Imports and Utilities

In [1]:
import sys

%load_ext autoreload
%autoreload 2
from modules.data.load.loaders import GraphLoader
from modules.data.preprocess.preprocessor import PreProcessor
from modules.utils.utils import (
    describe_data,
    load_dataset_config,
    load_model_config,
    load_transform_config,
)

Loading the Dataset

In [2]:
# Name of the dataset whose configuration is looked up.
dataset_name = "Ethereum"
# Load the configuration for that dataset (the cell output echoes the loaded values).
dataset_config = load_dataset_config(dataset_name)
# Build a graph loader from the configuration; the data itself is loaded in the next cell.
loader = GraphLoader(dataset_config)


Dataset configuration for Ethereum:

{'data_domain': 'graph',
 'data_type': 'Transactions',
 'data_name': 'EthereumTokenNetwork',
 'data_dir': 'datasets/Transactions/EthereumTokenNetwork',
 'num_features': 1,
 'num_classes': 5,
 'task': 'classification',
 'loss_type': 'cross_entropy',
 'task_level': 'graph'}


In [3]:
# Load the dataset described by `dataset_config`.
dataset = loader.load()
# Describe sample 0: the dataset holds a single sample (see the cell output),
# so 0 is the only in-range sample index.
# NOTE(review): assumes the second argument of describe_data is the sample index — confirm.
describe_data(dataset, 0)

/mnt/c/Users/ronan/Downloads/GitHub/challenge-icml-2024/datasets/Transactions/EthereumTokenNetwork/EthereumTokenNetwork.pt

Dataset only contains 1 sample:
 - Graph with 47052 vertices and 79722 edges.
 - Features dimensions: [1, 0]
 - There are 32686 isolated nodes.



Load and Apply Lifting

In [4]:
# Identifier of the transform. For a topological lifting it must contain both
# the lifting type and the lifting name, written as "<type>/<name>".
transform_id = "digraph2simplicial/weighted_clique_lifting"
# Kind of transformation the identifier belongs to.
transform_type = "liftings"

# Read the transform's YAML configuration and register it under the "lifting" key.
transform_config = {
    "lifting": load_transform_config(transform_type, transform_id),
}


Transform configuration for digraph2simplicial/weighted_clique_lifting:

{'transform_type': 'lifting',
 'transform_name': 'WeightedSimplicialCliqueLifting',
 'complex_dim': 3,
 'preserve_edge_attr': True,
 'signed': True,
 'feature_lifting': 'ProjectionSum'}


In [5]:
# Run the preprocessor on the loaded dataset with the lifting configured in `transform_config`.
# NOTE(review): `loader.data_dir` is presumably where PreProcessor stores its results — confirm.
lifted_dataset = PreProcessor(dataset, transform_config, loader.data_dir)
# Print a summary of the lifted dataset.
describe_data(lifted_dataset)

Create and Run the Simplicial NN Model

In [None]:
# NOTE(review): this import sits in the model section rather than in the imports cell
# at the top of the notebook; consider moving it there once this cell is known to run.
from modules.models.simplicial.san import SANModel

# Model family and identifier used to look up the model configuration.
model_type = "simplicial"
model_id = "san"
model_config = load_model_config(model_type, model_id)  # TODO: review the loaded model configuration

# Instantiate the SAN model from the model and dataset configurations.
model = SANModel(model_config, dataset_config)

In [None]:
# Smoke test: run a single forward pass on the sample at index 0 of the lifted
# dataset to verify the model accepts it.
# TODO: set up a full model/training run afterwards if one is needed.
y_hat = model(lifted_dataset.get(0))
print(y_hat)