# Experiment for RQ2: Intermediate node prediction

### We first build the dataset for intermediate node prediction

In [1]:
import torch
import numpy as np
from itertools import combinations
from tqdm.notebook import tqdm
from utils import load_vocab
from torch.utils.data import DataLoader
import argparse
import os
from pathlib import Path
from py_config_runner import ConfigObject
from module import FastASTTrans
from ignite.utils import convert_tensor
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import torch.nn as nn
from torch.optim import AdamW
import torch.nn.functional as F

In [2]:
matrices_path = './csa-trans/processed/tree_sitter_python/test/split_matrices.npz'
data = np.load(matrices_path, allow_pickle=True)
test_rfs = data["root_first_seq"]

for split_idx, rfs in enumerate([test_rfs]):
    original_test_dataset = {}
    for data_idx, sample_ast in enumerate(tqdm(rfs)):
        # generate parent_path_list and brother_path_list
        distance_map = {}
        brother_map = {}

        parent_path_list = []
        brother_path_list = []

        for node in sample_ast:
            if len(node.children) == 0:
                path = [node.label]
                n = node
                while n.parent is not None:
                    path.append(n.parent.label)
                    n = n.parent
                parent_path_list.append(list(reversed(path)))
            else:
                brother_path_list.append([child.label for child in node.children])

        # remove identifier dangling nodes from parent path
        refined_parent_path_list = []
        for path in parent_path_list:
            idt_remove_path = []
            for e in path:
                if e.split(':')[0] != 'idt':
                    idt_remove_path.append(e)
                else:
                    break
            if len(idt_remove_path) < 2:
                continue
            refined_parent_path_list.append(idt_remove_path)

        # remove identifier nodes from brother list
        refined_brother_path_list = []
        for path in brother_path_list:
            idt_remove_path = []
            for e in path:
                if e.split(':')[0] != 'idt':
                    idt_remove_path.append(e)
            refined_brother_path_list.append(idt_remove_path)

        # get dataset
        parent_intermediate = set()
        brother_intermediate = set()
        for par_path in refined_parent_path_list:
            for idx in range(len(par_path) - 2):
                prev = par_path[idx]
                inter = par_path[idx + 1]
                after = par_path[idx + 2]
                parent_intermediate.add((prev, inter, after))

        for bro_path in refined_brother_path_list:
            for e in combinations(bro_path, 2):
                parent = sample_ast[int(e[0].split(':')[-1]) -1].parent
                brother_intermediate.add((e[0], parent.label, e[1]))

        parent_intermediate = list(parent_intermediate)
        brother_intermediate = list(brother_intermediate)
        # sample
        SAMPLE_NUM = min([10, len(parent_intermediate), len(brother_intermediate)])
        idx1 = np.random.choice(range(len(parent_intermediate)), size=SAMPLE_NUM, replace=False)
        idx2 = np.random.choice(range(len(brother_intermediate)), size=SAMPLE_NUM, replace=False)
        parent_inter = [parent_intermediate[i] for i in idx1]
        brother_inter = [brother_intermediate[i] for i in idx2]
        original_test_dataset[data_idx] = parent_inter + brother_inter

  0%|          | 0/18502 [00:00<?, ?it/s]

In [3]:
def _graph_prepare_batch(batch, device=None, non_blocking: bool = False):
    x, y = batch
    return (
        x.to(device),
        convert_tensor(y, device=device, non_blocking=non_blocking),
    )

class SyntheticDataset(Dataset):
    def __init__(self, embeddings, targets):
        self.x = embeddings
        self.y = targets

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        x = self.x[idx]
        y = self.y[idx]
        return x, y
    
