# Load DNS and mDNS datasets

In [7]:
import os
import sys

import numpy as np
import pandas as pd

# Put the directory two levels above the working directory on sys.path
# (once), so packages located there can be imported by the cells below.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
if module_path not in sys.path:
    sys.path.append(module_path)

import torch
import src.temporal_loader_v2 as tl
from src.utils import to_homogeneous


In [8]:
# Index of the GPU to make current when CUDA is available.
cuda_device = 4

# Bind `device` on both branches so that later `.to(device)` calls also work
# on CPU-only machines instead of failing with a NameError.
if torch.cuda.is_available():
    torch.cuda.set_device(cuda_device)  # make `cuda_device` the current GPU
    device = torch.device('cuda')       # 'cuda' refers to the current GPU
else:
    device = torch.device('cpu')

# Seed PyTorch's global RNG; as the last expression of the cell this also
# displays the returned generator.
torch.manual_seed(42)

<torch._C.Generator at 0x7f823441a790>

## Load Graphs

In [9]:
def kg_path(graph_name):
    """Return the relative path of the data directory for `graph_name`."""
    return f'../../data/{graph_name}'


# NOTE(review): presumably `start`/`end` delimit the training snapshots and
# `test_list` names the held-out ones -- confirm against tl.DNS.
dataset = tl.DNS(
    root=kg_path('DNS_2m'),
    start=0,
    end=6,
    test_list=[7],
    balance_gt=True,
    domain_file='domains2.csv',
)

Total labeled 897635
Labeled node count for 0, 6: 31778
After balancing labeled count: 31282
Labeled node count for 0, 7: 2610


#### DNS

In [10]:
# Heterogeneous training graph. The homogeneous alternative would be
# to_homogeneous(dataset.train_data) for training and
# to_homogeneous(dataset.test_data[0]) for testing.
data = dataset.train_data
# Last expression of the cell: display the graph's (node types, edge types).
data.metadata()

(['domain_node', 'ip_node'],
 [('domain_node', 'apex', 'domain_node'),
  ('domain_node', 'resolves', 'ip_node'),
  ('domain_node', 'similar', 'domain_node')])

#### GNN

In [11]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import Linear
from hgt_conv import HGTConv

class HGT(torch.nn.Module):
    """Heterogeneous graph transformer that classifies one node type.

    Every node type gets its own linear projection to `hidden_channels`,
    followed by a ReLU. The projected features go through `num_layers`
    HGTConv layers, and the embeddings of the target node type are mapped to
    `out_channels` logits by a final linear layer.

    Args:
        data: Graph object exposing `node_types` and `metadata()`.
        hidden_channels: Width of the per-type projections and conv layers.
        out_channels: Size of the output layer (number of classes).
        num_heads: Number of heads passed to each HGTConv layer.
        num_layers: Number of stacked HGTConv layers.
        num_features: Input size of the per-type projections. The default
            of -1 relies on lazy size inference in `Linear`, so the model
            needs one forward pass before its parameters are materialised.
    """

    def __init__(self, data, hidden_channels, out_channels, num_heads, num_layers, num_features=-1):
        super().__init__()

        # One input projection per node type, keyed by the node type name.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            self.lin_dict[node_type] = Linear(num_features, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, data.metadata(),
                           num_heads, group='sum')
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict, target_nodetype='domain_node'):
        """Return the output-layer logits for the nodes of `target_nodetype`.

        Args:
            x_dict: Mapping from node type to its feature tensor. It is only
                read; the caller's mapping is left unchanged.
            edge_index_dict: Mapping from edge type to its edge index,
                forwarded to every conv layer.
            target_nodetype: Node type whose embeddings are classified.
        """
        # Build a new dict rather than assigning into `x_dict`, so the
        # caller's mapping keeps pointing at the raw input features.
        h_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }

        for conv in self.convs:
            h_dict = conv(h_dict, edge_index_dict)

        return self.lin(h_dict[target_nodetype])
    



In [12]:
# GPU index that experiment() reads as a global when CUDA is available; this
# replaces the value 4 assigned in the earlier device-setup cell.
cuda_device = 0
# Re-seed PyTorch's global RNG before the models in this cell are created.
torch.manual_seed(42)
from src.utils import score

