In [1]:
from transformers import AutoTokenizer, AutoModel
import torch

import importlib
from torch.utils.data import DataLoader
import torch.optim as optim
import pandas as pd
from tqdm import tqdm
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np
import xml.etree.ElementTree as ET
import GraphModelContrastivev5
importlib.reload(GraphModelContrastivev5)
from GraphModelContrastivev5 import SemanticFilter
from MethodInfo import MethodInfo
from MutantInfo import MutantInfo
import ChangeImpactDataBuilder
from ChangeImpactDataBuilder import ChangeImpactDataBuilder
from ChangeImpactMutantIndicesDataset import ChangeImpactMutantIndicesDataset
from ChangeImpactNodeIndicesDataset import ChangeImpactNodeIndicesDataset

import config
importlib.reload(config)
from config import *
import random
from util import my_collate_fn

In [2]:
# Run configuration for this notebook.
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)` call.
device = "cuda" if torch.cuda.is_available() else "cpu"
# When True, downstream cells work on a reduced subset of the data.
debug = False
# Batch sizes for the mutant-index and node-index DataLoaders used in training.
mutant_batch_size_train = 400
node_batch_size_train = 7000

In [None]:
# Build the project-level metadata objects. MutantInfo is constructed on top
# of the MethodInfo instance and receives the notebook-wide `debug` flag.
method_info = MethodInfo()
mutant_info = MutantInfo(method_info, debug=debug)

In [3]:
# Load the saved training indices as a plain Python list.
train_indices_path = model_state_home + f"{path_sep}train_indices.npy"
train_indices = np.load(train_indices_path).tolist()

# In debug mode the loaded indices are replaced by the fixed range 0..799.
if debug:
    train_indices = list(range(800))

In [None]:
# Builder that assembles a training batch from (mutant indices, node indices).
data_builder = ChangeImpactDataBuilder(method_info, mutant_info)

# Shuffled loader over node indices.
node_dataset = ChangeImpactNodeIndicesDataset(method_info)
node_loader_train = DataLoader(
    node_dataset,
    batch_size=node_batch_size_train,
    shuffle=True,
    collate_fn=my_collate_fn,
)

# Shuffled loader over the training mutant indices.
mutant_dataset_train = ChangeImpactMutantIndicesDataset(train_indices)
mutant_loader_train = DataLoader(
    mutant_dataset_train,
    batch_size=mutant_batch_size_train,
    shuffle=True,
    collate_fn=my_collate_fn,
)

In [3]:
def get_methods_from_csv(file_dir:str):
    """Read a methods CSV and return one string per row.

    The CSV is read as UTF-8 and must contain the columns "method name" and
    "method body". Each returned string is the method name immediately
    followed by the method body; the body is passed through ``str()`` first,
    so non-string cells are concatenated by their string form.

    Args:
        file_dir: Path of the CSV file to read.

    Returns:
        A list of ``name + body`` strings, in row order.
    """
    frame = pd.read_csv(file_dir, encoding="utf-8")
    name_column = frame["method name"].tolist()
    body_column = frame["method body"].tolist()
    return [name + str(body) for name, body in zip(name_column, body_column)]

In [None]:
# Load the pretrained CodeBERT tokenizer and encoder; the encoder is moved to
# the configured device.
codebert_checkpoint = "microsoft/codebert-base"
tokenizer = AutoTokenizer.from_pretrained(codebert_checkpoint)
model = AutoModel.from_pretrained(codebert_checkpoint).to(device)

In [7]:
def embed_data(raw_data):
    """Embed each text in ``raw_data`` with the notebook-level encoder.

    Every text is tokenized on its own (with truncation enabled), run through
    the global ``model`` without gradient tracking, and reduced to the
    first-token slice ``last_hidden_state[:, 0, :]``.

    Args:
        raw_data: Iterable of texts accepted by the global ``tokenizer``.

    Returns:
        A list with one CPU tensor per input text, in input order.
        Presumably each tensor has a leading batch dimension of size 1,
        since one text is encoded per call -- TODO confirm.
    """
    cls_vectors = []
    for snippet in raw_data:
        encoded = tokenizer(snippet, return_tensors="pt", truncation=True, padding=True).to(device)

        with torch.no_grad():
            hidden_states = model(**encoded).last_hidden_state
            # Keep only position 0 of the sequence axis and move it off the
            # device so results accumulate in host memory.
            cls_vectors.append(hidden_states[:, 0, :].to("cpu"))

    return cls_vectors

In [None]:
# Load one "name + body" string per method from the CSV at `method_dir`,
# then report the container type and the number of methods.
raw_data = get_methods_from_csv(method_dir)
# raw_data=get_methods_from_callgraph(callgraph_dir)
print(type(raw_data))
print(len(raw_data))

In [None]:
# Flatten every method into its individual lines (`splitline_raw`) and record,
# per method, the half-open [start, end) range its lines occupy in that flat
# list (`splitline_idx`).
splitline_raw = []
splitline_idx = []
start = end = 0
for method in raw_data:
    splitline = method.split("\n")
    # The next range begins where the previous one ended.
    start, end = end, end + len(splitline)
    splitline_idx.append([start, end])
    splitline_raw.extend(splitline)
print(len(splitline_raw))
print(len(splitline_idx))

In [None]:
# Turn the list of [start, end] pairs into a tensor and show its shape.
# Note: this rebinds `splitline_idx`, so from here on it is a tensor, not a list.
splitline_idx = torch.tensor(splitline_idx)
print(splitline_idx.shape)

In [13]:
# Embed every individual source line; the result is a list with one tensor
# per entry of `splitline_raw`.
# (Earlier whole-method variant, kept for reference:)
# all_embeddings=embed_data(raw_data).squeeze(dim=1)
all_embeddings = embed_data(splitline_raw)

In [60]:
# Group the per-line embeddings by method: one stacked tensor per method,
# placed on `device`.
all_embed_list = []
for line_start, line_end in splitline_idx.tolist():
    # Stack this method's line embeddings, then drop axis 1 if it has size 1
    # (presumably the per-line batch axis produced by embed_data -- TODO confirm).
    e = torch.stack(all_embeddings[line_start:line_end]).squeeze(1)
    # Tensor.to() returns a new tensor and is NOT in-place; the result must be
    # kept, otherwise the stored tensor silently stays on its original device.
    e = e.to(device)
    all_embed_list.append(e)

In [None]:
# Build and initialise the SemanticFilter from the per-line embeddings and the
# per-method line ranges.
se_model = SemanticFilter(all_embeddings, splitline_idx)
se_model.initialize_model()
# Move the model to its device BEFORE constructing the optimizer, so the
# optimizer is built over parameters that already live on `device`.
se_model.to(device)
optimizer = optim.Adam(se_model.parameters(), lr=0.001)
# Last expression of the cell: show the model in the cell output.
se_model

In [None]:
# Training schedule.
n_epochs = 40
snapshot_every = 5  # record merged embeddings at epochs 0, 5, 10, ...

# Snapshots of the merged method embeddings (numpy arrays), one per recorded epoch.
merge_embeddings = []
n_method = splitline_idx.shape[0]

for epoch in range(n_epochs):
    # Snapshots are taken at the START of the epoch, i.e. before that epoch's
    # updates; the state after the final epoch is therefore not recorded.
    if epoch % snapshot_every == 0:
        with torch.no_grad():
            cur_merge_embeddings = se_model.merge(all_embed_list)
            merge_embeddings.append(cur_merge_embeddings.cpu().numpy())

    epoch_loss = 0
    n_batches = 0
    for mutant_indexs in tqdm(mutant_loader_train):
        for node_indexs in node_loader_train:
            (
                change_embeddings,
                node_embeddings,
                edge_indexs,
                node_predict_indexs,
                node_predict_labels,
                node_predict_types,
                node_change_indexs,
                node_mutant_predict_indexs,
                node_predict_indexs_origin,
                st_embeddings,
            ) = data_builder.build_batch_data(mutant_indexs, node_indexs)
            # NOTE(review): these four tensors are moved to the device but are
            # not passed to the forward call below -- confirm they are needed.
            change_embeddings = change_embeddings.to(device)
            st_embeddings = st_embeddings.to(device)
            node_embeddings = node_embeddings.to(device)
            edge_indexs = edge_indexs.to(device)

            info_loss = se_model.forward(
                [all_embed_list[i.item()] for i in node_predict_indexs],
                device,
                node_predict_indexs,
                node_predict_labels,
                node_predict_types,
                node_change_indexs,
            )
            epoch_loss += info_loss.item()
            n_batches += 1
            optimizer.zero_grad()
            info_loss.backward()
            optimizer.step()

    # Report once per epoch (previously this line ran once per batch and the
    # accumulated epoch_loss was never used). The value is the mean batch loss;
    # max(..., 1) guards against an epoch with no batches.
    print(f"epoch {epoch} loss: {epoch_loss / max(n_batches, 1)}\n")

# for epoch in range(120):
#     if (epoch%20)==0:
#         with torch.no_grad():
#             cur_merge_embeddings=se_model.merge(all_embed_list)
#             merge_embeddings.append(cur_merge_embeddings.cpu().numpy())
#     start=0
#     end=0
#     batch_size=1000
#     pick_list=list(range(n_method))
#     random.shuffle(pick_list)
#     epoch_loss=0
#     while end < n_method:
#         end+=batch_size
#         info_loss=se_model.forward(pick_list[start:end],device=device)
#         epoch_loss+=info_loss.item()
#         start=end
#         optimizer.zero_grad()
#         info_loss.backward()
#         optimizer.step()
#     print(f"epoch {epoch} loss: {info_loss.item()}\n")

In [41]:
# merge_embeddings=[]
# for split in splitline_idx:
#     start=split[0]
#     end=split[1]
#     line_embeddings=all_embeddings[start:end]
#     merge_embedding=torch.sum(line_embeddings,dim=0)/(end-start)
#     merge_embeddings.append(merge_embedding)
# merge_embeddings=torch.stack(merge_embeddings)
# print(merge_embeddings.shape)

In [43]:
# all_embeddings=all_embeddings.numpy()

In [None]:
# Number of embedding snapshots collected by the training cell.
snapshot_count = len(merge_embeddings)
print(snapshot_count)

In [None]:
# Plot a 2-D PCA projection of every embedding snapshot, one panel each.
# The grid is sized from the number of snapshots: a fixed 2x3 grid raises
# IndexError as soon as more than six snapshots have been recorded.
n_snapshots = len(merge_embeddings)

if n_snapshots == 0:
    # plt.subplots cannot build a grid with zero rows; nothing to draw.
    print("No embedding snapshots to plot.")
else:
    n_cols = 3
    n_rows = (n_snapshots + n_cols - 1) // n_cols  # ceiling division
    # squeeze=False always yields a 2-D axes array, even for a single row.
    # Height of 4 per row reproduces the former (20, 8) size for two rows.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 4 * n_rows), squeeze=False)
    axes = axes.flatten()

    for i, embeddings in enumerate(merge_embeddings):
        pca = PCA(n_components=2)
        reduced = pca.fit_transform(embeddings)

        ax = axes[i]
        ax.scatter(reduced[:, 0], reduced[:, 1], s=1)
        ax.set_title(f"Matrix {i+1}")
        ax.set_xlabel("PC 1")
        ax.set_ylabel("PC 2")

    # Hide the panels of the last row that have no snapshot to show.
    for unused_ax in axes[n_snapshots:]:
        unused_ax.set_visible(False)

    plt.tight_layout()
    plt.show()

In [None]:
# Persist the method embeddings to `method_embedding_dir`.
# Alternatives tried earlier, kept for reference:
# torch.save(splitline_idx,method_splitlins_dir)
# torch.save(merge_embeddings,method_embedding_dir)
# torch.save(all_embeddings,method_embedding_dir)
#
# NOTE(review): index 0 is the first snapshot in `merge_embeddings`, which the
# training cell records at epoch 0 before any optimizer step. Confirm that the
# untrained snapshot (rather than the latest one) is what should be saved.
first_snapshot = merge_embeddings[0]
torch.save(first_snapshot, method_embedding_dir)