# Syntax-based Analysis (Part 2)

## Generating Parsers

In the previous chapter we built a small parser for an example language, extracted parse trees, and used syntax trees to convert source code to a format suitable for machine learning applications. Writing the parser was hard work, even though we only looked at a very simplistic language -- doing the same for "real" programming languages would be very cumbersome. Luckily, we don't need to construct parsers by hand, but can resort to compiler construction tools. We will be using [Antlr](https://www.antlr.org/) to have some parsers generated for us.

The starting point for a parser generator is a grammar describing the language, together with lexical information that helps to tokenize raw text. In Antlr, both are specified in the same file; by convention, terminals are named in all caps and specified using regular expressions, while nonterminals are written in lower case.

```
grammar Expr1;

expr : expr '+' term  |
       expr '-' term  |
       term;

term : DIGIT ;

DIGIT : ('0'..'9') ;
WS : [ \t\r\n]+ -> skip ;
```

This grammar tells Antlr to skip whitespace (`WS`) and to match individual digits (`DIGIT`), and then describes a simple grammar of expressions consisting of additions and subtractions of terms (which are simply individual digits for now).
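
To make the two lexer rules more concrete, the following is a rough hand-written equivalent in plain Python. It is only a sketch: the name `tokenize_expr` and the `OP` token type are made up for illustration, and the lexer Antlr generates is far more general.

```python
import re

# One alternative per lexer rule; the literals '+' and '-' come from the parser rules.
TOKEN_RE = re.compile(r"(?P<DIGIT>[0-9])|(?P<OP>[+-])|(?P<WS>[ \t\r\n]+)")

def tokenize_expr(text):
    """Split text into (type, text) pairs, dropping whitespace like `-> skip`."""
    tokens = []
    pos = 0
    while pos < len(text):
        match = TOKEN_RE.match(text, pos)
        if match is None:
            raise ValueError(f"unexpected character {text[pos]!r} at position {pos}")
        if match.lastgroup != "WS":
            tokens.append((match.lastgroup, match.group()))
        pos = match.end()
    return tokens

tokens = tokenize_expr("1 + 2")
```

Note that `DIGIT` matches exactly one character, so an input such as `12` yields two separate `DIGIT` tokens.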

Given such a grammar, Antlr automatically produces a lexer, a parser, and some further helpful files. To avoid a dependency on the Antlr tool itself, this notebook does not call Antlr; instead, the files it produces are included in the repository.

To process the above grammar with Antlr, we would need to save the grammar in a file `Expr1.g4`, and then call Antlr like so:

```
 antlr -Dlanguage=Python3 -visitor Expr1.g4
```

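The `language` option tells Antlr which programming language the parser should be generated in, and `-visitor` requests a visitor interface in addition to the listener that is generated by default.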

The generated files are included in the `antlr` subdirectory of this notebook's repository.

In [None]:
!ls antlr/Expr1*

`Expr1Lexer.py` is the tokenizer, `Expr1Parser.py` contains the parser, `Expr1Visitor.py` provides a visitor interface for the parse tree, and `Expr1Listener.py` provides an interface with which we can react to parse events while parsing.

Since the generated files are in the `antlr` subdirectory of this notebook's repository, we need to add that directory to Python's module search path.

In [None]:
import sys  
sys.path.insert(0, 'antlr')

import antlr

We also need to include the Antlr runtime library.

In [None]:
from antlr4 import *

We can now include the generated lexer and parser.

In [None]:
from Expr1Lexer import Expr1Lexer
from Expr1Parser import Expr1Parser

The pipeline to parse textual input is to (1) generate an input stream based on the text, (2) create a token stream out of the input stream, and (3) invoke the parser to consume the tokens. The parsing is started by invoking the starting rule of the grammar (`expr` in our case).

In [None]:
input = InputStream('1+2')
lexer = Expr1Lexer(input)
stream = CommonTokenStream(lexer)
parser = Expr1Parser(stream)
tree = parser.expr() 

The result (`tree`) is the parse tree produced by `Expr1Parser`. Antlr provides a helper function to look at the parse tree.

In [None]:
from antlr4.tree.Trees import Trees
Trees.toStringTree(tree, None, parser)

## Translating code

We can add attributes to the terminals and nonterminals of our grammar in order to store semantic information, and we can interleave code that is executed by the parser during the parsing process. For example, if we want to convert our expressions from infix notation to postfix notation, we can simply add `print` statements at the appropriate locations.

```
grammar Expr2;

expr : expr '+' term {print("+")} |
       expr '-' term {print("-")} |
       term;

term : DIGIT {print($DIGIT.text) } ;

DIGIT : ('0'..'9') ;
WS : [ \t\r\n]+ -> skip ;
```

The resulting lexer and parser are generated by Antlr as usual, and already included in the repository, so we can immediately parse an expression and convert it to postfix notation.

In [None]:
from Expr2Lexer import Expr2Lexer
from Expr2Parser import Expr2Parser

input = InputStream('1+2+3+4')
lexer = Expr2Lexer(input)
stream = CommonTokenStream(lexer)
parser = Expr2Parser(stream)
tree = parser.expr() 
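
To see why placing the `print` actions at the end of each alternative yields postfix order, here is a hand-written sketch of the same translation in plain Python. It is not what Antlr generates: the left recursion of `expr` is replaced by a loop, and the function name `to_postfix` is hypothetical.

```python
def to_postfix(text):
    """Translate single-digit infix expressions with + and - to postfix tokens."""
    tokens = [ch for ch in text if not ch.isspace()]
    pos = 0
    output = []

    def term():
        nonlocal pos
        if pos >= len(tokens) or not tokens[pos].isdigit():
            raise SyntaxError(f"digit expected at token {pos}")
        output.append(tokens[pos])  # corresponds to {print($DIGIT.text)}
        pos += 1

    term()
    while pos < len(tokens) and tokens[pos] in "+-":
        operator = tokens[pos]
        pos += 1
        term()
        output.append(operator)     # emitted only after both operands
    if pos != len(tokens):
        raise SyntaxError(f"unexpected token {tokens[pos]!r}")
    return output

postfix = to_postfix("1+2+3+4")
```

Each operator is appended only once its right operand has been processed, which is exactly the effect of the actions in `Expr2`.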

Changing the language is simply a matter of updating the grammar rules, and rerunning Antlr. For example, if we want to allow our expressions to contain numbers with more than one digit, we could include a new nonterminal `number` that consists of at least one `DIGIT`.

```
grammar Expr3;

expr : expr '+' term {print("+")} |
       expr '-' term {print("-")} |
       term;

term : number  {print($number.text) } ;

number: DIGIT+;

DIGIT : ('0'..'9') ;
WS : [ \t\r\n]+ -> skip ;
```

In [None]:
from Expr3Lexer import Expr3Lexer
from Expr3Parser import Expr3Parser

input = InputStream('12+2+443+4')
lexer = Expr3Lexer(input)
stream = CommonTokenStream(lexer)
parser = Expr3Parser(stream)
tree = parser.expr() 

Let's make things a bit more challenging and move from these simple expressions to program code. We'll try to parse a simple fictitious language again.

In [None]:
example = """
begin
  x := 4;
  if y > 42 then
    x := 10;
    while x > 0 do
      begin
        x := x - 1
      end
end
"""

We'll start by defining the grammar for this language.

```
grammar SimpleProgram;

start : statement
      ;

statement : Identifier ':=' expr        # assignmentStatement
          | 'begin' opt_stmts 'end'     # blockStatement
          | 'if' expr 'then' statement  # ifStatement
          | 'while' expr 'do' statement # whileStatement
          ;

expr : expr op=('+' | '-' | '>') term  # binaryExpr
     | term                      # unaryExpr
     ;

term : Number
     | Identifier
     ;

opt_stmts : statement ';' opt_stmts
          | statement
          ;

Number : Digit+
       ;

Identifier : [a-zA-Z_] [a-zA-Z_0-9]*
           ;

Digit : ('0'..'9') ;
WS : [ \t\r\n]+ -> skip ;
```

In [None]:
from SimpleProgramLexer import SimpleProgramLexer
from SimpleProgramParser import SimpleProgramParser

input = InputStream(example)
lexer = SimpleProgramLexer(input)
stream = CommonTokenStream(lexer)
parser = SimpleProgramParser(stream)
tree = parser.start() 

In [None]:
Trees.toStringTree(tree, None, parser)

The translation from infix expressions to postfix expressions we did earlier is actually quite similar to the translation from Java source code to Java bytecode. The Java virtual machine is a stack machine, where all operations are performed on an operand stack: just like an operator in a postfix expression, an instruction pops as many operands as it needs from the stack, performs the operation, and pushes the result back onto the stack.
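
The operand-stack idea can be sketched with a tiny postfix evaluator. This is an illustration only; the function name `eval_postfix` is hypothetical and just `+` and `-` are supported.

```python
def eval_postfix(tokens):
    """Evaluate a postfix token list using an explicit operand stack."""
    stack = []
    for token in tokens:
        if token in ("+", "-"):
            right = stack.pop()   # operands come off in reverse order
            left = stack.pop()
            stack.append(left + right if token == "+" else left - right)
        else:
            stack.append(int(token))
    return stack.pop()

result = eval_postfix(["9", "5", "-", "2", "+"])
```

The order of the two `pop` calls matters for subtraction: the right operand is on top of the stack.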

To represent our simple program in a bytecode-like notation, we define the following bytecode instructions:
- `HALT`: End of execution
- `PUSH`: Push a constant onto the stack
- `LVALUE`: Pop the top of the stack and store it in a variable
- `RVALUE`: Push the value of a variable onto the stack
- `LABEL`: Denote a location as jump target
- `GOTO`: Unconditional jump to target label
- `GOFALSE`: If top of stack represents the value false, then jump to target label
- `IADD`: Pop the top two operands from the stack, push result of addition back to stack
- `ISUB`: Pop the top two operands from the stack, push result of subtraction back to stack
- `CMPGT`: Pop the top two operands from the stack, apply numerical comparison and push integer (0/1) with result back to stack.

The following annotated version of the grammar prints out a bytecode version of the program, in the same way that our annotated grammar converted infix to postfix notation expressions.

```
grammar Expr4;

start : {self.unique_id=10000} statement {
print("HALT") }
      ;

statement : Identifier ':=' expr  {print("LVALUE "+$Identifier.text) }
          | 'begin' opt_stmts 'end'
          | 'if' expr 'then' {
label = str(self.unique_id)
self.unique_id += 1
print("GOFALSE "+label)
          } statement {print("LABEL "+label)
          }
          | 'while' {
label1 = str(self.unique_id)
self.unique_id += 1
label2 = str(self.unique_id)
self.unique_id += 1
print("LABEL "+label1)
                       }
                       expr {
print("GOFALSE "+label2)
                       }
                      'do' statement {
print("GOTO "+label1)
print("LABEL "+label2)
                       }
          ;

expr : expr '+' term {print("IADD") }
     | expr '-' term {print("ISUB") }
     | expr '>' term  {print("CMPGT") }
     | term
     ;
     
term : Number  {print("PUSH "+$Number.text) }
     | Identifier  {print("RVALUE "+$Identifier.text) }
     ;

opt_stmts : statement ';' opt_stmts
          | statement
          ;

Number : Digit+
       ;

Identifier : [a-zA-Z_] [a-zA-Z_0-9]*
           ;

Digit : ('0'..'9') ;
WS : [ \t\r\n]+ -> skip ;
```

As in the other cases, the results of running Antlr on this grammar are already in the repository, so we can immediately try to parse the `example` code.

In [None]:
from Expr4Lexer import Expr4Lexer
from Expr4Parser import Expr4Parser

input = InputStream(example)
lexer = Expr4Lexer(input)
stream = CommonTokenStream(lexer)
parser = Expr4Parser(stream)
tree = parser.start() 
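
To check that a printed instruction sequence behaves as intended, one could run it through a small interpreter. The following is a minimal sketch under the assumption that the instructions mean what the annotated grammar uses them for: `LVALUE x` (emitted after the right-hand side of an assignment) stores the top of the stack in `x`, and `RVALUE x` pushes the value of `x`. The function `run_bytecode` and the sample program are hypothetical and written by hand.

```python
def run_bytecode(program, env=None):
    """Execute one instruction per line; return the final variable bindings."""
    env = dict(env or {})
    code = [line.split() for line in program.strip().splitlines()]
    labels = {ins[1]: index for index, ins in enumerate(code) if ins[0] == "LABEL"}
    stack = []
    pc = 0
    while pc < len(code):
        op, *args = code[pc]
        pc += 1
        if op == "HALT":
            break
        elif op == "PUSH":
            stack.append(int(args[0]))
        elif op == "RVALUE":
            stack.append(env[args[0]])
        elif op == "LVALUE":
            env[args[0]] = stack.pop()
        elif op == "LABEL":
            pass  # jump target only, nothing to execute
        elif op == "GOTO":
            pc = labels[args[0]]
        elif op == "GOFALSE":
            if stack.pop() == 0:
                pc = labels[args[0]]
        elif op in ("IADD", "ISUB", "CMPGT"):
            right = stack.pop()
            left = stack.pop()
            if op == "IADD":
                stack.append(left + right)
            elif op == "ISUB":
                stack.append(left - right)
            else:
                stack.append(1 if left > right else 0)
        else:
            raise ValueError(f"unknown instruction {op!r}")
    return env

# Hand-written countdown: x := 3; while x > 0 do x := x - 1
countdown = """
PUSH 3
LVALUE x
LABEL 1
RVALUE x
PUSH 0
CMPGT
GOFALSE 2
RVALUE x
PUSH 1
ISUB
LVALUE x
GOTO 1
LABEL 2
HALT
"""
final_env = run_bytecode(countdown)
```

Variables that are read before being assigned, such as `y` in the `example` program, would have to be supplied through the `env` argument.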

Our goal actually isn't compilation; we are considering all this to understand where the Abstract Syntax Tree comes from. The data structure that Antlr gives us is the raw parse tree, which we can regard as a _concrete_ syntax tree. To create an abstract syntax tree, we need to decide on the abstraction and create a class hierarchy.

In [None]:
node_id = 0

class ASTNode:
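    def __init__(self, name, children=None):
        global node_id
        # Avoid a shared mutable default: each node gets its own child list.
        self.children = children if children is not None else []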
        self.name = name
        self.id = node_id
        node_id += 1
        
    def get_label(self):
        return self.name
    
    def get_id(self):
        return str(self.id)

We need a unique ID for each node in order to visualize the resulting tree with GraphViz; the graph should show a more readable label for each node (`get_label`). We also need the nodes to be aware of their children, such that we can traverse the tree structure.

In [None]:
class Number(ASTNode):
    def __init__(self, num):
        self.number = num
        super().__init__("Number")

In [None]:
class Identifier(ASTNode):
    def __init__(self, name):
        self.identifier = name
        super().__init__("Identifier")  
        
    def get_label(self):
        return "Id: "+str(self.identifier)

In [None]:
class AssignmentStatement(ASTNode):
    def __init__(self, identifier, expression):
        self.identifier = identifier
        self.expression = expression
        super().__init__("Assignment", [identifier, expression])        

In [None]:
class BlockStatement(ASTNode):
    def __init__(self, statements):
        self.statements = statements
        super().__init__("Block", statements )

The `BlockStatement` is an example of where we are abstracting: the corresponding node in the concrete syntax tree is a `statement` node with three children. Two of them are the terminals `begin` and `end`, which are irrelevant in our abstraction; the third is `opt_stmts`, an unnecessary indirection that we can avoid by directly adding the statements as children of `BlockStatement`.
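
The effect of removing the `opt_stmts` indirection can be sketched on a toy representation. Here the concrete tree is assumed to be encoded as nested tuples, which is purely illustrative and not how Antlr represents parse trees; `flatten_opt_stmts` is a hypothetical helper.

```python
def flatten_opt_stmts(node):
    """Turn a right-nested opt_stmts chain into a flat list of statements.

    A node is either ("opt_stmts", statement) or
    ("opt_stmts", statement, ";", rest_of_chain).
    """
    statements = []
    while True:
        statements.append(node[1])
        if len(node) == 2:
            return statements
        node = node[3]

concrete = ("opt_stmts", "x := 4", ";",
            ("opt_stmts", "y := 2", ";",
             ("opt_stmts", "z := 0")))
flat = flatten_opt_stmts(concrete)
```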

In [None]:
class Expression(ASTNode):
    def __init__(self, lhs, rhs, op):
        self.lhs = lhs
        self.rhs = rhs
        self.op  = op
        super().__init__("Expression", [lhs, rhs])
        
    def get_label(self):
        return "Expression: "+str(self.op)

In [None]:
class IfStatement(ASTNode):
    def __init__(self, expr, then):
        self.expr = expr
        self.then = then
        super().__init__("If", [expr, then])

In [None]:
class WhileStatement(ASTNode):
    def __init__(self, expr, body):
        self.expr = expr
        self.body = body
        super().__init__("While", [expr, body])

One way of creating the AST is by visiting the concrete syntax tree and instantiating appropriate nodes. Antlr has already produced a visitor interface for our `SimpleProgram` grammar.

In [None]:
from SimpleProgramVisitor import SimpleProgramVisitor

In [None]:
class ASTBuilder(SimpleProgramVisitor):
    def visitStart(self, ctx:SimpleProgramParser.StartContext):
        return self.visit(ctx.statement())
    
    def visitAssignmentStatement(self, ctx):        
        return AssignmentStatement(Identifier(ctx.Identifier().getText()), self.visit(ctx.expr()))
    
    def visitBlockStatement(self, ctx):
        return BlockStatement(self.visit(ctx.opt_stmts()))
    
    def visitIfStatement(self, ctx):
        return IfStatement(self.visit(ctx.expr()), self.visit(ctx.statement()))

    def visitWhileStatement(self, ctx):
        return WhileStatement(self.visit(ctx.expr()), self.visit(ctx.statement()))
    
    def visitUnaryExpr(self, ctx):
        return self.visitTerm(ctx.term())

    def visitBinaryExpr(self, ctx):
        return Expression(self.visit(ctx.expr()), self.visit(ctx.term()), ctx.op.text)

    def visitTerm(self, ctx):
        # getAltNumber() stays 0 unless the grammar enables alternative tracking,
        # so check which token is actually present instead.
        if ctx.Number() is not None:
            return Number(ctx.getText())
        return Identifier(ctx.getText())

    def visitOpt_stmts(self, ctx):
        statements = []
        statements.append(self.visit(ctx.statement()))
        if ctx.getChildCount() > 1:
            remaining_stmts = self.visitOpt_stmts(ctx.opt_stmts())
            statements.extend(remaining_stmts)
        return statements

Let's use our non-translating parser for the `SimpleProgram` grammar again.

In [None]:
input = InputStream(example)
lexer = SimpleProgramLexer(input)
stream = CommonTokenStream(lexer)
parser = SimpleProgramParser(stream)
tree = parser.start() 

To create our AST, we just need to apply the visitor.

In [None]:
builder = ASTBuilder()

In [None]:
tree.accept(builder)

...which of course doesn't tell us anything useful since we have not defined a string representation. Let's rather visualise the tree directly.

In [None]:
from graphviz import Digraph
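def print_tree(tree, dot=None):
    # Create a fresh graph per top-level call; a default Digraph() argument
    # would be shared across calls and accumulate nodes.
    if dot is None:
        dot = Digraph()

    dot.node(tree.get_id(), tree.get_label())

    for child in tree.children:
        dot.edge(tree.get_id(), child.get_id())
        print_tree(child, dot)

    return dot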

In [None]:
print_tree(tree.accept(builder))
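
Where GraphViz is not available, an indented text dump can serve as a lightweight alternative. The sketch below works with any object offering `get_label()` and `children`; `dump_tree` and the stand-in class `DemoNode` are hypothetical and only there to keep the example self-contained.

```python
def dump_tree(node, indent=0):
    """Return the tree as a list of lines, indented by depth."""
    lines = ["  " * indent + node.get_label()]
    for child in node.children:
        lines.extend(dump_tree(child, indent + 1))
    return lines

class DemoNode:
    """Minimal stand-in with the same interface as the AST nodes above."""
    def __init__(self, label, children=None):
        self.label = label
        self.children = children if children is not None else []

    def get_label(self):
        return self.label

demo = DemoNode("Assignment", [DemoNode("Id: x"), DemoNode("Number")])
text = "\n".join(dump_tree(demo))
```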

Of course we could also integrate the AST Node creation directly in the attributed grammar.

```
grammar SimpleProgramAttributed;

start returns [node]
      : statement {$node = $statement.node }
      ;

statement returns [node]
          : Identifier ':=' expr        {$node = AssignmentStatement(Identifier($Identifier.text), $expr.node) }
          | 'begin' opt_stmts 'end'     {$node = BlockStatement($opt_stmts.nodes) }
          | 'if' a=expr 'then' statement  {$node = IfStatement($a.node, $statement.node) }
          | 'while' a=expr 'do' statement {$node = WhileStatement($a.node, $statement.node) }
          ;

expr returns [node]
     : a=expr op=('+' | '-' | '>') term  {$node = Expression($a.node, $term.node, $op.text) }
     | term                            {$node = $term.node }
     ;

term returns [node]
     : Number      {$node = Number($Number.text) }
     | Identifier  {$node = Identifier($Identifier.text) }
     ;

opt_stmts returns [nodes]
          : statement ';' opt_stmts  {$nodes = [ $statement.node] + $opt_stmts.nodes }
          | statement                {$nodes = [ $statement.node] }
          ;

Number : Digit+
       ;

Identifier : [a-zA-Z_] [a-zA-Z_0-9]*
           ;

Digit : ('0'..'9') ;
WS : [ \t\r\n]+ -> skip ;
```

## Linting

The first application of ASTs we looked at last week was to create embeddings from source code. However, the AST is also immediately useful without any machine learning. A common application of ASTs is linting, i.e., checking whether the AST satisfies certain syntactic rules and whether it matches known patterns of problems. For example, many of the checks that [SpotBugs performs](https://spotbugs.readthedocs.io/en/latest/bugDescriptions.html) are based on the AST.

Let's use some Java code snippets for our analysis.

In [None]:
code1 = """
public class Foo {
  public void foo(int x) {
    System.out.println("Hello Clone!");
    int j = 10;
    for(int i = 0; i < x; i++) {
      System.out.println("Another iteration");
    }
  }
}
"""

In [None]:
code2 = """
public class Foo {
  public void foo(int x) { System.out.println("This is a very long line for the sake of the check"); }
}
"""

We'll start by implementing some checks that we can apply directly at the character level. For example, [Checkstyle](https://checkstyle.sourceforge.io/config_sizes.html#FileLength) contains rules to check whether a maximum allowed number of lines is exceeded by a source code file, or if a maximum line length is exceeded.

In [None]:
class FileChecker:
    def check(self, code):
        lines = code.split('\n')
        return self.checkLines(lines)

In [None]:
class FileLengthChecker(FileChecker):
    def __init__(self):
        self.max_length = 6 # Extra small for example
        
    def checkLines(self, lines):
        return len(lines) > self.max_length

In [None]:
class LineLengthChecker(FileChecker):
    def __init__(self):
        self.max_length = 50 # Extra small for example
        
    def checkLines(self, lines):
        long_lines = [line for line in lines if len(line) > self.max_length]
        return len(long_lines) > 0

The first code example is longer than allowed.

In [None]:
FileLengthChecker().check(code1)

The second one isn't.

In [None]:
FileLengthChecker().check(code2)

The first contains only short lines.

In [None]:
LineLengthChecker().check(code1)

The second one contains a very long line.

In [None]:
LineLengthChecker().check(code2)
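
A boolean answer is of limited use to a developer, who usually wants to know where the violation is. As a sketch of how the line-length check could report locations instead, consider the following hypothetical helper `find_long_lines`, which is not part of Checkstyle or of the checker classes above.

```python
def find_long_lines(code, max_length=50):
    """Return (line_number, length) for every line longer than max_length."""
    return [
        (number, len(line))
        for number, line in enumerate(code.split("\n"), start=1)
        if len(line) > max_length
    ]

sample = "short line\n" + "x" * 60 + "\nanother short line"
violations = find_long_lines(sample)
```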

To extend these basic checks to more complicated syntactical checks, we will use the javalang parser again.

In [None]:
import javalang

In [None]:
class ASTChecker:
    def check(self, code):
        self.tree = javalang.parse.parse(code)
        return self.check_ast(self.tree)

For example, let's consider the SpotBugs check for [Covariant equals methods](https://spotbugs.readthedocs.io/en/latest/bugDescriptions.html#eq-covariant-equals-method-defined-eq-self-no-object). That is, if there is a method named `equals` whose signature differs from the one inherited from `java.lang.Object`, then this is suspicious code.

In [None]:
class CovariantEqualsChecker(ASTChecker):
    def check_ast(self, ast):
        for _, node in ast.filter(javalang.tree.MethodDeclaration):
            if node.name == "equals":
                if len(node.parameters) != 1:
                    return True
                if node.parameters[0].type.name != "Object":
                    return True
        return False

In [None]:
code3 = """
public class Foo {
  public boolean equals(String str) {
    return true;
  }
}
"""

In [None]:
CovariantEqualsChecker().check(code1)

In [None]:
CovariantEqualsChecker().check(code3)

As another AST example, let's consider the [Format String Newline](https://spotbugs.readthedocs.io/en/latest/bugDescriptions.html#fs-format-string-should-use-n-rather-than-n-va-format-string-uses-newline) check in SpotBugs. The problem matched by this check is whether a formatting string, used in the static method `String.format`, contains an explicit newline character (`\n`) rather than using the correct newline formatting string (`%n`).

In [None]:
code4 = """
public class Foo {
  public void foo(String str) {
    String foo = String.format("Foo\n");
    System.out.println(foo);
  }
}
"""

In [None]:
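class FormatStringNewlineChecker(ASTChecker):
    def check_ast(self, ast):
        for _, node in ast.filter(javalang.tree.MethodInvocation):
            if node.member == "format" and \
                len(node.arguments) >= 1 and \
                node.qualifier == "String":
                first_arg = node.arguments[0]
                # Only literals can be inspected statically. Accept both the escape
                # sequence as written in Java source and a raw newline, which is
                # what the non-raw Python string in `code4` actually contains.
                if isinstance(first_arg, javalang.tree.Literal) and \
                    ("\\n" in first_arg.value or "\n" in first_arg.value):
                    return True

        return False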

In [None]:
FormatStringNewlineChecker().check(code1)

In [None]:
FormatStringNewlineChecker().check(code4)

As another example, consider the [Useless control flow](https://spotbugs.readthedocs.io/en/latest/bugDescriptions.html#ucf-useless-control-flow-ucf-useless-control-flow) checker: it flags an if-statement that has no effect because its then-block is empty.

In [None]:
code5 = """
public class Foo {
  public boolean foo(int x) {
    if (x > 0) {
    
    }
    System.out.println("Foo");
  }
}
"""

In [None]:
class UselessControlFlowChecker(ASTChecker):
    def check_ast(self, ast):
        for _, node in ast.filter(javalang.tree.IfStatement):
            if isinstance(node.then_statement, javalang.tree.BlockStatement):
                if not node.then_statement.statements:
                    return True

        return False

In [None]:
UselessControlFlowChecker().check(code1)

In [None]:
UselessControlFlowChecker().check(code5)

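A further check concerns methods that are declared to return `Boolean` but explicitly return `null`, which invites a `NullPointerException` as soon as the result is unboxed.

In [None]: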
code6 = """
public class Foo {
  public Boolean foo(int x) {
    return null;
  }
}
"""

In [None]:
class BooleanReturnNullChecker(ASTChecker):
    def check_ast(self, ast):
        for _, node in ast.filter(javalang.tree.MethodDeclaration):
            if node.return_type and node.return_type.name == "Boolean":
                # Search only this method's body, not the whole compilation unit.
                for _, return_stmt in node.filter(javalang.tree.ReturnStatement):
                    expr = return_stmt.expression
                    if isinstance(expr, javalang.tree.Literal) and expr.value == "null":
                        return True

        return False

In [None]:
BooleanReturnNullChecker().check(code1)

In [None]:
BooleanReturnNullChecker().check(code6)

## Code2vec

Although we've already covered the ASTNN approach for creating code embeddings, we will now also consider an alternative approach, which has contributed much to the general idea of code embeddings in the first place:

Alon, U., Zilberstein, M., Levy, O., & Yahav, E. (2019). code2vec: Learning distributed representations of code. Proceedings of the ACM on Programming Languages, 3(POPL), 1-29.

We will look at how to embed an individual method using code2vec. For this, let's define a simple helper function that gives us an AST rooted at a method declaration.

In [None]:
code = """
public int sum(int a, int b) {
   return a + b + 2;
}
"""

In [None]:
import javalang

In [None]:
def parse_method(code):
    class_code = "class Dummy {\n" + code + "\n}"
    tokens = javalang.tokenizer.tokenize(class_code)
    parser = javalang.parser.Parser(tokens)
    ast = parser.parse()
    _, node = list(ast.filter(javalang.tree.MethodDeclaration))[0]
    return node

In [None]:
tree = parse_method(code)

In [None]:
tree

In [None]:
from graphviz import Digraph
def print_tree(tree):
    unique_id = 1
    dot = Digraph()
    for path, node in tree:
        dot.node(str(id(node)), str(type(node)))
        
        for child in node.children:
            if isinstance(child, javalang.ast.Node):
                dot.edge(str(id(node)), str(id(child)))
            elif type(child) == str:
                strid = str(unique_id)
                unique_id = unique_id + 1
                dot.node(strid, child)
                dot.edge(str(id(node)), strid)
            elif type(child) == list:
                for lc in child:
                    dot.edge(str(id(node)), str(id(lc)))
                 
    return dot

In [None]:
print_tree(tree)

In contrast to ASTNN with its statement trees, code2vec looks at the concept of a path context, which is a path between two tokens in the AST.

We can easily retrieve a list of all terminals in the AST; for example we could traverse the tree and look for strings or sets.

In [None]:
for path, node in tree:
    for child in node.children:
        if child:
            if type(child) is str:
                print("Terminal: ", child)
            elif type(child) is set:
                for x in child:
                    print("Terminal ", x)
        
        

Let's put this into a function that gives us the terminals as well as the corresponding AST nodes.

In [None]:
def get_terminal_nodes(tree):
    for path, node in tree:
        for child in node.children:
            if child:
                if type(child) is str and child != "Dummy":
                    yield(node, child)
                elif type(child) is set:
                    for x in child:
                        yield(node, x)      

In [None]:
[ terminal for _, terminal in list(get_terminal_nodes(tree))]

A path context is defined as the path between two terminals, so let's pick two terminals.

In [None]:
node1, terminal1 = list(get_terminal_nodes(tree))[-1]
node2, terminal2 = list(get_terminal_nodes(tree))[-2]

In [None]:
terminal1

In [None]:
terminal2

Let's first construct the path from a root node to a chosen terminal node.

In [None]:
def get_path(tree, node):
    if tree == node:
        return [tree]
    
    if type(tree) == list:
        for child in tree:
            path = get_path(child, node)
            if path:
                return path  
    
    if not isinstance(tree, javalang.tree.Node):
        return None
    
    for child in tree.children:
        path = get_path(child, node)
        if path:
            return [tree] + path  
    
    return None

In [None]:
def print_path(path):
    result = ""
    for node in path:
        if type(node) == str:
            result += node
        elif type(node) == list:
            result += print_path(node)
        else:
            result += str(type(node))

    return result

In [None]:
print_path(get_path(tree, node1))

In [None]:
print_path(get_path(tree, node2))

A path context consists of the path up the AST from the first terminal node to the least common ancestor of both terminal nodes, and then down the AST again to the second terminal node.

In [None]:
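def path_context(tree, node1, node2):
    # Walk up from node1 until we reach an ancestor that also contains node2
    # (the least common ancestor), then follow the path down to node2.
    path1 = get_path(tree, node1)
    path1.reverse()
    for i in range(len(path1)):
        node = path1[i]
        path2 = get_path(node, node2)
        if path2:
            return (path1[:i], path2)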

In [None]:
def print_path_context(path_context):
    down_path = []
    up_path = []
    for node in path_context[0]:
        if type(node) == str:
            up_path.append(node)
        else:
            up_path.append(node.__class__.__name__)
    for node in path_context[1]:
        if type(node) == str:
            down_path.append(node)
        else:
            down_path.append(node.__class__.__name__)
            
    return "↑".join(up_path) + "↑" + "↓".join(down_path)

In [None]:
print_path_context(path_context(tree, node1, node2))

In [None]:
print_path_context(path_context(tree, node2, node1))

In [None]:
terminal1

In [None]:
terminal2
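
The up-then-down structure of a path context can also be seen on a toy tree that is independent of javalang. This is a sketch under simplifying assumptions: terminals are modelled as leaf nodes of their own, and `ToyNode`, `path_from_root`, and `toy_path_context` are hypothetical names.

```python
class ToyNode:
    """Stand-in for an AST node: a label plus child nodes."""
    def __init__(self, label, *children):
        self.label = label
        self.children = list(children)

def path_from_root(node, target, prefix=()):
    """Return the tuple of nodes from `node` down to `target`, or None."""
    prefix = prefix + (node,)
    if node is target:
        return prefix
    for child in node.children:
        found = path_from_root(child, target, prefix)
        if found is not None:
            return found
    return None

def toy_path_context(root, start, end):
    to_start = path_from_root(root, start)
    to_end = path_from_root(root, end)
    shared = 0
    while (shared < min(len(to_start), len(to_end))
           and to_start[shared] is to_end[shared]):
        shared += 1
    # The last shared node is the least common ancestor.
    up = [n.label for n in reversed(to_start[shared:])]
    down = [n.label for n in to_end[shared - 1:]]
    return up, down

a, b = ToyNode("a"), ToyNode("b")
root = ToyNode("ReturnStatement",
               ToyNode("BinaryOperation",
                       ToyNode("MemberReference", a),
                       ToyNode("MemberReference", b)))
up, down = toy_path_context(root, a, b)
```

Joining `up` with `↑` and `down` with `↓` gives a string in the same style as `print_path_context`.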

To build the embeddings for a method, we next require the path context for every pair of terminal nodes in the AST.

In [None]:
terminals = list(get_terminal_nodes(tree))
paths = []
for index1 in range(len(terminals)-1):
    node1, terminal1 = terminals[index1]
    for index2 in range(index1 + 1, len(terminals)):
        node2, terminal2 = terminals[index2]
        path = path_context(tree, node1, node2)
        print(terminal1, ",", print_path_context(path), ",", terminal2)
        paths.append(path)
len(paths)
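
The nested index loops enumerate every unordered pair of terminals exactly once, which is what `itertools.combinations` does as well. A small sketch (the helper name `terminal_pairs` is hypothetical) makes the quadratic growth visible:

```python
from itertools import combinations

def terminal_pairs(terminals):
    """All unordered pairs, in the same order as the nested index loops."""
    return list(combinations(terminals, 2))

pairs = terminal_pairs(["a", "b", "c"])
# n terminals yield n * (n - 1) / 2 path contexts.
pair_count_for_seven = len(terminal_pairs(range(7)))
```

Because the number of contexts grows quadratically with the number of terminals, larger methods quickly produce many path contexts.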

Next, we convert several methods into path contexts in order to build a small training set.

In [None]:
method1 = """
public int sum(int a, int b) {
   return a + b + 2;
}
"""

In [None]:
method2 = """
public void printHello(String name) {
   System.out.println("Hello " + name +"! ");
}
"""

In [None]:
method3 = """
public boolean isTheAnswer(int x) {
   if (x == 42) {
     return true;
   } else {
     return false;
   }
}
"""

In [None]:
def get_method_name(tree):
    for _, node in tree.filter(javalang.tree.MethodDeclaration):
        return node.name
    return None

In [None]:
def get_id(value, dictionary, vocab):
    if value in dictionary:
        return dictionary[value]
    else:
        new_id = len(dictionary.keys())
        dictionary[value] = new_id
        vocab[new_id] = value
    return new_id

In [None]:
from dataclasses import dataclass

# collection of all known paths, terminals, and method names in the dataset
@dataclass
class Vocabulary:
    # actual type of the value is not important, put in whatever is best
    paths: dict[int, str]
    terminals: dict[int, str]
    method_names: dict[int, str]

vocabulary = Vocabulary({}, {}, {})

In [None]:
train_x = []
train_y = []

paths = {}
method_names = {}
terminal_names = {}

for method in [method1, method2, method3]:
    method_ast = parse_method(method)
    name = get_method_name(method_ast)    
    method_id = get_id(name, method_names, vocabulary.method_names)
    path_contexts = []
    
    terminals = list(get_terminal_nodes(method_ast))
    for index1 in range(len(terminals)-1):
        node1, terminal1 = terminals[index1]
        terminal1_id = get_id(terminal1, terminal_names, vocabulary.terminals)
        for index2 in range(index1 + 1, len(terminals)):
            node2, terminal2 = terminals[index2]
            terminal2_id = get_id(terminal2, terminal_names, vocabulary.terminals)
            path = path_context(method_ast, node1, node2)
            path_str = print_path_context(path)
            path_id = get_id(path_str, paths, vocabulary.paths)
            print(terminal1, ",", path_str, ",", terminal2)
            print(terminal1_id, ",", path_id, ",", terminal2_id)
            path_contexts.append((terminal1_id, path_id, terminal2_id))
            
    train_x.append(path_contexts)
    train_y.append(method_id)

In [None]:
train_x

In [None]:
train_y

### Representation
Each path is represented as a vector $p$ with values that are not known initially. The terminals are each represented by a vector $t$ with unknown elements as well. By concatenating the three parts of a context its representation $c_i = [t_\mathrm{start}, p, t_\mathrm{end}]$ is created.
To learn how the different parts of $c_i$ relate to each other, a weight matrix $W$ with learnable weights is introduced. The product $\tilde{c}_i := \mathrm{tanh}(W \cdot c_i)$ is then called a *combined context vector* as it now no longer contains just the concatenation of the three separate parts.

The whole code snippet/method body is again represented as a single vector $v$. As the different contexts of this code snippet are not equally important, the network has to learn which ones actually are. To achieve this, a global attention vector $a$ is learned, from which a weight $\alpha_i$ for each context is derived by applying a softmax over the scores $\tilde{c}_i^T \cdot a$. The code vector $v$ can then be calculated as the weighted sum
$$
    v := \sum_{i=1}^{n} \alpha_i \cdot \tilde{c}_i
$$


Each method name is again represented as a vector $y$ with unknown values. The probability $q(y)$ that a code vector should be associated with this tag is calculated as $q(y) := \mathrm{softmax}(v^T \cdot y)$. By performing this calculation for all known tags the one with the highest probability to fit the code can be chosen.
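
The forward computation described in this section can be sketched in plain Python for tiny vectors. This is only an illustration, not the Keras model: it assumes that the attention weights are obtained as a softmax over the dot products of each combined context vector with a single attention vector, as in the code2vec paper, and the names `code_vector` and `tag_probabilities` are hypothetical.

```python
import math

def softmax(scores):
    peak = max(scores)                      # subtract the maximum for stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(u, v):
    return sum(x * y for x, y in zip(u, v))

def code_vector(contexts, W, attention):
    """Return (v, alphas) for context vectors c_i, weights W, attention vector."""
    # Combined context vectors: tanh(W * c_i)
    combined = [[math.tanh(dot(row, c)) for row in W] for c in contexts]
    # One attention weight per context
    alphas = softmax([dot(c, attention) for c in combined])
    # Weighted sum of the combined context vectors
    v = [sum(alpha * c[k] for alpha, c in zip(alphas, combined))
         for k in range(len(W))]
    return v, alphas

def tag_probabilities(v, tag_vectors):
    """Softmax over the dot products of v with every method-name vector y."""
    return softmax([dot(v, y) for y in tag_vectors])

identity = [[1.0, 0.0], [0.0, 1.0]]
v, alphas = code_vector([[1.0, 0.0], [1.0, 0.0]], identity, [1.0, 1.0])
probabilities = tag_probabilities(v, [[1.0, 0.0], [0.0, 1.0]])
```

In the real model, the entries of `W`, the attention vector, and all embeddings are learned during training rather than fixed by hand.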

### Learned Elements
- Embedding vectors $p$ for paths and $t$ for terminals, which are concatenated into the representation $c_i$ of each context.
- A weight matrix $W$ that contains information on how the three parts of a context are combined.
- An attention vector that determines the weights $\alpha_i$, i.e., which contexts in a method are important.
- A vector $y$ as representation for each method name.

In [None]:
import tensorflow as tf
from tensorflow import keras
from keras.layers import Layer
from keras import Input, activations, optimizers, losses
import keras.backend as kb
from keras.layers import Embedding, Concatenate, Dropout, TimeDistributed, Dense

# how many paths does the biggest analysed function have
MAX_PATHS = 50
# length of the vectors that should represent paths and labels (same size for simplicity)
EMBEDDING_SIZE = 100
# embedding sizes of start, path, end added together
CONTEXT_EMBEDDING_SIZE = 3 * EMBEDDING_SIZE

In [None]:
# Adapted from: https://github.com/tech-srl/code2vec/blob/master/keras_model.py
def build_code2vec_model(vocab: Vocabulary):
    path_start_token = Input((MAX_PATHS,), dtype=tf.int32)
    path_input = Input((MAX_PATHS,), dtype=tf.int32)
    path_end_token = Input((MAX_PATHS,), dtype=tf.int32)
    # The contexts of each function are padded to MAX_PATHS entries;
    # the mask is 1 for real contexts and 0 for padding.
    context_mask = Input((MAX_PATHS,))

    # The elements of the matrix are chosen randomly, as the actual values have to be learned.
    paths_embedded = Embedding(len(vocab.paths), EMBEDDING_SIZE,
                               name='path_embedding')(path_input)

    # Embed terminals the same way as paths.
    token_embedding = Embedding(len(vocab.terminals), EMBEDDING_SIZE,
                                name='token_embedding')
    path_start_token_embedded = token_embedding(path_start_token)
    path_end_token_embedded = token_embedding(path_end_token)

    # Representation of contexts $c_i$: concatenation of start, path, end
    context_embedded = Concatenate()([path_start_token_embedded, paths_embedded, path_end_token_embedded])
    # Dropout to prevent overfitting.
    context_embedded = Dropout(0.25)(context_embedded)

    # $\tilde{c}_i = tanh(Wc_i)$
    # Fully connected layer that learns to combine the three parts of a context.
    context_after_dense = TimeDistributed(
        Dense(CONTEXT_EMBEDDING_SIZE, use_bias=False,
              activation=activations.tanh))(context_embedded)

    # AttentionLayer learns which path contexts are the most important.
    # A code_vector $v$ now is the final representation for a piece of code.
    code_vectors, attention_weights = AttentionLayer(name='attention')([context_after_dense, context_mask])

    # $q(y) := softmax(v^T y)$
    # Final dense layer: learns how the method names should be represented.
    # For each method name, the probability that a given code vector represents it is
    # the dot product of the two vectors after softmax normalisation.
    # The output is a probability distribution over all method names; the index with
    # the highest probability is the key of the predicted method name in the vocabulary.
    target_index = Dense(len(vocab.method_names), use_bias=False,
                         activation=activations.softmax, name='target_index')(
        code_vectors)

    inputs = [path_start_token, path_input, path_end_token, context_mask]
    outputs = [target_index]
    return keras.Model(name='code2vec', inputs=inputs, outputs=outputs)

In [None]:
# Learns which of the contexts in the method are the most important.
#
# Adapted from: https://github.com/tech-srl/code2vec/blob/master/keras_attention_layer.py
class AttentionLayer(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, inputs_shape):
        shape_actual_input = inputs_shape[0]
        self.input_length = int(shape_actual_input[1])
        self.input_dim = int(shape_actual_input[2])

        # The vector that defines how much each context should be weighted.
        # Initialized with random values, model learns the actual ones.
        attention_param_shape = (self.input_dim, 1)
        self.attention_param = self.add_weight(name='attention_param',
                                               shape=attention_param_shape,
                                               initializer='uniform',
                                               trainable=True, dtype=tf.float32)

        super(AttentionLayer, self).build(shape_actual_input)

    def call(self, inputs, **kwargs):
        context = inputs[0]
        mask = inputs[1]

        # Score each context by its dot product with the attention vector.
        attention_weights = kb.dot(context, self.attention_param)

        # Exclude padded contexts: log(1) = 0 keeps a score unchanged, while
        # log(0) = -inf drives the weight of a padded context to 0 after softmax.
        if len(mask.shape) == 2:
            mask = kb.expand_dims(mask, axis=2)
        mask = kb.log(mask)
        attention_weights += mask

        # normalise weights
        attention_weights = kb.softmax(attention_weights, axis=1)
        # the code vector is just a weighted sum of contexts
        code_vector = kb.sum(context * attention_weights, axis=1)

        return code_vector, attention_weights

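    def compute_output_shape(self, input_shape):
        # input_shape is [context_shape, mask_shape]; the layer returns the code
        # vector (batch, dim) and the attention weights (batch, paths, 1).
        batch_size, num_paths, dim = input_shape[0]
        return [(batch_size, dim), (batch_size, num_paths, 1)]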
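
The masking step in `call` is easy to miss, so here is a minimal sketch of its effect in plain Python, with toy scores. Unlike `kb.log`, Python's `math.log(0)` raises an error, so the sketch uses a small helper that returns negative infinity instead.

```python
import math

scores = [0.3, 1.2, -0.5, 0.0, 0.0]  # attention scores; the last two belong to padding
mask = [1.0, 1.0, 1.0, 0.0, 0.0]     # 1 = real context, 0 = padding


def log_mask(m):
    # log(1) = 0 leaves a score unchanged; log(0) is treated as -infinity
    return 0.0 if m == 1.0 else -math.inf


masked = [s + log_mask(m) for s, m in zip(scores, mask)]

# softmax over the masked scores; exp(-inf) evaluates to 0.0
largest = max(masked)
exps = [math.exp(s - largest) for s in masked]
total = sum(exps)
weights = [e / total for e in exps]

print(weights)
```

The padded positions receive a weight of exactly zero, and the remaining weights still sum to one, so padding does not influence the code vector.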

In [None]:
model = build_code2vec_model(vocabulary)
model.summary()

In [None]:
# CategoricalCrossentropy expects the target method names to be one-hot encoded.
model.compile(optimizer='adam', loss=losses.CategoricalCrossentropy())

# TODO: model.fit
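
Training is not shown here, but the model expects four inputs of length `MAX_PATHS` per method. The following sketch shows how the contexts of one method might be brought into that shape, assuming they are already available as index triples. The helper `pad_contexts` and the reserved `PAD_INDEX` are hypothetical and not part of the notebook's vocabulary code.

```python
MAX_PATHS = 4  # small for illustration; the notebook uses 50
PAD_INDEX = 0  # assumption: index 0 is reserved for padding in every vocabulary


def pad_contexts(contexts, max_paths=MAX_PATHS, pad_index=PAD_INDEX):
    """Turn a list of (start, path, end) index triples into four fixed-length lists."""
    contexts = contexts[:max_paths]        # truncate methods with too many contexts
    padding = max_paths - len(contexts)
    starts = [s for s, _, _ in contexts] + [pad_index] * padding
    paths = [p for _, p, _ in contexts] + [pad_index] * padding
    ends = [e for _, _, e in contexts] + [pad_index] * padding
    mask = [1.0] * len(contexts) + [0.0] * padding   # 1 = real context, 0 = padding
    return starts, paths, ends, mask


starts, paths, ends, mask = pad_contexts([(3, 17, 5), (5, 42, 9)])
print(starts, paths, ends, mask)
```

The four lists correspond, in order, to the `path_start_token`, `path_input`, `path_end_token` and `context_mask` inputs of the model.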

# Creating Big Code Datasets

Large datasets of source code are a prerequisite for all these deep learning applications. Let's briefly look at how these are commonly created.

The most common approach is to collect code from GitHub, for which GitHub provides a convenient [API](https://docs.github.com/en/rest).

In [None]:
import requests

The question of which repositories to mine (a random set? the most popular ones?) is a tricky one. Let's assume we are interested in the repositories with the most stars.

In [None]:
url = 'https://api.github.com/search/repositories?q=language:java&sort=stars'
response = requests.get(url, timeout=30)
response.raise_for_status()  # fail early, e.g. when the API rate limit is exceeded

response_dict = response.json()

print("Total repos:", response_dict['total_count'])
repos_dicts = response_dict['items']
print("Repos found:", len(repos_dicts))
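
A single request only returns one page of results. To collect more repositories, the search endpoint accepts the `per_page` and `page` query parameters (at the time of writing, up to 100 results per page). The sketch below only builds the URLs and does not send any requests; the helper `build_search_url` is hypothetical.

```python
from urllib.parse import urlencode

SEARCH_ENDPOINT = "https://api.github.com/search/repositories"


def build_search_url(language, page=1, per_page=100):
    """Build the search URL for one page of the most-starred repositories."""
    params = {
        "q": f"language:{language}",
        "sort": "stars",
        "order": "desc",
        "per_page": per_page,
        "page": page,
    }
    return f"{SEARCH_ENDPOINT}?{urlencode(params)}"


# URLs for the first three result pages
urls_to_fetch = [build_search_url("java", page) for page in range(1, 4)]
for search_url in urls_to_fetch:
    print(search_url)
```

Each of these URLs could then be fetched with `requests.get` as above, ideally with a pause between requests to stay within the API's rate limits.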

Let's have a look at what these repositories are, and collect their Git URLs.

In [None]:
urls = []

for repos_dict in repos_dicts:
    print('\nName:', repos_dict['name'])
    print('Owner:', repos_dict['owner']['login'])
    print('Stars:', repos_dict['stargazers_count'])
    print('Repository:', repos_dict['html_url'])
    urls.append((repos_dict['name'], repos_dict['html_url'] + ".git"))

Given our list of Git URLs, the next question is how to extract source code.

In [None]:
from git import Repo

In [None]:
import tempfile

In [None]:
tmp_dir = tempfile.mkdtemp()

In [None]:
import itertools  # to limit loop iterations
import os

for repo_name, repo_url in itertools.islice(urls, 3):
    print("Current url", repo_url)
    repo = Repo.clone_from(repo_url, os.path.join(tmp_dir, repo_name))

In [None]:
import os

java_files = []
for root, d_names, f_names in os.walk(tmp_dir):
    for f in f_names:
        if f.endswith(".java"):
            java_files.append(os.path.join(root, f))

print(len(java_files))

## Mining Bugs

Often we are interested in more than just raw source code. For example, many analysis approaches require information about specific bugs -- either because we want to evaluate analysis techniques (how effective is the analysis at finding bugs?), or because we want to train a model to detect bugs.

A convenient way to investigate Git repositories is offered by [PyDriller](https://github.com/ishepard/pydriller). For example, we can conveniently traverse the commits of a repository.

In [None]:
from pydriller import Repository


example_repo = 'https://github.com/se2p/LitterBox.git'

for commit in itertools.islice(Repository(example_repo).traverse_commits(), 1200, 1210):
    print('Hash {}, author {}'.format(commit.hash, commit.author.name))

In [None]:
example_git = os.path.join(tmp_dir, "litterbox")

In [None]:
Repo.clone_from(example_repo, example_git)

Which files were modified in each of the commits?

In [None]:
for commit in itertools.islice(Repository(example_git).traverse_commits(), 1200, 1210):
    for file in commit.modified_files:
        print('Author {} modified {} in commit {}'.format(commit.author.name, file.filename, commit.hash))

In order to produce data on real bugs, there are two common strategies: The first strategy is to find commits that indicate they are fixing bugs. We then know that the version before that commit contains a bug, whereas the version afterwards does not. 

A more challenging question is when a bug was introduced. Identifying this is commonly done using the SZZ algorithm, named after the authors:

Śliwerski J, Zimmermann T, Zeller A (2005) When do changes induce fixes? In: Proceedings of the 2005 International Workshop on Mining Software Repositories, ACM, New York, NY, USA, MSR ’05, pp 1–5


The approach consists of two phases: 

- In the first phase, bug-fixing commits are identified, often by investigating bug tracker data, or by looking for commit messages that contain the word `fix`, or refer to an issue ID. 

- In the second phase, for each bug-fixing commit we identify all commits that previously made changes to the same lines of code that were changed in the bug-fixing commit. 

The latter can be done simply by running `git blame` on the lines changed by the fix, which yields candidates for bug-introducing commits.
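
The second phase can be sketched on toy data without touching a real repository. The dictionaries below stand in for the output of `git blame` and for the diff of a fix; all hashes, file names and line numbers are made up.

```python
# For the revision just before each fix: which commit last touched each line
# of each file (this is the information `git blame` reports).
blame_before_fix = {
    "fix1": {"Counter.java": {10: "c1", 11: "c2", 12: "c2", 13: "c3"}},
}

# Lines that each bug-fixing commit deleted or replaced, per file.
lines_changed_by_fix = {
    "fix1": {"Counter.java": [11, 12]},
}


def bug_introducing_candidates(fix_hash):
    """Collect the commits that last modified the lines changed by a fix."""
    candidates = set()
    for filename, lines in lines_changed_by_fix[fix_hash].items():
        blame = blame_before_fix[fix_hash][filename]
        for line in lines:
            candidates.add(blame[line])
    return candidates


print(bug_introducing_candidates("fix1"))  # {'c2'}
```

Commits `c1` and `c3` touched neighbouring lines but not the ones changed by the fix, so they are not reported as candidates.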

In [None]:
fix_commits = []

for commit in Repository(example_git).traverse_commits():
    # Exclude merge commits 
    if commit.merge: 
        continue
    msg = commit.msg.lower()
    if "fix" in msg:
        fix_commits.append(commit)
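
Note that the substring check `"fix" in msg` also matches words such as "prefix" or "suffix". A slightly stricter heuristic might use word boundaries and also look for issue references. The keyword list and the `#123`-style issue pattern below are assumptions for illustration and would need tuning for a concrete project.

```python
import re

# Hypothetical keyword list; issue IDs are assumed to look like "#123".
FIX_PATTERN = re.compile(r"\b(fix(es|ed|ing)?|bug|defect)\b|#\d+", re.IGNORECASE)


def looks_like_fix(message):
    """Heuristically decide whether a commit message describes a bug fix."""
    return FIX_PATTERN.search(message) is not None


for message in ["Fixed NPE in parser", "Add prefix handling", "Closes #123"]:
    print(message, "->", looks_like_fix(message))
```

Such a function could replace the `"fix" in msg` condition in the loop above, at the cost of missing fixes that are described in other words.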

In [None]:
len(fix_commits)

In [None]:
from pydriller import Git

git = Git(example_git)

In [None]:
commit = git.get_commit("cd665cfc27315130102e2cf816a3d58fe82ce77e")

In [None]:
print(f"Fix Commit {commit.hash}: {commit.msg}")
for file in commit.modified_files:
    lines = [line for (line, text) in file.diff_parsed["deleted"]]
    print(f" -> {file.filename}: {lines}")

In [None]:
commit.modified_files[0].diff_parsed["deleted"]

In [None]:
commit.modified_files[0].diff_parsed["added"]

In [None]:
commit.modified_files[1].diff_parsed["added"]

In [None]:
bug_changes = git.get_commits_last_modified_lines(commit)
bug_changes

In [None]:
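# bug_changes maps each modified file to the set of commits that last touched
# the lines changed by the fix; pick one candidate commit for the first file.
first_file = next(iter(bug_changes))
bug_commit_hash = next(iter(bug_changes[first_file]))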

In [None]:
bug_commit = git.get_commit(bug_commit_hash)

Given the bug commit, we can look at the buggy code that was added in this particular version.

In [None]:
for file in bug_commit.modified_files:
    if file.filename == "WeightedMethodCount.java":
        print([line for line, text in file.diff_parsed["added"]])