<h1>Use a Siamese GNN to predict the similarity of two source code files</h1>

<h3>Import dependencies</h3>

In [251]:
from tree_sitter import Language, Parser
import tree_sitter_java as ts_java
import os
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn.functional as F, torch.nn as nn
from torch_geometric.data import Batch

<h2>Data preparation</h2>

<h3>Define constants</h3>

In [252]:
# Root folders of the Java sources for the two datasets; file names taken from
# the label CSVs are joined onto these paths.
java_directory1 = './datasets/fire14-source-code-training-dataset/java'
java_directory2 = './datasets/ir_plag_preprocessed'
# tree-sitter Java grammar and the module-level parser that parse_java_file uses.
java_LANGUAGE = Language(ts_java.language())
parser = Parser(java_LANGUAGE)
# Label files, in the same order as the directories above: csv_paths[0] is
# loaded for java_directory1 and csv_paths[1] for java_directory2.
csv_paths = ['./labels/fire14-labels.csv', './labels/ir_plag_labels.csv']

In [253]:
def load_csv(csv_path):
    """Read the CSV at ``csv_path`` and return its contents as a DataFrame."""
    return pd.read_csv(csv_path)

<h3>Get AST</h3>

In [254]:
def parse_java_file(filepath):
    """Parse a Java source file and flatten its syntax tree into node and edge lists.

    The file is read as UTF-8 text and handed to the module-level ``parser``.
    The resulting tree is walked depth-first in pre-order; each visited node is
    given the next free integer index.

    Args:
        filepath: Path of the Java file to parse.

    Returns:
        A ``(nodes, edges)`` tuple where ``nodes[i]`` is the type string of the
        node with index ``i`` and ``edges`` is a list of
        ``(parent_index, child_index)`` tuples in visiting order. The root has
        no incoming edge.
    """
    with open(filepath, 'r', encoding='utf8') as file:
        code = file.read()

    tree = parser.parse(bytes(code, "utf8"))

    nodes = []
    edges = []

    # An explicit stack replaces recursion so that very deeply nested syntax
    # trees cannot exhaust Python's recursion limit.
    pending = [(tree.root_node, None)]
    while pending:
        node, parent_idx = pending.pop()
        idx = len(nodes)
        nodes.append(node.type)

        if parent_idx is not None:
            edges.append((parent_idx, idx))

        # Push children right-to-left so the leftmost child is popped first,
        # which keeps the same pre-order numbering a recursive walk produces.
        pending.extend((child, idx) for child in reversed(node.children))

    return nodes, edges

<h3>Build data for GNN</h3>

In [255]:
def build_global_vocab(java_directories, file_lists):
    """Collect every node type seen in the given files and number them.

    ``java_directories`` and ``file_lists`` are walked in parallel: each file
    name in ``file_lists[k]`` is resolved against ``java_directories[k]`` and
    parsed with ``parse_java_file``.

    Returns:
        A dict mapping each distinct node type to its position in the
        alphabetically sorted list of all types.
    """
    seen_types = set()

    for directory, names in zip(java_directories, file_lists):
        for name in names:
            node_types, _ = parse_java_file(os.path.join(directory, name))
            seen_types |= set(node_types)

    ordered_types = sorted(seen_types)
    return dict(zip(ordered_types, range(len(ordered_types))))



def create_node_features(nodes, node_type_to_idx):
    """Translate a sequence of node types into their vocabulary indices.

    Raises:
        KeyError: if a node type is missing from ``node_type_to_idx``.
    """
    feature_ids = []
    for node_type in nodes:
        feature_ids.append(node_type_to_idx[node_type])
    return feature_ids

