In [1]:
import torch
import torch.nn as nn

In [59]:
# Emulate a forward pass over a complete binary tree of depth 2 (3 nodes in total).
# Vocabulary sizes: 3 distinct tokens and 3 distinct node types.
# `dim` is the embedding / hidden size hyperparameter.
# Fix the RNG so the random initialisation below (and of any later torch.rand
# parameters) is reproducible on Restart & Run All.
torch.manual_seed(0)
token_size = 3
type_size = 3
dim = 2

# Lookup tables mapping token ids and type ids to dense vectors of length `dim`.
token_embed = nn.Embedding(token_size, dim)
type_embed = nn.Embedding(type_size, dim)

# Fuses a concatenated (type, token) embedding pair (length 2*dim) into one node vector (length dim).
linear = nn.Linear(dim*2, dim)

# All-zero vector of length `dim`, used as the embedding of padded (absent) nodes.
zero_tensor = torch.zeros(dim)

In [253]:
# Input tree representation: one row per sliding window, laid out as
# (parent, left child, right child). Every node is a (type_id, token_id) pair;
# [-1, -1] marks an absent child (padding).
tree_with_indices = torch.tensor([
    [[0, 0], [1, 1], [2, 2]],      # window rooted at node (0, 0) with children (1, 1) and (2, 2)
    [[1, 1], [-1, -1], [-1, -1]],  # window rooted at node (1, 1): no children
    [[2, 2], [-1, -1], [-1, -1]],  # window rooted at node (2, 2): no children
])
# Convert tree_with_indices into a tree of node vectors built from type and token embeddings.
def embed(tree):
    """Embed every node of an indexed tree into a single fused vector.

    Each node ``(type_id, token_id)`` is mapped to
    ``linear(cat(type_embed(type_id), token_embed(token_id)))``; nodes whose
    type id is ``-1`` (padding) become ``zero_tensor``.

    Args:
        tree: integer tensor indexed as ``tree[window][slot] -> (type_id, token_id)``,
            presumably shaped (num_windows, slots_per_window, 2) like
            ``tree_with_indices``. Every window must have the same slot count.

    Returns:
        Float tensor of shape (num_windows, slots_per_window, dim).

    The result is assembled with ``torch.stack`` instead of the original
    ``.tolist()`` -> ``torch.tensor`` round trip, so it stays attached to the
    autograd graph and gradients can reach the embeddings and ``linear``.
    """
    window_rows = []
    for subtree_nodes in tree:
        node_rows = []
        for type_index, token_index in subtree_nodes:
            if type_index != -1:
                fused_input = torch.cat((type_embed(type_index), token_embed(token_index)), 0)
                node_rows.append(linear(fused_input))
            else:
                # Padding slot: contributes an all-zero vector.
                node_rows.append(zero_tensor)
        window_rows.append(torch.stack(node_rows))
    return torch.stack(window_rows)

In [143]:
# Embedded tree: for each node, type and token information fused through the linear layer.
tree_embed = embed(tree_with_indices)
print(tree_embed.shape)
print(tree_embed)

torch.Size([3, 3, 2])
tensor([[[ 0.4627, -1.1416],
         [ 0.2623, -0.0223],
         [ 0.7116, -1.8547]],

        [[ 0.2623, -0.0223],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]],

        [[ 0.7116, -1.8547],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]]])


In [204]:
# Learnable TBCNN convolution weights: one (dim x dim) weight matrix for each of the
# top (parent), right-child and left-child contributions, plus a single-element bias.
# The creation order (t, r, l, bias) is kept so the random initialisation is unchanged.
t_emb = nn.Parameter(torch.rand(dim, dim))
r_emb = nn.Parameter(torch.rand(dim, dim))
l_emb = nn.Parameter(torch.rand(dim, dim))
bias = nn.Parameter(torch.rand(1))

