# Experiment for RQ2: Intermediate node prediction

### We first build the dataset for intermediate node prediction

In [1]:
import torch
import numpy as np
from itertools import combinations
from tqdm.notebook import tqdm
from utils import load_vocab
from torch.utils.data import DataLoader
import argparse
import os
from pathlib import Path
from py_config_runner import ConfigObject
from module import FastASTTrans
from ignite.utils import convert_tensor
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import torch.nn as nn
from torch.optim import AdamW
import torch.nn.functional as F

In [2]:
matrices_path = './csa-trans/processed/tree_sitter_java/test/split_matrices.npz'
data = np.load(matrices_path, allow_pickle=True)
test_rfs = data["root_first_seq"]

for split_idx, rfs in enumerate([test_rfs]):
    original_test_dataset = {}
    for data_idx, sample_ast in enumerate(tqdm(rfs)):
        # generate parent_path_list and brother_path_list
        distance_map = {}
        brother_map = {}

        parent_path_list = []
        brother_path_list = []

        for node in sample_ast:
            if len(node.children) == 0:
                path = [node.label]
                n = node
                while n.parent is not None:
                    path.append(n.parent.label)
                    n = n.parent
                parent_path_list.append(list(reversed(path)))
            else:
                brother_path_list.append([child.label for child in node.children])

        # remove identifier dangling nodes from parent path
        refined_parent_path_list = []
        for path in parent_path_list:
            idt_remove_path = []
            for e in path:
                if e.split(':')[0] != 'idt':
                    idt_remove_path.append(e)
                else:
                    break
            if len(idt_remove_path) < 2:
                continue
            refined_parent_path_list.append(idt_remove_path)

        # remove identifier nodes from brother list
        refined_brother_path_list = []
        for path in brother_path_list:
            idt_remove_path = []
            for e in path:
                if e.split(':')[0] != 'idt':
                    idt_remove_path.append(e)
            refined_brother_path_list.append(idt_remove_path)

        # get dataset
        parent_intermediate = set()
        brother_intermediate = set()
        for par_path in refined_parent_path_list:
            for idx in range(len(par_path) - 2):
                prev = par_path[idx]
                inter = par_path[idx + 1]
                after = par_path[idx + 2]
                parent_intermediate.add((prev, inter, after))

        for bro_path in refined_brother_path_list:
            for e in combinations(bro_path, 2):
                parent = sample_ast[int(e[0].split(':')[-1]) -1].parent
                brother_intermediate.add((e[0], parent.label, e[1]))

        parent_intermediate = list(parent_intermediate)
        brother_intermediate = list(brother_intermediate)
        # sample
        SAMPLE_NUM = min([10, len(parent_intermediate), len(brother_intermediate)])
        idx1 = np.random.choice(range(len(parent_intermediate)), size=SAMPLE_NUM, replace=False)
        idx2 = np.random.choice(range(len(brother_intermediate)), size=SAMPLE_NUM, replace=False)
        parent_inter = [parent_intermediate[i] for i in idx1]
        brother_inter = [brother_intermediate[i] for i in idx2]
        original_test_dataset[data_idx] = parent_inter + brother_inter

  0%|          | 0/8714 [00:00<?, ?it/s]

In [3]:
def _graph_prepare_batch(batch, device=None, non_blocking: bool = False):
    x, y = batch
    return (
        x.to(device),
        convert_tensor(y, device=device, non_blocking=non_blocking),
    )

class SyntheticDataset(Dataset):
    def __init__(self, embeddings, targets):
        self.x = embeddings
        self.y = targets

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        x = self.x[idx]
        y = self.y[idx]
        return x, y
    
