In [2]:
import torch
from torch.nn import MSELoss
import os
from EGNN5 import EGNN5
from get_MD17_data_loaders import get_MD17_data_loaders
import wandb
from test_model import test_model
from train_model import train_model
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [3]:
# Reproducibility: fix the seed of torch's global random number generator.
# Only torch is seeded in this cell.
SEED = 2002
torch.manual_seed(SEED)

<torch._C.Generator at 0x107698170>

In [4]:
# Hyperparameters. This dict is passed to wandb.init as the run config, and it
# also drives construction of the optimizer, scheduler and loss below, so what
# gets logged is what actually gets built.
config = {
    'name': 'EGNN5',
    'base_learning_rate': 0.001,
    'num_epochs': 50,
    'optimizer': 'Adam',
    'scheduler': 'ReduceLROnPlateau',
    'scheduler_mode': 'min',
    'scheduler_factor': 0.32,
    'scheduler_patience': 1,
    'scheduler_threshold': 0,
    'training_loss_fn': 'MSELoss',
    'rho': 1-1e-1,
    'batch_size': 32
}

# Name -> class registries. The string entries in `config` are resolved through
# these tables, so a name without an entry raises KeyError instead of silently
# disagreeing with the object that is really constructed. To make a new choice
# selectable from `config`, import its class and register it here.
OPTIMIZERS = {'Adam': Adam}
SCHEDULERS = {'ReduceLROnPlateau': ReduceLROnPlateau}
LOSS_FNS = {'MSELoss': MSELoss}

# initialize the star of the show
model = EGNN5()

# Optimizer class is looked up by name; the learning rate comes from config.
optimizer = OPTIMIZERS[config['optimizer']](
    model.parameters(),
    lr=config['base_learning_rate'],
)

# NOTE: the keyword arguments below are the ReduceLROnPlateau ones. A scheduler
# with a different constructor signature needs its own construction code.
scheduler = SCHEDULERS[config['scheduler']](
    optimizer=optimizer,
    mode=config['scheduler_mode'],
    factor=config['scheduler_factor'],
    patience=config['scheduler_patience'],
    threshold=config['scheduler_threshold'],
)

# Loss class is looked up by name and instantiated with its default arguments.
loss_fn = LOSS_FNS[config['training_loss_fn']]()


In [5]:
# wandb setup: set the WANDB_NOTEBOOK_NAME environment variable before the run starts.
os.environ.update({'WANDB_NOTEBOOK_NAME': 'main.py'})

# Start a run in the "EGNN" project and attach the hyperparameter dict to it.
wandb.init(project="EGNN", config=config)

0,1
E_train_losses,█▁▁▁▁
F_train_losses,▁█▆▆▅
train_losses,█▁▁▁▁
training_rates,▁▁▁▁▁

0,1
E_train_losses,1e-05
F_train_losses,0.02226
train_losses,0.02227
training_rates,0.001


In [6]:
# Build the train / validation / test loaders with split arguments 0.8 / 0.1 / 0.1;
# the batch size is taken from the config dict.
train_loader, val_loader, test_loader = get_MD17_data_loaders(
    train_split=0.8,
    val_split=0.1,
    test_split=0.1,
    batch_size=config['batch_size'],
)



In [7]:
# Run training with the configured model, optimizer, scheduler and loss;
# rho, the epoch count and the run name are read from the config dict.
train_model(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    rho=config['rho'],
    num_epochs=config['num_epochs'],
    name=config['name'],
)

# Afterwards, run test_model on the test loader with the same loss and rho.
test_model(
    model=model,
    loss_fn=loss_fn,
    test_loader=test_loader,
    rho=config['rho'],
)

DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], 

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x103959bd0>>
Traceback (most recent call last):
  File "/usr/local/Caskroom/miniconda/base/envs/GDL/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], force=[384, 3], edge_index=[2, 768], batch=[384], ptr=[33])
DataBatch(pos=[384, 3], z=[384], energy=[32], 