In [1]:
import torch
import torch_geometric
from torch_geometric.datasets import MoleculeNet


In [2]:
def unsqueeze_y(data):
    data.y = data.y.squeeze(1).long()
    return data

In [3]:
dataset = MoleculeNet(root="./dataset/just", name='HIV')

In [4]:
print()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('===========================================================================================================')

# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')


Dataset: HIV(41127):
Number of graphs: 41127
Number of features: 9
Number of classes: 2

Data(x=[19, 9], edge_index=[2, 40], edge_attr=[40, 3], smiles='CCC1=[O+][Cu-3]2([O+]=C(CC)C1)[O+]=C(CC)CC(CC)=[O+]2', y=[1, 1])
Number of nodes: 19
Number of edges: 40
Average node degree: 2.11
Has isolated nodes: False
Has self-loops: False
Is undirected: True


In [5]:
import sys
sys.path.insert(1, '/home/sam/Documents/network/supernode/HIV_test/')

In [6]:
from data.transformation import AddSupernodes
from data.concepts import *

In [7]:
dataset = MoleculeNet(root="./dataset/MoleculeNety1", name='HIV', pre_transform=unsqueeze_y)

In [8]:
concepts_list = [
       {"name": "GCB", "fun": cycle_basis, "args": [], "features" : [2]},
       {"name": "GMC", "fun": max_cliques, "args": []},
    ]
data = AddSupernodes(concepts_list)(dataset[0])
data

Data(x=[21], edge_index=[2, 64], y=[1], ntype=[21], S=[21], edge_S=[64, 1])

In [9]:
data.y

tensor([0.])

# DATASET TR

# FILES

In [10]:
from collections.abc import Mapping
from typing import Any, List, Optional, Sequence, Union

import torch.utils.data
from torch.utils.data.dataloader import default_collate

from torch_geometric.data import Batch, Dataset
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter
from torch_geometric.typing import TensorFrame, torch_frame


