# Fine-tune the pretrained CHGNet for better accuracy


In [1]:
from sklearn.model_selection import KFold, train_test_split
import time
import sys
from chgnet.data.dataset import StructureData, get_train_val_test_loader
import shutil
sys.path.append("../../")
from utils.save_and_load import load_from_json, save_to_json
from utils.experiment_tracking import log_mean_std_based_on_test_metrics

In [2]:
try:
    from chgnet.model import CHGNet
except ImportError:
    # install CHGNet (only needed on Google Colab or if you didn't install CHGNet yet)
    # !pip install chgnet
    pass

In [3]:
import numpy as np
from pymatgen.core import Structure

# If the above line fails in Google Colab due to numpy version issue,
# please restart the runtime, and the problem will be solved

In [4]:
import os
import re
import warnings
from typing import TYPE_CHECKING
import torch
import random

from monty.io import zopen
from monty.os.path import zpath
from pymatgen.io.vasp.outputs import Oszicar, Vasprun

from chgnet.utils import write_json
from aim import Run

if TYPE_CHECKING:
    from pymatgen.core import Structure


In [5]:
# # ./my_vasp_calc_dir contains vasprun.xml OSZICAR etc.
# dataset_dict = parse_vasp_dir(
#     base_dir="./my_vasp_calc_dir", save_path="./my_vasp_calc_dir/chgnet_dataset.json"
# )
# print(list(dataset_dict))
import orjson