In [288]:
def eta_t(di, d):
    """Return the TBCNN "top" coefficient of a node inside a sliding window.

    Args:
        di: height of node i inside the window, counted from 1 (not 0).
        d: total depth of the window.

    Returns:
        ``(di - 1) / (d - 1)``, so a node at height ``d`` gets 1.0 and a node at
        height 1 gets 0.0. A depth-1 window holds a single node, which is
        treated as the top node (1.0) instead of raising ZeroDivisionError.
    """
    if d == 1:
        # Guard the 0/0 case, mirroring the n == 1 guard in eta_r.
        return 1.0
    return (di - 1.0) / (d - 1.0)

def eta_r(etat, pi, n):
    """Return the TBCNN "right" coefficient for the child at position ``pi``.

    ``pi`` runs 1, 2, ..., n from left to right among ``n`` children, and the
    result is ``(1 - etat) * (pi - 1) / (n - 1)``. With a single child the
    formula would divide by zero, so 0 is returned instead.
    """
    # A lone child: n - 1 == 0, nothing to spread to the right.
    if n == 1:
        return 0
    remaining = 1.0 - etat
    numerator = remaining * (pi - 1.0)
    return numerator / (n - 1.0)

def eta_l(etat, etar):
    """Return the TBCNN "left" coefficient, ``(1 - etat) * (1 - etar)``."""
    top_complement = 1.0 - etat
    right_complement = 1.0 - etar
    return top_complement * right_complement

def convovle(tree_emb):
    """Apply one TBCNN-style convolution to every sliding window of an embedded tree.

    (The name is a misspelling of "convolve"; it is kept so existing callers keep working.)

    In each window ``subtree``, slot 0 is the parent and the remaining slots are
    its children from left to right. For parent ``p`` and real children ``c_i``:

        y = t_emb @ p + sum_i (eta_t_i * t_emb + eta_r_i * r_emb + eta_l_i * l_emb) @ c_i
        out = tanh(y + bias)

    Here the parent uses eta_t = 1 and children use eta_t = 0. The function reads
    the module-level parameters ``t_emb``, ``r_emb``, ``l_emb`` and ``bias``, and
    the helpers ``eta_r`` / ``eta_l``. Children whose vector is entirely zero are
    treated as padding and skipped.

    NOTE(review): the original author marked this function as incomplete.

    Args:
        tree_emb: node vectors, assumed shaped (num_windows, slots_per_window, dim)
            as produced by ``embed`` -- TODO confirm for other callers.

    Returns:
        The stacked window outputs with ``squeeze()`` applied, i.e.
        (num_windows, dim) in the usual case. The outputs are combined with
        ``torch.stack`` rather than ``.tolist()``, so the result stays attached
        to the autograd graph and gradients can flow to the convolution parameters.
    """
    window_outputs = []
    for subtree in tree_emb:
        # Child positions are normalised over every child slot of the window
        # (2 for the parent/left/right layout used in this notebook).
        num_child_slots = len(subtree) - 1
        y = 0.0
        for node_id, node_vec in enumerate(subtree):
            # Column vector for matmul; assumes node_vec is 1-D of length dim.
            node_vec = node_vec.unsqueeze(1)
            if node_id == 0:
                # Parent: top of a depth-2 window, so eta_t = 1.
                etat = 1
                y = etat * torch.matmul(t_emb, node_vec)
            elif torch.any(node_vec != 0):
                # Real child. Padding slots are exactly zero; testing the sum
                # instead would wrongly drop real vectors like [0.5, -0.5].
                etat = 0
                etar = eta_r(etat, node_id, num_child_slots)
                etal = eta_l(etat, etar)
                # Out-of-place add keeps the autograd graph free of in-place edits.
                y = y + (etat * torch.matmul(t_emb, node_vec)
                         + etar * torch.matmul(r_emb, node_vec)
                         + etal * torch.matmul(l_emb, node_vec))
        window_outputs.append(torch.tanh(y + bias))
    if not window_outputs:
        # Empty input: match the original's torch.tensor([]) result.
        return torch.tensor([])
    return torch.stack(window_outputs).squeeze()
    


In [289]:
# Run the TBCNN convolution over every window of the embedded tree.
tbcnn_output = convovle(tree_emb=tree_embed)