class Collater:
    def __init__(
        self,
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
    ):
        print("zzzzz")
        self.dataset = dataset
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch: List[Any]) -> Any:
        print("xxxxx")
        elem = batch[0]
        if isinstance(elem, BaseData):
            print("0000")
            return Batch.from_data_list(
                batch,
                follow_batch=self.follow_batch,
                exclude_keys=self.exclude_keys,
            )
        elif isinstance(elem, torch.Tensor):
            print("a")
            return default_collate(batch)
        elif isinstance(elem, TensorFrame):
            print("b")
            return torch_frame.cat(batch, dim=0)
        elif isinstance(elem, float):
            print("c")
            return torch.tensor(batch, dtype=torch.float)
        elif isinstance(elem, int):
            print("d")
            return torch.tensor(batch)
        elif isinstance(elem, str):
            print("e")
            return batch
        elif isinstance(elem, Mapping):
            print("f")
            return {key: self([data[key] for data in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
            print("g")
            return type(elem)(*(self(s) for s in zip(*batch)))
        elif isinstance(elem, Sequence) and not isinstance(elem, str):
            print("h")
            return [self(s) for s in zip(*batch)]

        raise TypeError(f"DataLoader found invalid type: '{type(elem)}'")


class DataLoader(torch.utils.data.DataLoader):
    r"""A data loader which merges data objects from a
    :class:`torch_geometric.data.Dataset` to a mini-batch.
    Data objects can be either of type :class:`~torch_geometric.data.Data` or
    :class:`~torch_geometric.data.HeteroData`.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`False`)
        follow_batch (List[str], optional): Creates assignment batch
            vectors for each key in the list. (default: :obj:`None`)
        exclude_keys (List[str], optional): Will exclude each key in the
            list. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`.
    """
    def __init__(
        self,
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        # Remove for PyTorch Lightning:
        kwargs.pop('collate_fn', None)

        # Save for PyTorch Lightning < 1.6:
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        print("eeee")

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=Collater(dataset, follow_batch, exclude_keys),
            **kwargs,
        )

In [11]:
from data.transformation import AddSupernodes
import torch_geometric.transforms as T


def squeeze_y(data):
    data.y = data.y.squeeze(1)
    data.smiles=None
    return data

concepts_list = [
       {"name": "GMC", "fun": cycle_basis, "args": []},
    ]


dataset = MoleculeNet("./dataset/testttaaa", name="HIV",
                      pre_transform=T.Compose(squeeze_y, AddSupernodes(concepts_list))
                     )
loader = DataLoader(dataset, 100,
                    shuffle=False, num_workers=0)

TypeError: Compose.__init__() takes 2 positional arguments but 3 were given

In [None]:
print(dataset[0])
print(dataset[1])

i = iter(loader)
for data in loader:
    print(data)

In [None]:
dataset[1].y

In [None]:
import torch
a = torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [None]:
a.long()

In [None]:
import torch
from torch_geometric.data import Data

# Assuming data1 and data2 are your Data objects
data1 = Data(x=torch.randn(19, 9), edge_index=torch.randint(0, 19, (2, 40)), y=torch.randn(1), ntype=torch.randn(19), S=torch.randn(19), edge_S=torch.randn(40, 1))
data2 = Data(x=torch.randn(39, 9), edge_index=torch.randint(0, 39, (2, 88)), y=torch.randn(1), ntype=torch.randn(39), S=torch.randn(39), edge_S=torch.randn(88, 1))

# Concatenate x and y attributes
concatenated_x = torch.cat([data1.x, data2.x], dim=0)
concatenated_y = torch.cat([data1.y, data2.y], dim=0)

# Create a new Data object with concatenated attributes
concatenated_data = Data(x=concatenated_x, edge_index=data1.edge_index, y=concatenated_y, ntype=data1.ntype, S=data1.S, edge_S=data1.edge_S)

# Print concatenated_data to verify
print(concatenated_data)

### DATASET


In [None]:
from torch_geometric.data import InMemoryDataset
import shutil
from data.transformation import AddSupernodesHeteroMulti
from torch_geometric.data import HeteroData

concepts_list = [
       {"name": "GMC", "fun": cycle_basis, "args": []},
]

In [None]:
class MoleculeHIV_herero_multi(InMemoryDataset):
    def __init__(self, root, concepts, transform=None, pre_transform=None):
        self.concepts = concepts
        super().__init__(root, transform, pre_transform)
        self.load(self.processed_paths[0], data_cls=HeteroData)

    def raw_file_names(self):
        return ['data.pth']        

    def processed_file_names(self):
        return ['transformed_dataset.pth']


    def download(self):
        dataset = MoleculeNet("./dataset/STEP/", name="HIV",
                              pre_transform=squeeze_y,)
        transformed_dataset = [AddSupernodesHeteroMulti(self.concepts)(data) for data in dataset]
        torch.save(transformed_dataset, f'{self.raw_dir}/data.pth')

    def process(self):
        print(self.processed_paths[0])
        data_list = torch.load(f'{self.raw_dir}/data.pth')

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        self.save(data_list, f'{self.processed_dir}/transformed_dataset.pth')


In [None]:
dataset = MoleculeHIV_herero_multi("./dataset/gg", concepts=concepts_list)

In [None]:
dataset[4]

In [None]:
T

dataset = MoleculeNet("./dataset/STEP/", name="HIV",
                      transform=squeeze_y,)


In [13]:
from data.transformation import AddSupernodesHeteroMulti
import torch_geometric.transforms as T


def squeeze_y(data):
    data.y = data.y.squeeze(1)
    data.smiles=None
    return data

concepts_list = [
       {"name": "GCB", "fun": cycle_basis, "args": []}, # max_num
       {"name": "GMC", "fun": max_cliques, "args": []},
    ]

dataset = MoleculeNet("./dataset/uuu", name="HIV",
                      transform=T.Compose([squeeze_y, AddSupernodesHeteroMulti(concepts_list)])
                     )
loader = DataLoader(dataset, 100,
                    shuffle=False, num_workers=0)

Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv
Processing...


eeee
zzzzz


Done!


In [14]:
a = next(iter(loader))

xxxxx
0000


In [15]:
a

HeteroDataBatch(
  normal={
    x=[1653, 9],
    batch=[1653],
    ptr=[101],
  },
  label={ y=[100] },
  GCB={
    x=[185, 9],
    batch=[185],
    ptr=[101],
  },
  GMC={
    x=[100, 9],
    batch=[100],
    ptr=[101],
  },
  (normal, orig, normal)={
    edge_index=[2, 3412],
    edge_attr=[3412, 3],
  },
  (normal, identity, normal)={ edge_index=[2, 1653] },
  (normal, toSup, GCB)={ edge_index=[2, 943] },
  (GCB, toNor, normal)={ edge_index=[2, 943] },
  (GCB, identity, GCB)={ edge_index=[2, 161] }
)

In [16]:
a['normal', 'toSup', 'GCB']

{'edge_index': tensor([[   3,    4,    5,  ..., 1328, 1329, 1324],
        [   0,    0,    0,  ...,  160,  160,  160]])}

# LIGHT

In [271]:
from data.dataset import MoleculeHIV_hetero_multi_NetDataModule
import hashlib

BATCH_SIZE = 100

concepts_list = [
       {"name": "GCB", "fun": cycle_basis, "args": []}, # max_num
    ]
supnodes_name = [concept['name'] for concept in concepts_list]

path_name = ''.join(map(lambda x: x['name'] + str(x['args']), concepts_list))
hash_name = hashlib.sha256(path_name.encode('utf-8')).hexdigest()
dataset_name = f"HIV_supernode_hetero_multi_{hash_name}"

dm = MoleculeHIV_hetero_multi_NetDataModule(f'./dataset/{dataset_name}',
                              concept_list=concepts_list, batch_size=1,
                              train_prop=0.6, test_prop=0.2, val_prop=0.2,
                              num_workers=0
                              )
dm.setup()

In [272]:
td = dm.val_dataloader()
itd = iter(td)
a = next(itd)
b = next(itd)

In [273]:
a

HeteroDataBatch(
  normal={
    x=[23, 9],
    batch=[23],
    ptr=[2],
  },
  label={ y=[1] },
  GCB={
    x=[2, 9],
    batch=[2],
    ptr=[2],
  },
  (normal, orig, normal)={
    edge_index=[2, 48],
    edge_attr=[48, 3],
  },
  (normal, identity, normal)={ edge_index=[2, 23] },
  (normal, toSup, GCB)={ edge_index=[2, 18] },
  (GCB, toNor, normal)={ edge_index=[2, 18] },
  (GCB, identity, GCB)={ edge_index=[2, 3] }
)

In [245]:
a['normal', 'toSup', 'GCB'].edge_index

tensor([[19, 20, 21, 23, 24, 18,  9, 10, 11, 13, 14,  8,  4,  5,  6, 15, 25,  2],
        [ 0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2]])

In [106]:
a['GCB'].x

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.]])

