In [None]:
import time

import dgl
import dgl.function as fn
import torch
import torch.nn as nn
from tqdm import tqdm

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# dgl.load_graphs returns (list of graphs, dict of label tensors).
BENIGN_PATH = "benign.bin"
# NOTE(review): the malicious input is read from 'malicious_aftermp.bin', the
# same file the last cell of this notebook writes, while the benign input is
# the unprocessed 'benign.bin'. Re-running the notebook therefore applies
# message passing on top of already-processed malicious graphs; confirm this
# should not point at the raw malicious file.
MALICIOUS_PATH = "malicious_aftermp.bin"

benign_graphs, benign_labels = dgl.load_graphs(BENIGN_PATH)
malicious_graphs, malicious_labels = dgl.load_graphs(MALICIOUS_PATH)

In [3]:
# Rebind both collections to graphs placed on `device`. The generator is fully
# consumed before either name is reassigned, so each list is built from the
# original collection.
benign_graphs, malicious_graphs = (
    [graph.to(device) for graph in collection]
    for collection in (benign_graphs, malicious_graphs)
)

In [4]:
class GraphConvolution(nn.Module):
    """One round of relation-specific (heterogeneous) graph convolution.

    For every canonical edge type ``(srctype, etype, dsttype)`` the source
    node features are projected by a dedicated ``nn.Linear``. Each destination
    node averages the projected features arriving over that relation, and the
    per-relation averages are summed into ``g.nodes[dsttype].data['h']``.

    Parameters
    ----------
    graph : dgl.DGLGraph
        Heterogeneous graph to operate on. Stored as ``self.graph`` and
        updated in place by ``forward``.
    input_features : dict[str, int]
        Input feature width per node type.
    hidden_features : dict[tuple[str, str, str], int]
        Output width per canonical edge type. Relations that end at the same
        destination type must share a width so their results can be summed.
    """

    def __init__(self, graph, input_features, hidden_features):
        super(GraphConvolution, self).__init__()
        self.graph = graph
        # One projection per canonical edge type. ModuleDict keys must be
        # strings, hence the "src-etype-dst" encoding of the tuple.
        self.weight = nn.ModuleDict({
            f'{srctype}-{etype}-{dsttype}': nn.Linear(input_features[srctype], hidden_features[(srctype, etype, dsttype)])
            for srctype, etype, dsttype in graph.canonical_etypes
        })

    def forward(self, feature_dictionary):
        """Run one message-passing round, writing the result to ``ndata['h']``.

        Parameters
        ----------
        feature_dictionary : dict[str, torch.Tensor]
            Input features per node type (one row per node of that type).

        Returns
        -------
        dict[str, torch.Tensor]
            The ``'h'`` feature of every node type that has one after the
            update. Node types that receive no messages keep their previous
            ``'h'``.
        """
        g = self.graph
        funcs = {}
        projected = []

        for srctype, etype, dsttype in g.canonical_etypes:
            key = f'{srctype}-{etype}-{dsttype}'
            Wh = self.weight[key](feature_dictionary[srctype])
            # Each relation gets its own source-side key. Message computation
            # is deferred to multi_update_all, so writing every projection to
            # a shared 'h' made all relations leaving the same source type
            # send whichever projection was written last. It also overwrote
            # the source nodes' input features.
            wh_key = f'Wh_{key}'
            g.nodes[srctype].data[wh_key] = Wh
            projected.append((srctype, wh_key))
            funcs[(srctype, etype, dsttype)] = (fn.copy_u(wh_key, 'm'), fn.mean('m', 'h'))

        # Mean within each relation, then sum across relations per dst type.
        g.multi_update_all(funcs, 'sum')

        # Remove the per-relation temporaries so they are neither fed into
        # later rounds nor serialized together with the graph.
        for srctype, wh_key in projected:
            del g.nodes[srctype].data[wh_key]

        return {ntype: g.nodes[ntype].data['h'] for ntype in g.ntypes if 'h' in g.nodes[ntype].data}
            



In [5]:
def multi_round_message_passing(graph, convolution, rounds):
    """Apply ``convolution`` to ``graph`` for ``rounds`` successive rounds.

    Each round snapshots the current ``'h'`` feature of every node type and
    passes it to ``convolution``. The module is expected to write its output
    back into ``graph``'s node data, so that the next round reads the updated
    ``'h'``.

    Parameters
    ----------
    graph : dgl.DGLGraph
        Graph whose node types all carry an ``'h'`` feature.
    convolution : torch.nn.Module
        Module invoked as ``convolution(feature_dictionary)``.
    rounds : int
        Number of rounds; ``0`` or less leaves the graph untouched.
    """
    for _ in range(rounds):
        feature_dictionary = {ntype: graph.nodes[ntype].data['h'] for ntype in graph.ntypes}
        # Call the module itself instead of `.forward(...)` so that any
        # registered forward hooks run, as PyTorch recommends.
        convolution(feature_dictionary)


In [None]:
# Message passing is used purely as feature propagation here (nothing is
# trained), so autograd is disabled. Otherwise every stored 'h' tensor keeps
# its autograd history alive for as long as `benign_graphs` exists.
# NOTE(review): a fresh, randomly initialised GraphConvolution is built per
# graph and no seed is set, so each graph is projected with different weights
# and results differ between runs; confirm this is intended.
# NOTE(review): the malicious cell runs 1 round; confirm the mismatch.
BENIGN_ROUNDS = 3
with torch.no_grad():
    for g in tqdm(benign_graphs, desc="Message Passing for Benign", unit="Graphs"):
        # Assumes every node type's 'h' feature is 128 wide — TODO confirm.
        input_features = {ntype: 128 for ntype in g.ntypes}
        hidden_features = {canonical_etype: 128 for canonical_etype in g.canonical_etypes}
        convolution = GraphConvolution(g, input_features, hidden_features).to(device)
        multi_round_message_passing(g, convolution, BENIGN_ROUNDS)

dgl.save_graphs('benign_aftermp.bin', benign_graphs, benign_labels)

In [None]:
# Feature propagation only (nothing is trained), so autograd is disabled to
# stop every stored 'h' tensor from keeping its autograd history alive.
# FIXME(review): this cell runs 1 round while the benign cell runs 3. If the
# two classes are compared downstream, preprocessing them differently can leak
# the label into the features. Note also that the malicious input file may
# already be the output of an earlier run; confirm the intended total.
# NOTE(review): a fresh, unseeded GraphConvolution is built per graph.
MALICIOUS_ROUNDS = 1
with torch.no_grad():
    for g in tqdm(malicious_graphs, desc="Message Passing for Malicious", unit="Graphs"):
        # Assumes every node type's 'h' feature is 128 wide — TODO confirm.
        input_features = {ntype: 128 for ntype in g.ntypes}
        hidden_features = {canonical_etype: 128 for canonical_etype in g.canonical_etypes}
        convolution = GraphConvolution(g, input_features, hidden_features).to(device)
        multi_round_message_passing(g, convolution, MALICIOUS_ROUNDS)



In [None]:
# Persist the processed malicious graphs with their labels.
# NOTE(review): this is the same path the loading cell reads the malicious
# graphs from, so the input is overwritten on every run.
malicious_output_path = 'malicious_aftermp.bin'
dgl.save_graphs(malicious_output_path, malicious_graphs, malicious_labels)