## Imports

In [None]:
# Import NFF package from path

import sys
sys.path.append("./NFF")
from nff.data import Dataset, collate_dicts
from nff.train import Trainer, get_model, loss, hooks, metrics, evaluate

In [None]:
import os
import shutil
import numpy as np
import scipy
import json

import matplotlib.pyplot as plt
import pandas as pd

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

from evaluate import make_scatterplot, get_stats

In [None]:
# Output folder of this run; later cells write params.json, the learning
# curve, evaluate.json and the scatter plot here and pass it to the Trainer.
OUTDIR = "./Model_Adiabatic"
# Create OUTDIR together with its "0" subfolder
# (presumably the per-model folder expected by the Trainer -- TODO confirm).
try:
    os.makedirs(os.path.join(OUTDIR, "0"))
except FileExistsError:
    # Only an already existing folder is tolerated. Any other OS error
    # (e.g. missing permissions) propagates instead of being reported,
    # misleadingly, as "already exists".
    print("The folders already exists! This will overwrite what has been written there!")

In [None]:
# Use CUDA when a GPU is available, otherwise fall back to the CPU.
# DEVICE is handed to T.train(...) and evaluate(...) further down.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Load the three dataset splits from their files in the working directory.
train, test, val = map(
    Dataset.from_file,
    ('./train.pth.tar', './test.pth.tar', './val.pth.tar'),
)

In [None]:
# Names of the model outputs: one energy key per state index (0, 1, 2)
# and the matching gradient key for each of them.
output_keys = [f"energy_{state}" for state in range(3)]
grad_keys = [f"energy_{state}_grad" for state in range(3)]


In [None]:
# Read the default hyperparameters and override the run-specific entries.
with open("./default_params.json", 'r') as param_file:
    params = json.load(param_file)

params["compute_delta"] = False
params["output_keys"] = output_keys
params["grad_keys"] = grad_keys
params['details'] = None

# Serialize first, then write, so a serialization error cannot leave a
# truncated file behind.
params_json = json.dumps(params, indent=4)

# Store the parameters used for this run as params.json inside OUTDIR.
params_path = os.path.join(OUTDIR, "params.json")
with open(params_path, "w") as outfile:
    outfile.write(params_json)

In [None]:
# Loss specification: one entry per energy key (loss_type "mae", coef 0.1)
# followed by one entry per gradient key (loss_type "mse", coef 1.0),
# i.e. a 1:10 weighting of energies to gradients.
# NOTE(review): the energy terms use "mae" although they are listed under
# the "mse" key -- confirm which loss is intended for the energies.
energy_terms = [
    {"coef": 0.1, "params": {"key": key, "loss_type": "mae"}}
    for key in output_keys
]
gradient_terms = [
    {"coef": 1.0, "params": {"key": key, "loss_type": "mse"}}
    for key in grad_keys
]
multi_loss_dict = {"mse": energy_terms + gradient_terms}
# Currently disabled: additional MAE term on the energy gap between
# neighbouring adiabatic states.
#"diff": [{"coef": 0.1, "params": {"keys": [output_keys[ii+1], output_keys[ii]], "loss_type": "mae",}} for ii in range(len(output_keys)-1) if "energy_S" in output_keys[ii+1] ]


In [None]:
# Metrics logged while training: the mean absolute error of every energy
# key, followed by the mean absolute error of every gradient key.
train_metrics = [
    metrics.MeanAbsoluteError(target_key)
    for target_key in output_keys + grad_keys
]

In [None]:
# One torch DataLoader per split; all three share the batch size from
# `params` and the NFF collate function.
# NOTE(review): the training loader is not shuffled -- confirm this is intended.
loader_kwargs = {
    "batch_size": params['batch_size'],
    "collate_fn": collate_dicts,
}
train_loader = DataLoader(train, **loader_kwargs)
val_loader = DataLoader(val, **loader_kwargs)
test_loader = DataLoader(test, **loader_kwargs)

In [None]:
# Build a new model from the hyperparameters in `params`;
# model_type="Painn" is the architecture name handed to get_model.
model = get_model(params, model_type="Painn")

In [None]:
# Build the loss function from the loss specification and create an Adam
# optimizer over the trainable model parameters only (requires_grad set).
loss_fn = loss.build_multi_loss(multi_loss_dict=multi_loss_dict)
trainable_params = (
    param for param in model.parameters() if param.requires_grad
)
optimizer = Adam(trainable_params, lr=params['lr'])

In [None]:
# Training hooks: epoch limit, CSV and console logging of the training
# metrics, and a learning-rate schedule driven by params['lr_*'].

# Epoch limit taken from params['max_epochs'].
max_epoch_hook = hooks.MaxEpochHook(params['max_epochs'])

# CSV logging of the training metrics into OUTDIR.
csv_hook = hooks.CSVHook(
    OUTDIR,
    metrics=train_metrics,
)

