In [1]:
import pandas as pd
from tree_sitter import Language, Parser

import os
# Change into the function_parser package directory before importing its local
# modules (language_data, process) below — presumably they are resolved via the
# current working directory on sys.path; confirm if the layout changes.
os.chdir('/home/dev/function_parser/function_parser/')

from language_data import LANGUAGE_METADATA
from process import DataProcessor

In [0]:
import shutil
import subprocess
from pathlib import Path

VENDOR_DIR = Path('/home/dev/.vendor')  # grammar source checkouts
BUILD_DIR = Path('/home/dev/.build')    # compiled tree-sitter library

# Start from empty vendor and build directories. Only these two paths are
# removed; ignore_errors mirrors `rm -rf` on a path that does not exist yet.
for directory in (VENDOR_DIR, BUILD_DIR):
    shutil.rmtree(directory, ignore_errors=True)
    directory.mkdir(parents=True, exist_ok=True)

# Fetch the Python grammar. check=True makes a failed clone stop here instead of
# surfacing later as a confusing library-build error.
subprocess.run(
    ['git', 'clone', 'https://github.com/tree-sitter/tree-sitter-python',
     str(VENDOR_DIR / 'tree-sitter-python')],
    check=True,
)

In [7]:
# Compile the vendored grammar checkout(s) into one shared library, which the
# next cell loads. To support more languages, append their checkout directories
# (e.g. tree-sitter-go, tree-sitter-javascript) to the list below.
# NOTE(review): the `False` shown as output presumably means no rebuild was
# needed — confirm against the py-tree-sitter docs.
Language.build_library(
    '/home/dev/.build/py-tree-sitter-languages.so',  # output library path
    ['/home/dev/.vendor/tree-sitter-python'],       # grammar repositories to compile in
)

False

In [2]:
# Load the Python grammar from the compiled library and bind a parser to it.
language_name = 'python'
library_path = '/home/dev/.build/py-tree-sitter-languages.so'

language = Language(library_path, language_name)
parser = Parser()
parser.set_language(language)

In [2]:
# Toy snippet used to explore the parse tree. It is kept as UTF-8 bytes because
# node positions (start_byte / end_byte) are later used to slice these bytes.
code = """
def foo():
    if bar:
        baz()
    if bar2:
        baz2()
    i = 0
    k = 1
    k = i + k
"""
blob = code.encode("utf8")
# Peek at the first few bytes.
blob[:15]


b'\ndef foo():\n   '

In [4]:
# Parse the source bytes with the tree-sitter parser configured above.
tree = parser.parse(blob)

In [5]:
def print_ast(cursor, blob, depth = 0):
    """Print the subtree rooted at ``cursor`` as an indented outline.

    Each node goes on its own line, indented one space per level; the starting
    node gets ``depth + 1`` spaces. Identifier nodes show their source text,
    sliced from ``blob`` by byte offset; every other node shows its type.
    On return the cursor is back on the node it started at.

    Args:
        cursor: tree cursor (e.g. from ``tree.walk()``) positioned at the
            subtree root.
        blob: the UTF-8 bytes the tree was parsed from.
        depth: extra indentation to start from.
    """
    base_level = depth + 1
    level = base_level
    while True:
        node = cursor.node
        if node.type == 'identifier':
            text = blob[node.start_byte:node.end_byte].decode("utf8")
        else:
            text = node.type
        print(' ' * level + text)

        # Pre-order: descend into the first child before anything else.
        if cursor.goto_first_child():
            level += 1
            continue

        # Leaf: move to the next sibling, climbing back up through exhausted
        # parents. The `else` runs when we are back at the starting level,
        # i.e. the whole subtree has been printed.
        while level > base_level:
            if cursor.goto_next_sibling():
                break
            cursor.goto_parent()
            level -= 1
        else:
            return



In [6]:
# Walk the parsed tree from its root and print it as an indented outline.
cursor = tree.walk()
print_ast(cursor, blob)


module
  function_definition
   def
   foo
   parameters
    (
    )
   :
   block
    if_statement
     if
     bar
     :
     block
      expression_statement
       call
        baz
        argument_list
         (
         )
    if_statement
     if
     bar2
     :
     block
      expression_statement
       call
        baz2
        argument_list
         (
         )
    expression_statement
     assignment
      expression_list
       i
      =
      expression_list
       integer
    expression_statement
     assignment
      expression_list
       k
      =
      expression_list
       integer
    expression_statement
     assignment
      expression_list
       k
      =
      expression_list
       binary_operator
        i
        +
        k