In [107]:
b

HeteroDataBatch(
  normal={
    x=[30, 9],
    batch=[30],
    ptr=[2],
  },
  label={ y=[1] },
  GCB={
    x=[4, 9],
    batch=[4],
    ptr=[2],
  },
  (normal, orig, normal)={
    edge_index=[2, 66],
    edge_attr=[66, 3],
  },
  (normal, identity, normal)={ edge_index=[2, 30] },
  (normal, toSup, GCB)={ edge_index=[2, 22] },
  (GCB, toNor, normal)={ edge_index=[2, 22] },
  (GCB, identity, GCB)={ edge_index=[2, 4] }
)

In [108]:
b['normal', 'toSup', 'GCB'].edge_index

tensor([[18, 19, 20, 21, 22, 17,  7,  8,  9, 10, 11,  6,  9, 14, 13, 12, 10,  2,
          3,  4, 15,  1],
        [ 0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  3,
          3,  3,  3,  3]])

In [246]:
from data.dataset import MoleculeHIVNetDataModule
import hashlib

BATCH_SIZE = 100

dm = MoleculeHIVNetDataModule("./dataset/Molecule_normaleee", batch_size=1,
                              train_prop=0.6, test_prop=0.2, val_prop=0.2,
                              pre_transform=squeeze_y,
                              )
