In [None]:
%load_ext autoreload
%autoreload 2

import json
import random
import torch

import numpy as np
import pandas as pd

from jarvis.core.graphs import Graph
from jarvis.core.atoms import Atoms

from pymatgen.io.jarvis import JarvisAtomsAdaptor
from pymatgen.core import Structure

#from alignn.models.alignn import ALIGNN
from alignn_multi import ALIGNN

from tqdm.notebook import tqdm_notebook

import dgl

In [None]:
def atoms_to_graph(atoms, cutoff=6.0, max_neighbors=12,
    atom_features="cgcnn", use_canonize=True):
    """Build the jarvis/DGL graph for a serialized pymatgen structure.

    Args:
        atoms: dictionary accepted by ``pymatgen.core.Structure.from_dict``
            (despite the name, this is not a jarvis ``Atoms`` dict).
        cutoff: forwarded to ``Graph.atom_dgl_multigraph`` as ``cutoff``.
        max_neighbors: forwarded as ``max_neighbors``.
        atom_features: forwarded as ``atom_features``.
        use_canonize: forwarded as ``use_canonize``.

    Returns:
        Whatever ``Graph.atom_dgl_multigraph`` returns when called with
        ``compute_line_graph=True``.
        NOTE(review): presumably a (graph, line graph) pair -- confirm
        against the installed jarvis version.
    """
    # pymatgen dict -> pymatgen Structure -> jarvis Atoms.
    pmg_structure = Structure.from_dict(atoms)
    jarvis_atoms = JarvisAtomsAdaptor.get_atoms(pmg_structure)

    graph_options = dict(
        cutoff=cutoff,
        atom_features=atom_features,
        max_neighbors=max_neighbors,
        compute_line_graph=True,
        use_canonize=use_canonize,
    )
    return Graph.atom_dgl_multigraph(jarvis_atoms, **graph_options)

def group_decay(model):
    """Split a model's parameters into optimizer groups by weight decay.

    A parameter is exempt from weight decay when its name (as yielded by
    ``model.named_parameters()``) contains "bias", "bn" or "norm" as a
    substring.

    Args:
        model: object exposing ``named_parameters()`` yielding
            ``(name, parameter)`` pairs.

    Returns:
        A two-element list of parameter-group dicts: the first holds the
        remaining parameters with no explicit ``weight_decay`` key, the
        second holds the exempt parameters with ``weight_decay`` set to 0.
    """
    exempt_markers = ("bias", "bn", "norm")
    with_decay = []
    without_decay = []

    for param_name, param in model.named_parameters():
        is_exempt = any(marker in param_name for marker in exempt_markers)
        target_group = without_decay if is_exempt else with_decay
        target_group.append(param)

    return [
        {"params": with_decay},
        {"params": without_decay, "weight_decay": 0},
    ]

def collate_line_graph(samples):
    """Dataloader helper to batch graphs across `samples`.

    Args:
        samples: iterable of ``(graph, line_graph, label)`` triples. Labels
            are expected to be tensors, since ``size()`` is called on the
            first one.

    Returns:
        ``(batched_graph, batched_line_graph, labels)``. Graphs and line
        graphs are each merged with ``dgl.batch``. Labels with at least one
        dimension are combined with ``torch.stack``; zero-dimensional labels
        are combined with ``torch.tensor``.
    """
    graphs, line_graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    batched_line_graph = dgl.batch(line_graphs)

    # The per-batch debug prints of the first label were removed: they
    # produced output on every call and carried no information for callers.
    if len(labels[0].size()) > 0:
        batched_labels = torch.stack(labels)
    else:
        batched_labels = torch.tensor(labels)
    return batched_graph, batched_line_graph, batched_labels

In [None]:
# Location of the JSON dataset, relative to this notebook.
data = '../data/No-dup-complete_dataset_100.json'
with open(data, "rb") as f:
    dataset = json.load(f)