In [7]:
def ast2graph(cursor, blob, depth = 0):
    """Print the subtree under ``cursor`` as an indented outline.

    NOTE(review): despite its name this does not build a graph yet; its body
    mirrors ``print_ast``. TODO: decide what graph it should return.

    Args:
        cursor: tree cursor, presumably from ``tree.walk()``, positioned at the
            subtree root. It is moved back to that node before returning.
        blob: the source bytes the tree was parsed from; identifier text is
            sliced from it by byte offset.
        depth: indentation offset; each level adds one space.
    """
    depth += 1
    node = cursor.node
    if node.type == 'identifier':
        label = blob[node.start_byte:node.end_byte].decode("utf8")
    else:
        label = node.type
    print(' ' * depth + label)

    if cursor.goto_first_child():
        # Visit every child in order (recursing into this same function), then
        # step back up so the caller's cursor is where it was.
        while True:
            ast2graph(cursor, blob, depth)
            if not cursor.goto_next_sibling():
                break
        cursor.goto_parent()

In [18]:
# Switch to the project source directory. The local ast_graph_generator module
# is presumably resolved from the working directory, and later cells use paths
# relative to it (e.g. "../resources/data/...").
os.chdir('/home/dev/src/')
from ast_graph_generator import AstGraphGenerator, NODE_TYPE
from ast import parse

In [115]:
def next_terminal(node, non_terminal):
    """Return the first index >= ``node`` that is not a non-terminal.

    Args:
        node: starting node index.
        non_terminal: collection of non-terminal node indices to skip over.

    Returns:
        ``node`` itself if it is not in ``non_terminal``; otherwise the smallest
        larger integer that is not in it.
    """
    # Iterate rather than recurse: a long run of consecutive non-terminals would
    # otherwise exceed Python's recursion limit. A set makes each membership
    # check O(1) instead of a linear list scan.
    skip = non_terminal if isinstance(non_terminal, (set, frozenset)) else set(non_terminal)
    while node in skip:
        node += 1
    return node

def fix_index(node, non_terminal):
    """Renumber ``node`` into a terminal-only index space.

    A non-terminal ``node`` is first redirected to the next terminal index
    (via ``next_terminal``). The result is then shifted down by the number of
    non-terminal indices smaller than it. Assuming node indices are contiguous
    from 0, this is the node's position among the remaining (terminal) nodes.
    """
    terminal = next_terminal(node, non_terminal)
    preceding = sum(1 for index in non_terminal if index < terminal)
    return terminal - preceding
    
# Build the AstGraphGenerator graph for the toy `code` snippet, then collapse it
# onto terminal nodes: edges touching a non-terminal are redirected to the next
# terminal, and nodes are renumbered into terminal-only positions.
visitor = AstGraphGenerator()
visitor.visit(parse(code))

# Keep the first node and rename it to 'root'.
visitor.node_label[0] = 'root'

# Indices of non-terminal nodes other than the root, in ascending order.
n_t = sorted(index for index in visitor.node_label
             if visitor.node_type[index] == NODE_TYPE['non_terminal'] and index > 0)
# Set view of the same indices for O(1) membership tests below.
n_t_set = set(n_t)

# Flatten {(origin, destination): edge_types} into (edge_type, origin, destination).
E = [(edge_type, origin, destination)
     for (origin, destination), edge_types in visitor.graph.items()
     for edge_type in edge_types]
E = [(edge_type, fix_index(origin, n_t), fix_index(destination, n_t))
     for (edge_type, origin, destination) in E]
# Collapsing can map both endpoints onto the same terminal; drop those self-loops.
E = [edge for edge in E if edge[1] != edge[2]]

# Labels of the root and terminal nodes in index order. Assuming node indices
# are contiguous from 0, V[i] is the label of renumbered node i used in E.
V = [label.strip() for (index, label) in sorted(visitor.node_label.items())
     if index not in n_t_set]
[(edge_type, V[origin], V[destination]) for (edge_type, origin, destination) in E]