class MLP(nn.Module):
    def __init__(self, indim, hidden, outdim):
        super().__init__()
        self.fc1 = nn.Linear(indim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc4 = nn.Linear(hidden, outdim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.dropout(F.relu(self.fc3(x)))
        x = F.relu(self.fc4(x))
        return x

## Node prediction for CSA-Trans

In [4]:
parser = argparse.ArgumentParser("Example application")
parser.add_argument("--config", type=Path, help="Input configuration file")
parser.add_argument("--use_hype_params", action="store_true")
parser.add_argument("--data_type", type=str, default="")
parser.add_argument("--exp_type", type=str, default="summary")
parser.add_argument("--g", type=str, default="")

args = parser.parse_args(['--config', './config/python.py', '--g', '0',])
config = ConfigObject(args.config)

if args.g != "":
    os.environ["CUDA_VISIBLE_DEVICES"] = args.g
    config.device = "cuda"
    config.g = args.g
    
(config.src_vocab,config.tgt_vocab,) = load_vocab(config.data_dir, config.is_split, config.data_type)
test_data_set = config.data_set(config, "test")
test_loader = DataLoader(dataset=test_data_set,batch_size=config.batch_size // len(config.g.split(",")),shuffle=False,collate_fn=test_data_set.collect_fn,)

Data Set Name : < Fast AST Data Set >
loading test data...
loading ./processed/tree_sitter_python/test/split_pot.seq...


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 18502/18502 [00:04<00:00, 4320.55it/s]


loading ./processed/tree_sitter_python/test/nl.original ...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 18502/18502 [00:00<00:00, 306835.57it/s]


loading matrices...
building dataset
building edges.


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 18502/18502 [01:27<00:00, 210.99it/s]


dataset lenght: 18502


In [5]:
model = config.model(
    config.src_vocab.size(),
    config.tgt_vocab.size(),
    config.hidden_size,
    config.num_heads,
    config.num_layers,
    config.sbm_layers,
    config.use_pegen,
    config.dim_feed_forward,
    config.dropout,
    config.pe_dim,
    config.pegen_dim,
    config.sbm_enc_dim,
    config.clusters,
    config.full_att,
    config.checkpoint,
    config.max_src_len,
)
state_path = './csa-trans/outputs/final_models/256_512_512_4_4_10_10_10_b64_tgt50_vanilla/best_model_445_val_bleu=0.3652.pt'
state_dict = torch.load(state_path)
model.load_state_dict(state_dict)
model = model.to('cuda')

sbm_param: 6777344
decoder param: 16817152
generator param: 10260000
Init or load model.


In [6]:
src_pes = []
from tqdm.notebook import tqdm
with torch.no_grad():
    for idx, batch in tqdm(enumerate(test_loader)):
        x, y = _graph_prepare_batch(batch)
        y_, sparsity, src_pe, graphs, attns = model(x.to('cuda'))
        src_pe = src_pe.detach()
        src_pes += src_pe

0it [00:00, ?it/s]

In [7]:
BATCH_SIZE=128
input_dim=512
hidden_dim=1024
device ='cuda'
criterion = nn.CrossEntropyLoss()

In [8]:
whole = list(original_test_dataset.values())
X = [] # the two embeddings
Y = [] # intermediate vocabulary
for ast_idx, ast_instance in tqdm(enumerate(whole)):
    for path in ast_instance:
        node1 = src_pes[ast_idx][int(path[0].split(':')[-1]) - 1]
        node2 = src_pes[ast_idx][int(path[2].split(':')[-1]) - 1]

        if path[1].split(':')[1] not in config.src_vocab.w2i.keys():
            continue
        tgt = config.src_vocab.w2i[path[1].split(':')[1]]
        
        X.append(torch.cat([node1, node2]))
        Y.append(tgt)
train_X = X[:int(len(X)*0.8)]
train_Y = Y[:int(len(X)*0.8)]
test_X = X[int(len(X)*0.8):]
test_Y = Y[int(len(X)*0.8):]
train_dataset = SyntheticDataset(train_X, train_Y)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = SyntheticDataset(test_X, test_Y)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

0it [00:00, ?it/s]

In [9]:
print(len(train_X))
print(len(test_X))

293814
73454


In [10]:
out_dim=config.src_vocab.size()
model = MLP(input_dim, hidden_dim, out_dim)
model = model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-4)
num_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"num_param: {num_param}")

model.train()
for epoch in range(50):
    loss_acc = 0
    for batch in train_loader:
        x, y = batch
        pred = model(x.to(device))
        loss = criterion(pred, y.to(device).squeeze())
        loss.backward()
        loss_acc += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        
    if epoch%5 == 0:
        per_sample = loss_acc / len(train_loader) / BATCH_SIZE
        total_correct = 0
        model.eval()
        for batch in test_loader:
            x, y = batch
            pred = model(x.to(device))
            total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
        print(f'Epoch - {epoch}, Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')
        model.train()


num_param: 12874512
Epoch - 0, Accuracy: 0.8217688202857971
Epoch - 5, Accuracy: 0.8769599199295044
Epoch - 10, Accuracy: 0.8866643309593201
Epoch - 15, Accuracy: 0.8973486423492432
Epoch - 20, Accuracy: 0.8905161023139954
Epoch - 25, Accuracy: 0.892271876335144
Epoch - 30, Accuracy: 0.8931974172592163
Epoch - 35, Accuracy: 0.8904889225959778
Epoch - 40, Accuracy: 0.8949259519577026
Epoch - 45, Accuracy: 0.8969947695732117


In [11]:
total_correct = 0
model.eval()
for batch in tqdm(test_loader):
    x, y = batch
    pred = model(x.to(device))
    total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
print(f'Total correct: {total_correct}')
print(f'Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')

  0%|          | 0/574 [00:00<?, ?it/s]

Total correct: 65822
Accuracy: 0.8958787322044373


## Treepos

In [12]:
args = parser.parse_args(['--config', './config/python_treepos.py', '--g', '0',])
config = ConfigObject(args.config)
(config.src_vocab,config.tgt_vocab,) = load_vocab(config.data_dir, config.is_split, config.data_type)
model = config.model(
    config.src_vocab.size(),
    config.tgt_vocab.size(),
    config.hidden_size,
    config.num_heads,
    config.num_layers,
    config.sbm_layers,
    config.use_pegen,
    config.dim_feed_forward,
    config.dropout,
    config.pe_dim,
    config.pegen_dim,
    config.sbm_enc_dim,
    config.clusters,
    config.full_att,
    config.checkpoint,
    config.max_src_len,
)
test_data_set = config.data_set(config, "test")
test_loader = DataLoader(dataset=test_data_set,batch_size=config.batch_size // len(config.g.split(",")),shuffle=False,collate_fn=test_data_set.collect_fn,)
state_path = './csa-trans/outputs/final_models/python_v3_treepos/best_model_495_val_bleu=0.3628.pt'
state_dict = torch.load(state_path)
model.load_state_dict(state_dict)
model.eval()
model = model.to('cuda')

sbm_param: 6777344
decoder param: 16817152
generator param: 10260000
Init or load model.
Data Set Name : < Fast AST Data Set >
loading test data...
loading existing dataset
dataset lenght: 18502


In [13]:
src_pes = []
with torch.no_grad():
    for idx, batch in tqdm(enumerate(test_loader)):
        x, y = _graph_prepare_batch(batch)
        y_, sparsity, src_pe, _, _ = model(x.to('cuda'))
        src_pe = src_pe.detach()
        src_pes += src_pe

0it [00:00, ?it/s]

In [14]:
whole = list(original_test_dataset.values())
X = []
Y = []
for ast_idx, ast_instance in tqdm(enumerate(whole)):
    for path in ast_instance:
        node1 = src_pes[ast_idx][int(path[0].split(':')[-1]) - 1]
        node2 = src_pes[ast_idx][int(path[2].split(':')[-1]) - 1]

        if path[1].split(':')[1] not in config.src_vocab.w2i.keys():
            continue
        tgt = config.src_vocab.w2i[path[1].split(':')[1]]
        
        X.append(torch.cat([node1, node2]))
        Y.append(tgt)
        
train_X = X[:int(len(X)*0.8)]
train_Y = Y[:int(len(X)*0.8)]
test_X = X[int(len(X)*0.8):]
test_Y = Y[int(len(X)*0.8):]
train_dataset = SyntheticDataset(train_X, train_Y)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = SyntheticDataset(test_X, test_Y)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

0it [00:00, ?it/s]

In [15]:
device ='cuda'
out_dim=config.src_vocab.size()
model = MLP(input_dim, hidden_dim, out_dim)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-4)
num_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"num_param: {num_param}")

