In [1]:
import torch
import torch.nn as nn
from datasets.pcba import PCBA
from datasets.moshkov import MoshkovCS
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pamnet.model import PAMNetModel
from dataloader import GraphDataModule


In [2]:
# Hyper-parameters shared by the model (PAMNetModel) and the data module.
hparams = {
    # --- data split / optimisation ---
    "train_split": 0.8,
    "train_shuffle": True,
    "seed": 42,
    "learning_rate": 3e-4,
    "batch_size": 384,
    "exit_dim": 270,  # n_classes for the classification task

    # --- network ---
    "cutoff": 5.0,
    "cutoff_l": 5.0,
    "cutoff_g": 5.0,
    "num_layers": 4,
    "in_channels": 79,
    "hidden_dim": 64,
    "n_heads": 4,
    "dropout": 0.2,
    "out_channels": 64,  # size of the embedding vector
    "num_spherical": 7,
    "num_radial": 6,
    "envelope_exponent": 5,
}

In [3]:
# Apply the configured seed *before* building the model so that weight
# initialisation (and later shuffling) is reproducible. `hparams["seed"]` is not
# applied anywhere else in this notebook. seed_everything seeds Python's `random`,
# NumPy and torch; workers=True also derives per-dataloader-worker seeds.
pl.seed_everything(hparams["seed"], workers=True)

model = PAMNetModel(hparams)
n_params = sum(p.numel() for p in model.parameters())
print(f"Number of Parameters: {n_params}")

Number of Parameters: 687670


In [4]:
# Moshkov CS dataset (no hydrogens, with coordinates). `PCBA(root='data')`
# is an alternative dataset source, currently disabled.
moshkov_root = 'data/moshkov_cs'
dataset = MoshkovCS(moshkov_root, with_hydrogen=False, with_coords=True).split(1)

# DataModule configured from the same hyper-parameters as the model.
dm = GraphDataModule(dataset, hparams)

In [5]:
# Suggested by Lightning: let float32 matmuls use bfloat16 internally when a
# fast kernel is available (otherwise torch behaves as with 'high').
torch.set_float32_matmul_precision(precision="medium")

In [6]:
# Flip to True to log to Weights & Biases; with False, Trainer receives
# logger=None, i.e. Lightning's default logger.
USE_WANDB = False

# Keep only the single best checkpoint, ranked by the logged "val_metric"
# (higher is better, hence mode="max").
checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    monitor="val_metric",
    mode="max",
    dirpath="model_checkpoints/",
    # NOTE(review): `num_good_labels` is read from the logged metrics; confirm
    # the model actually logs it, otherwise this field is not meaningful.
    filename="PAMNet_checkpoint-{epoch:02d}-{val_metric:.2f}-{num_good_labels:.0f}",
)

trainer = pl.Trainer(
    # NOTE(review): an earlier note said ~30 epochs suffice at lr=5e-4; hparams
    # currently uses 3e-4, so that estimate may not carry over.
    max_epochs=50,
    accelerator="auto",
    devices="auto",
    callbacks=[checkpoint_callback],
    precision="bf16-mixed",  # bfloat16 automatic mixed precision
    log_every_n_steps=30,
    logger=WandbLogger() if USE_WANDB else None,
    detect_anomaly=False,  # set True to enable autograd anomaly detection while debugging
    fast_dev_run=False,    # set True for a quick smoke-test run
)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Train the model; the dataloaders come from the GraphDataModule `dm`.
trainer.fit(model=model, datamodule=dm)