### Retrain M3GNet

In [None]:
# Pol Benítez Colominas, March 2024
# Universitat Politècnica de Catalunya

# Script to re-train M3GNet from the dataset created with DFT data
# IMPORTANT: this code is based on the code developed by Cibrán: https://github.com/CibranLopez/m3gnet

Import necessary modules

In [None]:
import os
import glob
import shutil
import warnings

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import pytorch_lightning as pl

from pytorch_lightning.loggers import CSVLogger

from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.core.structure import Structure

import matgl
from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.graph.data import M3GNetDataset, MGLDataLoader, collate_fn_efs
from matgl.utils.training import PotentialLightningModule

# Install a blanket 'ignore' filter so that library warnings do not flood the cell outputs
warnings.simplefilter('ignore')

# Pick the compute device: CUDA when torch reports it as available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

Define the materials with DFT data. The data are stored in a directory that contains one folder per material; each material folder contains a variable number of sub-folders, each holding a vasprun.xml file, for example: data_training/Ag3SCl/01/vasprun.xml

In [None]:
# Alternative, longer list of materials (currently disabled):
#materials = ['Ag3SCl', 'Ag3SBr', 'Ag3SI', 'Ag3SeCl', 'Ag3SeBr', 'Ag3SeI', 
#             'Cu3SCl', 'Cu3SBr', 'Cu3SI', 'Cu3SeCl', 'Cu3SeBr', 'Cu3SeI']

# Materials processed in this run; each name must match a folder inside data_path
materials = ['Ag3SeBr']

# Root directory of the training data (one sub-folder per material)
data_path = 'data_training/'

Save six features for each ionic step: the material, the phase, the structure, the energy, the forces and the stresses

In [None]:
# One (initially empty) list per feature; every ionic step appends one entry to each of them
material_name, phase_number, structures = [], [], []
energies, forces, stresses = [], [], []

In [None]:
# Read every <data_path>/<material>/<phase>/vasprun.xml and store, for each ionic step,
# the material, the phase, the structure, the energy, the forces and the stress.
# The accumulated number of data rows after each material is written to info_data.txt.

# start from empty lists, so that running this cell again does not duplicate the rows
# (num_rows is reset below as well, keeping info_data.txt consistent with the lists)
for feature_list in (material_name, phase_number, structures, energies, forces, stresses):
    feature_list.clear()

num_rows = 0

# the context manager closes the file even if something raises while parsing
with open('info_data.txt', 'w') as info_data:
    info_data.write('Material   #rows\n')

    # run for each material
    for mat in materials:
        mat_path = os.path.join(data_path, mat)

        # sorted: os.listdir returns the entries in arbitrary order
        phases = sorted(d for d in os.listdir(mat_path) if os.path.isdir(os.path.join(mat_path, d)))

        # run for each phase of the given material
        for phase in phases:
            vasprun_path = os.path.join(mat_path, phase, 'vasprun.xml')
            try:
                vasprun = Vasprun(vasprun_path, exception_on_bad_xml=False)
            except Exception as exc:
                # skip unreadable simulations, but say which one and why
                # (Exception, not a bare except, so Ctrl-C still interrupts the cell)
                print(f'Error: vasprun not correctly loaded ({vasprun_path}): {exc}')
                continue

            # save the desired features for each ionic step of each phase of the given material
            for step in vasprun.ionic_steps:
                material_name.append(mat)
                phase_number.append(phase)
                structures.append(step['structure'])
                # energy taken from the last electronic step of this ionic step
                energies.append(step['electronic_steps'][-1]['e_fr_energy'])
                forces.append(step['forces'])
                stresses.append(step['stress'])

                num_rows += 1

        info_data.write(f'{mat}   {num_rows:06d}\n')

Save the data in a pandas dataframe

In [None]:
# Column name -> list of per-ionic-step values (all lists have one entry per ionic step)
data = dict(
    material=material_name,
    phase=phase_number,
    structure=structures,
    energy=energies,
    force=forces,
    stress=stresses,
)

# One row per ionic step
df_data = pd.DataFrame(data)

In [None]:
# Preview the first rows of the assembled table
df_data.head()

Split into train, validation and test sets

In [None]:
# Identify all the different simulations (material + phase): all the ionic steps of one
# simulation should be in the same train-validation-test set, so the split is done over
# these identifiers and not over individual rows.