dm.setup()

In [247]:
td = dm.val_dataloader()
itd = iter(td)
aT = next(itd)
bT = next(itd)

In [248]:
aT

DataBatch(x=[23, 9], edge_index=[2, 48], edge_attr=[48, 3], y=[1], batch=[23], ptr=[2])

In [254]:
from torch_geometric.utils import to_networkx
import networkx as nx

G = to_networkx(aT, to_undirected=True)
cycle_basis(G)

[[4, 15, 17, 18, 20, 3], [6, 7, 11, 14, 5]]

In [255]:
concepts_list = [
       {"name": "GCB", "fun": cycle_basis, "args": []}, # max_num
    ]
aT2 = AddSupernodesHeteroMulti(concepts_list)(aT)

In [256]:
aT2

HeteroData(
  normal={ x=[23, 9] },
  label={ y=[1] },
  GCB={ x=[2, 9] },
  (normal, orig, normal)={
    edge_index=[2, 48],
    edge_attr=[48, 3],
  },
  (normal, identity, normal)={ edge_index=[2, 23] },
  (normal, toSup, GCB)={ edge_index=[2, 11] },
  (GCB, toNor, normal)={ edge_index=[2, 11] },
  (GCB, identity, GCB)={ edge_index=[2, 2] }
)

In [117]:
aT2['normal','toSup','GCB'].edge_index

tensor([[ 4, 15, 17, 18, 20,  3,  6,  7, 11, 14,  5],
        [ 0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1]])

In [114]:
a

HeteroDataBatch(
  normal={
    x=[23, 9],
    batch=[23],
    ptr=[2],
  },
  label={ y=[1] },
  GCB={
    x=[2, 9],
    batch=[2],
    ptr=[2],
  },
  (normal, orig, normal)={
    edge_index=[2, 48],
    edge_attr=[48, 3],
  },
  (normal, identity, normal)={ edge_index=[2, 23] },
  (normal, toSup, GCB)={ edge_index=[2, 18] },
  (GCB, toNor, normal)={ edge_index=[2, 18] },
  (GCB, identity, GCB)={ edge_index=[2, 3] }
)

In [118]:
a['normal','toSup','GCB'].edge_index

tensor([[19, 20, 21, 23, 24, 18,  9, 10, 11, 13, 14,  8,  4,  5,  6, 15, 25,  2],
        [ 0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2]])

In [260]:
from data.transformation import AddSupernodesHeteroMulti
import torch_geometric.transforms as T


def squeeze_y(data):
    data.y = data.y.squeeze(1)
    data.smiles=None
    return data

concepts_list = [
       {"name": "GCB", "fun": cycle_basis, "args": []}, # max_num
    ]

dm = MoleculeHIVNetDataModule("./dataset/Molecule_normaleee111", batch_size=1,
                              train_prop=0.6, test_prop=0.2, val_prop=0.2,
                              transform=T.Compose([squeeze_y, AddSupernodesHeteroMulti(concepts_list)]),
                              num_workers=4
                              )
dm.setup()

In [261]:
td = dm.train_dataloader()
itd = iter(td)
aT = next(itd)
bT = next(itd)

In [262]:
aT

HeteroDataBatch(
  normal={
    x=[19, 9],
    batch=[19],
    ptr=[2],
  },
  label={ y=[1] },
  GCB={
    x=[2, 9],
    batch=[2],
    ptr=[2],
  },
  (normal, orig, normal)={
    edge_index=[2, 40],
    edge_attr=[40, 3],
  },
  (normal, identity, normal)={ edge_index=[2, 19] },
  (normal, toSup, GCB)={ edge_index=[2, 12] },
  (GCB, toNor, normal)={ edge_index=[2, 12] },
  (GCB, identity, GCB)={ edge_index=[2, 2] }
)

In [207]:
aT['GCB','identity','GCB'].edge_index

tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
         126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
         140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
         154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
         168, 169, 170, 171, 172, 173, 174, 175, 176

In [266]:
from data.transformation import AddSupernodesHeteroMulti
import torch_geometric.transforms as T


