In [1]:
import os
import sys

# Make the project's `concepts` package importable. The location can be
# overridden with the LIGHTSUPERNODE_ROOT environment variable so the notebook
# is not tied to one machine; the original absolute path remains the default.
PROJECT_ROOT = os.environ.get(
    "LIGHTSUPERNODE_ROOT", '/home/sam/Documents/network/project/lightsupernode/'
)
# Guard keeps the cell idempotent: re-running it does not stack duplicates.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(1, PROJECT_ROOT)

In [2]:
import torch
from concepts.concepts import *
from concepts.transformations import AddSupernodesHetero
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear, GraphConv, SimpleConv, global_mean_pool
from torch.nn import Linear
from torch_geometric.loader import DataLoader

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
from torch_geometric.datasets import TUDataset

# MUTAG graph-classification benchmark from the TU collection, with
# ./dataset/TUD as its root directory.
dataset = TUDataset(name='MUTAG', root="./dataset/TUD")

In [4]:
# Peek at the first mini-batch of plain MUTAG graphs (fractional slice [:0.8],
# batch size 3, no shuffling) to see how graphs are concatenated in a batch.
train_loader = DataLoader(dataset[:0.8], batch_size=3, shuffle=False)
batch = next(iter(train_loader))
print(batch.edge_index)
print(batch.x.size(0))
# Per-graph mean of the node features.
print(global_mean_pool(batch.x, batch.batch))

# Sanity check: compare the GraphConv output row count with the node count above.
conv = GraphConv(dataset.num_features, 10)
print(conv(batch.x, batch.edge_index).size(0))

tensor([[ 0,  0,  1,  1,  2,  2,  3,  3,  3,  4,  4,  4,  5,  5,  6,  6,  7,  7,
          8,  8,  8,  9,  9,  9, 10, 10, 11, 11, 12, 12, 12, 13, 13, 14, 14, 14,
         15, 16, 17, 17, 18, 18, 19, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24,
         24, 24, 25, 25, 25, 26, 26, 27, 27, 27, 28, 29, 30, 30, 31, 31, 32, 32,
         32, 33, 33, 33, 34, 34, 35, 35, 36, 36, 37, 37, 38, 38, 38, 39, 39, 40,
         40, 40, 41, 42],
        [ 1,  5,  0,  2,  1,  3,  2,  4,  9,  3,  5,  6,  0,  4,  4,  7,  6,  8,
          7,  9, 13,  3,  8, 10,  9, 11, 10, 12, 11, 13, 14,  8, 12, 12, 15, 16,
         14, 14, 18, 26, 17, 19, 18, 20, 24, 19, 21, 20, 22, 21, 23, 22, 24, 19,
         23, 25, 24, 26, 27, 17, 25, 25, 28, 29, 27, 27, 31, 39, 30, 32, 31, 33,
         37, 32, 34, 38, 33, 35, 34, 36, 35, 37, 32, 36, 33, 39, 40, 30, 38, 38,
         41, 42, 40, 40]])