df_data['material_phase'] = df_data['material'] + '_' + df_data['phase']

# np.array makes a copy, so the in-place shuffle below works on its own array
unique_elements = np.array(df_data['material_phase'].unique())

# Fixed seed: the same split is obtained on every "Restart & Run All".
# Change split_seed to draw a different split.
split_seed = 42
np.random.default_rng(split_seed).shuffle(unique_elements)

print(unique_elements)

In [None]:
# Assign the desired proportions of data for train-validation-test. Note that, as steps
# are grouped by chunks of the same simulation, the final proportions may be slightly
# different from the requested ones.
train_prop = 0.7
val_prop = 0.15
test_prop = 0.15  # not used below: the test set receives whatever is left


def _merge_chunks(frames):
    """Concatenate the chunks of one set; an empty set yields an empty frame with the df_data columns."""
    if not frames:
        return pd.DataFrame(columns=df_data.columns)
    return pd.concat(frames, ignore_index=True)


# fractions of already-assigned rows at which the filling moves on to the next set
train_limit = train_prop
val_limit = val_prop + train_prop

total_rows = len(df_data)
num_rows = 0  # rows assigned so far, over the three sets

# Gather the chunks of each set and concatenate them once at the end: growing a
# DataFrame with pd.concat inside the loop copies all previous rows at every iteration.
chunks = {'train': [], 'validation': [], 'test': []}
for element in unique_elements:
    # the decision uses the fraction assigned BEFORE adding this simulation
    assigned_fraction = num_rows / total_rows
    if assigned_fraction <= train_limit:
        target = 'train'
    elif assigned_fraction <= val_limit:
        target = 'validation'
    else:
        target = 'test'

    new_elements = df_data[df_data['material_phase'] == element]
    chunks[target].append(new_elements)
    num_rows += len(new_elements)

train_set = _merge_chunks(chunks['train'])
validation_set = _merge_chunks(chunks['validation'])
test_set = _merge_chunks(chunks['test'])

n_test       = len(test_set)
n_validation = len(validation_set)
n_train      = len(train_set)

print(f'Using {n_train} samples to train, {n_validation} to evaluate, and {n_test} to test')


Convert into graphs and define parameters

In [None]:
# Name/path handed to matgl.load_model to fetch the pretrained potential
model_load_path = 'M3GNet-MP-2021.2.8-PES'
# Destination handed to the model's save() once the training has finished
model_save_path = 'finetuned_model'

stress_weight = 0 # weight given to PotentialLightningModule; when 0, zero-filled stress labels replace the DFT stresses
# Batch size of the data loaders
batch_size = 128
# Upper limit of training epochs for the Lightning trainer
max_epochs = 50
lr = 1e-4 # learning rate

In [None]:
# Convert each of the train-validation-test sets into an M3GNetDataset (graphs + labels).

# Build ONE element list from all the structures and share the converter between the
# three sets: deriving it per set would give a set that lacks some element a different
# element list than the one used for training.
# NOTE(review): the pretrained model presumably carries its own element ordering;
# confirm that this list is compatible with it when fine-tuning.
element_types = get_element_list(df_data['structure'].tolist())
converter = Structure2Graph(element_types=element_types, cutoff=5.0)

all_data = []
for name, dataset in zip(['train', 'val', 'test'], [train_set, validation_set, test_set]):
    # extract data from dataset (set-specific names, so that the global feature lists
    # and the `data` dictionary defined earlier are not overwritten)
    split_structures = dataset['structure'].tolist()

    # define data labels from dataset
    if stress_weight == 0:
        # stresses are not trained on: use 3x3 zero placeholders as labels
        split_stresses = [np.zeros((3, 3)).tolist() for _ in split_structures]
    else:
        split_stresses = dataset['stress'].tolist()

    labels = {
        'energies': dataset['energy'].tolist(),
        'forces':   dataset['force'].tolist(),
        'stresses': split_stresses,
    }

    # generate dataset; the file names carry the set name so the three sets do not clash
    split_data = M3GNetDataset(
        filename=f'dgl_graph-{name}.bin',
        filename_line_graph=f'dgl_line_graph-{name}.bin',
        filename_state_attr=f'state_attr-{name}.pt',
        filename_labels=f'labels-{name}.json',
        threebody_cutoff=4.0,
        structures=split_structures,
        converter=converter,
        labels=labels,
        name=f'M3GNetDataset-{name}',
    )
    all_data.append(split_data)

