In [1]:
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader
from schnetspec import SchNetspec
from pyg_schnet import SchNet
from create_dataset import XASDataset

INFO:rdkit:Enabling RDKit 2024.03.5 jupyter extensions


In [2]:
def train(train_loader, model, criterion, optimizer, epoch, device):

    model.train()

    loss_all = 0

    for batch in train_loader:
        batch = batch.to(device)

        optimizer.zero_grad()

        output = model(batch.z, batch.x, batch.batch)

        target = torch.Tensor(batch.spectrum).to(device)

        train_loss = criterion(output.double(), target.double())

        loss_all += train_loss.data * batch.num_graphs

        train_loss.backward()

        optimizer.step()

    return loss_all
    
def validate(val_loader, model, criterion, epoch, device):
    losses_all = 0

    model.eval()

    for batch in val_loader:
        batch = batch.to(device)

        output = model(batch.z, batch.x, batch.batch)

        target = torch.Tensor(batch.spectrum).to(device)

        val_loss = criterion(output.double(), target.double())

        losses_all += val_loss.data * batch.num_graphs

    return losses_all

In [3]:
data = XASDataset('./')

  if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
  if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):


In [4]:
train_dataset = data[0:10]#113249]
val_dataset = data[11:20]#113249:119541]

train_loader = DataLoader(train_dataset, batch_size=50, shuffle=False, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=50, shuffle=True, num_workers=0)

conv = SchNet(hidden_channels=300, num_filters=200, num_interactions=6,
                  num_gaussians=50, cutoff=10.0)

In [5]:
for batch in train_loader:
    print(batch)
    break

DataBatch(x=[182, 3], edge_index=[2, 182], edge_attr=[182, 4], spectrum=[3000], z=[182], idx=[10], smiles=[10], batch=[182], ptr=[11])


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = conv.to(device)

In [7]:
model

SchNet(hidden_channels=300, num_filters=200, num_interactions=6, num_gaussians=50, cutoff=10.0)

In [8]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                       factor=0.9, patience=5,
                                                       min_lr=0.000001)
criterion = nn.MSELoss()

In [9]:
loss_train = []

for epoch in range(300):
    lr = scheduler.optimizer.param_groups[0]['lr']

    train_loss = (train(train_loader, model, criterion, optimizer, epoch, device)/(len(train_dataset)))

    print(train_loss)

    val_loss = (validate(val_loader, model, criterion, epoch, device)/len(val_dataset))

    print(val_loss)

    scheduler.step(val_loss)
    

  return F.mse_loss(input, target, reduction=self.reduction)


tensor(46.5009, dtype=torch.float64)
tensor(12117.4524, dtype=torch.float64)


  return F.mse_loss(input, target, reduction=self.reduction)


tensor(19513.6994, dtype=torch.float64)
tensor(966.0238, dtype=torch.float64)
tensor(1383.2972, dtype=torch.float64)
tensor(90.0306, dtype=torch.float64)
tensor(128.6611, dtype=torch.float64)
tensor(791.1295, dtype=torch.float64)
tensor(1073.4461, dtype=torch.float64)
tensor(828.5694, dtype=torch.float64)
tensor(1076.2567, dtype=torch.float64)
tensor(373.3004, dtype=torch.float64)
tensor(441.6027, dtype=torch.float64)
tensor(28.8363, dtype=torch.float64)
tensor(21.8099, dtype=torch.float64)
tensor(125.8554, dtype=torch.float64)
tensor(259.8282, dtype=torch.float64)
tensor(289.6558, dtype=torch.float64)
tensor(515.6080, dtype=torch.float64)
tensor(172.9860, dtype=torch.float64)
tensor(296.0455, dtype=torch.float64)
tensor(30.5391, dtype=torch.float64)
tensor(51.6172, dtype=torch.float64)
tensor(11.8333, dtype=torch.float64)
tensor(9.7059, dtype=torch.float64)
tensor(78.0135, dtype=torch.float64)
tensor(101.0262, dtype=torch.float64)
tensor(132.8564, dtype=torch.float64)
tensor(176.8667,