# Attach the model inputs and labels to every record in place.
for datum in tqdm_notebook(dataset):
    # Graphs are built with cutoff=10.0 rather than atoms_to_graph's default.
    datum['atoms'] = atoms_to_graph(datum['structure'], cutoff=10.0)
    # NOTE(review): 'OH' looks like a per-property availability mask and
    # 'prop_list' like the matching target values -- confirm against the data.
    datum['has_prop'] = torch.FloatTensor(datum['OH'])
    datum['target'] = torch.FloatTensor(datum['prop_list'])

In [None]:
# Network with seven outputs.
model = ALIGNN(n_outputs=7)

# `device` is a torch.device when CUDA is available and the plain string
# "cpu" otherwise.
# NOTE(review): nothing in this cell moves `model` to `device`; confirm
# whether training is meant to run on the GPU.
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

## Load an old model

Uncomment the line in the next cell to reload the weights of a previously saved model.

In [None]:
#model.load_state_dict(torch.load('./best_model.pt', map_location=torch.device(device)))

In [None]:
# Index boundaries of the sequential split:
#   [0, train_split)           -> training
#   [train_split, test_split)  -> validation
#   [test_split, end)          -> test
train_split = 60
test_split = 80


def _columns(records):
    """Return the 'atoms', 'has_prop' and 'target' entries of `records` as three lists."""
    return (
        [record['atoms'] for record in records],
        [record['has_prop'] for record in records],
        [record['target'] for record in records],
    )


Xtrain, Ptrain, ytrain = _columns(dataset[:train_split])
Xval, Pval, yval = _columns(dataset[train_split:test_split])
Xtest, Ptest, ytest = _columns(dataset[test_split:])

### Define Train/Eval Loops

In [None]:
# Set up the optimiser and the loss used by the training loop.

# Parameter groups from group_decay: the second group has weight_decay=0.
params = group_decay(model)
optimizer = torch.optim.AdamW(params=params, lr=1e-4)

# L1 loss between predictions and targets.
criterion = torch.nn.L1Loss()



In [None]:
n_epochs = 12    # number of epochs to run
batch_size = 1  # size of each batch
batches_per_epoch = len(Xtrain) // batch_size

training_losses = []     # mean training loss per epoch
validation_losses = []   # mean validation loss per epoch
best_val_loss = np.inf   # lowest validation loss seen so far
for epoch in range(n_epochs):
    print('Epoch Number ', epoch)

    # ---- training pass ----
    model.train()
    training_loss = 0
    for i in tqdm_notebook(range(batches_per_epoch)):
        start = i * batch_size
        # take a batch
        Xbatch = Xtrain[start:start+batch_size]
        Propbatch = Ptrain[start:start+batch_size]
        ybatch = ytrain[start:start+batch_size]
        # forward pass
        # NOTE(review): only the first sample of each slice is used, so
        # batch_size > 1 would silently skip samples.
        y_pred = model(Xbatch[0], Propbatch[0])
        loss = criterion(y_pred, ybatch[0])
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        # update weights
        optimizer.step()
        # Accumulate a detached copy so the autograd graph of every step is
        # not kept alive for the whole epoch.
        training_loss += loss.detach()
    # Average over the number of optimisation steps actually taken.
    training_loss = training_loss / batches_per_epoch
    print('Training loss: %.3f'% training_loss.item())
    training_losses.append(training_loss.item())

    # ---- validation pass ----
    # eval() + no_grad(): normalisation layers use their stored statistics
    # and no gradients are tracked while scoring the validation samples.
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for X_sample, P_sample, true in zip(Xval, Pval, yval):
            y_pred = model(X_sample, P_sample)
            val_loss += criterion(y_pred, true)
    # Divide by the number of validation samples (the previous code divided
    # by the last loop index, i.e. one less than the sample count).
    val_loss = val_loss / len(yval)
    print('Validation loss: %.3f'% val_loss.item())
    validation_losses.append(val_loss.item())

    # Keep the weights of the best model seen so far.
    if val_loss.item() < best_val_loss:
        best_val_loss = val_loss.item()
        print('Model improved, saving')
        torch.save(model.state_dict(), 'best_model.pt')
        