def create_graph_data(nodes, edges, node_features, embedding_layer):
    """Build a graph ``Data`` object from a parsed file.

    Args:
        nodes: Node type list of the graph (not used here; kept so existing
            callers keep working).
        edges: List of ``(parent_index, child_index)`` pairs.
        node_features: Integer vocabulary index for every node.
        embedding_layer: Callable mapping a long tensor of indices to one
            feature vector per node.

    Returns:
        ``Data`` with ``x`` holding the embedded node features and
        ``edge_index`` holding the edges as a ``(2, num_edges)`` long tensor.
    """
    indices = torch.tensor(node_features, dtype=torch.long)

    # The features are computed once and reused for every batch and epoch, so
    # they are stored without an autograd graph. Keeping the graph would make
    # every later backward pass walk back into the embedding table and would
    # require retain_graph=True to survive a second epoch.
    with torch.no_grad():
        x = embedding_layer(indices)

    if len(edges) == 0:
        # A single-node graph has no edges; keep the (2, 0) layout so the
        # tensor is still a valid edge index.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    else:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

    return Data(x=x, edge_index=edge_index)

In [256]:
class NodeEmbeddingLayer(nn.Module):
    """Lookup table that turns node-type indices into dense vectors."""

    def __init__(self, num_node_types, embedding_dim):
        """Create a table with ``num_node_types`` rows of width ``embedding_dim``."""
        super().__init__()
        self.embeddings = nn.Embedding(num_node_types, embedding_dim)

    def forward(self, node_indices):
        """Return the embedding row for every index in ``node_indices``."""
        embedded = self.embeddings(node_indices)
        return embedded
    
# Width of each node-type embedding vector. Used when the embedding layer is
# constructed; the training cell assigns the same value again before building
# the model.
embedding_dim = 16


In [257]:
def prepare_data_for_pairs(pairs_df, java_directory, node_type_to_idx, embedding_layer):
    """Turn every labelled file pair into a pair of graphs.

    Args:
        pairs_df: DataFrame with columns ``id1`` and ``id2`` (file names
            relative to ``java_directory``) and ``plagio`` (the pair label).
        java_directory: Folder containing the referenced Java files.
        node_type_to_idx: Vocabulary mapping node types to integer indices.
        embedding_layer: Layer used to embed the node indices.

    Returns:
        A list of ``(graph1, graph2, label)`` tuples, one per row and in row
        order. A file that occurs in several pairs is parsed and embedded only
        once, so the same graph object is shared by all pairs that mention it.
    """
    graph_cache = {}

    def load_graph(file_name):
        # Parsing and embedding are the expensive steps; do them once per file.
        if file_name not in graph_cache:
            file_path = os.path.join(java_directory, file_name)
            nodes, edges = parse_java_file(file_path)
            node_features = create_node_features(nodes, node_type_to_idx)
            graph_cache[file_name] = create_graph_data(
                nodes, edges, node_features, embedding_layer
            )
        return graph_cache[file_name]

    data_pairs = []
    for _, row in pairs_df.iterrows():
        data_pairs.append((load_graph(row['id1']), load_graph(row['id2']), row['plagio']))

    return data_pairs


In [258]:
# Read the labelled pair tables of both datasets.
pairs_df1, pairs_df2 = load_csv(csv_paths[0]), load_csv(csv_paths[1])

# Every distinct file name that appears on either side of a pair, per dataset.
file_list1 = list({*pairs_df1['id1'].tolist(), *pairs_df1['id2'].tolist()})
file_list2 = list({*pairs_df2['id1'].tolist(), *pairs_df2['id2'].tolist()})

java_directories = [java_directory1, java_directory2]
file_lists = [file_list1, file_list2]

# One vocabulary and one embedding table are shared by both datasets.
node_type_to_idx = build_global_vocab(java_directories, file_lists)
embedding_layer = NodeEmbeddingLayer(
    num_node_types=len(node_type_to_idx),
    embedding_dim=embedding_dim,
)

# Convert each labelled pair into a pair of graphs.
data_pairs1 = prepare_data_for_pairs(
    pairs_df=pairs_df1,
    java_directory=java_directory1,
    node_type_to_idx=node_type_to_idx,
    embedding_layer=embedding_layer,
)
data_pairs2 = prepare_data_for_pairs(
    pairs_df=pairs_df2,
    java_directory=java_directory2,
    node_type_to_idx=node_type_to_idx,
    embedding_layer=embedding_layer,
)