model.train()
for epoch in range(50):
    loss_acc = 0
    for batch in train_loader:
        x, y = batch
        pred = model(x.to(device))
        loss = criterion(pred, y.to(device).squeeze())
        loss.backward()
        loss_acc += loss.item()
        optimizer.step()
        optimizer.zero_grad()
    if epoch%5 == 0:
        per_sample = loss_acc / len(train_loader) / BATCH_SIZE
        total_correct = 0
        model.eval()
        for batch in test_loader:
            x, y = batch
            pred = model(x.to(device))
            total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
        print(f'Epoch - {epoch}, Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')
        model.train()

num_param: 12874512
Epoch - 0, Accuracy: 0.6163436770439148
Epoch - 5, Accuracy: 0.6502476930618286
Epoch - 10, Accuracy: 0.6522485017776489
Epoch - 15, Accuracy: 0.651921808719635
Epoch - 20, Accuracy: 0.6494311094284058
Epoch - 25, Accuracy: 0.649567186832428
Epoch - 30, Accuracy: 0.649703323841095
Epoch - 35, Accuracy: 0.6477161645889282
Epoch - 40, Accuracy: 0.648015558719635
Epoch - 45, Accuracy: 0.6464503407478333


In [16]:
total_correct = 0
model.eval()
for batch in tqdm(test_loader):
    x, y = batch
    pred = model(x.to(device))
    total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