class MLP(nn.Module):
    def __init__(self, indim, hidden, outdim):
        super().__init__()
        self.fc1 = nn.Linear(indim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc4 = nn.Linear(hidden, outdim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.dropout(F.relu(self.fc3(x)))
        x = F.relu(self.fc4(x))
        return x

## Node prediction for CSA-Trans on Java

In [4]:
parser = argparse.ArgumentParser("Example application")
parser.add_argument("--config", type=Path, help="Input configuration file")
parser.add_argument("--use_hype_params", action="store_true")
parser.add_argument("--data_type", type=str, default="")
parser.add_argument("--exp_type", type=str, default="summary")
parser.add_argument("--g", type=str, default="")

args = parser.parse_args(['--config', './config/java.py', '--g', '0',])
config = ConfigObject(args.config)

if args.g != "":
    os.environ["CUDA_VISIBLE_DEVICES"] = args.g
    config.device = "cuda"
    config.g = args.g

(config.src_vocab,config.tgt_vocab,) = load_vocab(config.data_dir, config.is_split, config.data_type)
test_data_set = config.data_set(config, "test")
test_loader = DataLoader(dataset=test_data_set,batch_size=config.batch_size // len(config.g.split(",")),shuffle=False,collate_fn=test_data_set.collect_fn,)

Data Set Name : < Fast AST Data Set >
loading test data...
loading existing dataset
dataset lenght: 8714


In [5]:
model = config.model(
    config.src_vocab.size(),
    config.tgt_vocab.size(),
    config.hidden_size,
    config.num_heads,
    config.num_layers,
    config.sbm_layers,
    config.use_pegen,
    config.dim_feed_forward,
    config.dropout,
    config.pe_dim,
    config.pegen_dim,
    config.sbm_enc_dim,
    config.clusters,
    config.full_att,
    config.checkpoint,
    config.max_src_len,
)

state_path = './csa-trans/outputs/final_models/java_v3/java_v3.pt'
state_dict = torch.load(state_path)
model.load_state_dict(state_dict)
model = model.to('cuda')

sbm_param: 14789888
decoder param: 16817152
generator param: 10260000
Init or load model.


In [6]:
src_pes = []
with torch.no_grad():
    for idx, batch in tqdm(enumerate(test_loader)):
        x, y = _graph_prepare_batch(batch)
        y_, sparsity, src_pe, graphs, attns = model(x.to('cuda'))
        src_pe = src_pe.detach()
        src_pes += src_pe

0it [00:00, ?it/s]

In [7]:
BATCH_SIZE=128
input_dim=256
hidden_dim=1024
device ='cuda'
criterion = nn.CrossEntropyLoss()

In [8]:
whole = list(original_test_dataset.values())
X = [] # the two embeddings
Y = [] # intermediate vocabulary
for ast_idx, ast_instance in tqdm(enumerate(whole)):
    for path in ast_instance:
        node1 = src_pes[ast_idx][int(path[0].split(':')[-1]) - 1]
        node2 = src_pes[ast_idx][int(path[2].split(':')[-1]) - 1]

        if path[1].split(':')[1] not in config.src_vocab.w2i.keys():
            continue
        tgt = config.src_vocab.w2i[path[1].split(':')[1]]
        
        X.append(torch.cat([node1, node2]))
        Y.append(tgt)
        
train_X = X[:int(len(X)*0.8)]
train_Y = Y[:int(len(X)*0.8)]
test_X = X[int(len(X)*0.8):]
test_Y = Y[int(len(X)*0.8):]

train_dataset = SyntheticDataset(train_X, train_Y)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = SyntheticDataset(test_X, test_Y)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

0it [00:00, ?it/s]

In [9]:
print(len(train_X))
print(len(test_X))

138051
34513


In [10]:
out_dim=config.src_vocab.size()
model = MLP(input_dim, hidden_dim, out_dim)
model = model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-4)
num_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"num_param: {num_param}")

model.train()
for epoch in range(50):
    loss_acc = 0
    for batch in train_loader:
        x, y = batch
        pred = model(x.to(device))
        loss = criterion(pred, y.to(device).squeeze())
        loss.backward()
        loss_acc += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        per_sample = loss.item() / BATCH_SIZE
    if epoch%5 == 0:
        per_sample = loss_acc / len(train_loader) / BATCH_SIZE
        total_correct = 0
        model.eval()
        for batch in test_loader:
            x, y = batch
            pred = model(x.to(device))
            total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
        print(f'Epoch - {epoch}, Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')
        model.train()


num_param: 12612368
Epoch - 0, Accuracy: 0.7125868201255798
Epoch - 5, Accuracy: 0.8318576216697693
Epoch - 10, Accuracy: 0.8503182530403137
Epoch - 15, Accuracy: 0.8592592477798462
Epoch - 20, Accuracy: 0.8638599514961243
Epoch - 25, Accuracy: 0.8642071485519409
Epoch - 30, Accuracy: 0.8642361164093018
Epoch - 35, Accuracy: 0.8656249642372131
Epoch - 40, Accuracy: 0.8673321604728699
Epoch - 45, Accuracy: 0.8692708015441895


In [11]:
total_correct = 0
model.eval()
for batch in tqdm(test_loader):
    x, y = batch
    pred = model(x.to(device))
    total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
print(f'Total correct: {total_correct}')
print(f'Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')

  0%|          | 0/270 [00:00<?, ?it/s]

Total correct: 29961
Accuracy: 0.8669270873069763


## Do the same experiment for Treepos encoding

In [12]:
args = parser.parse_args(['--config', './config/java_treepos.py', '--g', '0',])
config = ConfigObject(args.config)
(config.src_vocab,config.tgt_vocab,) = load_vocab(config.data_dir, config.is_split, config.data_type)
model = config.model(
    config.src_vocab.size(),
    config.tgt_vocab.size(),
    config.hidden_size,
    config.num_heads,
    config.num_layers,
    config.sbm_layers,
    config.use_pegen,
    config.dim_feed_forward,
    config.dropout,
    config.pe_dim,
    config.pegen_dim,
    config.sbm_enc_dim,
    config.clusters,
    False,
    config.checkpoint,
    config.max_src_len,
)
test_data_set = config.data_set(config, "test")
test_loader = DataLoader(dataset=test_data_set,batch_size=config.batch_size // len(config.g.split(",")),shuffle=False,collate_fn=test_data_set.collect_fn,)
state_path = './csa-trans/outputs/final_models/java_v3_treepos/java_v3_treepos.pt'
state_dict = torch.load(state_path)
model.load_state_dict(state_dict)
model.eval()
model = model.to('cuda')

sbm_param: 14789888
decoder param: 16817152
generator param: 10260000
Init or load model.
Data Set Name : < Fast AST Data Set >
loading test data...
loading existing dataset
dataset lenght: 8714


In [13]:
src_pes = []
with torch.no_grad():
    for idx, batch in tqdm(enumerate(test_loader)):
        x, y = _graph_prepare_batch(batch)
        y_, sparsity, src_pe, _, _ = model(x.to('cuda'))
        src_pe = src_pe.detach()
        src_pes += src_pe

0it [00:00, ?it/s]

In [14]:
whole = list(original_test_dataset.values())
X = [] # the two embeddings
Y = [] # intermediate vocabulary
for ast_idx, ast_instance in tqdm(enumerate(whole)):
    for path in ast_instance:
        node1 = src_pes[ast_idx][int(path[0].split(':')[-1]) - 1]
        node2 = src_pes[ast_idx][int(path[2].split(':')[-1]) - 1]

        if path[1].split(':')[1] not in config.src_vocab.w2i.keys():
            continue
        tgt = config.src_vocab.w2i[path[1].split(':')[1]]
        
        X.append(torch.cat([node1, node2]))
        Y.append(tgt)
        
train_X = X[:int(len(X)*0.8)]
train_Y = Y[:int(len(X)*0.8)]
test_X = X[int(len(X)*0.8):]
test_Y = Y[int(len(X)*0.8):]
train_dataset = SyntheticDataset(train_X, train_Y)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = SyntheticDataset(test_X, test_Y)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

0it [00:00, ?it/s]

In [15]:
out_dim=config.src_vocab.size()
model = MLP(input_dim, hidden_dim, out_dim)
model = model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
num_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"num_param: {num_param}")

model.train()
for epoch in range(50):
    loss_acc = 0
    for batch in train_loader:
        x, y = batch
        pred = model(x.to(device))
        loss = criterion(pred, y.to(device).squeeze())
        loss.backward()
        loss_acc += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        per_sample = loss.item() / BATCH_SIZE
    if epoch%5 == 0:
        per_sample = loss_acc / len(train_loader) / BATCH_SIZE
        total_correct = 0
        model.eval()
        for batch in test_loader:
            x, y = batch
            pred = model(x.to(device))
            total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
        print(f'Epoch - {epoch}, Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')
        model.train()

num_param: 12612368
Epoch - 0, Accuracy: 0.4765046238899231
Epoch - 5, Accuracy: 0.5166956186294556
Epoch - 10, Accuracy: 0.5260995030403137
Epoch - 15, Accuracy: 0.5280382037162781
Epoch - 20, Accuracy: 0.5299189686775208
Epoch - 25, Accuracy: 0.5328414440155029
Epoch - 30, Accuracy: 0.5312210321426392
Epoch - 35, Accuracy: 0.528211772441864
Epoch - 40, Accuracy: 0.5264177918434143
Epoch - 45, Accuracy: 0.5288194417953491


In [16]:
total_correct = 0
model.eval()
for batch in tqdm(test_loader):
    x, y = batch
    pred = model(x.to(device))
    total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
print(f'Total correct: {total_correct}')
print(f'Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')

  0%|          | 0/270 [00:00<?, ?it/s]

Total correct: 18163
Accuracy: 0.5255497694015503


## Try with laplacian pe

### Since laplacian pe and sequential pe are not learnable, they can be just used by changing the configs from treepos configuration

In [17]:
args = parser.parse_args(['--config', './config/java_lap.py', '--g', '0',])
config = ConfigObject(args.config)
(config.src_vocab,config.tgt_vocab,) = load_vocab(config.data_dir, config.is_split, config.data_type)
model = config.model(
    config.src_vocab.size(),
    config.tgt_vocab.size(),
    config.hidden_size,
    config.num_heads,
    config.num_layers,
    config.sbm_layers,
    config.use_pegen,
    config.dim_feed_forward,
    config.dropout,
    config.pe_dim,
    config.pegen_dim,
    config.sbm_enc_dim,
    config.clusters,
    config.full_att,
    config.checkpoint,
    config.max_src_len,
)
test_data_set = config.data_set(config, "test")
test_loader = DataLoader(dataset=test_data_set,batch_size=config.batch_size // len(config.g.split(",")),shuffle=False,collate_fn=test_data_set.collect_fn,)

sbm_param: 14789888
decoder param: 16817152
generator param: 10260000
Init or load model.
Data Set Name : < Fast AST Data Set >
loading test data...
loading existing dataset
dataset lenght: 8714


In [18]:
state_path = './csa-trans/outputs/final_models/java_v3_lap/java_v3_lap.pt'
state_dict = torch.load(state_path)
model.load_state_dict(state_dict)
model.eval()
model = model.to('cuda')

In [19]:
src_pes = []
with torch.no_grad():
    for idx, batch in tqdm(enumerate(test_loader)):
        x, y = _graph_prepare_batch(batch)
        y_, sparsity, src_pe, _, _ = model(x.to('cuda'))
        src_pe = src_pe.detach()
        src_pes += src_pe

0it [00:00, ?it/s]

In [20]:
whole = list(original_test_dataset.values())
X = [] # the two embeddings
Y = [] # intermediate vocabulary
for ast_idx, ast_instance in tqdm(enumerate(whole)):
    for path in ast_instance:
        node1 = src_pes[ast_idx][int(path[0].split(':')[-1]) - 1]
        node2 = src_pes[ast_idx][int(path[2].split(':')[-1]) - 1]

        if path[1].split(':')[1] not in config.src_vocab.w2i.keys():
            continue
        tgt = config.src_vocab.w2i[path[1].split(':')[1]]
        
        X.append(torch.cat([node1, node2]))
        Y.append(tgt)
        
train_X = X[:int(len(X)*0.8)]
train_Y = Y[:int(len(X)*0.8)]
test_X = X[int(len(X)*0.8):]
test_Y = Y[int(len(X)*0.8):]
train_dataset = SyntheticDataset(train_X, train_Y)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = SyntheticDataset(test_X, test_Y)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

0it [00:00, ?it/s]

In [21]:
out_dim=config.src_vocab.size()
model = MLP(input_dim, hidden_dim, out_dim)
model = model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
num_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"num_param: {num_param}")

model.train()
for epoch in range(50):
    loss_acc = 0
    for batch in train_loader:
        x, y = batch
        pred = model(x.to(device))
        loss = criterion(pred, y.to(device).squeeze())
        loss.backward()
        loss_acc += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        per_sample = loss.item() / BATCH_SIZE
    if epoch%5 == 0:
        per_sample = loss_acc / len(train_loader) / BATCH_SIZE
        
        total_correct = 0
        model.eval()
        for batch in test_loader:
            x, y = batch
            pred = model(x.to(device))
            total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
        print(f'Epoch - {epoch}, Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')
        model.train()

num_param: 12612368
Epoch - 0, Accuracy: 0.24635416269302368
Epoch - 5, Accuracy: 0.3444155156612396
Epoch - 10, Accuracy: 0.3838252127170563
Epoch - 15, Accuracy: 0.3977430462837219
Epoch - 20, Accuracy: 0.41197916865348816
Epoch - 25, Accuracy: 0.4162904918193817
Epoch - 30, Accuracy: 0.4221932888031006
Epoch - 35, Accuracy: 0.41495949029922485
Epoch - 40, Accuracy: 0.4181423485279083
Epoch - 45, Accuracy: 0.4087962806224823


In [22]:
total_correct = 0
model.eval()
for batch in tqdm(test_loader):
    x, y = batch
    pred = model(x.to(device))
    total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
print(f'Total correct: {total_correct}')
print(f'Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')

  0%|          | 0/270 [00:00<?, ?it/s]

Total correct: 14274
Accuracy: 0.4130208194255829


## Sequential

In [23]:
from module.components import PositionalEncoding
sbm_enc_dim=768
src_pe = PositionalEncoding(sbm_enc_dim, config["max_src_len"])
src_pe = src_pe.pe.squeeze()

In [24]:
whole = list(original_test_dataset.values())
X = [] # the two embeddings
Y = [] # intermediate vocabulary
for ast_idx, ast_instance in tqdm(enumerate(whole)):
    for path in ast_instance:
        node1 = src_pe[int(path[0].split(':')[-1]) - 1]
        node2 = src_pe[int(path[2].split(':')[-1]) - 1]

        if path[1].split(':')[1] not in config.src_vocab.w2i.keys():
            continue
        tgt = config.src_vocab.w2i[path[1].split(':')[1]]
        
        X.append(torch.cat([node1, node2]))
        Y.append(tgt)
        
train_X = X[:int(len(X)*0.8)]
train_Y = Y[:int(len(X)*0.8)]
test_X = X[int(len(X)*0.8):]
test_Y = Y[int(len(X)*0.8):]
train_dataset = SyntheticDataset(train_X, train_Y)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = SyntheticDataset(test_X, test_Y)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

0it [00:00, ?it/s]

In [25]:
input_dim=768*2
hidden_dim=1024
out_dim=config.src_vocab.size()
model = MLP(input_dim, hidden_dim, out_dim)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-4)
num_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"num_param: {num_param}")

