In [1]:
import torch
from dataset import MoleculeNetWithNewFeatures
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from graph_models.mxmnet.model import MXMNetModel
from graph_models.pamnet.model import PAMNetModel
from dataloader import GraphDataModule

In [2]:
# Single source of truth for the run configuration; the same mapping is
# handed to the data module (and to the model when it is built from scratch).
hparams = {
    # --- data / optimisation ---
    "train_split": 0.9,
    "train_shuffle": True,
    "seed": 42,
    "learning_rate": 3e-4,
    "batch_size": 512,
    # Size of the classifier output, i.e. n_classes for the classification task.
    "exit_dim": 128,

    # --- network ---
    # NOTE(review): cutoff / cutoff_l / cutoff_g are presumably distance
    # cutoffs consumed by the MXMNet / PAMNet encoders -- confirm in graph_models.
    "cutoff": 5.0,
    "cutoff_l": 5.0,
    "cutoff_g": 5.0,
    "num_layers": 4,
    "in_channels": 79,
    # Size of the embedding vector produced by the encoder.
    "out_channels": 128,
    "hidden_dim": 128,
    "num_spherical": 7,
    "num_radial": 6,
    "envelope_exponent": 5,
}

In [3]:
# Build the PCBA dataset (stored under ./data) and wrap it in the project's
# data module together with the run configuration.
dataset = MoleculeNetWithNewFeatures(name='PCBA', root='data')
dm = GraphDataModule(dataset, hparams)

# Continue from a previously saved PAMNet checkpoint.
# To start from freshly initialised weights instead, use PAMNetModel(hparams).
checkpoint_path = 'model_checkpoints/PAMNET_checkpoint-epoch=03-val_loss=0.28-v1.ckpt'
model = PAMNetModel.load_from_checkpoint(checkpoint_path)

# Total number of scalar parameters across every tensor of the model.
n_params = sum(weight.numel() for weight in model.parameters())
print(f"Number of Parameters: {n_params}")

Number of Parameters: 2550696


In [4]:
# Float32 matmul precision setting, as suggested by Lightning.
MATMUL_PRECISION = 'medium'
torch.set_float32_matmul_precision(MATMUL_PRECISION)

In [5]:
# Metrics of this run are reported through the Weights & Biases logger.
wandb_logger = WandbLogger()

# Keep the 3 checkpoints with the LOWEST "val_loss".
# The previous settings (monitor="epoch", mode="max") kept the three most
# recent epochs instead, which silently discards the best model as soon as
# validation loss starts to rise. `save_last=True` additionally keeps the most
# recent state so an interrupted run can still be resumed.
checkpoint_callback = ModelCheckpoint(
    save_top_k=3,
    monitor="val_loss",
    mode="min",
    save_last=True,
    dirpath="model_checkpoints/",
    filename="PAMNET_checkpoint-{epoch:02d}-{val_loss:.2f}",
)

trainer = pl.Trainer(
    # Author's note: ~30 epochs is enough at lr 5e-4 (hparams currently sets
    # learning_rate = 3e-4, so the budget is slightly larger).
    max_epochs=40,
    accelerator="auto",
    devices='auto',
    callbacks=[checkpoint_callback],
    precision="bf16-mixed",  # bfloat16 automatic mixed precision
    logger=wandb_logger,
    # Debugging switches, disabled for real runs:
    #detect_anomaly=True,
    #fast_dev_run = True,
)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
# Launch training of the loaded model; batches come from the data module.
trainer.fit(model=model, datamodule=dm)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mklyambus[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111479883333328, max=1.0)…

/home/langley/miniconda3/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:634: Checkpoint directory model_checkpoints/ exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type              | Params
-------------------------------------------------
0 | encoder    | PAMNet            | 2.5 M 
1 | classifier | Sequential        | 16.5 K
2 | crit       | BCEWithLogitsLoss | 0     
3 | metric     | BinaryAUROC       | 0     
-------------------------------------------------
2.6 M     Trainable params
0         Non-trainable params
2.6 M     Total params
10.203    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

BCE_val_loss: 0.2983, AUROC: 0.9488


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

BCE_val_loss: 0.2819, AUROC: 0.9535


Validation: |          | 0/? [00:00<?, ?it/s]

BCE_val_loss: 0.2849, AUROC: 0.9532


Validation: |          | 0/? [00:00<?, ?it/s]

BCE_val_loss: 0.2823, AUROC: 0.9536


Validation: |          | 0/? [00:00<?, ?it/s]

BCE_val_loss: 0.2868, AUROC: 0.9529


Validation: |          | 0/? [00:00<?, ?it/s]

BCE_val_loss: 0.2883, AUROC: 0.9524


/home/langley/miniconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
