In [50]:
import sys
sys.path.insert(1, '/home/sam/Documents/network/supernode/BREC_test/')

In [51]:
import torch

from concepts.concepts import *
from concepts.transformations import AddSupernodesHeteroMulti

In [52]:
import networkx as nx
import torch
from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import BaseTransform
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.utils import from_networkx, to_networkx

class AddSupernodesHeteroMulti(BaseTransform):
    def __init__(self, concepts_list) -> None:
        self.concepts_list = concepts_list

    def forward(self, data: Data) -> HeteroData:
        data_with_supernodes = HeteroData({
            'normal'    : {'x' : data.x.float()},
            ('normal', 'orig', 'normal'  )   : { 'edge_index': data.edge_index, 'edge_attr' : data.edge_attr},
#            ('normal', 'orig', 'normal'  )   : { 'edge_index': data.edge_index},
        })
        t1 = torch.arange(data.x.shape[0])
        data_with_supernodes['normal', 'identity', 'normal'].edge_index = torch.stack([t1, t1], dim=0).long()

        G = to_networkx(data, to_undirected=True, node_attrs=["x"])

        # find all the concepts in the graph on the original graph only
        for concept in self.concepts_list:
            concept_name = concept["name"]
            comp = concept["fun"](G, *concept["args"])
            if len(comp) != 0:
                current_supernode = 0
                from_normal = []
                to_sup      = []
                supnodes    = []
                for concept in comp:
                    supnodes.append(current_supernode)
                    for node in concept:
                        from_normal.append(node)
                        to_sup.append(current_supernode)
                    current_supernode += 1

                toSup_edges = torch.Tensor((from_normal, to_sup)).long()
                toNor_edges = torch.Tensor((to_sup, from_normal)).long()
                #data_with_supernodes[concept_name].x = torch.zeros(len(comp), data.num_features)
                data_with_supernodes[concept_name].x = torch.ones(len(comp), data.num_features)
                data_with_supernodes['normal', 'toSup', concept_name].edge_index = toSup_edges
                data_with_supernodes[concept_name, 'toNor', 'normal'].edge_index = toNor_edges
                t2 = torch.arange(len(comp))
                data_with_supernodes[concept_name, 'identity', concept_name].edge_index = torch.stack([t2, t2], dim=0).long()
            else:
                data_with_supernodes[concept_name].x = torch.zeros(1, data.num_features)

        return data_with_supernodes


In [65]:
import torch
from torch_geometric.nn import MLP, global_add_pool, HeteroConv, SimpleConv, GATConv, GINConv
from torch_geometric.nn import HGTConv, Linear

def get_HGAT_multi_simple(args, supnodes_name, dropout=0.5, hidden_channels=64,
                   num_layers=4, out_channels=16):
    SConv_dict = {
            ('normal', 'identity', 'normal'): SimpleConv('add'),
            }
    for supnode_type in supnodes_name:
        SConv_dict |= {('normal', 'toSup', supnode_type) : SimpleConv('add')}


    SConv = HeteroConv(SConv_dict, aggr='sum')

    HConvs = torch.nn.ModuleList()
    for _ in range(num_layers):
        Conv_dict = {("normal", "orig", "normal") : GATConv((-1, -1), hidden_channels, add_self_loops=True)}

        for supnode_type in supnodes_name:
            Conv_dict |= {("normal", "toSup", supnode_type) : SimpleConv('add'),
                          (supnode_type, "toNor", "normal") : GATConv((-1, -1), hidden_channels, add_self_loops=False)}

        conv = HeteroConv(Conv_dict, aggr='sum')
        HConvs.append(conv)

    class HGAT_simple_multi(torch.nn.Module):
        def __init__(self):
            super(HGAT_simple_multi, self).__init__()
            self.supinit = SConv
            self.convs = HConvs
            self.readout = global_add_pool
            self.classifier = MLP([hidden_channels, hidden_channels, out_channels],
                                   norm="batch_norm", dropout=dropout)

        def forward(self, data):
            x_dict, edge_index_dict = (data.x_dict, data.edge_index_dict)
            x_dict = self.supinit(x_dict, edge_index_dict)

            print("here", x_dict.keys())
            for conv in self.convs:
                x_dict = conv(x_dict, edge_index_dict)
                x_dict = {key: x.relu() for key, x in x_dict.items()}

            return x_dict


    model = HGAT_simple_multi()
    return model

