# Python Code to AST Using LIBCST
* You can find the full documentation at: https://libcst.readthedocs.io/en/latest/

## INSTALLATION
LIBCST is for Python3 only

In [None]:
!pip3 install libcst

## PARSE CODE

In [1]:
import libcst as cst

### PARSE EXPRESSIONS

In [2]:
# Parse a single expression. The uneven spacing around ">=" is kept in the
# resulting tree (see whitespace_before / whitespace_after in the output below).
exp = cst.parse_expression("x >=5 ")

In [3]:
print(exp)

Comparison(
    left=Name(
        value='x',
        lpar=[],
        rpar=[],
    ),
    comparisons=[
        ComparisonTarget(
            operator=GreaterThanEqual(
                whitespace_before=SimpleWhitespace(
                    value=' ',
                ),
                whitespace_after=SimpleWhitespace(
                    value='',
                ),
            ),
            comparator=Integer(
                value='5',
                lpar=[],
                rpar=[],
            ),
        ),
    ],
    lpar=[],
    rpar=[],
)


### PARSE STATEMENTS

In [4]:
# Parse one statement; the result printed below is a SimpleStatementLine
# wrapping the Assign node.
stmt = cst.parse_statement('a = list("abcd") + ["e"]') 

In [5]:
print(stmt)

