In [None]:
import dgl
import torch
from tqdm import tqdm
import dgl.function as fn
import torch.nn as nn
from torch.utils.data import DataLoader
import dgl.nn as dglnn
import torch.nn.functional as F
from torch.optim import Adam
# Use the GPU when PyTorch reports one is available, otherwise fall back to the CPU.
# NOTE(review): none of the cells visible here move the model or graphs to
# `device` - confirm whether that is intended.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
# Load both graph collections from disk. Each dgl.load_graphs call is unpacked
# into a pair: the graphs and an accompanying labels object (indexed with
# 'labels' further down). The benign file is read first, then the malicious one.
(benign_graphs, benign_labels), (malicious_graphs, malicious_labels) = (
    dgl.load_graphs(path) for path in ("benign.bin", "malicious.bin")
)

In [3]:
# Merge the benign and malicious samples into one collection, benign first,
# and concatenate their label tensors in the same order.
graphs = benign_graphs + malicious_graphs
labels = torch.cat([benign_labels['labels'], malicious_labels['labels']])

# zip() silently stops at the shorter input, so a graph/label count mismatch
# would quietly drop samples. Fail loudly instead.
assert len(graphs) == len(labels), (
    f"graph/label count mismatch: {len(graphs)} graphs vs {len(labels)} labels"
)

# Each dataset item is a (graph, label) pair.
dataset = list(zip(graphs, labels))

In [None]:
# Sanity check: print the 'h' node data stored on one sample graph (index 5).
# NOTE(review): `g` stays bound in the kernel namespace and is reused as a
# loop variable in a later cell.
g = graphs[5]
print(g.ndata['h'])

In [5]:
def custom_collate_fn(batch):
    """Collate (graph, label) samples into one batched graph and one label tensor.

    Args:
        batch: sequence of ``(graph, label)`` pairs, as produced by the dataset.

    Returns:
        A tuple ``(batched_graph, batch_labels)`` where ``batched_graph`` is the
        result of ``dgl.batch`` over the sample graphs and ``batch_labels`` is
        the per-sample labels stacked along a new leading dimension.
    """
    # Transpose the list of pairs into a list of graphs and a list of labels.
    sample_graphs, sample_labels = map(list, zip(*batch))
    return dgl.batch(sample_graphs), torch.stack(sample_labels)

In [6]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the (graph, label) pairs for testing; random_state is fixed at 42.
train, test = train_test_split(dataset, random_state=42, test_size=0.2)

# Both loaders yield one graph per batch and collate with custom_collate_fn.
# Only the training loader shuffles its samples.
train_dataloader = DataLoader(
    train,
    collate_fn=custom_collate_fn,
    batch_size=1,
    shuffle=True,
)
test_dataloader = DataLoader(
    test,
    collate_fn=custom_collate_fn,
    batch_size=1,
    shuffle=False,
)