#BATCH_SIZE=64
model.train()
for epoch in range(50):
    loss_acc = 0
    for batch in train_loader:
        x, y = batch
        pred = model(x.to(device))
        loss = criterion(pred, y.to(device).squeeze())
        loss.backward()
        loss_acc += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        per_sample = loss.item() / BATCH_SIZE
    if epoch%5 == 0:
        per_sample = loss_acc / len(train_loader) / BATCH_SIZE
        total_correct = 0
        model.eval()
        for batch in test_loader:
            x, y = batch
            pred = model(x.to(device))
            total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
        print(f'Epoch - {epoch}, Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')
        model.train()

num_param: 13923088
Epoch - 0, Accuracy: 0.419502317905426
Epoch - 5, Accuracy: 0.4334779977798462
Epoch - 10, Accuracy: 0.44441550970077515
Epoch - 15, Accuracy: 0.4505786895751953
Epoch - 20, Accuracy: 0.4515624940395355
Epoch - 25, Accuracy: 0.4497685134410858
Epoch - 30, Accuracy: 0.4528645873069763
Epoch - 35, Accuracy: 0.45315393805503845
Epoch - 40, Accuracy: 0.45092591643333435
Epoch - 45, Accuracy: 0.45321178436279297


In [26]:
total_correct = 0
model.eval()
for batch in tqdm(test_loader):
    x, y = batch
    pred = model(x.to(device))
    total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