SimpleStatementLine(
    body=[
        Assign(
            targets=[
                AssignTarget(
                    target=Name(
                        value='a',
                        lpar=[],
                        rpar=[],
                    ),
                    whitespace_before_equal=SimpleWhitespace(
                        value=' ',
                    ),
                    whitespace_after_equal=SimpleWhitespace(
                        value=' ',
                    ),
                ),
            ],
            value=BinaryOperation(
                left=Call(
                    func=Name(
                        value='list',
                        lpar=[],
                        rpar=[],
                    ),
                    args=[
                        Arg(
                            value=SimpleString(
                                value='"abcd"',
                                lpar=[],
                                rpar=[],
                

### PARSE MODULE

In [6]:
# Sample source: a single function definition (the literal starts with a newline).
code = """
def random_func(a, b, s):
    a = b * s
    b = a * s
    s = a * b
    return a * s + b * s
"""
# Parse the whole source text; the result printed below is a Module node
# whose body holds the FunctionDef.
mdl = cst.parse_module(code)

In [7]:
print(mdl)

Module(
    body=[
        FunctionDef(
            name=Name(
                value='random_func',
                lpar=[],
                rpar=[],
            ),
            params=Parameters(
                params=[
                    Param(
                        name=Name(
                            value='a',
                            lpar=[],
                            rpar=[],
                        ),
                        annotation=None,
                        equal=MaybeSentinel.DEFAULT,
                        default=None,
                        comma=Comma(
                            whitespace_before=SimpleWhitespace(
                                value='',
                            ),
                            whitespace_after=SimpleWhitespace(
                                value=' ',
                            ),
                        ),
                        star='',
                        whitespace_after_star=SimpleWhitespace(
          

## AST TO CODE

In [8]:
# code to ast
code = """
def random_func(a, b, s):
    a = b * s
    b = a * s
    s = a * b
    return a * s + b * s
"""
mdl = cst.parse_module(code)

# ast to code
# `mdl` is already a Module, so its `.code` property renders the source
# directly; there is no need to wrap it inside another Module.
code_prime = mdl.code

In [9]:
print(code_prime)


def random_func(a, b, s):
    a = b * s
    b = a * s
    s = a * b
    return a * s + b * s



## SEARCH IN AST USING VISITORS

### Example : search for calls to the function "print" in an arbitrary code
* We need to implement a class that extends the visitor class `cst.CSTVisitor`
* The class should have a method with the name visit_NODE where NODE is the type of the node we want to find in the AST
* In this first example, the name of the node is "Call" which represents any call to any function in a python code

#### Visitor definition

In [10]:
class FindPrint(cst.CSTVisitor):
    """Visitor that records every call in a tree, separating calls to ``print``.

    Attributes:
        prints: ``cst.Call`` nodes whose callee is the bare name ``print``.
        prints_detailed: the ``args`` of each node in ``prints``, in the same order.
        other_calls: every other ``cst.Call`` node that was visited (kept only
            for the sake of the example).
    """

    def __init__(self):
        # Let the LibCST base classes initialise their own state first.
        super().__init__()
        # Per-instance result lists, so separate visitors never share results.
        self.prints = []
        self.prints_detailed = []
        self.other_calls = []

    # The method name must mention the node type we are looking for: Call.
    def visit_Call(self, node: cst.Call) -> None:
        """Classify one call node as a ``print`` call or as any other call."""
        callee = node.func
        # Only a bare Name can be the builtin-style `print(...)`. Attribute
        # callees (obj.print(...)) and computed callees (f()(...)) are other
        # calls; checking the node type avoids reading `.value` on nodes
        # that do not have one.
        if isinstance(callee, cst.Name) and callee.value == 'print':
            self.prints.append(node)
            self.prints_detailed.append(node.args)
        else:
            self.other_calls.append(node)

#### Usage example

In [11]:
# Sample program: one call to int() and two calls to print().
code = """
x = int(0.5)
if x >= 0:
    print('positive')
else:
    print('negative')
"""

# Create the visitor, parse the sample, then walk the tree with the visitor.
print_finder = FindPrint()
code_tree = cst.parse_module(code)
_ = code_tree.visit(print_finder)

# Both print calls from the sample should have been collected.
assert len(print_finder.prints) == 2

In [12]:
print_finder.prints

[Call(
     func=Name(
         value='print',
         lpar=[],
         rpar=[],
     ),
     args=[
         Arg(
             value=SimpleString(
                 value="'positive'",
                 lpar=[],
                 rpar=[],
             ),
             keyword=None,
             equal=MaybeSentinel.DEFAULT,
             comma=MaybeSentinel.DEFAULT,
             star='',
             whitespace_after_star=SimpleWhitespace(
                 value='',
             ),
             whitespace_after_arg=SimpleWhitespace(
                 value='',
             ),
         ),
     ],
     lpar=[],
     rpar=[],
     whitespace_after_func=SimpleWhitespace(
         value='',
     ),
     whitespace_before_args=SimpleWhitespace(
         value='',
     ),
 ),
 Call(
     func=Name(
         value='print',
         lpar=[],
         rpar=[],
     ),
     args=[
         Arg(
             value=SimpleString(
                 value="'negative'",
                 lpar=[],
           

In [13]:
# Render each collected print call back to source text. A Call is an
# expression, not a statement, so it is rendered with Module.code_for_node
# instead of being placed in a Module body.
for print_call in print_finder.prints:
    print(cst.Module([]).code_for_node(print_call))

print('positive')
print('negative')


In [14]:
# Render every non-print call back to source text. Call nodes are
# expressions, so Module.code_for_node is used rather than a Module body.
for other_call in print_finder.other_calls:
    print(cst.Module([]).code_for_node(other_call))

int(0.5)


In [16]:
# Show the argument list of each collected print call. Each entry of
# prints_detailed is the sequence of Arg nodes of one call; every Arg is
# rendered on its own with Module.code_for_node and the pieces are concatenated.
# NOTE(review): for parsed source LibCST appears to keep the separating comma
# on the Arg node itself, so plain concatenation should reproduce the original
# argument text - confirm for multi-argument calls.
for print_detail in print_finder.prints_detailed:
    args_code = ''.join(cst.Module([]).code_for_node(arg) for arg in print_detail)
    print('The function print was called with the following argument(s):', args_code)

The function print was called with the following argument(s): 'positive'
The function print was called with the following argument(s): 'negative'


## AST BASED CODE MODIFICATION

### Example: replace the argument value for all function calls with a specific value
Of course, you can develop this example further by adding conditions on the name of the called function and the value of the argument.
This is just an example of how to do a basic transformation on an AST node.

In [17]:
class CallTransformer(cst.CSTTransformer):
    """Replace the arguments of every call with one fixed string literal.

    Args:
        name: stored on the instance as ``self.name``; not read by the
            visit/leave methods below.
        args: stored on the instance as ``self.args``; not read by the
            visit/leave methods below.
    """

    def __init__(self, name, args):
        # Let the LibCST base classes initialise their own state first.
        super().__init__()
        self.stack = []  # Call nodes that have been entered but not yet left
        self.name = name
        self.args = args

    def visit_Call(self, node: cst.Call) -> None:
        """Remember the call being entered."""
        self.stack.append(node)

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
        """Return the call with its arguments replaced by the fixed literal."""
        self.stack.pop()
        replacement = cst.Arg(value=cst.SimpleString(value="'custom value here'"))
        # Build on updated_node only: its `func` already contains any rewrites
        # made to calls nested inside the callee (e.g. `f(1)(2)`), which would
        # be thrown away by copying `func` back from original_node.
        return updated_node.with_changes(args=(replacement,))

In [18]:
# Same sample program as in the visitor example.
code = """
x = int(0.5)
if x >= 0:
    print('positive')
else:
    print('negative')
"""

code_tree = cst.parse_module(code)

# NOTE(review): `print_detail` is the loop variable left over from the earlier
# `prints_detailed` loop cell, so this cell only runs after that one.
# CallTransformer stores both constructor arguments, but its visit/leave
# methods never read them, so the values passed here do not affect the result.
transformer = CallTransformer('log', print_detail)
# visit() returns the transformed tree, which is rendered in the next cell.
tree = code_tree.visit(transformer)


In [19]:
print(tree.code)


x = int('custom value here')
if x >= 0:
    print('custom value here')
else:
    print('custom value here')

