In [1]:

import ast 
from ast import NodeTransformer

# class DeleteNodeAtHeight(NodeTransformer):
    
#     def __init__(self, max_height):
#         self.max_height = max_height
    
#     def visit(self, node, height): 
#         if height == self.max_height:
#             print(f"removing node {ast.dump(node)}")
#             # import pdb; pdb.set_trace()
#             return None
#         for field, value in ast.iter_fields(node):
#             if isinstance(value, list):
#                 new_values = []
#                 for item in value:
#                     if isinstance(item, ast.AST):
#                         new_node = self.visit(item, height + 1)
#                         if new_node is not None:
#                             new_values.append(new_node)
#                     else:
#                         new_values.append(item)
#                 setattr(node, field, new_values)
#             elif isinstance(value, ast.AST):
#                 new_node = self.visit(value, height + 1)
#                 if new_node is not None:
#                     setattr(node, field, new_node)
#         return node

# class DeleteNodeAtHeight(NodeTransformer):
#     def __init__(self, max_height):
#         self.max_height = max_height

#     def visit(self, node):
#         return self.visit_with_height(node, 0)  # Start with height 0

#     def visit_with_height(self, node, height):
#         """A custom method to handle the height tracking."""
#         if height == self.max_height:
#             print(f"Removing node at height {height}: {ast.dump(node)}")
#             return None  # Remove the node

#         # Increment the height for children
#         for field, value in ast.iter_fields(node):
#             if isinstance(value, list):
#                 new_values = []
#                 for item in value:
#                     if isinstance(item, ast.AST):
#                         new_node = self.visit_with_height(item, height + 1)
#                         if new_node is not None:
#                             new_values.append(new_node)
#                     else:
#                         new_values.append(item)
#                 setattr(node, field, new_values)
#             elif isinstance(value, ast.AST):
#                 new_node = self.visit_with_height(value, height + 1)
#                 if new_node is not None:
#                     setattr(node, field, new_node)
#         return node

class DeleteNodeAtHeight(NodeTransformer):
    """Remove every AST node found at depth ``max_height`` (the root is depth 0).

    A removed node takes its whole subtree with it. The transformed tree
    therefore keeps only depths ``0 .. max_height - 1``. With ``max_height == 0``
    the root itself is removed and ``visit`` returns None.

    The tree is modified in place (by ``NodeTransformer.generic_visit``), so
    callers that need the original should pass a deep copy.
    """

    def __init__(self, max_height, verbose=False):
        self.max_height = max_height
        self.verbose = verbose
        # Depth of the node currently being visited; the root is at depth 0.
        self.current_height = 0

    def generic_visit(self, node):
        """Drop ``node`` if it sits at the cut-off depth, otherwise recurse into it."""
        if self.current_height == self.max_height:
            if self.verbose:
                print(f"Removing node: {ast.dump(node)}")
            # NodeTransformer drops list items / deletes the field for a None result.
            return None
        self.current_height += 1
        try:
            # The base implementation visits each child and writes results back.
            return super().generic_visit(node)
        finally:
            # Restore the depth even if a child visit raised, so this
            # transformer instance stays usable afterwards.
            self.current_height -= 1


ex_1 = """

def foo():
    for i in range(10):
        x = (i + 1) * 2
        print(x)
    return x

"""

# Parse the running example and show its full (untruncated) tree.
orig_ast = ast.parse(ex_1)
orig_dump = ast.dump(orig_ast, indent=4)
print(orig_dump)

    
        