print(f'Total correct: {total_correct}')
print(f'Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')

  0%|          | 0/270 [00:00<?, ?it/s]

Total correct: 15645
Accuracy: 0.4526909589767456


## Triplet

In [27]:
args = parser.parse_args(['--config', './config/java_triplet.py', '--g', '0',])
config = ConfigObject(args.config)
(config.src_vocab,config.tgt_vocab,) = load_vocab(config.data_dir, config.is_split, config.data_type)
model = config.model(
    config.src_vocab.size(),
    config.tgt_vocab.size(),
    config.hidden_size,
    config.num_heads,
    config.num_layers,
    config.sbm_layers,
    config.use_pegen,
    config.dim_feed_forward,
    config.dropout,
    config.pe_dim,
    config.pegen_dim,
    config.sbm_enc_dim,
    config.clusters,
    config.full_att,
    config.checkpoint,
    config.max_src_len,
)
test_data_set = config.data_set(config, "test")
test_loader = DataLoader(dataset=test_data_set,batch_size=config.batch_size // len(config.g.split(",")),shuffle=False,collate_fn=test_data_set.collect_fn,)

sbm_param: 14789888
decoder param: 16817152
generator param: 10260000
Init or load model.
Data Set Name : < Fast AST Data Set >
loading test data...
loading existing dataset
dataset lenght: 8714


In [28]:
state_path = './csa-trans/outputs/final_models/java_v3_triplet/java_triplet.pt'
state_dict = torch.load(state_path)
model.load_state_dict(state_dict)
model.eval()
model = model.to('cuda')

In [29]:
src_pes = []
from tqdm.notebook import tqdm
with torch.no_grad():
    for idx, batch in tqdm(enumerate(test_loader)):
        # if idx >= 5:
        #     break
        x, y = _graph_prepare_batch(batch)
        #x.to('cuda')
        y_, sparsity, src_pe, _, _ = model(x.to('cuda'))
        src_pe = src_pe.detach()
        src_pes += src_pe

0it [00:00, ?it/s]

In [30]:
whole = list(original_test_dataset.values())
X = [] # the two embeddings
Y = [] # intermediate vocabulary
for ast_idx, ast_instance in tqdm(enumerate(whole)):
    for path in ast_instance:
        node1 = src_pes[ast_idx][int(path[0].split(':')[-1]) - 1]
        node2 = src_pes[ast_idx][int(path[2].split(':')[-1]) - 1]

        if path[1].split(':')[1] not in config.src_vocab.w2i.keys():
            continue
        tgt = config.src_vocab.w2i[path[1].split(':')[1]]
        
        X.append(torch.cat([node1, node2]))
        Y.append(tgt)
        
train_X = X[:int(len(X)*0.8)]
train_Y = Y[:int(len(X)*0.8)]
test_X = X[int(len(X)*0.8):]
test_Y = Y[int(len(X)*0.8):]
train_dataset = SyntheticDataset(train_X, train_Y)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataset = SyntheticDataset(test_X, test_Y)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

0it [00:00, ?it/s]

In [31]:
input_dim=256
hidden_dim=1024
out_dim=config.src_vocab.size()
model = MLP(input_dim, hidden_dim, out_dim)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-4)
num_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"num_param: {num_param}")