print(f'Total correct: {total_correct}')
print(f'Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')

  0%|          | 0/574 [00:00<?, ?it/s]

Total correct: 47571
Accuracy: 0.6474711298942566


## Laplacian

### Since laplacian pe and sequential pe are not learnable, they can be just used by changing the configs from treepos configuration

In [17]:
args = parser.parse_args(['--config', './config/python_lap.py', '--g', '0',])
config = ConfigObject(args.config)
(config.src_vocab,config.tgt_vocab,) = load_vocab(config.data_dir, config.is_split, config.data_type)
model = config.model(
    config.src_vocab.size(),
    config.tgt_vocab.size(),
    config.hidden_size,
    config.num_heads,
    config.num_layers,
    config.sbm_layers,
    config.use_pegen,
    config.dim_feed_forward,
    config.dropout,
    config.pe_dim,
    config.pegen_dim,
    config.sbm_enc_dim,
    config.clusters,
    config.full_att,
    config.checkpoint,
    config.max_src_len,
)
test_data_set = config.data_set(config, "test")
test_loader = DataLoader(dataset=test_data_set,batch_size=config.batch_size // len(config.g.split(",")),shuffle=False,collate_fn=test_data_set.collect_fn,)

sbm_param: 6777344
decoder param: 16817152
generator param: 10260000
Init or load model.
Data Set Name : < Fast AST Data Set >
loading test data...
loading existing dataset
dataset lenght: 18502


In [18]:
state_path = './csa-trans/outputs/final_models/python_v3_lap/best_model_470_val_bleu=0.3542.pt'
state_dict = torch.load(state_path)
model.load_state_dict(state_dict)
model.eval()
model = model.to('cuda')

In [19]:
src_pes = []
from tqdm.notebook import tqdm
with torch.no_grad():
    for idx, batch in tqdm(enumerate(test_loader)):
        x, y = _graph_prepare_batch(batch)
        y_, sparsity, src_pe, _, _ = model(x.to('cuda'))
        src_pe = src_pe.detach()
        src_pes += src_pe

0it [00:00, ?it/s]

In [20]:
whole = list(original_test_dataset.values())
X = [] # the two embeddings
Y = [] # intermediate vocabulary
for ast_idx, ast_instance in tqdm(enumerate(whole)):
    for path in ast_instance:
        node1 = src_pes[ast_idx][int(path[0].split(':')[-1]) - 1]
        node2 = src_pes[ast_idx][int(path[2].split(':')[-1]) - 1]

        if path[1].split(':')[1] not in config.src_vocab.w2i.keys():
            continue
        tgt = config.src_vocab.w2i[path[1].split(':')[1]]
        
        X.append(torch.cat([node1, node2]))
        Y.append(tgt)
        
train_X = X[:int(len(X)*0.8)]
train_Y = Y[:int(len(X)*0.8)]
test_X = X[int(len(X)*0.8):]
test_Y = Y[int(len(X)*0.8):]
train_dataset = SyntheticDataset(train_X, train_Y)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = SyntheticDataset(test_X, test_Y)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

0it [00:00, ?it/s]

In [21]:
device ='cuda'
out_dim=config.src_vocab.size()
model = MLP(input_dim, hidden_dim, out_dim)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-4)
num_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"num_param: {num_param}")

model.train()
for epoch in range(50):
    loss_acc = 0
    for batch in train_loader:
        x, y = batch
        pred = model(x.to(device))
        loss = criterion(pred, y.to(device).squeeze())
        loss.backward()
        loss_acc += loss.item()
        optimizer.step()
        optimizer.zero_grad()
    if epoch%5 == 0:
        per_sample = loss_acc / len(train_loader) / BATCH_SIZE
        total_correct = 0
        model.eval()
        for batch in test_loader:
            x, y = batch
            pred = model(x.to(device))
            total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
        print(f'Epoch - {epoch}, Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')
        model.train()