def train(model, data, optimizer):
    """Perform one full-batch optimisation step and return the loss.

    The cross-entropy loss is computed only on the `domain_node` nodes
    selected by their `train_mask`.

    Args:
        model: Module called as `model(data.x_dict, data.edge_index_dict)`.
        data: Graph providing `x_dict`, `edge_index_dict` and a
            `domain_node` store with `train_mask` and `y`.
        optimizer: Optimizer that updates the parameters of `model`.

    Returns:
        The training loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(data.x_dict, data.edge_index_dict)
    domain = data['domain_node']
    train_mask = domain.train_mask

    loss = F.cross_entropy(logits[train_mask], domain.y[train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test(model, data):
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)

    accs = []
    for split in ['train_mask', 'val_mask']:
        mask = data['domain_node'][split]
        acc = (pred[mask] == data['domain_node'].y[mask]).sum() / mask.sum()
        accs.append(float(acc))
    return accs

def experiment(model,start,end,test_list, model_type):
    """Train `model` on one DNS window and log its scores on each test graph.

    Loads the `DNS_2m` dataset for the given window, trains for 201 epochs
    with Adam, then scores the trained model on every graph in
    `dataset.test_data`. One CSV row per test graph is appended to
    `../../results_copy.csv`, and the metrics are printed.

    Runs on the GPU selected by the global `cuda_device` when CUDA is
    available, and on the CPU otherwise.

    Args:
        model: Model to train; it is moved to the selected device in place.
        start: Forwarded to `tl.DNS` as `start`.
        end: Forwarded to `tl.DNS` as `end`.
        test_list: Forwarded to `tl.DNS` as `test_list`.
        model_type: Label written as the first column of each CSV row.
    """
    dataset = tl.DNS(root='../../data/DNS_2m', start=start, end=end,
                     test_list=test_list, balance_gt=False,
                     domain_file='domains2.csv')
    data = dataset.train_data  # heterogeneous training graph

    # Always bind `device`: it is needed for the test graphs below even when
    # no GPU is present.
    if torch.cuda.is_available():
        torch.cuda.set_device(cuda_device)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    data, model = data.to(device), model.to(device)

    with torch.no_grad():  # Initialize lazy modules.
        model(data.x_dict, data.edge_index_dict)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)

    for epoch in range(0, 201):
        loss = train(model, data, optimizer)
        train_acc, val_acc = test(model,data)
        if epoch % 20 == 0:
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
                f'Val: {val_acc:.4f}')

    model.eval()
    for index, test_data in enumerate(dataset.test_data):
        test_data = test_data.to(device)
        with torch.no_grad():
            pred = model(test_data.x_dict, test_data.edge_index_dict).argmax(dim=-1)
        # On the test graphs the evaluated nodes are the ones in `val_mask`.
        mask = test_data['domain_node']['val_mask']
        scores = score(pred[mask],test_data['domain_node'].y[mask])

        # Append after every test graph so earlier rows survive a later failure.
        with open("../../results_copy.csv", "a") as logger:
            logger.write("{},{},{},{},".format(model_type,start,end,index))
            logger.write(",".join(str(x) for x in scores.values()))
            logger.write('\n')

        for metric, val in scores.items():
            print(metric, ':{:.4f}'.format(val))
    
# Run five experiments whose window arguments are shifted by one each time:
# (start, end) = (i, i + 6), with i + 7 and i + 8 passed as the test list.
# A fresh HGT model is built for every run. The global `data` is used only
# for its node types and metadata.
for model_type in ['hgt']:
    for i in range(5):
        model = HGT(
            data,
            hidden_channels=64,
            out_channels=2,
            num_heads=8,
            num_layers=2,
        )
        experiment(
            model,
            start=i,
            end=i + 6,
            test_list=[i + 7, i + 8],
            model_type=model_type,
        )

Total labeled 897635
Labeled node count for 0, 6: 31778
Labeled node count for 0, 7: 2610
Labeled node count for 0, 8: 2083
Epoch: 000, Loss: 0.6959, Train: 0.5134, Val: 0.5194
Epoch: 020, Loss: 0.5720, Train: 0.7262, Val: 0.7223
Epoch: 040, Loss: 0.5300, Train: 0.7676, Val: 0.7613
Epoch: 060, Loss: 0.4711, Train: 0.7658, Val: 0.7583
Epoch: 080, Loss: 0.4396, Train: 0.7883, Val: 0.7789
Epoch: 100, Loss: 0.4278, Train: 0.8031, Val: 0.7929
Epoch: 120, Loss: 0.4188, Train: 0.8058, Val: 0.7959
Epoch: 140, Loss: 0.4102, Train: 0.8124, Val: 0.8028
Epoch: 160, Loss: 0.4023, Train: 0.8169, Val: 0.8101
Epoch: 180, Loss: 0.3959, Train: 0.8173, Val: 0.8107
Epoch: 200, Loss: 0.3908, Train: 0.8182, Val: 0.8101
tn, fp, fn, tp 1369 236 175 830
acc :0.8425
f1 :0.8433
auc :0.8394
prec :0.7786
recall :0.8259
fpr :0.1470
mi_f1 :0.8425
ma_f1 :0.8355
tn, fp, fn, tp 1122 187 150 624
acc :0.8382
f1 :0.8389
auc :0.8317
prec :0.7694
recall :0.8062
fpr :0.1429
mi_f1 :0.8382
ma_f1 :0.8284
Total labeled 897635
La