43
tensor([[0.8235, 0.0588, 0.1176, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6923, 0.1538, 0.1538, 0.0000, 0.0000, 0.0000, 0.0000],
  

In [5]:
# Concept extractors from the project's `concepts` package (cycle basis,
# maximal cliques, line paths); each entry presumably yields the node groups
# that AddSupernodesHetero turns into supernodes.
concepts_list_ex = [
    {"name": "GCB", "fun": cycle_basis, "args": []},
    {"name": "GMC", "fun": max_cliques, "args": []},
    {"name": "GLP2", "fun": line_paths, "args": []},
]

# Attach the transform so indexing the dataset returns HeteroData, then take graph 0.
dataset.transform = AddSupernodesHetero(concepts_list_ex)
data = dataset[0]

In [6]:
# Display the HeteroData built for graph 0: 'normal' and 'supernodes' node
# types plus the orig / toSup / toNor / void relations.
data

HeteroData(
  y=[1],
  normal={ x=[17, 7] },
  supernodes={ x=[7, 7] },
  (normal, orig, normal)={
    edge_index=[2, 38],
    edge_attr=[38, 4],
  },
  (normal, toSup, supernodes)={ edge_index=[2, 35] },
  (supernodes, toNor, normal)={ edge_index=[2, 35] },
  (normal, void, normal)={ edge_index=[2, 17] },
  (supernodes, void, supernodes)={ edge_index=[2, 7] }
)

In [7]:
import torch_geometric.transforms as T
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear, GraphConv, SimpleConv, global_mean_pool, MLP

class HeteroGNN(torch.nn.Module):
    """First prototype: stacked HeteroConv layers followed by a linear head.

    Each layer exchanges messages supernode->normal (GraphConv),
    normal->supernode (SAGEConv) and normal->normal over the original bonds
    (GATConv), summing per destination type, then applies ReLU.

    NOTE(review): ``forward`` returns one row of logits per *supernode*, not one
    per graph, so it does not line up with the graph-level ``data.y``; the
    later HeteroGNN in this notebook adds a pooling readout.
    """

    def __init__(self, hidden_channels, out_channels, num_layers):
        super().__init__()

        def make_layer():
            # NOTE(review): GraphConv / SAGEConv do not document an
            # `add_self_loops` argument; it is forwarded through **kwargs.
            # Verify the installed PyG version accepts it.
            relations = {
                ('supernodes', 'toNor', 'normal'): GraphConv(-1, hidden_channels, add_self_loops=False),
                ('normal', 'toSup', 'supernodes'): SAGEConv((-1, -1), hidden_channels, add_self_loops=False),
                ('normal', 'orig', 'normal'): GATConv((-1, -1), hidden_channels, add_self_loops=True),
            }
            return HeteroConv(relations, aggr='sum')

        self.convs = torch.nn.ModuleList([make_layer() for _ in range(num_layers)])
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Return logits for every supernode after ``num_layers`` rounds."""
        h = x_dict
        for layer in self.convs:
            h = {node_type: feat.relu() for node_type, feat in layer(h, edge_index_dict).items()}
        return self.lin(h['supernodes'])

model = HeteroGNN(
    hidden_channels=64,
    out_channels=dataset.num_classes,
    num_layers=2,
)

# Single forward pass on one transformed graph as a smoke test.
out = model(data.x_dict, data.edge_index_dict)
for item in (data, out):
    print(item)

HeteroData(
  y=[1],
  normal={ x=[17, 7] },
  supernodes={ x=[7, 7] },
  (normal, orig, normal)={
    edge_index=[2, 38],
    edge_attr=[38, 4],
  },
  (normal, toSup, supernodes)={ edge_index=[2, 35] },
  (supernodes, toNor, normal)={ edge_index=[2, 35] },
  (normal, void, normal)={ edge_index=[2, 17] },
  (supernodes, void, supernodes)={ edge_index=[2, 7] }
)
tensor([[-0.5770, -0.1057],
        [-0.6203, -0.1284],
        [-0.5564, -0.0937],
        [-0.5574, -0.0956],
        [-0.5564, -0.0937],
        [-0.6214, -0.1305],
        [-0.6204, -0.1282]], grad_fn=<AddmmBackward0>)


## LOADER

In [22]:
import hashlib
from torch_geometric.loader import DataLoader

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Same three concept extractors as in the exploration section.
concepts_list_ex = [
    {"name": "GCB", "fun": cycle_basis, "args": []},
    {"name": "GMC", "fun": max_cliques, "args": []},
    {"name": "GLP2", "fun": line_paths, "args": []},
]

# Directory name unique to this concept configuration: concatenate
# "<name><args>" for every concept and take the SHA-256 hex digest.
path_name = "".join(f"{concept['name']}{concept['args']}" for concept in concepts_list_ex)
hash_name = hashlib.sha256(path_name.encode('utf-8')).hexdigest()

# NOTE(review): AddSupernodesHetero is passed as `transform`. Per the PyG
# Dataset docs, a transform is applied on every access and not stored on disk,
# so the hash-specific root only holds another raw copy of MUTAG. If caching
# the transformed graphs was the goal, `pre_transform` is the relevant hook;
# confirm TUDataset supports HeteroData there before switching.
dataset = TUDataset("./dataset/MutagHetero" + hash_name, name="MUTAG",
                    transform=AddSupernodesHetero(concepts_list_ex))
train_loader = DataLoader(dataset[:0.8], batch_size=3)

Using device: cuda


In [23]:
# Pull the first mini-batch (batch size 3) from the loader. Note: this rebinds
# `data`, which previously held a single graph, to a batched HeteroData.
data = next(iter(train_loader))
# Graph-assignment vector for each node type ('normal' and 'supernodes').
data.collect('batch')

{'normal': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
 'supernodes': tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])}

## CONV

In [24]:
from typing import List, Optional, Union
import torch
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing

class IdentityConv(MessagePassing):
    """Parameter-free layer that aggregates raw neighbour features.

    Every edge forwards its source-node feature unchanged (``message`` returns
    ``x_j``). The messages arriving at each destination node are reduced with
    ``aggr``. ``x`` may be a single tensor or, for bipartite relations driven
    by ``HeteroConv`` such as ('normal', 'toSup', 'supernodes'), an
    ``(x_src, x_dst)`` tuple.

    Args:
        aggr: aggregation scheme passed to ``MessagePassing`` (default ``'add'``).
        **kwargs: further ``MessagePassing`` options.
    """

    def __init__(self, aggr: str = 'add', **kwargs):
        super().__init__(aggr=aggr, **kwargs)

    def forward(self, x: Union[Tensor, tuple], edge_index: Tensor) -> Tensor:
        """Return the aggregated neighbour features for every destination node."""
        return self.propagate(edge_index, x=x)

    def message(self, x_j: Tensor) -> Tensor:
        # Identity message: pass the source feature through unchanged.
        return x_j

    # No update() override. The previous override returned the raw input `x`
    # and discarded `aggr_out`, so propagation had no effect; for a bipartite
    # input it would also hand back the (x_src, x_dst) tuple instead of a
    # tensor. MessagePassing's default update returns the aggregation.

## MODEL

In [36]:
class HeteroGNN(torch.nn.Module):
    """Graph-level classifier on the supernode-augmented heterogeneous graph.

    Pipeline:
      1. ``num_layers`` HeteroConv layers, each followed by ReLU:
         - GATConv over the original bonds ('normal', 'orig', 'normal');
         - parameter-free SimpleConv over ('supernodes', 'void', 'supernodes').
      2. ``supinit``:
         - normal-node embeddings are summed over their 'void' relation (one
           edge per node, presumably a self-loop);
         - they are also summed into the supernodes connected to them via
           ('normal', 'toSup', 'supernodes').
         Afterwards both node types carry ``hidden_channels`` features.
      3. Mean-pool each node type per graph, add the pooled vectors, and
         classify with an MLP.

    Args:
        hidden_channels: width of each GNN layer and of the MLP hidden layer.
        out_channels: number of logits produced per graph.
        num_layers: number of stacked HeteroConv layers.
    """

    def __init__(self, hidden_channels, out_channels, num_layers):
        super().__init__()

        # The supernode relation must be an edge type that actually exists in
        # edge_index_dict. HeteroConv silently skips unknown edge types, so the
        # old key ('supernodes', 'toSup', 'supernodes') was never run and the
        # supernodes vanished from the readout.
        self.supinit = HeteroConv({
            ('normal', 'void', 'normal'): SimpleConv('add'),
            ('normal', 'toSup', 'supernodes'): SimpleConv('add'),
        })

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
#                ('supernodes', 'toNor', 'normal'): GraphConv(-1, hidden_channels, add_self_loops=False),
#                ('normal', 'toSup', 'supernodes') : IdentityConv(),
#                ('normal', 'toSup', 'supernodes'): SAGEConv((-1, -1), hidden_channels, add_self_loops=False),
               ('supernodes', 'void', 'supernodes'): SimpleConv('add'),
               ('normal', 'orig', 'normal'): GATConv((-1, -1), hidden_channels, add_self_loops=True),
            }, aggr='sum')
            self.convs.append(conv)

        self.readout = global_mean_pool

        self.mlp = MLP([hidden_channels, hidden_channels, out_channels],
           norm=None, dropout=0.5)

    def forward(self, x_dict, edge_index_dict, batch_dict, num_graphs=None):
        """Return graph-level logits of shape [num_graphs, out_channels].

        Args:
            x_dict: node features per node type.
            edge_index_dict: edge indices per edge type.
            batch_dict: graph index of every node, per node type (as returned
                by ``collect('batch')`` on a batch). A node type missing from
                the dict is pooled as a single graph.
            num_graphs: number of graphs in the batch. Inferred from
                ``batch_dict`` when omitted; pass e.g. ``data.num_graphs`` to
                be explicit.
        """
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}

        x_dict = self.supinit(x_dict, edge_index_dict)

        if num_graphs is None:
            num_graphs = max(
                (int(b.max()) + 1 for b in batch_dict.values() if b.numel() > 0),
                default=1,
            )

        # Pass `size` so every node type pools to the same number of rows.
        # Without it, if the last graphs in the batch have no supernodes, the
        # shapes differ and torch.stack fails.
        pooled = [
            self.readout(x, batch_dict.get(key), size=num_graphs)
            for key, x in x_dict.items()
        ]
        x = torch.stack(pooled, dim=0).sum(dim=0)

        return self.mlp(x)

model = HeteroGNN(
    num_layers=2,
    hidden_channels=64,
    out_channels=dataset.num_classes,
)
print(model)

HeteroGNN(
  (supinit): HeteroConv(num_relations=2)
  (convs): ModuleList(
    (0-1): 2 x HeteroConv(num_relations=2)
  )
  (mlp): MLP(64, 64, 2)
)


In [37]:
# Smoke test of the graph-level model on the first mini-batch.
batch_dict = data.collect('batch')
out = model(data.x_dict, data.edge_index_dict, batch_dict)
print(out)

tensor([[-0.0193, -0.1412],
        [ 0.0004, -0.0802],
        [ 0.0356, -0.0968]], grad_fn=<AddmmBackward0>)


In [38]:
# Display the batched HeteroData to check node and edge counts per type.
data

HeteroDataBatch(
  y=[3],
  normal={
    x=[43, 7],
    batch=[43],
    ptr=[4],
  },
  supernodes={
    x=[15, 7],
    batch=[15],
    ptr=[4],
  },
  (normal, orig, normal)={
    edge_index=[2, 94],
    edge_attr=[94, 4],
  },
  (normal, toSup, supernodes)={ edge_index=[2, 81] },
  (supernodes, toNor, normal)={ edge_index=[2, 81] },
  (normal, void, normal)={ edge_index=[2, 43] },
  (supernodes, void, supernodes)={ edge_index=[2, 15] }
)