In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedGraphRecurrentLayer(nn.Module):
    def __init__(self, hidden_dim, edge_types):
        super(GatedGraphRecurrentLayer, self).__init__()
        self.hidden_dim = hidden_dim
        self.edge_types = edge_types
        self.weight_matrices = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim, bias=True) for _ in range(edge_types)
        ])
        self.gru = nn.GRUCell(hidden_dim, hidden_dim)

    def forward(self, node_features, adjacency_matrices, num_steps):
        h = node_features
        for _ in range(num_steps):
            aggregated_messages = []
            for i, adjacency_matrix in enumerate(adjacency_matrices):
                weighted_features = self.weight_matrices[i](h)
                aggregated_message = torch.matmul(adjacency_matrix.T, weighted_features)
                aggregated_messages.append(aggregated_message)
            combined_message = torch.stack(aggregated_messages).sum(dim=0)
            h = self.gru(combined_message, h)

        return h


In [None]:
class ConvLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_conv_layers):
        super(ConvLayer, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_conv_layers = num_conv_layers
        self.conv_layers = nn.ModuleList([
            nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=1) for _ in range(num_conv_layers)
        ])
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, node_features):
        node_features = node_features.unsqueeze(0).permute(0, 2, 1)
        for conv in self.conv_layers:
            node_features = F.relu(conv(node_features))
            node_features = F.max_pool1d(node_features, kernel_size=node_features.size(2))
        graph_features = node_features.view(-1)
        output = self.mlp(graph_features)
        return output

In [3]:
class GraphClassificationModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, edge_types, num_conv_layers, num_gru_steps):
        super(GraphClassificationModel, self).__init__()
        self.input_linear = nn.Linear(input_dim, hidden_dim)
        self.ggrl = GatedGraphRecurrentLayer(hidden_dim, edge_types)
        self.conv = ConvLayer(hidden_dim, hidden_dim, num_conv_layers)
        self.num_gru_steps = num_gru_steps

    def forward(self, node_features, adjacency_matrices):
        node_features = self.input_linear(node_features)
        updated_node_features = self.ggrl(node_features, adjacency_matrices, self.num_gru_steps)
        output = self.conv(updated_node_features)
        return output

In [4]:
import os
import torch
from torch_geometric.data import Data
from torch.utils.data import Dataset
import pickle

class Dataset(Dataset):
    def __init__(self, pkl_dir):
        self.pkl_dir = pkl_dir
        self.file_list = os.listdir(pkl_dir)
        self.data_list = []
        for filename in self.file_list:
            batch_file = os.path.join(pkl_dir, filename)
            with open(batch_file, 'rb') as f:
                batch_data = pickle.load(f)
                self.data_list.extend(batch_data)
    
    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, idx):
        return self.data_list[idx]

In [None]:
from sklearn.model_selection import train_test_split

pkl_dir = 'data/emb'
dataset = Dataset(pkl_dir)
train_indices, test_indices = train_test_split(
    list(range(len(dataset))), test_size=0.2, random_state=42
)

train_dataset = torch.utils.data.Subset(dataset, train_indices)
test_dataset = torch.utils.data.Subset(dataset, test_indices)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

input_dim = 768 + 8
hidden_dim = 128
edge_types = 3
num_conv_layers = 5
num_gru_steps = 5

num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GraphClassificationModel(input_dim, hidden_dim, edge_types, num_conv_layers, num_gru_steps)
model.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0

    for idx in range(len(train_dataset)):
        optimizer.zero_grad()
        sample = train_dataset[idx]
        node_features = sample.x
        label = sample.y.float()
        adjacency_matrices = [sample.ast_adj_matrix, sample.cfg_adj_matrix, sample.pdg_adj_matrix]

        node_features = node_features.to(device)
        label = label.to(device)
        adjacency_matrices = [adj.to(device).float() for adj in adjacency_matrices]
        output = model(node_features, adjacency_matrices)

        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        pred = torch.sigmoid(output).item() >= 0.5
        correct = pred == label.item()
        train_correct += int(correct)
        train_total += 1

    train_acc = train_correct / train_total
    
    model.eval()
    test_loss = 0
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for idx in range(len(test_dataset)):
            sample= test_dataset[idx]
            node_features = sample.x
            label = sample.y.float()
            adjacency_matrices = [sample.ast_adj_matrix, sample.cfg_adj_matrix, sample.pdg_adj_matrix]

            node_features = node_features.to(device)
            label = label.to(device)
            adjacency_matrices = [adj.to(device).float() for adj in adjacency_matrices]
            output = model(node_features, adjacency_matrices)

            loss = criterion(output, label)
            test_loss += loss.item()

            pred = torch.sigmoid(output).item() >= 0.5
            correct = pred == label.item()
            test_correct += int(correct)
            test_total += 1

    test_acc = test_correct / test_total

    print(f'Epoch {epoch+1}/{num_epochs}, '
          f'Train Loss: {train_loss/train_total:.4f}, Train Acc: {train_acc:.4f}, '
          f'Test Loss: {test_loss/test_total:.4f}, Test Acc: {test_acc:.4f}')


Epoch 1/10, Train Loss: 0.6912, Train Acc: 0.5563, Test Loss: 0.7059, Test Acc: 0.5200
Epoch 2/10, Train Loss: 0.6864, Train Acc: 0.5637, Test Loss: 0.7129, Test Acc: 0.5200
Epoch 3/10, Train Loss: 0.6874, Train Acc: 0.5637, Test Loss: 0.7048, Test Acc: 0.5200
Epoch 4/10, Train Loss: 0.6865, Train Acc: 0.5637, Test Loss: 0.7005, Test Acc: 0.5200
Epoch 5/10, Train Loss: 0.6864, Train Acc: 0.5625, Test Loss: 0.7032, Test Acc: 0.5200
Epoch 6/10, Train Loss: 0.6862, Train Acc: 0.5637, Test Loss: 0.7008, Test Acc: 0.5200
Epoch 7/10, Train Loss: 0.6859, Train Acc: 0.5637, Test Loss: 0.6975, Test Acc: 0.5200
Epoch 8/10, Train Loss: 0.6858, Train Acc: 0.5637, Test Loss: 0.6975, Test Acc: 0.5200
Epoch 9/10, Train Loss: 0.6857, Train Acc: 0.5637, Test Loss: 0.6989, Test Acc: 0.5200
Epoch 10/10, Train Loss: 0.6859, Train Acc: 0.5637, Test Loss: 0.6983, Test Acc: 0.5200