In [66]:
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx

dataset = TUDataset(root="./dataset/TUD", name='MUTAG')

concepts_list_ex = [
       {"name": "GCB", "fun": cycle_basis, "args": []},
       {"name": "GMC", "fun": max_cliques, "args": []},
       {"name": "GLP2", "fun": line_paths, "args": []}
    ]

print("cliques found: ", len(max_cliques(to_networkx(dataset[0], to_undirected=True))))

supnodes_name = [concept['name'] for concept in concepts_list_ex]
datasetT = dataset.transform = AddSupernodesHeteroMulti(concepts_list_ex)

loader = DataLoader(dataset, batch_size=1)
data = next(iter(loader))
print(data)
print(data['GMC'].x)
model = get_HGAT_multi_simple(None, supnodes_name)

print(data.x_dict.keys())
model(data)

cliques found:  0
HeteroDataBatch(
  normal={
    x=[17, 7],
    batch=[17],
    ptr=[2],
  },
  GCB={
    x=[3, 7],
    batch=[3],
    ptr=[2],
  },
  GMC={
    x=[1, 7],
    batch=[1],
    ptr=[2],
  },
  GLP2={
    x=[4, 7],
    batch=[4],
    ptr=[2],
  },
  (normal, orig, normal)={
    edge_index=[2, 38],
    edge_attr=[38, 4],
  },
  (normal, identity, normal)={ edge_index=[2, 17] },
  (normal, toSup, GCB)={ edge_index=[2, 18] },
  (GCB, toNor, normal)={ edge_index=[2, 18] },
  (GCB, identity, GCB)={ edge_index=[2, 3] },
  (normal, toSup, GLP2)={ edge_index=[2, 17] },
  (GLP2, toNor, normal)={ edge_index=[2, 17] },
  (GLP2, identity, GLP2)={ edge_index=[2, 4] }
)
tensor([[0., 0., 0., 0., 0., 0., 0.]])
dict_keys(['normal', 'GCB', 'GMC', 'GLP2'])
here dict_keys(['normal', 'GCB', 'GLP2'])