In [259]:
# Report how many pairs each dataset produced.
print("\n".join([
    "Data preparation complete.",
    f"Number of pairs in dataset 1: {len(data_pairs1)}",
    f"Number of pairs in dataset 2: {len(data_pairs2)}",
]))

# Take the first pair of each dataset as a sanity-check sample.
data1, data2, label1 = data_pairs1[0]
data3, data4, label2 = data_pairs2[0]

print("\n".join([
    f"Dataset 1 - First pair:",
    f"  Graph 1: {data1.num_nodes} nodes, {data1.num_edges} edges",
    f"  Graph 2: {data2.num_nodes} nodes, {data2.num_edges} edges",
    f"  Label: {label1}",
]))

print("\n".join([
    f"Dataset 2 - First pair:",
    f"  Graph 1: {data3.num_nodes} nodes, {data3.num_edges} edges",
    f"  Graph 2: {data4.num_nodes} nodes, {data4.num_edges} edges",
    f"  Label: {label2}",
]))


Data preparation complete.
Number of pairs in dataset 1: 504
Number of pairs in dataset 2: 460
Dataset 1 - First pair:
  Graph 1: 2068 nodes, 2067 edges
  Graph 2: 619 nodes, 618 edges
  Label: 0
Dataset 2 - First pair:
  Graph 1: 108 nodes, 107 edges
  Graph 2: 109 nodes, 108 edges
  Label: 1


<h2>Model</h2>

<h3>Build the Siamese GNN architecture</h3>

In [260]:
class GNNEncoder(torch.nn.Module):
    """Two GCN layers followed by mean pooling: one vector per graph."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        """Create the two convolutions ``in -> hidden`` and ``hidden -> out``."""
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Encode node features ``x`` and average them per graph using ``batch``."""
        hidden = F.relu(self.conv1(x, edge_index))
        node_embeddings = self.conv2(hidden, edge_index)
        return global_mean_pool(node_embeddings, batch)

