In [34]:
import ast

In [35]:
code = '''
def addition(a,b):
    sum=a+b
    return sum
'''

#Use the `ast.parse` function to generate an AST from the code:

tree = ast.parse(code)

In [36]:
tree

<ast.Module at 0x27cd6b629d0>

In [37]:
# Displaying AST structure:
print(ast.dump(tree,indent=4))

Module(
    body=[
        FunctionDef(
            name='addition',
            args=arguments(
                posonlyargs=[],
                args=[
                    arg(arg='a'),
                    arg(arg='b')],
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[]),
            body=[
                Assign(
                    targets=[
                        Name(id='sum', ctx=Store())],
                    value=BinOp(
                        left=Name(id='a', ctx=Load()),
                        op=Add(),
                        right=Name(id='b', ctx=Load()))),
                Return(
                    value=Name(id='sum', ctx=Load()))],
            decorator_list=[])],
    type_ignores=[])


In [38]:
class MyVisitor(ast.NodeVisitor):
    def visit_FunctionDef(self, node):
        print(f'Function: {node.name}')
        self.generic_visit(node)

    def visit_Print(self, node):
        print('Print statement')
        self.generic_visit(node)

visitor = MyVisitor()
visitor.visit(tree)


Function: addition


In [39]:
#Analyzing function arguments:
import ast

class FunctionAnalyzer(ast.NodeVisitor):
    def visit_FunctionDef(self, node):
        print(f"Analyzing function '{node.name}':")
        print(f"  Args: {[arg.arg for arg in node.args.args]}")
        print(f"  Vararg: {node.args.vararg}")
        print(f"  Kwarg: {node.args.kwarg}")
        self.generic_visit(node)

analyzer = FunctionAnalyzer()
analyzer.visit(tree)


Analyzing function 'addition':
  Args: ['a', 'b']
  Vararg: None
  Kwarg: None


In [40]:
#Collecting variables
def collect_variables(a):
    if type(a) is ast.Module:
        return [v for s in a.body for v in collect_variables(s)]

    elif type(a) is ast.FunctionDef:
        vs = [v for s in a.body for v in collect_variables(s)]
        return [a.name] + vs
    elif type(a) is ast.Assign:
        vs = [v for s in a.targets for v in collect_variables(s)]
        return vs + collect_variables(a.value)

    elif type(a) is ast.Return:
        return collect_variables(a.value)
    elif type(a) is ast.Name:
        return [a.id]
    elif type(a) is ast.BinOp:
        return collect_variables(a.left) + collect_variables(a.right)

    else:
      print(type(a)) # Display trees not captured by cases above.
    return []

In [41]:
code="""
def f(x, y):
  a = 5
  u = 1 + 2
  z = x + y
  return z
  """
tree=ast.parse(code)
collect_variables(tree)

<class 'ast.Constant'>
<class 'ast.Constant'>
<class 'ast.Constant'>


['f', 'a', 'u', 'z', 'x', 'y', 'z']

In [42]:
#Example with import statement
code="""
import random
if answer == 42:
    print('Correct answer!')
    """
tree=ast.parse(code)

In [43]:
print(ast.dump(tree,indent=4))

Module(
    body=[
        Import(
            names=[
                alias(name='random')]),
        If(
            test=Compare(
                left=Name(id='answer', ctx=Load()),
                ops=[
                    Eq()],
                comparators=[
                    Constant(value=42)]),
            body=[
                Expr(
                    value=Call(
                        func=Name(id='print', ctx=Load()),
                        args=[
                            Constant(value='Correct answer!')],
                        keywords=[]))],
            orelse=[])],
    type_ignores=[])


In [26]:
#understanding ctx in the AST
code="""
age=age+1
    """
tree=ast.parse(code)
print(ast.dump(tree,indent=4))

Module(
    body=[
        Assign(
            targets=[
                Name(id='age', ctx=Store())],
            value=BinOp(
                left=Name(id='age', ctx=Load()),
                op=Add(),
                right=Constant(value=1)))],
    type_ignores=[])