Module(
    body=[
        FunctionDef(
            name='foo',
            args=arguments(
                posonlyargs=[],
                args=[],
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[]),
            body=[
                For(
                    target=Name(id='i', ctx=Store()),
                    iter=Call(
                        func=Name(id='range', ctx=Load()),
                        args=[
                            Constant(value=10)],
                        keywords=[]),
                    body=[
                        Assign(
                            targets=[
                                Name(id='x', ctx=Store())],
                            value=BinOp(
                                left=BinOp(
                                    left=Name(id='i', ctx=Load()),
                                    op=Add(),
                                    right=Constant(value=1)),
                                op=Mult

In [2]:
def find_ast_height(node):
    """Return the number of levels in the AST rooted at ``node``.

    A node without AST children has height 1. Anything that is not an
    ``ast.AST`` (e.g. a plain str/int field value) has height 0.
    """
    if not isinstance(node, ast.AST):
        return 0
    # iter_child_nodes yields exactly the AST values found in the node's
    # fields, including those inside list-valued fields.
    child_heights = (find_ast_height(child) for child in ast.iter_child_nodes(node))
    return 1 + max(child_heights, default=0)

# Height of the whole example tree; the Module root counts as one level.
max_height = find_ast_height(orig_ast)
print(f"{max_height}")

8


In [3]:
from copy import deepcopy
# Work on a copy: NodeTransformer edits the tree in place.
copy_ast = deepcopy(orig_ast)
# Cut exactly one level off the example: nodes at depth max_height - 1
# (root = depth 0) are removed, leaving a tree of height max_height - 1.
truncate = DeleteNodeAtHeight(max_height - 1, verbose=True)
new_ast = truncate.visit(copy_ast)
print(ast.dump(new_ast, indent=4, annotate_fields=True))


Removing node: Load()
Module(
    body=[
        FunctionDef(
            name='foo',
            args=arguments(
                posonlyargs=[],
                args=[],
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[]),
            body=[
                For(
                    target=Name(id='i', ctx=Store()),
                    iter=Call(
                        func=Name(id='range', ctx=Load()),
                        args=[
                            Constant(value=10)],
                        keywords=[]),
                    body=[
                        Assign(
                            targets=[
                                Name(id='x', ctx=Store())],
                            value=BinOp(
                                left=BinOp(
                                    left=Name(id='i'),
                                    op=Add(),
                                    right=Constant(value=1)),
                             

In [11]:
class AstTruncator:
    """Hold every height-truncated deep copy of an AST, tallest first.

    For a tree of height H (levels counted by ``find_ast_height``), this builds
    copies truncated to heights H, H - 1, ..., 2. That is H - 1 trees, the first
    being an untruncated copy of the original. Height 1 would be only the bare
    ``Module`` and is not interesting, so it is skipped. The tree passed in is
    never modified.

    Iterating yields the copies tallest first. ``get_sub_ast(h)`` returns the
    copy whose height is exactly ``h``.
    """

    def __init__(self, ast, verbose=False):
        # NOTE: the parameter ``ast`` shadows the ``ast`` module inside this
        # method; the name is kept so existing keyword callers keep working.
        self.verbose = verbose
        self.ast = ast
        self.max_height = find_ast_height(self.ast)
        self.all_sub_asts = []
        if self.verbose:
            print(f"Max height of AST: {self.max_height}")
            print("making all sub asts")
        for i in range(2, self.max_height + 1):
            if self.verbose:
                print("-" * 40)
                print(f"Generating sub AST at height {i}")
            # DeleteNodeAtHeight edits in place, so always truncate a deep copy.
            copy_ast = deepcopy(self.ast)
            self.all_sub_asts.append(DeleteNodeAtHeight(i, verbose=self.verbose).visit(copy_ast))
        # Sanity check: the copy built for i must have height exactly i.
        heights = [find_ast_height(tree) for tree in self.all_sub_asts]
        assert heights == list(range(2, self.max_height + 1)), heights
        # Store tallest first. all_subtrees() walks the copies in this order
        # while sharing one visited set, so the order affects its output.
        self.all_sub_asts.reverse()
        if self.verbose:
            print("done making all sub asts")

    def get_sub_ast(self, height):
        """Return the truncated copy whose height is exactly ``height``.

        Raises:
            ValueError: unless ``2 <= height <= self.max_height``.
        """
        if not 2 <= height <= self.max_height:
            raise ValueError(f"Height must be between 2 and {self.max_height}")
        # all_sub_asts is stored tallest first: index 0 has height max_height,
        # index k has height max_height - k.
        return self.all_sub_asts[self.max_height - height]

    def __iter__(self):
        """Iterate over the truncated copies, tallest first."""
        return iter(self.all_sub_asts)
    
    

In [5]:
truncator = AstTruncator(orig_ast, verbose=True)

print(f"Number of sub ASTs: {len(truncator.all_sub_asts)}")
# The truncator yields trees tallest first, so take each label from the tree
# itself rather than counting upward from 2.
for sub_ast in truncator:
    print(f"Height {find_ast_height(sub_ast)} AST:")
    print(ast.dump(sub_ast, indent=4, annotate_fields=True))
    print('---------------------------------')

Max height of AST: 8
making all sub asts
----------------------------------------
Generating sub AST at height 2
Removing node: arguments(posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[])
Removing node: For(target=Name(id='i', ctx=Store()), iter=Call(func=Name(id='range', ctx=Load()), args=[Constant(value=10)], keywords=[]), body=[Assign(targets=[Name(id='x', ctx=Store())], value=BinOp(left=BinOp(left=Name(id='i', ctx=Load()), op=Add(), right=Constant(value=1)), op=Mult(), right=Constant(value=2))), Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Name(id='x', ctx=Load())], keywords=[]))], orelse=[])
Removing node: Return(value=Name(id='x', ctx=Load()))
----------------------------------------
Generating sub AST at height 3
Removing node: Name(id='i', ctx=Store())
Removing node: Call(func=Name(id='range', ctx=Load()), args=[Constant(value=10)], keywords=[])
Removing node: Assign(targets=[Name(id='x', ctx=Store())], value=BinOp(left=BinOp(left=Name(id='i', ctx

In [21]:


class AstSubTree:
    """A subtree found inside an AST, with its height and optional derivation path."""

    def __init__(self, root, height, derivation: str = None):
        self.root = root
        self.height = height
        self.derivation = derivation
        # Compact single-line dump, computed once; handy for de-duplication.
        self.str = ast.dump(root)

    def __str__(self):
        # Header always carries the height; the derivation is appended when known.
        header = f"Height: {self.height}"
        if self.derivation is not None:
            header += f", Derivation: {self.derivation}"
        return f"{header}\nAST: {ast.dump(self.root, indent=4)}"

    def as_string(self):
        """Return the cached single-line ``ast.dump`` of the subtree root."""
        return self.str

# def all_subtrees_for_ast(node, path=[], visited=set()):
#     """Analyze the subtree to compute the height, collect all subtrees, and track visited nodes by derivations,
#        returning the list of subtrees including the current node with its height."""
#     subtrees = []
#     max_height = 0
#     current_derivation = path + [type(node).__name__]

#     # Create a unique representation of the current path as a tuple
#     derivation_tuple = tuple(current_derivation)

#     # Check if this derivation has been visited
#     if derivation_tuple in visited:
#         return subtrees, max_height

#     # Add the current derivation to the visited set
#     visited.add(derivation_tuple)

#     # Recursively compute the height of each child and collect subtrees
#     for child in ast.iter_child_nodes(node):
#         children_subtrees, child_height, child_visited = all_subtrees_for_ast(child, current_derivation, visited)
#         subtrees.extend(children_subtrees)
#         max_height = max(max_height, child_height)
#         visited.update(child_visited)
    
#     # Increment to include the current node in the height calculation
#     max_height += 1

#     # Append the current node's subtree after processing all children
#     subtrees.append(AstSubTree(node, max_height))
#     return subtrees, max_height


# def bottom_up_subtrees_for_ast(node, derivation_path=[], visited=set()): 
#     # base case we're at a terminal, we have the derivation, if the derivation is unique, we propagate the node info up + a True, if it has already existed we do False
#         #input: derivation, node, visited
#         #output: do_add_subtree, visited -> add the derivation to this when we get to the terminal 
#     subtrees = []
#     max_height = 0 
#     current_derivation_list = derivation_path + [type(node).__name__]
    
#     children = list(ast.iter_child_nodes) 
    
    
#     if len(children) == 0: 
        
#         terminal_subtree = [AstSubTree(node, 1)]
#         current_derivation = "/".join(current_derivation_list)
        
#         if current_derivation in visited: 
#         # we've been down this path before
#             return terminal_subtree, 1, False, visited 
#         # this is a terminal we haven't seen before 
#         # add new derivationpath to the visited set
#         visited.add(current_derivation)
#         return terminal_subtree, 1, True, visited
    
#     # otherwise we're at a non-terminal 
#     any_children_new = False
#     children_subtrees = []
#     for child in children: 
#         child_subtrees, child_max_height, child_is_new, child_visited_set = bottom_up_subtrees_for_ast(child, current_derivation_list, visited)
#         # sanity check here 
#         if not child_is_new: 
#             # if it was visited, then it should only propagate that node up to the next level 
#             assert len(children_subtrees) == 1 
#             assert visited != child_visited_set
#         else: 
#             assert visited == child_visited_set
        
#         visited.update(child_visited_set)
#         max_height = max(max_height, child_max_height)
#         children_subtrees.extend(child_subtrees)
#         any_children_new = any_children_new or child_is_new # if any children (recursively) is novel / changed, we'll add all subtrees received 

#     # we propagate this node up as well as all children subtrees returned 
#     max_height += 1
#     if any_children_new: 
#         # add node to the children subtrees
#         subtrees = children_subtrees + [AstSubTree(node, max_height)] 
#         return subtrees, max_height, True, visited 
    
#     # all the children were actually previously visited 
#     return [AstSubTree(node, 1)], max_height, False, visited
    
        
def bottom_up_subtrees_for_ast(node, derivation_path=None, visited=None, verbose=False):
    """Collect subtrees of ``node`` bottom-up, skipping paths already seen.

    A *derivation* is the path from the root to a node. It is written as node
    type names interleaved with ``child_<i>`` positions, e.g.
    ``Module/child_0/FunctionDef/child_1/For``. Only terminal (leaf)
    derivations are recorded in ``visited``.

    A leaf is *new* if its derivation is not yet in ``visited``. An inner node
    is new if any of its children is new. A new node returns the subtrees
    collected from its new children followed by itself. A node that is not new
    returns only itself.

    Args:
        node: the ``ast.AST`` node to scan.
        derivation_path: list of path components leading to ``node``; defaults
            to the empty path (``node`` is the root).
        visited: set of terminal derivation strings seen so far, updated in
            place. A fresh set is created when omitted.
        verbose: print each leaf's derivation and the visited set.

    Returns:
        ``(subtrees, height, is_new, visited)``, where:
        - ``subtrees`` is a list of ``AstSubTree``.
        - ``height`` counts levels (a leaf has height 1).
        - ``visited`` is the same set object that was updated.
    """
    # None defaults: a shared mutable default set would leak visited
    # derivations from one top-level call into the next.
    if derivation_path is None:
        derivation_path = []
    if visited is None:
        visited = set()

    current_derivation_list = derivation_path + [type(node).__name__]
    current_derivation = "/".join(current_derivation_list)

    max_child_height = 0
    any_children_new = False
    children_subtrees = []
    for i, child in enumerate(ast.iter_child_nodes(node)):
        child_derivation = current_derivation_list + [f"child_{i}"]
        # ``visited`` is shared and mutated by the recursive call itself.
        child_subtrees, child_height, child_is_new, _ = bottom_up_subtrees_for_ast(
            child, child_derivation, visited, verbose=verbose
        )
        if child_is_new:
            children_subtrees.extend(child_subtrees)
            any_children_new = True
        else:
            # A child that is not new only ever propagates itself upward.
            assert len(child_subtrees) == 1
        max_child_height = max(max_child_height, child_height)

    height = max_child_height + 1
    this_node = [AstSubTree(node, height, current_derivation)]

    if height == 1:
        # Leaf: novelty is decided by whether this derivation was seen before.
        is_new = current_derivation not in visited
        if verbose:
            print(f"Processing {current_derivation}, is_new: {is_new}")
            print("Visited so far:", visited)
        visited.add(current_derivation)
        return this_node, height, is_new, visited

    # Inner node: new iff any child is new. A node with no new children still
    # returns itself, but none of its children's subtrees.
    is_new = any_children_new
    subtrees = children_subtrees + this_node if is_new else this_node
    return subtrees, height, is_new, visited



def all_subtrees(node, verbose=False):
    """Collect subtrees of ``node`` across all of its height-truncated copies.

    Each copy yielded by ``AstTruncator`` (tallest first) is scanned with
    ``bottom_up_subtrees_for_ast``. One set of visited terminal derivations is
    threaded through every scan, so later scans know which derivations were
    already seen. Returns a flat list of the subtrees reported by every scan.
    """
    collected = []
    seen_derivations = set()
    for truncated in AstTruncator(node, verbose=verbose):
        if verbose:
            print("-" * 40)
            print(f"Height of truncated tree: {find_ast_height(truncated)}")
        found, _height, _is_new, seen_derivations = bottom_up_subtrees_for_ast(
            truncated, [], seen_derivations, verbose=verbose
        )
        collected += found
    return collected
    

# def subtrees_from_code(source_code, obfuscate=False, strip_all=False):
#     if obfuscate:
#         source_code, _ = obfuscateString(source_code)
#     tree = ast.parse(source_code)
#     if strip_all:
#         tree = strip_id_value(tree)
#     subtrees, _ = all_subtrees(tree)
#     return subtrees  # Return the list of subtrees for the entire AST

def all_subtrees_from_code(source_code, verbose=False):
    """Parse ``source_code`` and return every subtree reported by ``all_subtrees``."""
    return all_subtrees(ast.parse(source_code), verbose=verbose)


def all_subtrees_of_height(subtrees, height):
    """Return, in order, the subtrees whose ``height`` equals ``height``."""
    return list(filter(lambda candidate: candidate.height == height, subtrees))

In [19]:
copy_ast.__class__.__name__

'Module'

In [8]:
# assert len(set([subtree.__str__() for subtree in height_2])) == len(height_2)
# assert len(set([subtree.__str__() for subtree in height_3])) == len(height_3)

# print(f"unique height 2 subtrees: {len(set([subtree.str for subtree in height_2]))}, vs {len(height_2)}")
# print(f"unique height 3 subtrees: {len(set([subtree.str for subtree in height_3]))}, vs {len(height_3)}")

In [9]:
# height_2_strs = [subtree.str for subtree in height_2]
# from collections import Counter
# duplication = [item for item, count in Counter(height_2_strs).items() if count > 1]
# duplication

In [16]:
# Recompute here instead of reading ``height_2``, which is only assigned in a
# later cell (this cell would fail on Restart & Run All).
all_subtrees_of_height(all_subtrees_from_code(ex_1), 2)[1].str

"Name(id='range', ctx=Load())"

In [20]:
# Use a distinct name: binding the result to ``all_subtrees`` would replace the
# all_subtrees() function, and all_subtrees_from_code would then fail on re-run.
ex_1_subtrees = all_subtrees_from_code(ex_1)
height_2 = all_subtrees_of_height(ex_1_subtrees, 2)
height_3 = all_subtrees_of_height(ex_1_subtrees, 3)

for shown_height, group in ((2, height_2), (3, height_3)):
    print(f"Subtrees of Height {shown_height}\n{'-' * 20}")
    for subtree in group:
        print(subtree)
        print("{:-^20}".format(""))


    
# ex_1 = """

# def foo():
#     for i in range(10):
#         x = (i + 1) * 2
#         print(x)
#     return x

# """    
    

Subtrees of Height 2
--------------------
Height: 2, Derivation: Module/child_0/FunctionDef/child_1/For/child_0/Name
AST: Name(id='i', ctx=Store())
--------------------
Height: 2, Derivation: Module/child_0/FunctionDef/child_1/For/child_1/Call/child_0/Name
AST: Name(id='range', ctx=Load())
--------------------
Height: 2, Derivation: Module/child_0/FunctionDef/child_1/For/child_2/Assign/child_0/Name
AST: Name(id='x', ctx=Store())
--------------------
Height: 2, Derivation: Module/child_0/FunctionDef/child_1/For/child_2/Assign/child_1/BinOp/child_0/BinOp/child_0/Name
AST: Name(id='i', ctx=Load())
--------------------
Height: 2, Derivation: Module/child_0/FunctionDef/child_1/For/child_3/Expr/child_0/Call/child_0/Name
AST: Name(id='print', ctx=Load())
--------------------
Height: 2, Derivation: Module/child_0/FunctionDef/child_1/For/child_3/Expr/child_0/Call/child_1/Name
AST: Name(id='x', ctx=Load())
--------------------
Height: 2, Derivation: Module/child_0/FunctionDef/child_2/Return/chil

In [None]:
x = {1}
x

In [None]:
# NodeTransformer.visit takes only the node, and the transformer edits that
# node in place, so truncate a deep copy to keep orig_ast intact.
truncate = DeleteNodeAtHeight(max_height - 2)
new_ast = truncate.visit(deepcopy(orig_ast))
print(ast.dump(new_ast, indent=4))

In [32]:
import tokenize
from io import BytesIO

def get_relevant_tokens_lexer(code_str):
    """Lex ``code_str`` with the stdlib tokenizer and return token strings.

    Processing rules:
    - ENCODING, ENDMARKER, INDENT and NL tokens are dropped.
    - STRING and COMMENT tokens are split on single spaces.
    - DEDENT tokens become the literal ``"DEDENT"``.
    - Every other token, including NEWLINE, contributes its own text.
    """
    # INDENT is dropped because it is usually implied by the preceding
    # statement (if/for/...); the DEDENT carries more information.
    skipped_kinds = {tokenize.ENCODING, tokenize.ENDMARKER, tokenize.INDENT, tokenize.NL}
    readline = BytesIO(code_str.encode('utf-8')).readline

    pieces = []
    for tok in tokenize.tokenize(readline):
        kind = tok.type
        if kind in skipped_kinds:
            continue
        if kind in (tokenize.STRING, tokenize.COMMENT):
            pieces += tok.string.split(" ")
        elif kind == tokenize.DEDENT:
            pieces.append("DEDENT")
        else:
            pieces.append(tok.string)
    return pieces

tokens = get_relevant_tokens_lexer(code_str=ex_1)
print(tokens)

['def', 'foo', '(', ')', ':', '\n', 'for', 'i', 'in', 'range', '(', '10', ')', ':', '\n', 'x', '=', '(', 'i', '+', '1', ')', '*', '2', '\n', 'print', '(', 'x', ')', '\n', 'DEDENT', 'return', 'x', '\n', 'DEDENT']


In [28]:
# bytes_io = BytesIO(ex_1.encode('utf-8'))
    
#     # Use the tokenize module to tokenize the code
# tokens = tokenize.tokenize(bytes_io.readline)
# for token in tokens:
#     if token.type == tokenize.DEDENT: 
#         import pdb; pdb.set_trace()
#     print(token)
    #

In [None]:

class AllSubtreeAnalysis:
    """Subtrees of one source string in three variants: plain, stripped and obfuscated.

    NOTE(review): relies on ``subtrees_from_code(source, obfuscate=..., strip_all=...)``,
    whose only definition visible in this notebook is commented out. It must be
    defined before this class is instantiated.
    """

    def __init__(self, source_code):
        self.source_code = source_code
        try:
            self.orig_ast = ast.parse(source_code)
            self.plain_subtrees = subtrees_from_code(source_code)
            self.strip_subtrees = subtrees_from_code(source_code, strip_all=True)
            self.obf_subtrees = subtrees_from_code(source_code, obfuscate=True)
        except Exception:
            # Imported here because nothing else in scope provides ``traceback``;
            # without it the handler would raise NameError and hide the real error.
            import traceback
            print(f"Error processing source code: {source_code}")
            traceback.print_exc()
            raise

    @staticmethod
    def filter_below_height(subtrees, height = None):
        """Return subtrees whose height equals ``height``, or all of them if None.

        NOTE(review): despite the name (and the ``max_height`` parameter of the
        getters), this is an exact-height filter, not "at or below" -- confirm
        which was intended.
        """
        if height is None:
            return subtrees
        return [subtree for subtree in subtrees if subtree.height == height]

    @staticmethod
    def subtrees_as_string(subtrees):
        """Map each subtree to its single-line string form."""
        return [subtree.as_string() for subtree in subtrees]

    def get_plain_subtrees(self, max_height = None):
        """String forms of the subtrees of the unmodified source."""
        subtrees = self.filter_below_height(self.plain_subtrees, max_height)
        return self.subtrees_as_string(subtrees)

    def get_stripped_subtrees(self, max_height = None):
        """String forms of the subtrees computed with ``strip_all=True``."""
        subtrees = self.filter_below_height(self.strip_subtrees, max_height)
        return self.subtrees_as_string(subtrees)

    def get_obfuscated_subtrees(self, max_height = None):
        """String forms of the subtrees computed with ``obfuscate=True``."""
        subtrees = self.filter_below_height(self.obf_subtrees, max_height)
        return self.subtrees_as_string(subtrees)

    def get_subtrees(self, typ: str, max_height = None):
        """Dispatch to the getter for ``typ``: "plain", "stripped" or "obfuscated".

        Raises:
            ValueError: if ``typ`` is not one of the three names.
        """
        getters = {
            "plain": self.get_plain_subtrees,
            "stripped": self.get_stripped_subtrees,
            "obfuscated": self.get_obfuscated_subtrees,
        }
        if typ not in getters:
            raise ValueError("Invalid type, must be one of 'plain', 'stripped', or 'obfuscated'")
        return getters[typ](max_height)

In [1]:
! pip install parso



In [3]:
# parso's tokenizer. NOTE: this rebinds the global name ``tokenize``, shadowing
# the stdlib ``tokenize`` module that get_relevant_tokens_lexer looks up at call
# time -- calling that function after this cell will use the wrong module.
from parso.python import tokenize

In [11]:
s="""from typing import List, Tuple
def f(inputs: List[Tuple[float, float, float, float]]):
    sum_values = 0
    for i, (a, b, c, d) in enumerate(inputs, 1):
        sum_value = a * b - c * d
        sum_values += sum_value
        print(f"Dataset {i}: {a} * {b} - {c} * {d} = {sum_value}")
    print(f"Total: {sum_values}")

if __name__ == "__main__":
    inputs = [eval(i) for i in input().split(",")]
    f(inputs)

if __name__ == "__main__":
    inputs = [eval(i) for i in input().split(",")]
    f(inputs)

from typing import List, Tuple
def f(inputs: List[Tuple[float, float, float, float]]):
    sum_values = 0
    for i, (a, b, c, d) in enumerate(inputs, 1):
        sum_value = a * b - c * d
        sum_values += sum_value
        print(f"Dataset {i}: {a} * {b} - {c} * {d} = {sum_value}")
    print(f"Total: {sum_values}")

if __name__ == "__main__":
    inputs = [eval(i) for i in input().split(",")]
    f(tuple(map(eval, eval(input().replace('(', '[').replace(')', ']')).split(',')))
"""
# Materialize the token stream: parso's tokenize() appears to return a one-shot
# iterator, so a consumer cell re-run later would otherwise see nothing.
tokens = list(tokenize.tokenize(s, version_info=(3, 12)))

In [12]:
# Display the parso tokens of the sample string as a list.
[*tokens]


[TokenInfo(type=NAME, string='from', start_pos=(1, 0), prefix=''),
 TokenInfo(type=NAME, string='typing', start_pos=(1, 5), prefix=' '),
 TokenInfo(type=NAME, string='import', start_pos=(1, 12), prefix=' '),
 TokenInfo(type=NAME, string='List', start_pos=(1, 19), prefix=' '),
 TokenInfo(type=OP, string=',', start_pos=(1, 23), prefix=''),
 TokenInfo(type=NAME, string='Tuple', start_pos=(1, 25), prefix=' '),
 TokenInfo(type=NEWLINE, string='\n', start_pos=(1, 30), prefix=''),
 TokenInfo(type=NAME, string='def', start_pos=(2, 0), prefix=''),
 TokenInfo(type=NAME, string='f', start_pos=(2, 4), prefix=' '),
 TokenInfo(type=OP, string='(', start_pos=(2, 5), prefix=''),
 TokenInfo(type=NAME, string='inputs', start_pos=(2, 6), prefix=''),
 TokenInfo(type=OP, string=':', start_pos=(2, 12), prefix=''),
 TokenInfo(type=NAME, string='List', start_pos=(2, 14), prefix=' '),
 TokenInfo(type=OP, string='[', start_pos=(2, 18), prefix=''),
 TokenInfo(type=NAME, string='Tuple', start_pos=(2, 19), prefix=