train_data, val_data, test_data = all_data

In [None]:
# Options shared by the three loaders
loader_options = {
    'collate_fn': collate_fn_efs,
    'batch_size': batch_size,
    'num_workers': 1,
    'pin_memory': True,
}

# MGLDataLoader hands back the loaders, unpacked here as train, validation and test
loaders = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    **loader_options,
)
train_loader, val_loader, test_loader = loaders

Retrain the model

In [None]:
# Load the pretrained potential and keep a handle on the model it wraps
m3gnet_nnp = matgl.load_model(model_load_path)
model_pretrained = m3gnet_nnp.model

# Wrap the pretrained model in the Lightning module that will be trained
finetune_options = {
    'model': model_pretrained,
    'stress_weight': stress_weight,
    'loss': 'mse_loss',
    'lr': lr,
}
lit_module_finetune = PotentialLightningModule(**finetune_options)

In [None]:
# Metrics are logged as CSV under logs/M3GNet_finetuning/
logger = CSVLogger(save_dir='logs', name='M3GNet_finetuning')

trainer_options = {
    'max_epochs': max_epochs,
    'accelerator': 'auto',
    'logger': logger,
    'inference_mode': False,
}
trainer = pl.Trainer(**trainer_options)

# Fine-tune on the train loader, validating on the validation loader
trainer.fit(
    model=lit_module_finetune,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# Save trained model
model_pretrained.save(model_save_path)

### Analyze metrics

In [None]:
# Evaluate the fine-tuned module on the test loader.
# Units: E_MAE = meV/atom, F_MAE = eV/A, S_MAE = GPa
trainer.test(model=lit_module_finetune, dataloaders=test_loader)

In [None]:
# Read the CSV file with the metrics of THIS training run.
# The logger is asked for its own version and folder instead of assuming version_0:
# if the logs folder already holds a previous run, the CSVLogger writes to version_1,
# version_2, ... and a hard-coded version_0 would silently load stale metrics.
current_version = logger.version
path_to_csv = logger.log_dir
df = pd.read_csv(os.path.join(path_to_csv, 'metrics.csv'))
df.head()

In [None]:
# Rows 0-1, 2-3, 4-5, ... share the same pair id
pair_id = df.index // 2

# NaN to zero, then add up every two consecutive rows
# NOTE(review): df is reassigned here, so running this cell twice pairs the rows again
df = df.fillna(0).groupby(pair_id).sum()
df.head()

In [None]:
# Get the list of loss column names (every train_* and val_* metric)
loss_columns = [col for col in df.columns if col.startswith(('val_', 'train_'))]

# Create a figure and axis (plt.subplots returns both, so unpack them)
fig, ax = plt.subplots(figsize=(10, 6))

# Plot the natural logarithm of each loss
for loss_column in loss_columns:
    values = df[loss_column]
    # non-positive entries (e.g. the zeros that replaced NaN) have no logarithm:
    # mask them to NaN so they are simply left out of the curve
    ax.plot(df.index, np.log(values.where(values > 0)), label=loss_column)

ax.set_xlabel('Epoch')
ax.set_ylabel('log(Loss)')  # the curves show log(loss), not the raw loss
ax.legend(loc=(1.01, 0))
fig.savefig('m3gnet_loss.eps', dpi=100, bbox_inches='tight')
plt.show()

### Cleanup the notebook

In [None]:
# This code just performs cleanup for this notebook: it removes the temporary files
# and folders written by the cells above

patterns = ['dgl_graph*.bin', 'dgl_line_graph*.bin', 'state_attr*.pt', 'labels*.json', '*labels.txt']
for pattern in patterns:
    files = glob.glob(pattern)
    for file in files:
        try:
            os.remove(file)
        except FileNotFoundError:
            # the file vanished between glob and remove: nothing left to clean
            pass

# The logs folder only exists after a training run; check first so that the cell can
# be run more than once (shutil.rmtree raises if the folder is missing)
if os.path.isdir('logs'):
    shutil.rmtree('logs')
#shutil.rmtree('trained_model')
#shutil.rmtree('finetuned_model')