In [1]:
import torch
import torch.nn as nn
from datasets.pcba import PCBA
from datasets.moshkov import MoshkovCS
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pamnet.model import PAMNetModel
from dataloader import GraphDataModule


In [2]:
# Single configuration dict handed to both PAMNetModel and GraphDataModule below.
hparams = {
    # --- data / optimisation ---
    "train_split": 0.8,
    "train_shuffle": True,
    "seed": 42,
    "learning_rate": 3e-4,
    "batch_size": 384,
    "exit_dim": 270,  # n_classes - for classification task

    # --- net ---
    "cutoff": 5.0,
    "cutoff_l": 5.0,
    "cutoff_g": 5.0,
    "num_layers": 4,
    "in_channels": 79,
    "hidden_dim": 64,
    "n_heads": 4,
    "dropout": 0.2,
    "out_channels": 64,  # size of embed vector
    "num_spherical": 7,
    "num_radial": 6,
    "envelope_exponent": 5,
}

In [3]:
# Seed the Python, NumPy and PyTorch RNGs *before* the model is built so that
# weight initialisation (and any later shuffling / dropout) is reproducible on
# Restart & Run All. `hparams["seed"]` is not applied anywhere else in the
# visible notebook cells.
pl.seed_everything(hparams["seed"], workers=True)

# Build the Lightning module from the shared hyperparameter dict and report
# its total parameter count.
model = PAMNetModel(hparams)
n_params = sum(p.numel() for p in model.parameters())
print(f"Number of Parameters: {n_params}")

Number of Parameters: 687670


In [4]:
# Alternative data source, kept for quick switching:
#dataset = PCBA(root='data')

# Load the Moshkov CS dataset (no hydrogens, with coordinates) and take split 1.
# NOTE(review): the meaning of the argument to `.split(1)` is defined in
# datasets.moshkov — presumably a split/fold index; confirm there.
dataset = (
    MoshkovCS('data/moshkov_cs', with_hydrogen=False, with_coords=True)
    .split(1)
)

# The data module receives the same hparams dict as the model.
dm = GraphDataModule(dataset, hparams)

In [5]:
# Lightning suggests lowering float32 matmul precision on capable GPUs;
# 'medium' trades some float32 matmul accuracy for speed.
torch.set_float32_matmul_precision('medium')

In [6]:
# Optional Weights & Biases logging; enable together with `logger=wandb_logger`
# in the Trainer call.
#wandb_logger = WandbLogger()

# Keep only the single best checkpoint, ranked by the highest "val_metric".
# NOTE(review): the filename template expects `val_metric` and
# `num_good_labels` to be logged by the model — confirm in PAMNetModel.
checkpoint_callback = ModelCheckpoint(
    dirpath="model_checkpoints/",
    filename="PAMNet_checkpoint-{epoch:02d}-{val_metric:.2f}-{num_good_labels:.0f}",
    monitor="val_metric",
    mode="max",
    save_top_k=1,
)

# Trainer configuration: hardware is auto-selected, training runs in bfloat16
# mixed precision, and the best checkpoint is tracked by `checkpoint_callback`.
trainer = pl.Trainer(
    # hardware / numerics
    accelerator="auto",
    devices='auto',
    precision="bf16-mixed",  # bfloat16 automatic mixed precision
    # schedule / logging
    max_epochs=50,  # author's note: ~30 epochs was enough at lr=5e-4 (hparams uses 3e-4)
    log_every_n_steps=30,
    callbacks=[checkpoint_callback],
    # optional switches, disabled for this run
    #logger=wandb_logger,
    #enable_checkpointing=False,
    #detect_anomaly=True,
    #fast_dev_run = True,
)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Run the training loop; data loading is delegated to the data module `dm`.
trainer.fit(model=model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type              | Params
-------------------------------------------------
0 | encoder    | PAMNet            | 644 K 
1 | classifier | Sequential        | 43.1 K
2 | crit       | BCEWithLogitsLoss | 0     
3 | metric     | BinaryAUROC       | 0     
-------------------------------------------------
687 K     Trainable params
0         Non-trainable params
687 K     Total params
2.751     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

                            CS
--------------------  --------
# assays AUROC > 0.9  2
val mean loss         0.730367
val mean metric       0.531389




Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  18
val mean loss          0.632352
val mean metric        0.69578




Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  19
val mean loss          0.624174
val mean metric        0.701501


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  21
val mean loss          0.616884
val mean metric        0.716148


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  21
val mean loss          0.61413
val mean metric        0.723709


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  22
val mean loss          0.617742
val mean metric        0.71947


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  18
val mean loss          0.605356
val mean metric        0.737786


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  20
val mean loss          0.59745
val mean metric        0.749485


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  20
val mean loss          0.59805
val mean metric        0.752647


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  21
val mean loss          0.605029
val mean metric        0.743651


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  25
val mean loss          0.591371
val mean metric        0.760909


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  23
val mean loss          0.610833
val mean metric        0.751793


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  24
val mean loss          0.582616
val mean metric        0.773035


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  27
val mean loss          0.585094
val mean metric        0.767381


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  25
val mean loss          0.587571
val mean metric        0.769988


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  25
val mean loss          0.592568
val mean metric        0.75938


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  24
val mean loss          0.583124
val mean metric        0.777276


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  23
val mean loss          0.585739
val mean metric        0.770783


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  27
val mean loss          0.583641
val mean metric        0.77381


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  22
val mean loss          0.59361
val mean metric        0.764043


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  22
val mean loss          0.586181
val mean metric        0.770846


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  25
val mean loss          0.578882
val mean metric        0.777769


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  19
val mean loss          0.584489
val mean metric        0.775972


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  21
val mean loss          0.583167
val mean metric        0.777509


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  24
val mean loss          0.579095
val mean metric        0.783495


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  20
val mean loss          0.586922
val mean metric        0.782402


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  23
val mean loss          0.585899
val mean metric        0.780742


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  21
val mean loss          0.590405
val mean metric        0.771524


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  24
val mean loss          0.585358
val mean metric        0.779399


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  23
val mean loss          0.58643
val mean metric        0.779963


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  27
val mean loss          0.58455
val mean metric        0.781584


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  28
val mean loss          0.590688
val mean metric        0.777067


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  22
val mean loss          0.594688
val mean metric        0.772505


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  32
val mean loss          0.584226
val mean metric        0.787453


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  31
val mean loss          0.589338
val mean metric        0.784471


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  25
val mean loss          0.589167
val mean metric        0.780514


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  30
val mean loss          0.599958
val mean metric        0.775058


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  30
val mean loss          0.595596
val mean metric        0.783712


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  33
val mean loss          0.585726
val mean metric        0.78286


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  31
val mean loss          0.597731
val mean metric        0.776115


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  32
val mean loss          0.595572
val mean metric        0.780353


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  31
val mean loss          0.587419
val mean metric        0.784599


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  30
val mean loss          0.594772
val mean metric        0.77691


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  27
val mean loss          0.607699
val mean metric        0.772127


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  29
val mean loss          0.598134
val mean metric        0.784235


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  34
val mean loss          0.60202
val mean metric        0.780223


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  32
val mean loss          0.610924
val mean metric        0.775524


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  30
val mean loss          0.600225
val mean metric        0.780295


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  34
val mean loss          0.603753
val mean metric        0.78537


Validation: |          | 0/? [00:00<?, ?it/s]

                             CS
--------------------  ---------
# assays AUROC > 0.9  28
val mean loss          0.616236
val mean metric        0.766342


Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


                             CS
--------------------  ---------
# assays AUROC > 0.9  29
val mean loss          0.613383
val mean metric        0.775768