num_param: 12874512
Epoch - 0, Accuracy: 0.2962352931499481
Epoch - 5, Accuracy: 0.38283973932266235
Epoch - 10, Accuracy: 0.4082643687725067
Epoch - 15, Accuracy: 0.43522703647613525
Epoch - 20, Accuracy: 0.44885125756263733
Epoch - 25, Accuracy: 0.45135563611984253
Epoch - 30, Accuracy: 0.4589911997318268
Epoch - 35, Accuracy: 0.45531630516052246
Epoch - 40, Accuracy: 0.4538191556930542
Epoch - 45, Accuracy: 0.46235302090644836


In [22]:
total_correct = 0
model.eval()
for batch in tqdm(test_loader):
    x, y = batch
    pred = model(x.to(device))
    total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
print(f'Total correct: {total_correct}')
print(f'Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')

  0%|          | 0/574 [00:00<?, ?it/s]

Total correct: 33482
Accuracy: 0.4557110369205475


## Sequential

In [23]:
from module.components import PositionalEncoding
sbm_enc_dim=512
src_pe = PositionalEncoding(sbm_enc_dim, config["max_src_len"])
src_pe = src_pe.pe.squeeze()

In [24]:
whole = list(original_test_dataset.values())
X = [] # the two embeddings
Y = [] # intermediate vocabulary
for ast_idx, ast_instance in tqdm(enumerate(whole)):
    for path in ast_instance:
        node1 = src_pe[int(path[0].split(':')[-1]) - 1]
        node2 = src_pe[int(path[2].split(':')[-1]) - 1]

        if path[1].split(':')[1] not in config.src_vocab.w2i.keys():
            continue
        tgt = config.src_vocab.w2i[path[1].split(':')[1]]
        X.append(torch.cat([node1, node2]))
        Y.append(tgt)
        
train_X = X[:int(len(X)*0.8)]
train_Y = Y[:int(len(X)*0.8)]
test_X = X[int(len(X)*0.8):]
test_Y = Y[int(len(X)*0.8):]
train_dataset = SyntheticDataset(train_X, train_Y)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = SyntheticDataset(test_X, test_Y)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

0it [00:00, ?it/s]

In [25]:
input_dim=1024
hidden_dim=1024
out_dim=config.src_vocab.size()
model = MLP(input_dim, hidden_dim, out_dim)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-4)
num_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"num_param: {num_param}")

model.train()
for epoch in range(50):
    loss_acc = 0
    for batch in train_loader:
        x, y = batch
        pred = model(x.to(device))
        loss = criterion(pred, y.to(device).squeeze())
        loss.backward()
        loss_acc += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        per_sample = loss.item() / BATCH_SIZE
    if epoch%5 == 0:
        per_sample = loss_acc / len(train_loader) / BATCH_SIZE
        
        total_correct = 0
        model.eval()
        for batch in test_loader:
            x, y = batch
            pred = model(x.to(device))
            total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
        print(f'Epoch - {epoch}, Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')
        model.train()

num_param: 13398800
Epoch - 0, Accuracy: 0.36192017793655396
Epoch - 5, Accuracy: 0.36646613478660583
Epoch - 10, Accuracy: 0.3715973496437073
Epoch - 15, Accuracy: 0.3741697669029236
Epoch - 20, Accuracy: 0.3732578456401825
Epoch - 25, Accuracy: 0.3739519715309143
Epoch - 30, Accuracy: 0.37321701645851135
Epoch - 35, Accuracy: 0.37287673354148865
Epoch - 40, Accuracy: 0.37434670329093933
Epoch - 45, Accuracy: 0.37412893772125244


In [26]:
total_correct = 0
model.eval()
for batch in tqdm(test_loader):
    x, y = batch
    pred = model(x.to(device))
    total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
print(f'Total correct: {total_correct}')
print(f'Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')

  0%|          | 0/574 [00:00<?, ?it/s]

Total correct: 27450
Accuracy: 0.373611718416214


## Triplet

