In [86]:
import json
import os

import torch
import torch.nn.functional as F

import multicoder
from tokenizecode import CodeParser, CodeTokenizer

In [87]:
# Small Python snippet used as the parsing example for the tree and the masks.
py_code = """def foo():
    print("hello world!")
    for i in [1,2,3]:
        if (i % 2 == 0):
            print(i)
        else:
            print(i+1)"""

In [88]:
# Java sample snippet. NOTE(review): not referenced by any cell visible in this part of the notebook.
java_code = """class HelloWorld {
    public static void main(String[] args) {
        System.out.println("Hello, World!"); 
    }
}"""

### Example tree

In [89]:
# Parse the Python sample and pretty-print the resulting syntax tree.
# `out` is kept as a global: later cells read `out.tree`.
parser = CodeParser()
out = parser.parse(py_code, "python")
out.tree.pprint()

TensorTree():
 0. [module]
╰── 1. [function_definition]
    ├── 2. def
    ├── 3. ·
    ├── 4. [identifier]
    │   ╰── 5. foo
    ├── 6. [parameters]
    │   ├── 7. (
    │   ╰── 8. )
    ├── 9. :
    ├── 10. ⏎····
    ╰── 11. [block]
        ├── 12. [expression_statement]
        │   ╰── 13. [call]
        │       ├── 14. [identifier]
        │       │   ╰── 15. print
        │       ╰── 16. [argument_list]
        │           ├── 17. (
        │           ├── 18. [string]
        │           │   ├── 19. "
        │           │   ├── 20. hello·world!
        │           │   ╰── 21. "
        │           ╰── 22. )
        ├── 23. ⏎····
        ╰── 24. [for_statement]
            ├── 25. for
            ├── 26. ·
            ├── 27. [identifier]
            │   ╰── 28. i
            ├── 29. ·
            ├── 30. in
            ├── 31. ·
            ├── 32. [list]
            │   ├── 33. [
            │   ├── 34. [integer]
            │   │   ╰── 35. 1
            │   ├── 36. ,
        

In [90]:
# Global handle on the parsed tree; the mask-building functions below read it as a global.
tree = out.tree

## Root candidates
A root candidate is a non-terminal node whose subtree can be removed from the tree and treated as a self-contained unit. Leaves within this subtree should attend to one another, since they are considered to serve the same semantic purpose in the tree.

### 1. Considering each non-terminal as a root candidate
Disadvantage: The terminal children (leaves) of each non-terminal attend only to their siblings and to the leaves beneath them, which creates a rather sparse matrix for deeply nested trees.

In [172]:
# All-zero square mask with one row and one column per tree node; create_mask fills it in place.
# `out.tree` is the object bound to the global `tree` in an earlier cell.
mask = torch.zeros(len(out.tree), len(out.tree))

def create_mask(node_idx):
    """Fill the global ``mask`` for ``node_idx`` and, recursively, for its subtree.

    Every non-terminal is treated as a root candidate: its row marks the
    entries flagged by the ``leaves_mask()`` of its subtree, plus the node
    itself. A leaf copies the row of its parent and additionally marks itself.

    Reads the globals ``tree`` and ``mask``; ``mask`` is modified in place and
    nothing is returned.
    """
    if tree.is_leaf(node_idx):
        # A leaf attends to whatever its parent attends to ...
        mask[node_idx] = mask[tree.get_parent(node_idx)]
        # ... and to itself.
        mask[node_idx][node_idx] = 1
        return

    # Row layout: [zeros before node_idx | leaves mask of the subtree | zeros for the rest].
    # Assumes the subtree occupies consecutive indices starting at node_idx - TODO confirm.
    subtree_leaves = tree[node_idx].leaves_mask()
    trailing = len(tree) - len(subtree_leaves) - node_idx
    mask[node_idx] = torch.cat(
        (torch.zeros(node_idx), subtree_leaves, torch.zeros(trailing))
    )

    # A non-terminal also attends to itself.
    mask[node_idx][node_idx] = 1

    # Descend only once this row is final, because leaf children copy it.
    for child_idx in tree.iter_children(node_idx):
        create_mask(child_idx)

In [173]:
# Start at index 0, the root ([module]) of the printed tree, so the recursion covers the whole tree.
create_mask(0)

In [174]:
# Display the filled mask (one row per tree node).
mask

tensor([[1., 0., 1.,  ..., 0., 1., 1.],
        [0., 1., 1.,  ..., 0., 1., 1.],
        [0., 1., 1.,  ..., 0., 1., 1.],
        ...,
        [0., 0., 0.,  ..., 1., 1., 0.],
        [0., 0., 0.,  ..., 1., 1., 0.],
        [0., 0., 0.,  ..., 0., 1., 1.]])

In [154]:
# Check the mask dimensions; it was allocated as len(tree) x len(tree).
mask.shape

torch.Size([97, 97])

### 2. Consider only special non-terminals as root candidates
For each terminal, the mask of the nearest root-candidate ancestor is copied.
Disadvantage: For deeply nested root candidate nodes (e.g. if-for-if-if...), the mask will be sparse as in (1), because the children only attend to the nodes that their parents attend to. This could perhaps be resolved by setting a fixed MAX_DEPTH (e.g. if a child is nested deeper than MAX_DEPTH, take the root-candidate ancestor at level MAX_DEPTH as its nearest parent).

In [95]:
# Location of the non-terminal vocabulary file. Set the NONTERMINALS_JSON environment
# variable to point at another checkout; the default is the original author's local path,
# so the cell behaves as before when the variable is unset.
nonterminals_path = os.environ.get(
    "NONTERMINALS_JSON",
    "/Users/savinadiez/masterthesis/git/code-buddy/multicoder/tokenizer/nonterminals.json",
)

# Read with an explicit encoding instead of relying on the platform default.
with open(nonterminals_path, "rt", encoding="utf-8") as f:
    nonterminals = json.load(f)

In [96]:
# Top-level keys of the loaded JSON object, in the order they were loaded.
nonterminal_names = [name for name in nonterminals.keys()]

In [98]:
# Children of these nodes should attend only to their siblings + their leaves, because they form self-contained units:
# *block*, if_statement, *method*, *function*, call, for_statement, do_block, switch_statement, while_statement

# Children of these nodes should attend to whatever the nearest higher root candidate attends to:
# identifier, tag_name, string, ...
# but should these non-terminals themselves still attend only to themselves and the nodes beneath them?

In [180]:
# Fresh all-zero mask (one row/column per tree node) for the second strategy.
# This rebinds the global `mask`, discarding the result of the first strategy.
mask = torch.zeros(len(tree), len(tree))
# Node labels treated as root candidates; compared against tree.get_node_data(...) in the function below.
root_candidates = ['[if_statement]', '[for_statement]']

def create_mask_considering_root_candidates(node_idx):
    """Fill the global ``mask`` for ``node_idx`` and its subtree using root candidates.

    Non-terminals get a row that marks the entries flagged by the
    ``leaves_mask()`` of their subtree, plus the node itself.

    A leaf copies the row of the first ancestor yielded by
    ``tree.iter_ancestors`` whose node data is listed in ``root_candidates``
    (presumably the nearest one - TODO confirm the iteration order). If no
    ancestor qualifies, the row of the direct parent is copied instead. In
    both cases the leaf also attends to itself.

    Reads the globals ``tree``, ``mask`` and ``root_candidates``; ``mask`` is
    modified in place and nothing is returned.
    """
    # Non-terminals
    if not tree.is_leaf(node_idx):
        # Row layout: [zeros before node_idx | leaves mask of the subtree | zeros for the rest].
        # Assumes the subtree occupies consecutive indices starting at node_idx - TODO confirm.
        leaves_mask = tree[node_idx].leaves_mask()
        trailing = len(tree) - len(leaves_mask) - node_idx
        mask[node_idx] = torch.cat((torch.zeros(node_idx), leaves_mask, torch.zeros(trailing)))

        # A non-terminal also attends to itself.
        mask[node_idx][node_idx] = 1

        # Descend only once this row is final, because leaves copy ancestor rows.
        for child_idx in tree.iter_children(node_idx):
            create_mask_considering_root_candidates(child_idx)
        return

    # Terminals (leaves): look for an enclosing root candidate.
    for ancestor_idx in tree.iter_ancestors(node_idx):
        if tree.get_node_data(ancestor_idx) in root_candidates:
            source_idx = ancestor_idx
            break
    else:
        # No enclosing root candidate. This is a regular case, e.g. the `def`
        # token directly under [function_definition] in the example tree.
        # Fall back to the direct parent's row (open question: use the root's row instead?).
        source_idx = tree.get_parent(node_idx)

    mask[node_idx] = mask[source_idx]

    # Every leaf attends to itself, whichever row it copied.
    mask[node_idx][node_idx] = 1

In [181]:
# Start at index 0, the root ([module]) of the printed tree, so the recursion covers the whole tree.
create_mask_considering_root_candidates(0)

In [184]:
# Inspect the mask row of node 87. NOTE(review): which node index 87 refers to is not visible here.
mask[87]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 1., 0., 0., 0.,
        1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 0., 0., 0.,
        1., 0., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1.,
        0., 0., 1., 1., 0., 1., 1.])