In [13]:
class RGCN(nn.Module):
    """Two-layer relational graph convolution for heterogeneous graphs.

    Each layer is a ``dglnn.HeteroGraphConv`` created with ``aggregate='mean'``
    that holds one ``dglnn.GraphConv`` per relation name.

    Args:
        in_feats: size of the input node features.
        hid_feats: size of the features between the two layers.
        out_feats: size of the features returned by ``forward``.
        rel_names: iterable of relation (edge type) names; one ``GraphConv`` is
            created per name in each layer.
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = self._relational_layer(in_feats, hid_feats, rel_names)
        self.conv2 = self._relational_layer(hid_feats, out_feats, rel_names)

    @staticmethod
    def _relational_layer(in_size, out_size, rel_names):
        """Build one HeteroGraphConv holding a GraphConv for every relation name."""
        per_relation = {}
        for rel in rel_names:
            per_relation[rel] = dglnn.GraphConv(in_size, out_size)
        return dglnn.HeteroGraphConv(per_relation, aggregate='mean')

    def forward(self, graph, inputs):
        """Run both layers with a ReLU in between.

        Args:
            graph: the graph handed to both convolution layers.
            inputs: node features handed to the first layer (the classifier in
                this file passes a dict keyed by node type).

        Returns:
            The output of the second layer for the ReLU-activated output of the
            first layer.
        """
        hidden = self.conv1(graph, inputs)

        # Apply the non-linearity to every entry of the first layer's output.
        activated = {}
        for ntype, feats in hidden.items():
            activated[ntype] = F.relu(feats)

        return self.conv2(graph, activated)

class HeteroClassifier(nn.Module):
    """Graph-level classifier: RGCN encoder, mean readout, linear head.

    Args:
        in_dim: size of the input node features stored under ``'h'``.
        hidden_dim: size of the RGCN hidden and output features.
        n_classes: number of output logits per graph.
        rel_names: relation (edge type) names forwarded to ``RGCN``.
    """

    def __init__(self, in_dim, hidden_dim, n_classes, rel_names):
        super().__init__()

        self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        """Return class logits for every graph in the (batched) graph ``g``.

        Args:
            g: a graph whose node types all carry a ``'h'`` feature.

        Returns:
            The output of the linear head applied to the summed per-node-type
            mean readouts.

        Raises:
            ValueError: if the RGCN returns no embedding for any node type.
        """
        inputs = {ntype: g.nodes[ntype].data['h'] for ntype in g.ntypes}

        h = self.rgcn(g, inputs)

        if not h:
            raise ValueError(
                "RGCN produced no node embeddings; cannot compute a graph readout"
            )

        with g.local_scope():
            # Calculate graph representation by average readout.
            #
            # Only node types that received an embedding from the RGCN are
            # pooled. A node type missing from `h` (typically one that is never
            # the destination of a relation) still holds its raw *input*
            # features under 'h'. Pooling it would add in_dim-sized raw
            # features to hidden_dim-sized embeddings, which fails on shape
            # whenever in_dim != hidden_dim.
            hg = None

            # Iterate in g.ntypes order so the summation order is stable.
            for ntype in g.ntypes:
                if ntype not in h:
                    continue

                g.nodes[ntype].data['h'] = h[ntype]
                pooled = dgl.mean_nodes(g, 'h', ntype=ntype)
                hg = pooled if hg is None else hg + pooled

            return self.classify(hg)

In [14]:
# Gather the edge type names of every graph, then keep each distinct name once,
# in sorted order.
unique_rel_names = []

for g in graphs:
    unique_rel_names.extend(g.etypes)

unique_rel_names = sorted(set(unique_rel_names))

In [15]:
from torch.optim import Adam
# Seed PyTorch's global RNG so weight initialisation and the training loader's
# shuffle order repeat on every Restart & Run All (the train/test split is
# already seeded separately).
torch.manual_seed(42)

# in_dim=32, hidden_dim=32, n_classes=2.
model = HeteroClassifier(32, 32, 2, unique_rel_names)
optimiser = Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for graph, label in train_dataloader:
        # Labels are cast to integer class indices before the loss is computed.
        label = label.long()

        logits = model(graph)

        loss = loss_fn(logits, label)
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        total_loss += loss.item()

    # Report the mean per-batch loss for this epoch.
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_dataloader)}')
    
    
        

Epoch 1/20, Loss: 0.6732132419943809
Epoch 2/20, Loss: 0.2707225933155669
Epoch 3/20, Loss: 0.15115901941095283
Epoch 4/20, Loss: 0.08752994238069726
Epoch 5/20, Loss: 0.15550163599788283
Epoch 6/20, Loss: 0.10215326090485352
Epoch 7/20, Loss: 0.10737856282532295
Epoch 8/20, Loss: 0.09697287421785691
Epoch 9/20, Loss: 0.07566454021243936
Epoch 10/20, Loss: 0.10460381269960309
Epoch 11/20, Loss: 0.052440187020162024
Epoch 12/20, Loss: 0.07175281691826721
Epoch 13/20, Loss: 0.02901981934497342
Epoch 14/20, Loss: 0.06837356890016019
Epoch 15/20, Loss: 0.035922021870017806
Epoch 16/20, Loss: 0.04403822421368653
Epoch 17/20, Loss: 0.048223911221109625
Epoch 18/20, Loss: 0.07065467952448631
Epoch 19/20, Loss: 0.025371484135782784
Epoch 20/20, Loss: 0.04040354458850115


In [18]:
    # Score the trained model on the held-out test loader.
    model.eval()
    correct = 0
    total = 0

    # No gradients are needed while evaluating.
    with torch.no_grad():
        for graph, label in test_dataloader:

            logits = model(graph)
            # Predicted class = index of the largest logit in each row.
            preds = logits.argmax(dim=1)
            print(f"Prediciction = {preds}, Label = {label}")

            # Tally the correct predictions and the number of samples seen.
            correct += preds.eq(label).sum().item()
            total += label.shape[0]

    accuracy = correct / total
    print(f"Test Accuracy: {accuracy * 100:.2f}%")

Prediciction = tensor([0]), Label = tensor([0.])
Prediciction = tensor([0]), Label = tensor([0.])
Prediciction = tensor([0]), Label = tensor([0.])
Prediciction = tensor([1]), Label = tensor([1.])
Prediciction = tensor([1]), Label = tensor([1.])
Prediciction = tensor([1]), Label = tensor([1.])
Prediciction = tensor([0]), Label = tensor([0.])
Prediciction = tensor([1]), Label = tensor([1.])
Prediciction = tensor([1]), Label = tensor([1.])
Prediciction = tensor([0]), Label = tensor([0.])
Prediciction = tensor([0]), Label = tensor([0.])
Prediciction = tensor([1]), Label = tensor([1.])
Prediciction = tensor([1]), Label = tensor([1.])
Prediciction = tensor([0]), Label = tensor([0.])
Prediciction = tensor([1]), Label = tensor([1.])
Prediciction = tensor([1]), Label = tensor([1.])
Prediciction = tensor([0]), Label = tensor([0.])
Prediciction = tensor([1]), Label = tensor([1.])
Prediciction = tensor([0]), Label = tensor([0.])
Prediciction = tensor([0]), Label = tensor([0.])
Prediciction = tenso

In [15]:
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn

class GNN(nn.Module):
    """Baseline graph classifier: pooled node features fed to a two-layer MLP.

    The forward pass performs no message passing. It sums the per-node-type
    mean of the ``'h'`` node feature and classifies that vector.

    Args:
        in_feats: size of the ``'h'`` node features.
        hidden_feats: width of the hidden linear layer.
        out_feats: number of output logits.
    """

    def __init__(self, in_feats, hidden_feats, out_feats):
        super().__init__()
        self.fc = nn.Linear(in_feats, hidden_feats)
        self.classify = nn.Linear(hidden_feats, out_feats)

    def forward(self, graph):
        """Return logits computed from the mean readout of ``graph``.

        Args:
            graph: a graph exposing ``ntypes`` and a ``'h'`` feature readable by
                ``dgl.mean_nodes`` for each node type.

        Returns:
            The output of ``classify`` applied to the ReLU-activated hidden layer.
        """
        with graph.local_scope():
            # Sum the mean 'h' feature over all node types, in graph.ntypes order.
            pooled = sum(
                dgl.mean_nodes(graph, 'h', ntype=ntype) for ntype in graph.ntypes
            )

            return self.classify(F.relu(self.fc(pooled)))

In [18]:


# Seed PyTorch's global RNG so the baseline's weight initialisation (and the
# shuffle order of the training loop that follows) repeat across runs.
torch.manual_seed(42)

# Sizes for the baseline model: input feature width, hidden width, class count.
in_feats = 32
hidden_feats = 256
out_feats = 2

# NOTE: this rebinds `model`, `optimiser` and `loss_fn`, replacing the objects
# created for the HeteroClassifier in the earlier training cell.
model = GNN(in_feats, hidden_feats, out_feats)
optimiser = Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

In [None]:
    # Train the current `model` for a fixed number of epochs, then measure
    # accuracy on the test loader.
    num_epochs = 50

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for graph, label in train_dataloader:

            logits = model(graph)
            # Labels are cast to integer class indices for the loss.
            loss = loss_fn(logits, label.long())

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

            total_loss += loss.item()

        # Report the mean per-batch loss for this epoch.
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_dataloader)}')

    # Evaluate
    model.eval()
    correct = 0
    total = 0

    # No gradients are needed while evaluating.
    with torch.no_grad():
        for graph, label in test_dataloader:

            logits = model(graph)
            # Predicted class = index of the largest logit in each row.
            preds = logits.argmax(dim=1)

            # Tally the correct predictions and the number of samples seen.
            correct += preds.eq(label).sum().item()
            total += label.shape[0]

    accuracy = correct / total
    print(f"Test Accuracy: {accuracy * 100:.2f}%")