In [1]:
import sys

sys.path.append("../")

import torch
import torch.utils as utils

import pytorch_lightning as pl

from torchemlp.groups import SO, O, S, Z
from torchemlp.nn.equivariant import EMLP
from torchemlp.nn.utils import RegressionLightning, Standardize
from torchemlp.datasets import O5Synthetic

In [2]:
# --- Configuration -----------------------------------------------------------
# Split sizes (number of samples); their sum is the size of the generated dataset.
TRAINING_SET_SIZE = 5_000
VALIDATION_SET_SIZE = 1_000
TEST_SET_SIZE = 1_000

BATCH_SIZE = 500

# Epoch count is derived from a fixed budget of training-sample presentations
# (epochs * TRAINING_SET_SIZE ~= TOTAL_TRAINING_SAMPLES), capped at MAX_EPOCHS_CAP.
# With 5_000 training samples this gives 180 epochs.
TOTAL_TRAINING_SAMPLES = 900_000
MAX_EPOCHS_CAP = 1_000
N_EPOCHS = min(TOTAL_TRAINING_SAMPLES // TRAINING_SET_SIZE, MAX_EPOCHS_CAP)

# Hidden size and number of layers, passed positionally to EMLP below.
N_CHANNELS = 384
N_LAYERS = 3

# DataLoader worker processes. Kept at 0 (load in the main process): the dataset
# is created on the GPU, and CUDA tensors generally cannot be shared with forked
# worker processes.
DL_WORKERS = 0

# Seed Python's `random`, NumPy and torch so that the synthetic data, the
# train/val/test split and weight initialization are reproducible across
# Restart & Run All (assuming O5Synthetic draws from one of these RNGs).
SEED = 0
pl.seed_everything(SEED)

In [3]:
# Generate one pool of synthetic O(5) samples large enough to be split into
# train/validation/test, then show the group representations of inputs/outputs.
n_total_samples = TRAINING_SET_SIZE + VALIDATION_SET_SIZE + TEST_SET_SIZE
dataset = O5Synthetic(n_total_samples, device="cuda")

input_rep = dataset.repin(dataset.G)
output_rep = dataset.repout(dataset.G)
f"Input type: {input_rep}, output type: {output_rep}"

'Input type: 2V, output type: V⁰'

In [4]:
# Partition the dataset into disjoint train/validation/test subsets of the
# configured sizes. A dedicated, explicitly seeded generator makes the split
# reproducible no matter how much global RNG state earlier cells consumed
# (e.g. during synthetic data generation).
SPLIT_SEED = 0
split_data = utils.data.random_split(
    dataset,
    [TRAINING_SET_SIZE, VALIDATION_SET_SIZE, TEST_SET_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_set, val_set, test_set = split_data

# Only the training loader shuffles; validation/test batches are iterated in a
# fixed order.
train_loader = utils.data.DataLoader(
    train_set, batch_size=BATCH_SIZE, num_workers=DL_WORKERS, shuffle=True
)
val_loader = utils.data.DataLoader(
    val_set, batch_size=BATCH_SIZE, num_workers=DL_WORKERS
)
test_loader = utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, num_workers=DL_WORKERS
)

In [5]:
# Equivariant MLP mapping the dataset's input rep to its output rep under the
# dataset's symmetry group, sized by the N_CHANNELS / N_LAYERS config values.
backbone = EMLP(dataset.repin, dataset.repout, dataset.G, N_CHANNELS, N_LAYERS)

# Wrap the backbone with the dataset statistics, move it to the GPU, and adapt
# it to Lightning's training loop.
model = Standardize(backbone, dataset.stats).cuda()
plmodel = RegressionLightning(model)

In [6]:
# Train on the training split, validating on the validation split, then
# evaluate `plmodel` (with its weights as they stand after fitting) on the
# held-out test split.
#
# Note: `limit_train_batches` is a per-epoch batch *count* (int) or fraction
# (float), not a batch size, so it is left at its default (use every batch).
# `devices=1` pins a single GPU so a multi-GPU host does not auto-select
# several devices, which is problematic inside an interactive notebook.
trainer = pl.Trainer(max_epochs=N_EPOCHS, accelerator="gpu", devices=1)
trainer.fit(plmodel, train_loader, val_loader)
test_metrics = trainer.test(plmodel, test_loader)
test_metrics

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type        | Params
--------------------------------------
0 | model | Standardize | 480 K 
--------------------------------------
480 K     Trainable params
0         Non-trainable params
480 K     Total params
1.923     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|                                                                                                                                                            | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                                                                                                                                                                     

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:  83%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                        | 10/12 [00:04<00:00,  2.03it/s, loss=41.7, v_num=34]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                                                                                              | 0/2 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                                                                                 | 0/2 [00:00<?, ?it/s][A
Epoch 0:  92%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊            | 11/12 [00:05<00:00,  2.19it/s, loss=41.7, v_num=34][A
Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████

`Trainer.fit` stopped: `max_epochs=180` reached.


Epoch 179: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:05<00:00,  2.28it/s, loss=0.0358, v_num=34]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 11.81it/s]


[{'test_loss': 0.06504896283149719}]