In [None]:
from typing import *
import torch
from torch import nn
from torch.nn import functional as F

from modules import HiFNode, HiFSubgraph, HiFGraph
from models import *


In [None]:
# Which partitioned dataset to run on. N is the number of clients it was split into;
# later cells iterate client ids 1..N.
dataset_name = 'dblp_conf'
clients_per_dataset = {
    'cora': 5,
    'citeseer': 5,
    'pubmed': 5,
    'dblp_conf': 20,
    'dblp_org': 8
}
N = clients_per_dataset[dataset_name]

In [None]:
# Load one saved data object per id: key -1 is the object whose edge_index is later used
# for the "cross client" edges, keys 1..N are the individual clients.
# NOTE(review): these files are unpickled by torch.load; on torch >= 2.6 the default
# weights_only=True may refuse non-tensor objects — confirm against the installed version.
client_ids = [-1] + [*range(1, N + 1)]
clients = dict(
    (i, torch.load(f"../data/{dataset_name}/{i}_clients.pt")) for i in client_ids
)
clients


In [None]:
# Gather every label id that occurs in any loaded data object (including key -1).
client_classes = [set(client.y.tolist()) for client in clients.values()]
classes = set().union(*client_classes)
assert min(classes, default=0) >= 0, "labels must be non-negative class ids"
# F.one_hot(Y, num_classes=num_classes) below requires every label to lie in
# [0, num_classes). Sizing by the largest label id (not the count of distinct ids)
# keeps that true even if some id in the middle of the range never occurs.
num_classes = max(classes, default=-1) + 1


In [None]:
# Build the hierarchy: one HiFGraph holding the global models, one HiFSubgraph per
# client (with its own local models), and one HiFNode per node, keyed by original id.
feature_dim = clients[-1].x.size(1)
hidden_dim = 64
graph = HiFGraph(global_models=get_module_dict(feature_dim, hidden_dim, num_classes))
subgraph_dict: Dict[int, HiFSubgraph] = {}  # client_id -> that client's HiFSubgraph
node_dict: Dict[int, HiFNode] = {}  # original node index -> HiFNode


def _add_undirected_edges(
    edge_index: torch.Tensor,
    index_orig: List[int],
    nodes: Dict[int, HiFNode],
) -> None:
    """Link every (source, target) pair of ``edge_index`` in both directions.

    Args:
        edge_index: unpacked into a source-index tensor and a target-index tensor;
            the indices are positions into ``index_orig``.
        index_orig: maps those positions to the original node ids that key ``nodes``.
        nodes: original node id -> HiFNode; every referenced id must already exist.
    """
    src_indices, tgt_indices = edge_index
    for src_index, tgt_index in zip(src_indices.tolist(), tgt_indices.tolist()):
        src_node = nodes[index_orig[src_index]]
        tgt_node = nodes[index_orig[tgt_index]]
        # NOTE(review): if edge_index already lists both directions of each edge (as
        # undirected PyG graphs usually do), each pair is linked twice here; whether
        # that creates duplicate neighbours depends on HiFNode.add_edge — confirm.
        src_node.add_edge(tgt_node)
        tgt_node.add_edge(src_node)


# intra client
for client_id in range(1, N + 1):
    client = clients[client_id]

    # subgraph with its own local models
    subgraph = HiFSubgraph(
        local_models=get_module_dict(feature_dim, hidden_dim, num_classes),
        num_classes=num_classes,
    )
    graph.add_subgraph(subgraph)
    subgraph_dict[client_id] = subgraph

    X: torch.Tensor = client.x
    Y: torch.Tensor = client.y
    Y_one_hot: torch.Tensor = F.one_hot(Y, num_classes=num_classes).float()
    # Plain list of ints: local node position in this client's tensors -> original id.
    index_orig: List[int] = client.index_orig.tolist()
    num_nodes = X.size(0)
    # All models are built with the global feature_dim, so every client must match it
    # (do not rebind feature_dim here, or the next client's models would silently change).
    assert X.size(1) == feature_dim, (
        f"client {client_id}: feature width {X.size(1)} != {feature_dim}"
    )

    # nodes
    for i in range(num_nodes):
        node = HiFNode(
            raw_feature=X[i],
            label=Y[i],
            label_one_hot=Y_one_hot[i],
        )
        subgraph.add_hif_node(node)
        node_dict[index_orig[i]] = node

    # edges local to this client
    _add_undirected_edges(client.edge_index, index_orig, node_dict)

    # split dataset
    subgraph.split_nodes_set(
        train_mask=client.train_mask,
        val_mask=client.val_mask,
        test_mask=client.test_mask,
    )

# cross client: edges stored on clients[-1], resolved through its own index_orig
# (presumably the edges that span different clients).
index_orig = clients[-1].index_orig.tolist()
_add_undirected_edges(clients[-1].edge_index, index_orig, node_dict)


In [None]:
# Federated training: each epoch runs one global training pass, then a validation pass.
num_epochs = 100
for epoch in range(num_epochs):
    print(f'Epoch {epoch}:')
    graph.global_train()
    graph.global_validate()