#BATCH_SIZE=64
model.train()
for epoch in range(50):
    loss_acc = 0
    for batch in train_loader:
        x, y = batch
        pred = model(x.to(device))
        loss = criterion(pred, y.to(device).squeeze())
        loss.backward()
        loss_acc += loss.item()
        optimizer.step()
        optimizer.zero_grad()
    if epoch%5 == 0:
        per_sample = loss_acc / len(train_loader) / BATCH_SIZE
        total_correct = 0
        model.eval()
        for batch in test_loader:
            x, y = batch
            pred = model(x.to(device))
            total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
        print(f'Epoch - {epoch}, Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')
        model.train()

num_param: 12612368
Epoch - 0, Accuracy: 0.567158579826355
Epoch - 5, Accuracy: 0.606336772441864
Epoch - 10, Accuracy: 0.6137441992759705
Epoch - 15, Accuracy: 0.61328125
Epoch - 20, Accuracy: 0.6109374761581421
Epoch - 25, Accuracy: 0.6127893328666687
Epoch - 30, Accuracy: 0.6154513955116272
Epoch - 35, Accuracy: 0.6135995388031006
Epoch - 40, Accuracy: 0.6155960559844971
Epoch - 45, Accuracy: 0.616348385810852


In [32]:
total_correct = 0
model.eval()
for batch in tqdm(test_loader):
    x, y = batch
    pred = model(x.to(device))
    total_correct += torch.sum(torch.argmax(pred, dim=-1) == y.squeeze().to('cuda'))
print(f'Total correct: {total_correct}')
print(f'Accuracy: {total_correct / len(test_loader) / BATCH_SIZE}')

  0%|          | 0/270 [00:00<?, ?it/s]

Total correct: 21303
Accuracy: 0.616406261920929
