In [1]:
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader
from schnetspec import SchNetspec
from create_dataset import XASDataset

INFO:rdkit:Enabling RDKit 2024.03.5 jupyter extensions


In [2]:
def train(train_loader, model, criterion, optimizer, epoch, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        train_loader: Iterable of batches. Each batch must provide
            ``to(device)``, ``z``, ``pos``, ``batch``, ``spectrum`` and
            ``num_graphs``.
        model: Module called as ``model(batch.z, batch.pos, batch.batch)``.
        criterion: Loss callable applied to ``(prediction, target)``.
        optimizer: Optimizer that owns the parameters of ``model``.
        epoch: Unused; kept so existing callers keep working.
        device: Device the batch and the target are moved to.

    Returns:
        The sum over batches of ``batch_loss * batch.num_graphs``. This is
        a 0-dim tensor detached from the autograd graph, or the int ``0``
        if the loader yields no batches. Callers normalise it themselves.
    """
    model.train()

    loss_all = 0

    for batch in train_loader:
        batch = batch.to(device)

        optimizer.zero_grad()

        output = model(batch.z, batch.pos, batch.batch)
        target = torch.Tensor(batch.spectrum).to(device)

        # Both sides are cast to float64 before the loss is evaluated.
        train_loss = criterion(output.double(), target.double())
        train_loss.backward()
        optimizer.step()

        # detach() replaces the legacy `.data` attribute: it yields the same
        # graph-free value while staying visible to autograd's in-place
        # correctness checks. Weighting by num_graphs lets the caller divide
        # by the dataset size to get a per-sample mean.
        loss_all += train_loss.detach() * batch.num_graphs

    return loss_all
    
def validate(val_loader, model, criterion, epoch, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Args:
        val_loader: Iterable of batches. Each batch must provide
            ``to(device)``, ``z``, ``pos``, ``batch``, ``spectrum`` and
            ``num_graphs``.
        model: Module called as ``model(batch.z, batch.pos, batch.batch)``.
        criterion: Loss callable applied to ``(prediction, target)``.
        epoch: Unused; kept so existing callers keep working.
        device: Device the batch and the target are moved to.

    Returns:
        The sum over batches of ``batch_loss * batch.num_graphs``. This is
        a 0-dim tensor, or the int ``0`` if the loader yields no batches.
        Callers normalise it themselves.
    """
    losses_all = 0

    model.eval()

    # No backward pass happens here, so disable autograd: otherwise every
    # forward pass records a graph that is kept alive until the loss tensor
    # is released, wasting memory and time.
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)

            output = model(batch.z, batch.pos, batch.batch)
            target = torch.Tensor(batch.spectrum).to(device)

            # Both sides are cast to float64 before the loss is evaluated.
            val_loss = criterion(output.double(), target.double())

            losses_all += val_loss.detach() * batch.num_graphs

    return losses_all

In [3]:
# Instantiate the project's XASDataset with './' as its argument
# (presumably the dataset root directory -- confirm in create_dataset.py).
data = XASDataset('./')

  if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
  if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
Processing...


Total number of molecules: 125833


Done!


In [8]:
# Split boundaries and loader settings, named so they can be tuned in one place.
TRAIN_END = 113249   # samples [0, TRAIN_END) are used for training
VAL_END = 119541     # samples [TRAIN_END, VAL_END) are used for validation
BATCH_SIZE = 50
NUM_WORKERS = 23     # loader worker processes; adjust to the host's CPU count

# Samples from VAL_END onward are not assigned to either loader in this cell.
train_dataset = data[0:TRAIN_END]
val_dataset = data[TRAIN_END:VAL_END]

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS)
# Evaluation does not benefit from shuffling; a fixed order keeps the
# validation pass deterministic from epoch to epoch.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS)

conv = SchNetspec(hidden_channels=300, num_filters=200, num_interactions=6,
                  num_gaussians=50, cutoff=10.0)

In [9]:
# Pick the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Bind `model` to whatever `conv.to(device)` returns.
model = conv.to(device)

In [10]:
# Display the model's repr as the cell output to check its configuration.
model

SchNetspec(hidden_channels=300, num_filters=200, num_interactions=6, num_gaussians=50, cutoff=10.0)

In [11]:
# Mean-squared-error objective used for both training and validation.
criterion = nn.MSELoss()

# Adam over every model parameter, starting at a learning rate of 1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Multiply the learning rate by 0.9 once the monitored value has stopped
# decreasing for more than 5 epochs, never going below 1e-6.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.9,
    patience=5,
    min_lr=1e-6,
)

In [15]:
# Per-epoch loss history (mean loss per sample), kept for later inspection.
loss_train = []
loss_val = []

for epoch in range(300):
    # Learning rate in effect for this epoch, read before the scheduler acts.
    lr = scheduler.optimizer.param_groups[0]['lr']

    # train()/validate() return a loss sum weighted by graphs per batch, so
    # dividing by the dataset size gives the mean loss per sample. float()
    # turns the 0-dim tensor into a plain Python number for logging/storage.
    train_loss = float(train(train_loader, model, criterion, optimizer, epoch, device)
                       / len(train_dataset))
    val_loss = float(validate(val_loader, model, criterion, epoch, device)
                     / len(val_dataset))

    loss_train.append(train_loss)
    loss_val.append(val_loss)

    print(f'epoch {epoch:3d} | lr {lr:.2e} | train {train_loss:.6f} | val {val_loss:.6f}')

    # ReduceLROnPlateau monitors the validation loss.
    scheduler.step(val_loss)
    

AssertionError: 