def load_ds_and_train(run: Run, ds_name: str = "dataset_of_max_abs_shg"):
    dataset_dict = load_from_json(f"../../data/final_data/{ds_name}.json")

    # with open("../data_shg/total_ds_no_qmof_no_max_abs.json", "rb") as f:
    #     dataset_dict = orjson.loads(f.read())

    dataset_dict = {
        "structure": [
            Structure.from_dict(v["structure"]) for k, v in dataset_dict.items()
        ],
        "energy_per_atom": [v["shg"] for k, v in dataset_dict.items()],  # aka shg
        "idxes": [k for k, v in dataset_dict.items()],
    }

    # Structure to json

    dict_to_json = [struct.as_dict() for struct in dataset_dict["structure"]]
    write_json(dict_to_json, "CHGNet_structures.json")

    # Structure to pickle
    import pickle

    with open("CHGNet_structures.p", "wb") as f:
        pickle.dump(dataset_dict, f)

    os.makedirs("cif_dir", exist_ok=True)
    # Structure to cif
    for idx, struct in enumerate(dataset_dict["structure"]):
        struct.to(filename=f"cif_dir/{idx}.cif")

    # Structure to CHGNet graph
    from chgnet.graph import CrystalGraphConverter

    converter = CrystalGraphConverter()
    for idx, struct in enumerate(dataset_dict["structure"]):
        graph = converter(struct)
        graph.save(fname=f"cif_dir/{idx}.pt")

        # from chgnet.utils import read_json
    ## 1. Prepare Training Data

    # dataset_dict = read_json("./my_vasp_calc_dir/chgnet_dataset.json")
    structures = [struct for struct in dataset_dict["structure"]]
    structure_names = [structure_name for structure_name in dataset_dict["idxes"]]
    shg = dataset_dict["energy_per_atom"]
    forces = dataset_dict.get("force") or None
    stresses = dataset_dict.get("stress") or None
    magmoms = dataset_dict.get("magmom") or None

    ## 2. Define DataSet

    dataset = StructureData(
        structures=structures,
        energies=shg,
        forces=forces,
        stresses=stresses,  # can be None
        magmoms=magmoms,  # can be None
    )

    # enable deterministic learning
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
    random.seed(42)
    np.random.seed(42)
    %env CUBLAS_WORKSPACE_CONFIG=:4096:8

    k_folds = 10
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
    data_range = np.arange(0, len(dataset))

    requires_grad = True
    run["hparams"] = dict(requires_grad=requires_grad)

    fold_partition = {}
    fold_partition_names = {}
    crystal_predictions_log_for_all_folds = {}
    run.set_artifacts_uri(f"file://{os.getcwd()}/artifacts/")
    for fold, (train_val_idx, test_idx) in enumerate(kfold.split(data_range)):
        print(f"FOLD {fold}")

        train_idx, val_idx = train_test_split(
            train_val_idx, train_size=8 / 9, random_state=42
        )

        train_loader, val_loader, test_loader = get_train_val_test_loader(
            dataset,
            batch_size=8,
            train_idx=train_idx,
            val_idx=val_idx,
            test_idx=test_idx,
        )

        ## 3. Define model and trainer
        from chgnet.model import CHGNet
        from chgnet.trainer import Trainer

        # Load pretrained CHGNet
        model = CHGNet.load()
        run["hparams"] = run["hparams"] | model.model_args

        # Optionally fix the weights of some layers
        for layer in [
            model.atom_embedding,
            model.bond_embedding,
            model.angle_embedding,
            model.bond_basis_expansion,
            model.angle_basis_expansion,
            model.atom_conv_layers[:-1],
            model.bond_conv_layers,
            model.angle_layers,
        ]:
            for param in layer.parameters():
                param.requires_grad = requires_grad

        run["model_parameters_count"] = sum(p.numel() for p in model.parameters())
        run["model_trainable_parameters_count"] = sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )

        train_idx, val_idx = train_test_split(
            train_val_idx, train_size=8 / 9, random_state=42
        )

        # log fold partition
        fold_partition = dict(
            fold=fold_partition.get("fold", {})
            | {
                fold: dict(
                    train_idx=train_idx.tolist(),
                    val_idx=val_idx.tolist(),
                    test_idx=test_idx.tolist(),
                )
            }
        )
        fold_partition_names = dict(
            fold=fold_partition_names.get("fold", {})
            | {
                fold: dict(
                    train_names=[structure_names[x] for x in train_idx.tolist()],
                    val_names=[structure_names[x] for x in val_idx.tolist()],
                    test_names=[structure_names[x] for x in test_idx.tolist()],
                )
            }
        )

        # log dataset once for all folds
        dataset_key_target = {
            structure_names[idx]: v[1]["e"].data.cpu().item()
            for idx, v in enumerate(dataset)
        }

        # Define Trainer
        trainer = Trainer(
            model=model,
            targets="e",
            optimizer="Adam",
            scheduler="CosLR",
            criterion="MAE",
            # criterion="MSE",
            epochs=20,
            learning_rate=1e-2,
            use_device="cpu",
            print_freq=20,
        )
        trainer.run = run
        trainer.fold = fold
        run["trainer_args"] = trainer.trainer_args

        ## 4. Start training
        trainer.crystal_predictions_log = {}

        trainer.train(train_loader, val_loader, test_loader)

        crystal_predictions_log_for_all_folds = dict(
            fold=crystal_predictions_log_for_all_folds.get("fold", {})
            | {fold: trainer.crystal_predictions_log.copy()}
        )
        # print(trainer.crystal_predictions_log.copy())
    log_mean_std_based_on_test_metrics(run)
    jsons_to_log = dict(
        fold_partition=fold_partition,
        fold_partition_names=fold_partition_names,
        dataset_key_target=dataset_key_target,
        crystal_predictions_log_for_all_folds=crystal_predictions_log_for_all_folds,
    )
    json_name = "train_info.json"
    save_to_json(jsons_to_log, json_name)
    run.log_artifact(json_name)
    os.makedirs("trained", exist_ok=True)
    os.makedirs(f"trained/{ds_name}", exist_ok=True)
    shutil.move(
        json_name,
        f"trained/{ds_name}/{json_name}",
    )
    run.close()

    # model = trainer.model
    # best_model = trainer.best_model  # best model based on validation energy MAE

In [6]:
# Initialize a new run
run = Run(
    experiment="CHGNet FT from 0.3.0 on dataset_of_max_abs_shg",
    log_system_params=True,
)
load_ds_and_train(run)