# Printed log of the same metrics.
printing_hook = hooks.PrintingHook(
    OUTDIR,
    metrics=train_metrics,
    separator=' | ',
    time_strf='%M:%S',
)

# Learning-rate reduction on plateau, configured from `params`
# (presumably ends training once min_lr is reached -- TODO confirm).
lr_hook = hooks.ReduceLROnPlateauHook(
    optimizer=optimizer,
    patience=params['lr_patience'],
    factor=params['lr_decay'],
    min_lr=params['lr_min'],
    window_length=1,
    stop_after_min=True,
)

train_hooks = [max_epoch_hook, csv_hook, printing_hook, lr_hook]

In [None]:
# Assemble the Trainer: it receives the model, loss, optimizer, the train
# and validation loaders and the hooks, and uses OUTDIR as its model path.
trainer_kwargs = {
    "model_path": OUTDIR,
    "model": model,
    "loss_fn": loss_fn,
    "optimizer": optimizer,
    "train_loader": train_loader,
    "validation_loader": val_loader,
    "checkpoint_interval": 1,
    "hooks": train_hooks,
}
T = Trainer(**trainer_kwargs)

In [None]:
# Run the training on DEVICE; n_epochs is set to params['max_epochs']
T.train(device=DEVICE, n_epochs=params['max_epochs'])

In [None]:
# Load the training log (log.csv in OUTDIR, presumably written by the CSV
# hook -- TODO confirm) and show its column names.
# NOTE: this rebinds `metrics`, shadowing the `metrics` module imported from
# nff.train; re-running the cell that builds `train_metrics` afterwards
# requires re-running the imports first.
log_path = os.path.join(OUTDIR, 'log.csv')
metrics = pd.read_csv(log_path)
metrics.columns

In [None]:
# Learning curves from the training log in `metrics`:
# left panel: train/validation loss, middle: energy MAE per state,
# right: gradient MAE per state. The figure is saved to OUTDIR and shown.
fig, axs = plt.subplots(1, 3, figsize=(14,5), sharex=True, sharey=False)

# One colour per state index: blue shades for energies, red for gradients.
blues = ['#00429d', '#346cc2', '#5399e8']
reds = ['#94003a', '#c82c46', '#ff4e52']

axs[0].plot(metrics['Train loss'], label='Train loss')
axs[0].plot(metrics['Validation loss'], label='Val loss')

for state, (blue, red) in enumerate(zip(blues, reds)):
    axs[1].plot(metrics[f'MAE_energy_{state}'], label=f'S{state}', color=blue)
    axs[2].plot(metrics[f'MAE_energy_{state}_grad'], label=f'S{state}', color=red)

# Shared cosmetics for all three panels.
for ax in axs:
    for axis in ('y', 'x'):
        ax.tick_params(axis=axis, length=6, width=3, labelsize=20, pad=10, direction='in')
    for spine in ax.spines.values():
        spine.set_linewidth(3)
    ax.set_xlabel("Epoch", fontsize=20)
    # Per-panel legend. No extra plt.legend() call afterwards: it would
    # replace the last panel's legend with one in the default font size.
    ax.legend(frameon=False, fontsize=16)

axs[1].set_title("Energy MAE", fontsize=20)
axs[2].set_title("Forces MAE", fontsize=20)
axs[0].set_ylabel(r"Error", fontsize=20)

# Keep the large tick labels of neighbouring panels from overlapping.
fig.tight_layout()
fig.savefig(os.path.join(OUTDIR, "learning_curve.png"), dpi=300)
plt.show()

## Calculate Test Stats

In [None]:
# Evaluate the model returned by T.get_best_model() on the test loader.
best_model = T.get_best_model()
results_targets_testloss = evaluate(best_model, test_loader, loss_fn, device=DEVICE)
# evaluate returns three items: predictions, targets and a third value
# that is not used here (presumably the test loss -- TODO confirm).
pred_dict, targ_dict, _ = results_targets_testloss

In [None]:
# Compute the test statistics from targets and predictions for all energy
# and gradient keys.
test_stats = get_stats(targ_dict, pred_dict, output_keys, grad_keys)

# Serialize first, then write the statistics as evaluate.json inside OUTDIR.
json_object = json.dumps(test_stats, indent=4)
stats_path = os.path.join(OUTDIR, "evaluate.json")
with open(stats_path, "w") as outfile:
    outfile.write(json_object)

In [None]:
# Print the statistics groups for energies, energy differences and gradients.
for stat_key in ('energy', 'delta_energy', 'energy_grad'):
    print(test_stats[stat_key])

In [None]:
# Draw the target-vs-prediction scatter plot for the test set and store it
# as Scatter_Test.png inside OUTDIR.
scatter_path = os.path.join(OUTDIR, "Scatter_Test.png")
make_scatterplot(scatter_path, targ_dict, pred_dict)