class SiameseNetwork(torch.nn.Module):
    """Applies one shared ``GNNEncoder`` to both graphs of a pair."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        """Build the single encoder whose weights are used for both inputs."""
        super().__init__()
        self.encoder = GNNEncoder(in_channels, hidden_channels, out_channels)

    def forward(self, data1, data2):
        """Return the encodings of ``data1`` and ``data2`` as a tuple."""
        def encode(graph):
            return self.encoder(graph.x, graph.edge_index, graph.batch)

        return encode(data1), encode(data2)

def contrastive_loss(h1, h2, label, margin=1.0):
    """Mean contrastive loss over a batch of embedding pairs.

    Pairs with ``label`` 1 are penalised by their squared distance; pairs with
    ``label`` 0 are penalised by the squared amount by which their distance
    falls short of ``margin``.
    """
    dist = F.pairwise_distance(h1, h2)
    pull_term = label * dist.pow(2)
    push_term = (1 - label) * F.relu(margin - dist).pow(2)
    return (pull_term + push_term).mean()

def collate_fn(pairs, device):
    """Merge ``(graph1, graph2, label)`` triples into two batches and a label tensor.

    Returns:
        ``(batch1, batch2, labels)`` where the batches hold all first and all
        second graphs respectively and ``labels`` is a float tensor; all three
        are placed on ``device``.
    """
    triples = list(pairs)
    first_graphs = [first for first, _, _ in triples]
    second_graphs = [second for _, second, _ in triples]
    targets = [target for _, _, target in triples]

    return (
        Batch.from_data_list(first_graphs).to(device),
        Batch.from_data_list(second_graphs).to(device),
        torch.tensor(targets, dtype=torch.float, device=device),
    )

<h3>Training</h3>

In [261]:
def train(model, optimizer, data_pairs, device, epochs=10, batch_size=32, threshold=1.0):
    """Train ``model`` on labelled graph pairs with the contrastive loss.

    Args:
        model: Module called as ``model(batch1, batch2)`` and returning the
            two batches of embeddings.
        optimizer: Optimizer that updates the model's parameters.
        data_pairs: Sequence of ``(graph1, graph2, label)`` triples; it is
            consumed in order, ``batch_size`` triples at a time.
        device: Device the batches are moved to.
        epochs: Number of passes over ``data_pairs``.
        batch_size: Number of pairs per optimisation step.
        threshold: A pair is predicted as label 1 when its embedding distance
            is below this value.

    Prints one line per epoch with the loss summed over all batches and the
    accuracy of the thresholded predictions. Returns nothing.
    """
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0

        for i in range(0, len(data_pairs), batch_size):
            batch_pairs = data_pairs[i:i+batch_size]

            batch1, batch2, labels = collate_fn(batch_pairs, device)

            # The node features are precomputed inputs, not something the
            # optimizer updates. Cutting them off from whatever graph produced
            # them keeps each backward pass inside the model, so the same
            # features can be reused every epoch without retain_graph=True.
            batch1.x = batch1.x.detach()
            batch2.x = batch2.x.detach()

            optimizer.zero_grad()
            h1, h2 = model(batch1, batch2)
            loss = contrastive_loss(h1, h2, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Metrics only; no gradient bookkeeping needed.
            with torch.no_grad():
                distances = F.pairwise_distance(h1, h2)
                predictions = (distances < threshold).float()
                correct += (predictions == labels).sum().item()
            total += labels.size(0)

        accuracy = correct / total if total > 0 else 0

        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}, Accuracy: {accuracy*100:.2f}%")

In [262]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [263]:
# Model sizes. embedding_dim is the encoder's input width and therefore has to
# match the width of the node features stored in the graph data.
embedding_dim = 16
hidden_dim = 32
out_dim = 32

model = SiameseNetwork(
    in_channels=embedding_dim,
    hidden_channels=hidden_dim,
    out_channels=out_dim
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Create the output folder before training so that a missing directory cannot
# make the save fail after the whole run has finished.
model_path = './models/siamese_gnn_model.pth'
os.makedirs(os.path.dirname(model_path), exist_ok=True)

train(model, optimizer, data_pairs1, device, epochs=20, batch_size=16)

# Save the model
# NOTE(review): only the Siamese encoder weights are written here. The
# embedding_layer that produced the node features and node_type_to_idx are not
# part of `model` - confirm they are persisted elsewhere before relying on
# this checkpoint for inference.
torch.save(model.state_dict(), model_path)

Epoch 1, Loss: 8.2825, Accuracy: 19.25%
Epoch 2, Loss: 4.6771, Accuracy: 33.93%
Epoch 3, Loss: 3.6900, Accuracy: 50.40%
Epoch 4, Loss: 3.3347, Accuracy: 58.53%
Epoch 5, Loss: 3.0904, Accuracy: 60.91%
Epoch 6, Loss: 2.8860, Accuracy: 61.90%
Epoch 7, Loss: 2.7075, Accuracy: 63.49%
Epoch 8, Loss: 2.5509, Accuracy: 63.89%
Epoch 9, Loss: 2.4139, Accuracy: 64.48%
Epoch 10, Loss: 2.2965, Accuracy: 65.48%
Epoch 11, Loss: 2.1975, Accuracy: 66.27%
Epoch 12, Loss: 2.1141, Accuracy: 66.87%
Epoch 13, Loss: 2.0437, Accuracy: 67.46%
Epoch 14, Loss: 1.9829, Accuracy: 68.65%
Epoch 15, Loss: 1.9292, Accuracy: 68.65%
Epoch 16, Loss: 1.8808, Accuracy: 69.25%
Epoch 17, Loss: 1.8360, Accuracy: 69.44%
Epoch 18, Loss: 1.7950, Accuracy: 70.24%
Epoch 19, Loss: 1.7567, Accuracy: 70.44%
Epoch 20, Loss: 1.7215, Accuracy: 70.63%