In [44]:
#Visit the tree step by step
import ast

class MyVisitor(ast.NodeVisitor):
    def generic_visit(self, node):
        print(f'entering {ast.dump(node)}')
        super().generic_visit(node)

visitor = MyVisitor()

tree = ast.parse('''
x = 5
print(x)
''')
visitor.visit(tree)

entering Module(body=[Assign(targets=[Name(id='x', ctx=Store())], value=Constant(value=5)), Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Name(id='x', ctx=Load())], keywords=[]))], type_ignores=[])
entering Assign(targets=[Name(id='x', ctx=Store())], value=Constant(value=5))
entering Name(id='x', ctx=Store())
entering Store()
entering Constant(value=5)
entering Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Name(id='x', ctx=Load())], keywords=[]))
entering Call(func=Name(id='print', ctx=Load()), args=[Name(id='x', ctx=Load())], keywords=[])
entering Name(id='print', ctx=Load())
entering Load()
entering Name(id='x', ctx=Load())
entering Load()


In [45]:
#Traversal
class MyVisitor(ast.NodeVisitor):
    def generic_visit(self, node):
        print(f'entering {node.__class__.__name__}')
        super().generic_visit(node)

visitor = MyVisitor()

tree = ast.parse('x = 5')
visitor.visit(tree)

entering Module
entering Assign
entering Name
entering Store
entering Constant


In [46]:
class MyVisitor(ast.NodeVisitor):
    def generic_visit(self, node):
        super().generic_visit(node)
        print(f'leaving {node.__class__.__name__}')

visitor = MyVisitor()

tree = ast.parse('x = 5')
visitor.visit(tree)

leaving Store
leaving Name
leaving Constant
leaving Assign
leaving Module


In [47]:
class MyVisitor(ast.NodeVisitor):
    def generic_visit(self, node):
        print(f'entering {node.__class__.__name__}')
        super().generic_visit(node)
        print(f'leaving {node.__class__.__name__}')

visitor = MyVisitor()

tree = ast.parse('x = 5')
visitor.visit(tree)

entering Module
entering Assign
entering Name
entering Store
leaving Store
leaving Name
entering Constant
leaving Constant
leaving Assign
leaving Module


In [48]:
#Counting Number of for nodes in a statement
import ast

class ForStmtCounter(ast.NodeVisitor):
    current_for_node = None
    stmt_count = 0

    def generic_visit(self, node):
        # If we are inside a for node, count statements
        if self.current_for_node is not None:
            if isinstance(node, ast.stmt):
                self.stmt_count += 1

        # If we just found a new for node, start counting
        elif isinstance(node, ast.For):
            self.current_for_node = node
            self.stmt_count = 0

        super().generic_visit(node)

        # This runs when coming back up from the children
        if node is self.current_for_node:
            # We're done counting this node. Print it out
            print(f'For node contains {self.stmt_count} statements')
            self.current_for_node = None

for_statement_counter = ForStmtCounter()

tree = ast.parse('''
for i in range(10):
    print(i)

for item in items:
    if item == 42:
        print('Magic item found!')
        break
''')
for_statement_counter.visit(tree)

For node contains 1 statements
For node contains 3 statements


In [49]:
#Modifying ast node
code="""
print(13)  # definitely not 42

num = 28
print(num)

def add(x, y):
    return x + y

for i in [1, 2, 3]:
    print(i)

output = add(num, 100)
print(output)
"""
import ast

class NumberChanger(ast.NodeTransformer):
    """Changes all number literals to 42."""
    def generic_visit(self, node):
        super().generic_visit(node)  # Added this line

        # if it isn't an int constant, do nothing with the node
        if not isinstance(node, ast.Constant) or not isinstance(node.value, int):
            return node

        return ast.Constant(value=42)
tree = ast.parse(code)
modified_tree = NumberChanger().visit(tree)
modified_tree = ast.fix_missing_locations(modified_tree)
exec(compile(modified_tree, '<my ast>', 'exec'))

42
42
42
42
42
84
