In [14]:
import torch
from torch.nn import MSELoss
import os
from EGNN5 import EGNN5
from MD17_data import benzene_dataloaders
import wandb
from test_model import test_model
from train_model import train_model
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [15]:
# pin torch's global RNG seed so that reruns of this notebook are reproducible
torch.manual_seed(seed=2002)

<torch._C.Generator at 0x10ba8ffd0>

In [16]:
# hyperparameters saved to config dict
# This dict is the single source of truth: the optimizer, scheduler and loss
# function below are all constructed from its entries, so the values handed to
# wandb.init(config=config) always match what is actually used.
config = {
    'name': 'EGNN5',
    'base_learning_rate': 0.001,
    'num_epochs': 5,
    'optimizer': 'Adam',                # class name looked up in torch.optim
    'scheduler': 'ReduceLROnPlateau',   # class name looked up in torch.optim.lr_scheduler
    # every 'scheduler_<kwarg>' entry is forwarded to the scheduler as <kwarg>
    'scheduler_mode': 'min',
    'scheduler_factor': 0.32,
    'scheduler_patience': 1,
    'scheduler_threshold': 0,
    'training_loss_fn': 'MSELoss',      # class name looked up in torch.nn
    'rho': 1-1e-1,
    'batch_size': 32
}

# initialize the star of the show
model = EGNN5()

# Resolve the optimizer class by name so that changing config['optimizer'] is
# enough to switch optimizers; an unknown name raises AttributeError here.
optimizer_cls = getattr(torch.optim, config['optimizer'])
optimizer = optimizer_cls(model.parameters(), lr=config['base_learning_rate'])

# Collect the scheduler's keyword arguments from the 'scheduler_*' keys.
# The bare 'scheduler' key (the class name) has no trailing underscore and is
# therefore not picked up by the prefix match.
scheduler_prefix = 'scheduler_'
scheduler_kwargs = {
    key[len(scheduler_prefix):]: value
    for key, value in config.items()
    if key.startswith(scheduler_prefix)
}
scheduler_cls = getattr(torch.optim.lr_scheduler, config['scheduler'])
scheduler = scheduler_cls(optimizer=optimizer, **scheduler_kwargs)

# The loss class is instantiated with its default arguments.
loss_fn = getattr(torch.nn, config['training_loss_fn'])()


In [17]:
# wandb setup: set the WANDB_NOTEBOOK_NAME environment variable before the run starts
# NOTE(review): the value is 'main.py' although this is a notebook — confirm it matches the actual file name
os.environ.update({'WANDB_NOTEBOOK_NAME': 'main.py'})

# start the wandb run under the "EGNN" project and attach the hyperparameter dict to it
wandb.init(project="EGNN", config=config)

In [18]:
# build the train / validation / test loaders for the benzene data
# split arguments passed through: 0.8 train, 0.1 validation, 0.1 test
split_kwargs = {'train_split': 0.8, 'val_split': 0.1, 'test_split': 0.1}
train_loader, val_loader, test_loader = benzene_dataloaders(
    **split_kwargs,
    batch_size=config['batch_size'],
)



In [19]:
# run training, then evaluate the same model object on the held-out test loader
# NOTE(review): the recorded run printed all 5 epochs and then stopped with
# "NameError: name 'E_squared_loss' is not defined" — presumably raised inside
# test_model (or at the very end of train_model); TODO confirm and fix there.
train_model(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    rho=config['rho'],
    num_epochs=config['num_epochs'],
    name=config['name'],
)
test_model(
    model=model,
    loss_fn=loss_fn,
    test_loader=test_loader,
    rho=config['rho'],
)

EPOCH 1 OF 5 | VAL MEAN LOSS: 3.7275174236128805e-07
EPOCH 2 OF 5 | VAL MEAN LOSS: 1.4036270385986427e-06
EPOCH 3 OF 5 | VAL MEAN LOSS: 8.260968797912938e-08
EPOCH 4 OF 5 | VAL MEAN LOSS: 7.601409635071832e-08
EPOCH 5 OF 5 | VAL MEAN LOSS: 1.5135384501263616e-08


NameError: name 'E_squared_loss' is not defined