# Load DNS and mDNS datasets

In [1]:
import os
import sys

import numpy as np
import pandas as pd

# Put the parent directory on the import path so that packages living one
# level above this notebook can be imported.
module_path = os.path.abspath('..')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

import torch
import torch_geometric.transforms as T

from libs.loader import DNS
from libs.utils import score

In [2]:
# Preferred GPU index; only honoured when that GPU actually exists.
cuda_device = 3

if torch.cuda.is_available():
    # Fall back to the first GPU instead of crashing on hosts that expose
    # fewer devices than the preferred index.
    if cuda_device >= torch.cuda.device_count():
        cuda_device = 0
    torch.cuda.set_device(cuda_device)
    # A bare 'cuda' device resolves to the current device selected above.
    device = torch.device('cuda')
else:
    # Always bind `device`, so CPU-only runs do not hit a NameError later.
    device = torch.device('cpu')

# Fix the global RNG seed for reproducible runs.
torch.manual_seed(42)

<torch._C.Generator at 0x7f633d47d8d0>

## Load Graphs

In [3]:
def kg_path(graph_name):
    """Return the relative data directory path for the graph *graph_name*."""
    return f'../data/{graph_name}'

#### mDNS

In [4]:
# dataset = DNS(root=kg_path('mDNS'), transform=T.Compose([T.NormalizeFeatures(), T.ToUndirected()]), balance_gt=True)
# data = dataset[0]
# data

#### DNS

In [5]:
# Build the DNS dataset; every graph it yields is feature-normalised and
# then made undirected by the composed transform.
dataset = DNS(
    root=kg_path('DNS'),
    transform=T.Compose([
        T.NormalizeFeatures(),
        T.ToUndirected(),
    ]),
    balance_gt=True,
)
data = dataset[0]

# Show the label distribution and the size of the training split for the
# 'domain_node' node type.
print(torch.unique(data['domain_node'].y, return_counts=True))
print(torch.unique(data['domain_node'].train_mask, return_counts=True))

Remove parallel edges: type
similar    50910
dtype: int64
(tensor([0, 1, 2]), tensor([  4386,  19953, 349136]))
(tensor([False,  True]), tensor([369089,   4386]))


In [6]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import Linear
from hgt.hgt_conv import HGTConv

class HGT(torch.nn.Module):
    """Heterogeneous graph model: per-type projection, HGTConv stack, linear head.

    Args:
        data: Heterogeneous graph object; only ``data.node_types`` and
            ``data.metadata()`` are read here.
        hidden_channels: Width of the per-node-type projection and of every
            ``HGTConv`` layer.
        out_channels: Output size of the final linear layer.
        num_heads: Passed through to each ``HGTConv``
            (presumably the number of attention heads -- confirm against
            ``hgt.hgt_conv``).
        num_layers: Number of stacked ``HGTConv`` layers.
        num_features: Input size of the per-node-type projection. The
            default ``-1`` relies on ``Linear`` inferring the size lazily on
            the first forward pass.
    """

    def __init__(self, data, hidden_channels, out_channels, num_heads, num_layers, num_features=-1):
        super().__init__()

        # One input projection per node type, all mapping to the same width.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            self.lin_dict[node_type] = Linear(num_features, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, data.metadata(),
                           num_heads, group='sum')
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict, target_nodetype='domain_node'):
        """Return the output of the final linear layer for ``target_nodetype``.

        Args:
            x_dict: Mapping from node type to its feature tensor. It is not
                modified.
            edge_index_dict: Passed unchanged to every ``HGTConv`` layer.
            target_nodetype: Node type whose representations are fed to the
                final linear layer.
        """
        # Build a new dict rather than assigning into `x_dict`, so the
        # caller's mapping keeps its original feature tensors.
        h_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }

        for conv in self.convs:
            h_dict = conv(h_dict, edge_index_dict)

        return self.lin(h_dict[target_nodetype])
    

# Instantiate the network: two HGTConv layers of width 64 with 8 heads and
# a two-way output layer.
model = HGT(
    data,
    hidden_channels=64,
    out_channels=2,
    num_heads=8,
    num_layers=2,
)

In [7]:
# Use the GPU when one is present, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)
model = model.to(device)

# One gradient-free forward pass initializes the lazy modules before the
# optimizer collects the model parameters.
with torch.no_grad():
    out = model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.005,
    weight_decay=0.001,
)


def train():
    """Run one full-graph optimisation step and return the loss as a float.

    Uses the module-level ``model``, ``optimizer`` and ``data``. The
    cross-entropy loss is computed only on the 'domain_node' entries
    selected by ``train_mask``.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(data.x_dict, data.edge_index_dict)
    domain_store = data['domain_node']
    train_mask = domain_store.train_mask
    loss = F.cross_entropy(logits[train_mask], domain_store.y[train_mask])

    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test():
    """Return ``[train_acc, val_acc, test_acc]`` for the 'domain_node' type.

    Uses the module-level ``model`` and ``data``. The decorator already
    disables gradient tracking for the whole call, so no inner
    ``torch.no_grad()`` block is needed.
    """
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)

    store = data['domain_node']
    accs = []
    for split in ['train_mask', 'val_mask', 'test_mask']:
        mask = store[split]
        # Fraction of masked nodes whose predicted class equals the label.
        acc = (pred[mask] == store.y[mask]).sum() / mask.sum()
        accs.append(float(acc))
    return accs


# Train for 200 epochs, evaluating after every step and reporting every
# 20th epoch.
for epoch in range(1, 201):
    loss = train()
    train_acc, val_acc, test_acc = test()
    if epoch % 20 != 0:
        continue
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')
    
# Final evaluation on the held-out test split of 'domain_node'.
model.eval()
with torch.no_grad():
    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)
mask = data['domain_node']['test_mask']
scores = score(pred[mask], data['domain_node'].y[mask])
# The loop variable must not be called `score`: that would rebind the
# imported `score` helper and break any later call to it (e.g. when this
# cell is re-run).
for metric, value in scores.items():
    print(metric, ':{:.2f}'.format(value))

Epoch: 020, Loss: 0.4275, Train: 0.7845, Val: 0.7840, Test: 0.7845
Epoch: 040, Loss: 0.3509, Train: 0.8548, Val: 0.8530, Test: 0.8590
Epoch: 060, Loss: 0.3219, Train: 0.8657, Val: 0.8661, Test: 0.8670
Epoch: 080, Loss: 0.3028, Train: 0.8792, Val: 0.8758, Test: 0.8746
Epoch: 100, Loss: 0.2915, Train: 0.8776, Val: 0.8684, Test: 0.8750
Epoch: 120, Loss: 0.2834, Train: 0.8867, Val: 0.8764, Test: 0.8791
Epoch: 140, Loss: 0.2763, Train: 0.8862, Val: 0.8678, Test: 0.8799
Epoch: 160, Loss: 0.2629, Train: 0.8919, Val: 0.8781, Test: 0.8867
Epoch: 180, Loss: 0.2583, Train: 0.8931, Val: 0.8803, Test: 0.8913
Epoch: 200, Loss: 0.2462, Train: 0.9036, Val: 0.8843, Test: 0.8951
tn, fp, fn, tp 1154 134 142 1201
tn, fp, fn, tp 1154 134 142 1201
acc :0.90
f1 :0.90
auc :0.90
prec :0.90
recall :0.89
fpr :0.10
