## Tutorial showing how to use DCM

### Setup

In [1]:
import warnings
# Suppress every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings from torch / lightning;
# presumably done to keep the tutorial output short.
warnings.filterwarnings("ignore")

import os

# These are set before `torch` is imported below, so they are already in the
# environment when CUDA gets initialized: enumerate GPUs in PCI bus order and
# expose only device 0 to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import pytorch_lightning as pl
import torch
from model import DCM
from torch_geometric import seed_everything
from torch_geometric.data import LightningNodeData
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# Make newly created tensors default to float32 on the GPU. The data-loading
# cell accordingly builds its `torch.Generator` on "cuda".
# NOTE(review): `set_default_tensor_type` is deprecated in recent PyTorch
# releases in favour of `torch.set_default_dtype` / `torch.set_default_device`
# -- confirm against the pinned torch version before migrating.
torch.set_default_tensor_type("torch.cuda.FloatTensor")
# Allow faster, reduced-precision algorithms for float32 matrix multiplications.
torch.set_float32_matmul_precision("high")


This configuration dictionary defines the hyperparameters of the model and the random seed, so that results are repeatable.

In [2]:
# Single source of hyperparameters for this tutorial. The dictionary is
# extended with dataset-dependent entries later on and then handed to `DCM`.
config = {
    # Model-selection metric; the name matches the quantity monitored by the
    # checkpoint callback in the training cell.
    "metric": {"name": "val_acc", "goal": "maximize"},

    # Reproducibility. `seed` is passed to `seed_everything`.
    # NOTE(review): `data_seed` is not read by any cell visible in this
    # notebook -- presumably consumed elsewhere; confirm or remove.
    "seed": 2,
    "data_seed": 0,

    # Layer widths / counts used to build the pre, post, DGM and conv layer
    # lists before the model is created.
    "hsize": 32,
    "n_pre": 1,
    "n_post": 1,
    "n_conv": 1,
    "n_dgm_layers": 2,

    # Remaining entries are forwarded to `DCM` untouched; their exact meaning
    # is defined in `model.py` (not shown here).
    "dropout": 0.5,
    "lr": 0.01,
    "use_gcn": False,
    "k": 4,
    "graph_loss_reg": 1,
    "poly_loss_reg": 1,
    "std": 0,
    "ensemble_steps": 3,
    "gamma": 50,

    # Upper bound on training epochs (used as the Trainer's `max_epochs`).
    "epochs": 200,
}

# Seed the random number generators (PyG's helper covers Python, NumPy and
# PyTorch) before any data splitting or model initialization happens.
seed_everything(config["seed"])

### Loading the dataset

For this example the Cora dataset is used.

In [3]:
# Load Cora with a freshly drawn random split. `NormalizeFeatures` is applied
# as a transform every time the graph is accessed.
dataset = Planetoid(
    root="data/Planetoid",
    name="cora",
    split="random",
    num_train_per_class=373,
    num_val=271,
    num_test=271,
    transform=NormalizeFeatures(),
)

# The dataset is indexed once; the resulting graph object carries the
# train/val/test masks used below.
data = dataset[0]

# The masks are passed by keyword on purpose: newer PyTorch Geometric releases
# interleave `input_*_time` parameters between the `input_*_nodes` ones, so
# positional masks would bind to the wrong parameters there. Keyword arguments
# work on both old and new versions.
datamodule = LightningNodeData(
    data,
    input_train_nodes=data.train_mask,
    input_val_nodes=data.val_mask,
    input_test_nodes=data.test_mask,
    # Full-batch training: the whole graph is a single batch per step.
    loader="full",
    # Generator lives on the GPU to match the CUDA default tensor type set in
    # the setup cell (presumably required by the underlying DataLoader).
    generator=torch.Generator(device="cuda"),
)

### Creating the model

In [4]:
# Derive the dataset-dependent entries of the configuration. Every hidden
# layer shares the same width (`hsize`); only the first and last sizes come
# from the dataset.
hidden = config["hsize"]
n_in, n_out = dataset.num_features, dataset.num_classes

config.update(
    {
        "num_features": n_in,
        "num_classes": n_out,
        # input -> n_pre hidden layers
        "pre_layers": [n_in] + [hidden] * config["n_pre"],
        # n_post hidden layers -> class logits
        "post_layers": [hidden] * config["n_post"] + [n_out],
        # one more size entry than `n_dgm_layers`
        "dgm_layers": [hidden] * (config["n_dgm_layers"] + 1),
        "conv_layers": [hidden] * config["n_conv"],
    }
)

# The model reads all of its hyperparameters from the completed config.
model = DCM(config)

### Training

In [5]:
# Keep the checkpoint that is best according to the model-selection metric
# declared in `config["metric"]`, instead of repeating the metric name and
# direction here. An unexpected `goal` value raises a KeyError rather than
# silently picking a direction.
selection_metric = config["metric"]
checkpoint_mode = {"maximize": "max", "minimize": "min"}[selection_metric["goal"]]
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor=selection_metric["name"],
    mode=checkpoint_mode,
)

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=config["epochs"],
    # Full-batch training gives one step per epoch, so log on every step.
    log_every_n_steps=1,
    # Validation (and therefore checkpoint selection) runs every third epoch.
    check_val_every_n_epoch=3,
    # Skip Lightning's pre-training sanity validation pass.
    num_sanity_val_steps=0,
    callbacks=[checkpoint_callback],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
# Train on the datamodule's training split, validating periodically.
trainer.fit(model=model, datamodule=datamodule)
# Evaluate on the test split using the best checkpoint tracked by the
# checkpoint callback; the returned metrics list is the cell's output.
trainer.test(datamodule=datamodule, ckpt_path="best")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | pre       | MLP                | 45.9 K
1 | graph_f   | DGM                | 2.1 K 
2 | gnn       | GNN                | 0     
3 | cwnn      | CWNN               | 0     
4 | poly_ln   | LayerNorm          | 2     
5 | post      | MLP                | 455   
6 | train_acc | MulticlassAccuracy | 0     
7 | val_acc   | MulticlassAccuracy | 0     
8 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
48.5 K    Trainable params
0         Non-trainable params
48.5 K    Total params
0.194     Total estimated model params size (MB)


Epoch 199: 100%|██████████| 1/1 [00:00<00:00,  6.01it/s, loss=0.378, v_num=57, train_acc=0.869, train_loss=0.381, val_acc=0.697]

`Trainer.fit` stopped: `max_epochs=200` reached.


Epoch 199: 100%|██████████| 1/1 [00:00<00:00,  5.71it/s, loss=0.378, v_num=57, train_acc=0.869, train_loss=0.381, val_acc=0.697]


Restoring states from the checkpoint path at /home/marco/Documents/phd/DCM_simple/differentiable_cell-complex_module/lightning_logs/version_57/checkpoints/epoch=8-step=9.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /home/marco/Documents/phd/DCM_simple/differentiable_cell-complex_module/lightning_logs/version_57/checkpoints/epoch=8-step=9.ckpt


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  1.83it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.7970479726791382
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_acc': 0.7970479726791382}]