## Tutorial showing how to use DCM

### Setup

In [1]:
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import pytorch_lightning as pl
import torch
from utils.data_utils import cross_validation_split
from model.model_dcm import ModelDCM
from modules.cell_network import CellNetwork
from torch_geometric import seed_everything
from torch_geometric.data import LightningNodeData
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

torch.set_default_tensor_type("torch.cuda.FloatTensor")
torch.set_float32_matmul_precision("high")


  from .autonotebook import tqdm as notebook_tqdm


This configuration file is used to determine the hyperparameters of the model and set the random seeds to have repeatable results.

In [2]:
config = {
    "metric": {"name": "val_acc", "goal": "maximize"},
    "seed":  42,
    "data_seed":  0,
    "hsize":  32,
    "n_pre":  1,
    "n_post":  1,
    "n_conv":  1,
    "n_dgm_layers":  2,
    "n_dgm_blocks":  1,
    "dropout":  0.5,
    "lr":  0.01,
    "use_gcn":  False,
    "sampler":  "entmax",
    "sample_P":  "entmax",
    "k":  4,
    "graph_loss_reg":  1,
    "poly_loss_reg":  1,
    "std":  0,
    "ensemble_steps":  1,
    "gamma":  50,
    "epochs":  100,
    }

seed_everything(config["seed"])

### Loading the dataset

For this example the Cora dataset is used.

In [3]:
dataset = Planetoid(
        root="data/Planetoid",
        name="cora",
        split="full",
        transform=NormalizeFeatures(),
    )
    
data = dataset[0]
# Update data split
data = cross_validation_split(
    data, dataset_name="cora", curr_seed=config["data_seed"]
)
print(f"Nodes features shape: {data.x.shape}")
print(f"Labels shape: {data.y.shape}")
print(f"Number of edges: {data.edge_index.shape[1]}")
print(f"Number of training samples: {torch.sum(data.train_mask)}")
print(f"Number of validation samples: {torch.sum(data.val_mask)}")
print(f"Number of test samples: {torch.sum(data.test_mask)}")

datamodule = LightningNodeData(
    data,
    data.train_mask,
    data.val_mask,
    data.test_mask,
    loader="full",
    generator=torch.Generator(device="cuda"),
)

Nodes features shape: torch.Size([2708, 1433])
Labels shape: torch.Size([2708])
Number of edges: 10556
Number of training samples: 2437
Number of validation samples: 271
Number of test samples: 271


### Creating the model

In [4]:
config["num_features"] = dataset.num_features
config["num_classes"] = dataset.num_classes
config["pre_layers"] = [dataset.num_features] + [config["hsize"] for _ in range(config["n_pre"])]
config["post_layers"] = [config["hsize"] for _ in range(config["n_post"])] + [dataset.num_classes]
config["dgm_layers"] = [config["hsize"] for _ in range(config["n_dgm_layers"] + 1)]
config["conv_layers"] = [config["hsize"] for _ in range(config["n_conv"])]

model = ModelDCM(config)

### Training

In [5]:
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="val_acc", mode="max")

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=config["epochs"],
    log_every_n_steps=3,
    check_val_every_n_epoch=3,
    num_sanity_val_steps=0,
    callbacks=[checkpoint_callback],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
trainer.fit(model, datamodule)
trainer.test(ckpt_path="best", datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type               | Params
----------------------------------------------------
0 | pre          | MLP                | 45.9 K
1 | dcm          | DCM                | 2.1 K 
2 | cell_network | CellNetwork        | 0     
3 | post         | MLP                | 455   
4 | train_acc    | MulticlassAccuracy | 0     
5 | val_acc      | MulticlassAccuracy | 0     
6 | test_acc     | MulticlassAccuracy | 0     
----------------------------------------------------
48.5 K    Trainable params
0         Non-trainable params
48.5 K    Total params
0.194     Total estimated model params size (MB)
  rank_zero_warn(


Epoch 2: 100%|██████████| 1/1 [00:00<00:00,  9.53it/s, loss=2.78, v_num=17, train_acc=0.159, train_loss=2.770]

  rank_zero_warn(


Epoch 99: 100%|██████████| 1/1 [00:00<00:00,  8.62it/s, loss=0.554, v_num=17, train_acc=0.816, train_loss=0.548, val_acc=0.775]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 1/1 [00:00<00:00,  8.16it/s, loss=0.554, v_num=17, train_acc=0.816, train_loss=0.548, val_acc=0.775]

Restoring states from the checkpoint path at /home/marco/Documents/phd/DCM_simple/differentiable_cell-complex_module/lightning_logs/version_17/checkpoints/epoch=74-step=75.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /home/marco/Documents/phd/DCM_simple/differentiable_cell-complex_module/lightning_logs/version_17/checkpoints/epoch=74-step=75.ckpt
  rank_zero_warn(



Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  6.35it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.7933579087257385
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_acc': 0.7933579087257385}]

### Changing the CellNetwork

Testing a different network over the structure inferred by DCM is straightforward.

In [7]:
# addign one more conv layer
config["num_features"] = dataset.num_features
config["num_classes"] = dataset.num_classes
config["pre_layers"] = [dataset.num_features] + [config["hsize"] for _ in range(config["n_pre"])]
config["post_layers"] = [config["hsize"] for _ in range(config["n_post"])] + [dataset.num_classes]
config["dgm_layers"] = [config["hsize"] for _ in range(config["n_dgm_layers"] + 1)]
config["conv_layers"] = [config["hsize"] for _ in range(config["n_conv"] + 1)]

model = ModelDCM(config)

new_cell_network = CellNetwork(config)
model.cell_network = new_cell_network

In [8]:
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="val_acc", mode="max")

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=config["epochs"],
    log_every_n_steps=3,
    check_val_every_n_epoch=3,
    num_sanity_val_steps=0,
    callbacks=[checkpoint_callback],
)
trainer.fit(model, datamodule)
trainer.test(ckpt_path="best", datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type               | Params
----------------------------------------------------
0 | pre          | MLP                | 45.9 K
1 | dcm          | DCM                | 2.1 K 
2 | cell_network | CellNetwork        | 4.2 K 
3 | post         | MLP                | 455   
4 | train_acc    | MulticlassAccuracy | 0     
5 | val_acc      | MulticlassAccuracy | 0     
6 | test_acc     | MulticlassAccuracy | 0     
----------------------------------------------------
52.7 K    Trainable params
0         Non-trainable params
52.7 K    Total params
0.211     Total estimated model params size (MB)
  rank_zero_warn(


Epoch 2: 100%|██████████| 1/1 [00:00<00:00,  9.60it/s, loss=2.7, v_num=18, train_acc=0.136, train_loss=2.700] 

  rank_zero_warn(


Epoch 99: 100%|██████████| 1/1 [00:00<00:00,  9.07it/s, loss=0.993, v_num=18, train_acc=0.684, train_loss=0.932, val_acc=0.760]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 1/1 [00:00<00:00,  8.59it/s, loss=0.993, v_num=18, train_acc=0.684, train_loss=0.932, val_acc=0.760]

Restoring states from the checkpoint path at /home/marco/Documents/phd/DCM_simple/differentiable_cell-complex_module/lightning_logs/version_18/checkpoints/epoch=77-step=78.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /home/marco/Documents/phd/DCM_simple/differentiable_cell-complex_module/lightning_logs/version_18/checkpoints/epoch=77-step=78.ckpt
  rank_zero_warn(



Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  5.48it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.7712177038192749
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_acc': 0.7712177038192749}]