[('child', 'root', 'def'),
 ('child', 'root', 'foo'),
 ('NextToken', 'def', 'foo'),
 ('child', 'root', '('),
 ('NextToken', 'foo', '('),
 ('child', 'root', ')'),
 ('NextToken', '(', ')'),
 ('child', 'root', ':'),
 ('NextToken', ')', ':'),
 ('child', 'root', 'if'),
 ('NextToken', ':', 'if'),
 ('child', 'if', 'bar'),
 ('NextToken', 'if', 'bar'),
 ('child', 'if', ':'),
 ('NextToken', 'bar', ':'),
 ('child', 'if', 'baz'),
 ('NextToken', ':', 'baz'),
 ('child', 'baz', '('),
 ('NextToken', 'baz', '('),
 ('child', 'baz', ')'),
 ('NextToken', '(', ')'),
 ('child', 'root', 'if'),
 ('NextToken', ')', 'if'),
 ('child', 'if', 'bar2'),
 ('NextToken', 'if', 'bar2'),
 ('child', 'if', ':'),
 ('NextToken', 'bar2', ':'),
 ('child', 'if', 'baz2'),
 ('NextToken', ':', 'baz2'),
 ('child', 'baz2', '('),
 ('NextToken', 'baz2', '('),
 ('child', 'baz2', ')'),
 ('NextToken', '(', ')'),
 ('child', 'root', 'i'),
 ('NextToken', ')', 'i'),
 ('child', 'i', '='),
 ('NextToken', 'i', '='),
 ('child', 'i', '0'),
 ('Nex

In [70]:
from dpu_utils.utils import RichPath
from typing import List, Dict, Any, Iterable, Tuple, Optional, Union, Callable, Type, DefaultDict
import numpy as np

def get_data_files_from_directory(data_dirs: List[RichPath],
                                  max_files_per_dir: Optional[int] = None,
                                  rng: Optional[Any] = None) -> List[RichPath]:
    """Collect the ``*.jsonl.gz`` files under each directory, in shuffled order.

    Args:
        data_dirs: directories to search.
        max_files_per_dir: if truthy, keep only the first ``max_files_per_dir``
            files (in sorted order) from each directory. ``None`` (or 0) keeps
            every file.
        rng: optional random source with a ``shuffle`` method, e.g.
            ``np.random.default_rng(seed)``, for a reproducible order. When
            omitted, numpy's global random state is used.

    Returns:
        The collected files as a list, shuffled in place.
    """
    files = []  # type: List[RichPath]
    for data_dir in data_dirs:
        dir_files = data_dir.get_filtered_files_in_dir('*.jsonl.gz')
        if max_files_per_dir:
            dir_files = sorted(dir_files)[:int(max_files_per_dir)]
        files += dir_files

    # Shuffling avoids having large_file_0, large_file_1, ... subsequences.
    if rng is None:
        np.random.shuffle(files)
    else:
        rng.shuffle(files)
    return files


In [12]:
import itertools

# Inspect the AST graph extracted from the first sample(s) of the dataset.
MAX_FILES = 1    # number of data files to inspect
MAX_SAMPLES = 1  # number of samples to inspect per file

data_files = get_data_files_from_directory(
    [RichPath.create("../resources/data/python/final/jsonl/train"),
     RichPath.create("../resources/data/python/final/jsonl/valid"),
     RichPath.create("../resources/data/python/final/jsonl/test")])

for data_file in data_files[:MAX_FILES]:
    print(data_file)
    for raw_sample in itertools.islice(data_file.read_by_file_suffix(), MAX_SAMPLES):
        # Separate name so the toy `code` snippet used by earlier cells is not
        # overwritten.
        sample_code = raw_sample['original_string']
        print(sample_code)

        # The visitor keeps node_label / graph as instance state (read after
        # visit() below). Presumably reusing one instance would mix samples,
        # so a fresh visitor is created per sample.
        visitor = AstGraphGenerator()
        visitor.visit(parse(sample_code))
        edge_list = [(edge_type, origin, destination)
                     for (origin, destination), edge_types in visitor.graph.items()
                     for edge_type in edge_types]
        print(edge_list)
        graph_node_labels = [label.strip() for (_, label) in sorted(visitor.node_label.items())]
        print(graph_node_labels)