def squeeze_y(data):
    data.y = data.y.squeeze(1)
    data.smiles=None
    return data

concepts_list = [
       {"name": "GCB", "fun": cycle_basis, "args": []}, # max_num
    ]

dm = MoleculeHIVNetDataModule("./dataset/Molecule_normaleee111222", batch_size=1,
                              train_prop=0.6, test_prop=0.2, val_prop=0.2,
                              pre_transform=T.Compose([squeeze_y, AddSupernodesHeteroMulti(concepts_list)]),
                              num_workers=0
                              )
dm.setup()

In [269]:
td = dm.val_dataloader()
itd = iter(td)
aT = next(itd)
bT = next(itd)

In [270]:
aT

HeteroDataBatch(
  normal={
    x=[23, 9],
    batch=[23],
    ptr=[2],
  },
  label={ y=[1] },
  GCB={
    x=[2, 9],
    batch=[2],
    ptr=[2],
  },
  (normal, orig, normal)={
    edge_index=[2, 48],
    edge_attr=[48, 3],
  },
  (normal, identity, normal)={ edge_index=[2, 23] },
  (normal, toSup, GCB)={ edge_index=[2, 18] },
  (GCB, toNor, normal)={ edge_index=[2, 18] },
  (GCB, identity, GCB)={ edge_index=[2, 3] }
)

In [177]:
loaded_dataset = []
loaded_dataset = torch.load("./dataset/HIV_supernode_hetero_multi_fb29bedbe192593aeebc9d5b07254bb2420628da983f4e051acf534b8f436426/raw/data.pth")

<torch_geometric.loader.dataloader.DataLoader at 0x7f71cd49fa90>

In [216]:
dataset = MoleculeNet("./dataset/atlol", name="HIV",
                     pre_transform=squeeze_y,)

Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv
Processing...
Done!


In [218]:
concepts_list = [
       {"name": "GCB", "fun": cycle_basis, "args": []}, # max_num
    ]
transformed_dataset = [AddSupernodesHeteroMulti(concepts_list)(data) for data in dataset]

In [230]:
from torch_geometric.loader import DataLoader
loader = DataLoader(transformed_dataset, batch_size=1)


In [239]:
next(iter(loader))['GCB', 'identity', 'GCB'] == {}

for data in loader:
    if data['GCB', 'identity', 'GCB'] == {}:
        print("hello")

hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hell

KeyboardInterrupt: 

In [306]:
dataset = MoleculeNet(root="./dataset/MoleculeNety1", name='HIV', pre_transform=unsqueeze_y,
                     transform=AddSupernodesHeteroMulti(concepts_list))
t = DataLoader(dataset[0.6:0.8], batch_size=1)

In [307]:
itd = iter(t)
aT = next(itd)
aT

HeteroDataBatch(
  normal={
    x=[23, 9],
    batch=[23],
    ptr=[2],
  },
  label={ y=[1] },
  GCB={
    x=[2, 9],
    batch=[2],
    ptr=[2],
  },
  (normal, orig, normal)={
    edge_index=[2, 48],
    edge_attr=[48, 3],
  },
  (normal, identity, normal)={ edge_index=[2, 23] },
  (normal, toSup, GCB)={ edge_index=[2, 11] },
  (GCB, toNor, normal)={ edge_index=[2, 11] },
  (GCB, identity, GCB)={ edge_index=[2, 2] }
)

In [301]:
aT2 = AddSupernodesHeteroMulti(concepts_list)(aT)

In [302]:
aT2

HeteroData(
  normal={ x=[53, 9] },
  label={ y=[2] },
  GCB={ x=[6, 9] },
  (normal, orig, normal)={
    edge_index=[2, 114],
    edge_attr=[114, 3],
  },
  (normal, identity, normal)={ edge_index=[2, 53] },
  (normal, toSup, GCB)={ edge_index=[2, 35] },
  (GCB, toNor, normal)={ edge_index=[2, 35] },
  (GCB, identity, GCB)={ edge_index=[2, 6] }
)

In [309]:
torch.arange(1)

tensor([0])