{'normal': tensor([[14.2489,  0.0000,  8.4846,  ...,  7.1229,  0.0000,  0.0000],
         [14.2490,  0.0000,  8.4907,  ...,  7.1360,  0.0000,  0.0000],
         [14.1447,  0.0000,  8.3263,  ...,  6.9013,  0.0000,  0.0000],
         ...,
         [ 0.3303,  0.0000,  0.2025,  ...,  0.6132,  0.5468,  0.0000],
         [ 0.2340,  0.0806,  0.1751,  ...,  0.5406,  0.3612,  0.0000],
         [ 0.2340,  0.0806,  0.1751,  ...,  0.5406,  0.3612,  0.0000]],
        grad_fn=<ReluBackward0>),
 'GCB': tensor([[4.1627e+01, 0.0000e+00, 2.3845e+01, 1.6031e+01, 3.4333e+01, 0.0000e+00,
          6.9825e+01, 1.3890e+00, 1.0975e+01, 5.0513e+00, 3.3483e+01, 0.0000e+00,
          1.3676e+01, 0.0000e+00, 0.0000e+00, 1.3857e+00, 0.0000e+00, 1.9050e+01,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 1.5465e+01, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 1.8609e+00, 3.8395e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          5.5033e-01, 2.6717e+01, 0.0000e+00, 7.6281e+00, 9.0239e+01, 0.0000e+00,
          0.0000

In [48]:
def get_HGIN_multi_simple(args, supnodes_name, dropout=0.5, hidden_channels=64,
                   num_layers=4, out_channels=16):
    SConv_dict = {
            ('normal', 'identity', 'normal'): SimpleConv('add'),
            }
    for supnode_type in supnodes_name:
        SConv_dict |= {('normal', 'toSup', supnode_type) : SimpleConv('add')}


    SConv = HeteroConv(SConv_dict, aggr='sum')

    HConvs = torch.nn.ModuleList()
    for _ in range(num_layers):
        Conv_dict = {("normal", "orig", "normal") :  GINConv(MLP([-1, hidden_channels, hidden_channels]))}

        for supnode_type in supnodes_name:
            Conv_dict |= {("normal", "toSup", supnode_type) : SimpleConv('add'),
                          (supnode_type, "toNor", "normal") : GINConv(MLP([-1, hidden_channels, hidden_channels]))}

        conv = HeteroConv(Conv_dict, aggr='sum')
        HConvs.append(conv)

    class HGIN_simple_multi(torch.nn.Module):
        def __init__(self):
            super(HGIN_simple_multi, self).__init__()
            self.supinit = SConv
            self.convs = HConvs
            self.readout = global_add_pool
            self.classifier = MLP([hidden_channels, hidden_channels, out_channels],
                                   norm="batch_norm", dropout=dropout)

        def forward(self, data):
            x_dict, edge_index_dict, batch_dict = (data.x_dict, data.edge_index_dict, data.collect('batch'))
            x_dict = self.supinit(x_dict, edge_index_dict)

            for conv in self.convs:
                x_dict = conv(x_dict, edge_index_dict)
                x_dict = {key: x.relu() for key, x in x_dict.items()}
            
            return x_dict

    model = HGIN_simple_multi()
    return model


In [49]:
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx

dataset = TUDataset(root="./dataset/TUD", name='MUTAG')

concepts_list_ex = [
       {"name": "GCB", "fun": cycle_basis, "args": []},
       {"name": "GMC", "fun": max_cliques, "args": []},
       {"name": "GLP2", "fun": line_paths, "args": []}
    ]

print("cliques found: ", len(max_cliques(to_networkx(dataset[0], to_undirected=True))))

supnodes_name = [concept['name'] for concept in concepts_list_ex]
datasetT = dataset.transform = AddSupernodesHeteroMulti(concepts_list_ex)

loader = DataLoader(dataset, batch_size=1)
data = next(iter(loader))
print(data)
print(data['GMC'].x)
model = get_HGIN_multi_simple(None, supnodes_name)

print(data.x_dict.keys())
model(data)

cliques found:  0
HeteroDataBatch(
  normal={
    x=[17, 7],
    batch=[17],
    ptr=[2],
  },
  GCB={
    x=[3, 7],
    batch=[3],
    ptr=[2],
  },
  GMC={
    x=[1, 7],
    batch=[1],
    ptr=[2],
  },
  GLP2={
    x=[4, 7],
    batch=[4],
    ptr=[2],
  },
  (normal, orig, normal)={
    edge_index=[2, 38],
    edge_attr=[38, 4],
  },
  (normal, identity, normal)={ edge_index=[2, 17] },
  (normal, toSup, GCB)={ edge_index=[2, 18] },
  (GCB, toNor, normal)={ edge_index=[2, 18] },
  (GCB, identity, GCB)={ edge_index=[2, 3] },
  (normal, toSup, GLP2)={ edge_index=[2, 17] },
  (GLP2, toNor, normal)={ edge_index=[2, 17] },
  (GLP2, identity, GLP2)={ edge_index=[2, 4] }
)
tensor([[0., 0., 0., 0., 0., 0., 0.]])
dict_keys(['normal', 'GCB', 'GMC', 'GLP2'])


RuntimeError: The size of tensor a (7) must match the size of tensor b (64) at non-singleton dimension 1