In [27]:
args = parser.parse_args(['--config', './config/python_triplet.py', '--g', '0',])
config = ConfigObject(args.config)
(config.src_vocab,config.tgt_vocab,) = load_vocab(config.data_dir, config.is_split, config.data_type)
model = config.model(
    config.src_vocab.size(),
    config.tgt_vocab.size(),
    config.hidden_size,
    config.num_heads,
    config.num_layers,
    config.sbm_layers,
    config.use_pegen,
    config.dim_feed_forward,
    config.dropout,
    config.pe_dim,
    config.pegen_dim,
    config.sbm_enc_dim,
    config.clusters,
    config.full_att,
    config.checkpoint,
    config.max_src_len,
)
test_data_set = config.data_set(config, "test")
test_loader = DataLoader(dataset=test_data_set,batch_size=config.batch_size // len(config.g.split(",")),shuffle=False,collate_fn=test_data_set.collect_fn,)

sbm_param: 6777344
decoder param: 16817152
generator param: 10260000
Init or load model.
Data Set Name : < Fast AST Data Set >
loading test data...
loading existing dataset
dataset lenght: 18502


In [28]:
state_path = '/home/coinse/greentea/src/cbcgt/outputs/final_models/python_v3_triplet/python_triplet.pt'
state_dict = torch.load(state_path)
model.load_state_dict(state_dict)
model.eval()
model = model.to('cuda')

In [29]:
src_pes = []
with torch.no_grad():
    for idx, batch in tqdm(enumerate(test_loader)):
        # if idx >= 5:
        #     break
        x, y = _graph_prepare_batch(batch)
        y_, sparsity, src_pe, _, _ = model(x.to('cuda'))
        src_pe = src_pe.detach()
        src_pes += src_pe

0it [00:00, ?it/s]

In [30]:
whole = list(original_test_dataset.values())
X = [] # the two embeddings
Y = [] # intermediate vocabulary
for ast_idx, ast_instance in tqdm(enumerate(whole)):
    for path in ast_instance:
        node1 = src_pes[ast_idx][int(path[0].split(':')[-1]) - 1]
        node2 = src_pes[ast_idx][int(path[2].split(':')[-1]) - 1]

        if path[1].split(':')[1] not in config.src_vocab.w2i.keys():
            continue
        tgt = config.src_vocab.w2i[path[1].split(':')[1]]
        
        X.append(torch.cat([node1, node2]))
        Y.append(tgt)
        
train_X = X[:int(len(X)*0.8)]
train_Y = Y[:int(len(X)*0.8)]
test_X = X[int(len(X)*0.8):]
test_Y = Y[int(len(X)*0.8):]
train_dataset = SyntheticDataset(train_X, train_Y)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = SyntheticDataset(test_X, test_Y)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

0it [00:00, ?it/s]

In [31]:
input_dim=512
hidden_dim=1024
out_dim=config.src_vocab.size()
model = MLP(input_dim, hidden_dim, out_dim)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-4)
num_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"num_param: {num_param}")

model.train()
for epoch in range(50):
    loss_acc = 0
    for batch in train_loader:
        x, y = batch
        pred = model(x.to(device))
        loss = criterion(pred, y.to(device).squeeze())
        loss.backward()
        loss_acc += loss.item()
        optimizer.step()
        optimizer.zero_grad()
    if epoch%5 == 0:
        per_sample = loss_acc / len(train_loader) / BATCH_SIZE
        total_correct = 0
        model.eval()
        for batch in test_loader:
            x, y = batch
            pred = model(x.to(device))
            total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
        print(f'Epoch - {epoch}, Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')
        model.train()

num_param: 12874512
Epoch - 0, Accuracy: 0.5974248647689819
Epoch - 5, Accuracy: 0.6176366806030273
Epoch - 10, Accuracy: 0.6222915053367615
Epoch - 15, Accuracy: 0.6238430738449097
Epoch - 20, Accuracy: 0.6251769661903381
Epoch - 25, Accuracy: 0.6228495240211487
Epoch - 30, Accuracy: 0.6257485747337341
Epoch - 35, Accuracy: 0.626415491104126
Epoch - 40, Accuracy: 0.6255171895027161
Epoch - 45, Accuracy: 0.6257213950157166


In [32]:
total_correct = 0
model.eval()
for batch in tqdm(test_loader):
    x, y = batch
    pred = model(x.to(device))
    total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
print(f'Total correct: {total_correct}')
print(f'Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')

  0%|          | 0/574 [00:00<?, ?it/s]

Total correct: 45896
Accuracy: 0.6246733665466309