StructureData imported 522 structures
env: CUBLAS_WORKSPACE_CONFIG=:4096:8
FOLD 0
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on cpu
Begin Training: using cpu device
training targets: e
Epoch: [0][1/52] | Time (0.106)(0.004) | Loss 16.3262(16.3262) | MAE e 16.326(16.326)  
Epoch: [0][20/52] | Time (0.075)(0.000) | Loss 5.1322(6.4583) | MAE e 5.132(6.458)  
Epoch: [0][40/52] | Time (0.075)(0.000) | Loss 1.0089(5.8708) | MAE e 1.009(5.871)  
*   e_MAE (5.382) 	
Epoch: [1][1/52] | Time (0.077)(0.001) | Loss 7.3779(7.3779) | MAE e 7.378(7.378)  
Epoch: [1][20/52] | Time (0.071)(0.000) | Loss 7.6233(4.4741) | MAE e 7.623(4.474)  
Epoch: [1][40/52] | Time (0.072)(0.000) | Loss 1.1731(4.3974) | MAE e 1.173(4.397)  
*   e_MAE (4.712) 	
Epoch: [2][1/52] | Time (0.077)(0.001) | Loss 1.1429(1.1429) | MAE e 1.143(1.143)  
Epoch: [2][20/52] | Time (0.073)(0.000) | Loss 4.9454(3.7259) | MAE e 4.945(3.726)  
Epoch: [2][40/52] | Time (0.074)(0.000) | Loss 6.2960(3.8466) | MAE e 6

In [7]:
run.close()

In [8]:
# Initialize a new run
ds_name = "base_dataset_of_eff_shg"
run = Run(
    experiment=f"CHGNet FT from 0.3.0 on {ds_name}",
    log_system_params=True,
)
load_ds_and_train(run, ds_name)

StructureData imported 974 structures
env: CUBLAS_WORKSPACE_CONFIG=:4096:8
FOLD 0
CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on cpu
Begin Training: using cpu device
training targets: e
Epoch: [0][1/98] | Time (2.295)(0.002) | Loss 0.5222(0.5222) | MAE e 0.522(0.522)  
Epoch: [0][20/98] | Time (1.787)(0.001) | Loss 0.5020(2.3136) | MAE e 0.502(2.314)  
Epoch: [0][40/98] | Time (1.998)(0.001) | Loss 5.9953(2.0341) | MAE e 5.995(2.034)  
Epoch: [0][60/98] | Time (1.812)(0.000) | Loss 0.5759(1.7822) | MAE e 0.576(1.782)  
Epoch: [0][80/98] | Time (1.754)(0.000) | Loss 0.1040(1.6240) | MAE e 0.104(1.624)  
*   e_MAE (0.702) 	
Epoch: [1][1/98] | Time (1.756)(0.003) | Loss 0.0870(0.0870) | MAE e 0.087(0.087)  
Epoch: [1][20/98] | Time (1.668)(0.001) | Loss 0.7858(2.4096) | MAE e 0.786(2.410)  
Epoch: [1][40/98] | Time (1.558)(0.000) | Loss 3.8590(1.8394) | MAE e 3.859(1.839)  
Epoch: [1][60/98] | Time (1.583)(0.000) | Loss 0.1314(7.6140) | MAE e 0.131(7.614)  
Epoch: [1

The training set is used to optimize the CHGNet through gradient descent, the validation set is used to see validation error at the end of each epoch, and the test set is used to see the final test error at the end of training. The test set can be optional.

The `batch_size` is defined to be 8 for small GPU-memory. If > 10 GB memory is available, we highly recommend to increase `batch_size` for better speed.

If you have very large numbers (>100K) of structures (which is typical for AIMD), putting them all in a python list can quickly run into memory issues. In this case we highly recommend you to pre-convert all the structures into graphs and save them as shown in `examples/make_graphs.py`. Then directly train CHGNet by loading the graphs from disk instead of memory using the `GraphData` class defined in `data/dataset.py`.


It's optional to freeze the weights inside some layers. This is a common technique to retain the learned knowledge during fine-tuning in large pretrained neural networks. You can choose the layers you want to freeze.


After training, the trained model can be found in the directory of today's date. Or it can be accessed by:
