# Fine-tune the pretrained CHGNet for better accuracy


In [None]:
try:
    from chgnet.model import CHGNet
except ImportError:
    # install CHGNet (only needed on Google Colab or if you didn't install CHGNet yet)
    !pip install chgnet

In [None]:
import numpy as np
from pymatgen.core import Structure

# If the import above fails on Google Colab because of a numpy version conflict,
# restart the runtime and rerun this cell; that should resolve it

## 0. Parse DFT outputs to CHGNet-readable formats


CHGNet is interfaced with [Pymatgen](https://pymatgen.org/), so the training samples (normally produced by DFT codes such as VASP)
need to be converted to [pymatgen.core.structure](https://pymatgen.org/pymatgen.core.html#module-pymatgen.core.structure).

To convert VASP calculations to pymatgen structures and CHGNet labels, you can use the following [code](https://github.com/CederGroupHub/chgnet/blob/main/chgnet/utils/vasp_utils.py):


In [None]:
from chgnet.utils import parse_vasp_dir

# ./my_vasp_calc_dir contains vasprun.xml OSZICAR etc.
dataset_dict = parse_vasp_dir(
    file_root="./my_vasp_calc_dir", save_path="./my_vasp_calc_dir/chgnet_dataset.json"
)
print(list(dataset_dict))

The parsed Python dictionary contains the CHGNet inputs (structures) and the CHGNet prediction labels (energy, force, stress, magmom).

We can save the parsed structures and labels to disk so that they can be easily reloaded across multiple rounds of training.

The JSON file is written when `save_path` is provided.


The pymatgen structures can also be saved separately if you want to inspect each structure.

Below is example code to save the structures as JSON, pickle, CIF, or CHGNet graphs.

For very large training datasets, such as MPtrj, we recommend [converting them to CHGNet graphs](https://github.com/CederGroupHub/chgnet/blob/main/examples/make_graphs.py). This saves significant memory and graph-computation time.


In [None]:
# Structure to json
from chgnet.utils import write_json

dict_to_json = [struct.as_dict() for struct in dataset_dict["structure"]]
write_json(dict_to_json, "CHGNet_structures.json")


# Structure to pickle
import pickle

with open("CHGNet_structures.p", "wb") as f:
    pickle.dump(dataset_dict, f)


# Structure to cif
for idx, struct in enumerate(dataset_dict["structure"]):
    struct.to(filename=f"{idx}.cif")


# Structure to CHGNet graph
from chgnet.graph import CrystalGraphConverter

converter = CrystalGraphConverter()
for idx, struct in enumerate(dataset_dict["structure"]):
    graph = converter(struct)
    graph.save(fname=f"{idx}.pt")

For other types of DFT calculations, please refer to the corresponding interfaces
in [pymatgen.io](https://pymatgen.org/pymatgen.io.html#module-pymatgen.io):

- [Quantum Espresso](https://pymatgen.org/pymatgen.io.html#module-pymatgen.io.pwscf)
- [CP2K](https://pymatgen.org/pymatgen.io.cp2k.html#module-pymatgen.io.cp2k)
- [Gaussian](https://pymatgen.org/pymatgen.io.html#module-pymatgen.io.gaussian)


## 1. Prepare Training Data


If you have parsed your VASP labels from step 0, you can reload the saved json file.


In [None]:
from chgnet.utils import read_json

dataset_dict = read_json("./my_vasp_calc_dir/chgnet_dataset.json")
structures = [Structure.from_dict(struct) for struct in dataset_dict["structure"]]
energies_per_atom = dataset_dict["energy_per_atom"]  # name matches the variable passed to StructureData below
forces = dataset_dict["force"]
stresses = dataset_dict.get("stress") or None
magmoms = dataset_dict.get("magmom") or None
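
Before building a dataset, it can help to confirm that every label list lines up with the structures. The helper below, `check_label_lengths`, is a hypothetical sketch (not part of CHGNet); it assumes the key names used above (`structure`, `energy_per_atom`, `force`, `stress`, `magmom`) and treats labels that are missing or `None` as optional.

```python
def check_label_lengths(dataset_dict: dict) -> int:
    """Hypothetical helper: verify that every label list matches the number of structures.

    Assumes the key names used in this notebook. Labels that are missing or None
    (e.g. stress, magmom) are treated as optional and skipped.
    Returns the number of structures.
    """
    n_structures = len(dataset_dict["structure"])
    for key in ("energy_per_atom", "force", "stress", "magmom"):
        labels = dataset_dict.get(key)
        if labels is None:
            continue  # optional label not present
        if len(labels) != n_structures:
            raise ValueError(
                f"{key!r} has {len(labels)} entries but there are {n_structures} structures"
            )
    return n_structures
```

For example, `check_label_lengths(dataset_dict)` returns the number of structures, or raises a `ValueError` naming the first mismatched label.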

If you don't have any DFT calculations yet, you can create a dummy fine-tuning dataset from CHGNet predictions with some random noise added.


In [None]:
try:
    from chgnet import ROOT

    lmo = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")
except Exception:
    from urllib.request import urlopen

    url = "https://raw.githubusercontent.com/CederGroupHub/chgnet/main/examples/mp-18767-LiMnO2.cif"
    cif = urlopen(url).read().decode("utf-8")
    lmo = Structure.from_str(cif, fmt="cif")

structures, energies_per_atom, forces, stresses, magmoms = [], [], [], [], []
chgnet = CHGNet.load()
for _ in range(100):
    structure = lmo.copy()
    # stretch the cell by a small amount
    structure.apply_strain(np.random.uniform(-0.1, 0.1, size=3))
    # perturb all atom positions by a small amount
    structure.perturb(0.1)

    pred = chgnet.predict_structure(structure)

    structures.append(structure)
    energies_per_atom.append(pred["e"] + np.random.uniform(-0.1, 0.1, size=1))
    forces.append(pred["f"] + np.random.uniform(-0.01, 0.01, size=pred["f"].shape))
    stresses.append(
        pred["s"] * -10 + np.random.uniform(-0.05, 0.05, size=pred["s"].shape)
    )
    magmoms.append(pred["m"] + np.random.uniform(-0.03, 0.03, size=pred["m"].shape))

Note that CHGNet outputs stress in GPa; the factor of -10 converts it to kbar with the sign convention of raw VASP output.
If you're using stress labels from VASP, you don't need any unit conversion, because the
`StructureData` dataset class takes VASP units.
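
A minimal sketch of the conversion described above, assuming the stress tensor is a nested list or NumPy array: multiplying a CHGNet prediction (GPa) by -10 gives kbar with the VASP sign convention, and dividing by -10 goes back. The function names are hypothetical, not CHGNet API.

```python
import numpy as np

# 1 GPa = 10 kbar; the minus sign flips to the sign convention of raw VASP output
GPA_TO_VASP_KBAR = -10.0


def chgnet_stress_to_vasp(stress_gpa):
    """Hypothetical helper: CHGNet stress (GPa) -> VASP raw units (kbar)."""
    return np.asarray(stress_gpa, dtype=float) * GPA_TO_VASP_KBAR


def vasp_stress_to_chgnet(stress_kbar):
    """Hypothetical helper: VASP stress label (kbar) -> CHGNet convention (GPa)."""
    return np.asarray(stress_kbar, dtype=float) / GPA_TO_VASP_KBAR
```

Labels that already come from VASP should be passed to `StructureData` unchanged; only CHGNet predictions need the forward conversion, as in the dummy dataset above.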


## 2. Define Dataset


In [None]:
from chgnet.data.dataset import StructureData, get_train_val_test_loader

In [None]:
dataset = StructureData(
    structures=structures,
    energies=energies_per_atom,
    forces=forces,
    stresses=stresses,  # can be None
    magmoms=magmoms,  # can be None
)
train_loader, val_loader, test_loader = get_train_val_test_loader(
    dataset, batch_size=8, train_ratio=0.9, val_ratio=0.05
)

100 structures imported


Alternatively, the dataset can be created directly from a VASP calculation directory.
This function parses the VASP directory, saves the labels to a JSON file, and creates the `StructureData` dataset.


In [None]:
dataset = StructureData.from_vasp(
    file_root="./my_vasp_calc_dir", save_path="./my_vasp_calc_dir/chgnet_dataset.json"
)

The training set is used to optimize CHGNet through gradient descent, the validation set is used to monitor the validation error at the end of each epoch, and the test set is used to report the final test error at the end of training. The test set is optional.

Here `batch_size` is set to 8 to fit in a small GPU memory. If more than 10 GB of memory is available, we highly recommend increasing `batch_size` for better speed.
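
As a rough check on what the split ratios and `batch_size` mean for the 100-structure dummy dataset, the sketch below assumes the train and validation sizes are `int(ratio * N)` and that the test set takes the remainder. That split rule is an assumption about `get_train_val_test_loader`, but it is consistent with the 12 batches per epoch reported in the training log of step 4.

```python
import math


def split_sizes(n_total: int, train_ratio: float, val_ratio: float) -> tuple[int, int, int]:
    """Assumed split rule: truncate ratio * N for train/val; the remainder goes to test."""
    n_train = int(train_ratio * n_total)
    n_val = int(val_ratio * n_total)
    n_test = n_total - n_train - n_val
    return n_train, n_val, n_test


def batches_per_epoch(n_samples: int, batch_size: int) -> int:
    """Number of batches when the last, partially filled batch is kept."""
    return math.ceil(n_samples / batch_size)


n_train, n_val, n_test = split_sizes(100, train_ratio=0.9, val_ratio=0.05)
print(n_train, n_val, n_test)  # 90 5 5
print(batches_per_epoch(n_train, batch_size=8))  # 12
```

Doubling `batch_size` to 16 would roughly halve the number of optimizer steps per epoch, which is where the speed-up on a larger GPU comes from.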

If you have a very large number (>100K) of structures, which is typical for AIMD, putting them all in a Python list can quickly run into memory issues. In this case we highly recommend pre-converting all the structures into graphs and saving them, as shown in `examples/make_graphs.py`. Then train CHGNet by loading the graphs from disk instead of memory, using the `GraphData` class defined in `data/dataset.py`.
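
The idea behind loading from disk is that the dataset keeps only file paths in memory and reads a sample when it is indexed, so memory use no longer grows with the dataset size. The sketch below illustrates that pattern with plain `pickle` files; `LazyPickleDataset` is a hypothetical stand-in, not CHGNet's `GraphData` class, which works with the saved `.pt` graphs.

```python
import pickle
from pathlib import Path


class LazyPickleDataset:
    """Hypothetical sketch: keep only file paths in memory and load samples on demand."""

    def __init__(self, data_dir: str, pattern: str = "*.p") -> None:
        # Sorted (lexicographically) so the index -> file mapping is deterministic
        self.paths = sorted(Path(data_dir).glob(pattern))

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int):
        # Only the requested sample is read from disk
        with open(self.paths[idx], "rb") as f:
            return pickle.load(f)
```

A data loader that indexes such an object only ever holds one batch of samples in memory at a time.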


## 3. Define model and trainer


In [None]:
from chgnet.model import CHGNet
from chgnet.trainer import Trainer

# Load pretrained CHGNet
chgnet = CHGNet.load()

CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on cpu


Freezing the weights of some layers is optional. It is a common technique for retaining learned knowledge when fine-tuning large pretrained neural networks. You can choose which layers to freeze.


In [None]:
# Optionally fix the weights of some layers
for layer in [
    chgnet.atom_embedding,
    chgnet.bond_embedding,
    chgnet.angle_embedding,
    chgnet.bond_basis_expansion,
    chgnet.angle_basis_expansion,
    chgnet.atom_conv_layers[:-1],
    chgnet.bond_conv_layers,
    chgnet.angle_layers,
]:
    for param in layer.parameters():
        param.requires_grad = False

In [None]:
# Define Trainer
trainer = Trainer(
    model=chgnet,
    targets="efsm",
    optimizer="Adam",
    scheduler="CosLR",
    criterion="MSE",
    epochs=5,
    learning_rate=1e-2,
    use_device="cpu",
    print_freq=6,
)

## 4. Start training


In [None]:
trainer.train(train_loader, val_loader, test_loader)

Begin Training: using cpu device
training targets: efsm
Epoch: [0][1/12]	Time (0.476)  Data (0.016)  Loss 0.0033 (0.0033)  MAEs:  e 0.053 (0.053)  f 0.004 (0.004)  s 0.002 (0.002)  m 0.016 (0.016)  
Epoch: [0][6/12]	Time (0.426)  Data (0.015)  Loss 0.0040 (0.0039)  MAEs:  e 0.054 (0.056)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.015 (0.015)  
Epoch: [0][12/12]	Time (0.414)  Data (0.014)  Loss 0.0040 (0.0038)  MAEs:  e 0.054 (0.054)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.015 (0.014)  
*   e_MAE (0.028) 	f_MAE (0.006) 	s_MAE (0.002) 	m_MAE (0.015) 	
Epoch: [1][1/12]	Time (0.409)  Data (0.000)  Loss 0.0052 (0.0052)  MAEs:  e 0.064 (0.064)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.013 (0.013)  
Epoch: [1][6/12]	Time (0.393)  Data (0.000)  Loss 0.0036 (0.0039)  MAEs:  e 0.053 (0.055)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.014 (0.014)  
Epoch: [1][12/12]	Time (0.371)  Data (0.000)  Loss 0.0029 (0.0038)  MAEs:  e 0.053 (0.054)  f 0.005 (0.005)  s 0.003 (0.002)  m 0.012 (0.014)  
*   e_MAE (0.028) 	

After training, the trained model is saved in a directory named after today's date. It can also be accessed directly:


In [None]:
model = trainer.model
best_model = trainer.best_model  # best model based on validation energy MAE

## Extras 1: GGA / GGA+U compatibility


### Q: Why and when do you care about this?

**When**: If you want to fine-tune the pretrained CHGNet with your own GGA+U VASP calculations and keep your VASP energies compatible with the pretraining dataset. If your dataset is so large that the pretrained knowledge does not matter to you, you can ignore this.

**Why**: CHGNet is trained on both GGA and GGA+U calculations from the Materials Project. Methods have been developed to make GGA and GGA+U energies compatible, which makes the energies applicable to cross-chemistry comparisons and phase-diagram construction. Please refer to:

https://journals.aps.org/prb/abstract/10.1103/PhysRevB.84.045115

Below we show an example to apply the compatibility.


In [None]:
# Imagine this is the VASP raw energy
vasp_raw_energy = -58.97

print(f"The raw total energy from VASP of LMO is: {vasp_raw_energy} eV")

The raw total energy from VASP of LMO is: -58.97 eV


You can find the energy correction applied to each element in:

https://github.com/materialsproject/pymatgen/blob/v2023.2.28/pymatgen/entries/MP2020Compatibility.yaml

For LiMnO2, both the Mn correction for transition-metal oxides and the oxide correction apply.


To demystify `MaterialsProject2020Compatibility`, basically all that's happening is:


In [None]:
Mn_correction_in_TMO = -1.668
oxide_correction = -0.687
# Look up counts by element rather than relying on the order of composition.values()
num_Mn = lmo.composition["Mn"]
num_O = lmo.composition["O"]

corrected_energy = (
    vasp_raw_energy + num_Mn * Mn_correction_in_TMO + num_O * oxide_correction
)
print(f"The corrected total energy after MP2020 = {corrected_energy:.4} eV")

The corrected total energy after MP2020 = -65.05 eV


You can also apply the `MaterialsProject2020Compatibility` through pymatgen


In [None]:
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
from pymatgen.entries.computed_entries import ComputedStructureEntry

params = {"hubbards": {"Mn": 3.9, "O": 0, "Li": 0}, "run_type": "GGA+U"}

cse = ComputedStructureEntry(lmo, vasp_raw_energy, parameters=params)

MaterialsProject2020Compatibility(check_potcar=False).process_entries(cse)
print(
    f"The total energy of LMO after MP2020Compatibility correction = {cse.energy:.4} eV"
)

The total energy of LMO after MP2020Compatibility correction = -62.31 eV


Now use this corrected energy as the label for fine-tuning CHGNet, and you're good to go!
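
`StructureData` takes energies per atom (the `energies_per_atom` list in step 2), so a corrected total energy needs to be divided by the number of atoms before it is used as a label. A minimal sketch, where `total_energy_to_label` is a hypothetical helper and `n_atoms` would be `len(structure)` for a pymatgen `Structure`:

```python
def total_energy_to_label(total_energy_ev: float, n_atoms: int) -> float:
    """Hypothetical helper: corrected total energy (eV) -> energy-per-atom label (eV/atom)."""
    if n_atoms <= 0:
        raise ValueError("n_atoms must be positive")
    return total_energy_ev / n_atoms


# e.g. label = total_energy_to_label(cse.energy, len(lmo))
```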


## Extras 2: AtomRef


### Q: Why and when do you care about this?

**When**: When you fine-tune CHGNet on DFT labels that are incompatible with the Materials Project, such as r2SCAN calculations or calculations from other DFT codes like Gaussian or QE. The large shifts in elemental energies are not of interest and should be reconciled. For example, Li has -0.95 eV/atom in GGA (https://next-gen.materialsproject.org/materials/mp-135/tasks/mp-990455) and -1.17 eV/atom in r2SCAN (https://next-gen.materialsproject.org/materials/mp-135/tasks/mp-1943895).

**Why**: The GNN learns the interactions between atoms, while the composition model (AtomRef) in CHGNet normalizes the elemental energy contribution, similar to a formation-energy calculation. During fine-tuning, we want to keep most of the knowledge in the GNN unchanged and let AtomRef shift to absorb the change in elemental energies, so that fine-tuning of the graph layers can focus on the energy contribution from atom-atom interactions instead of meaningless atomic reference energies.
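
To see why a per-element shift belongs in AtomRef, consider an intensive composition model of the form `E_ref = sum_i x_i * mu_i`, where `x_i` is the atomic fraction of element `i` and `mu_i` its reference energy (this linear form is an assumption for illustration). Shifting only Li by the GGA-to-r2SCAN difference quoted above (-1.17 - (-0.95) = -0.22 eV/atom) moves the per-atom reference of a hypothetical Li2Mn2O4 cell by 2/8 of that shift, an offset that a refitted AtomRef can absorb without touching the GNN. The Mn and O values below are made-up placeholders.

```python
def reference_energy_per_atom(composition: dict[str, float], refs: dict[str, float]) -> float:
    """Fraction-weighted sum of per-element reference energies (eV/atom)."""
    n_atoms = sum(composition.values())
    return sum(count / n_atoms * refs[el] for el, count in composition.items())


composition = {"Li": 2, "Mn": 2, "O": 4}
refs = {"Li": -0.95, "Mn": -9.0, "O": -5.0}  # Mn, O: made-up placeholder values
shifted = dict(refs, Li=-1.17)  # only the Li reference changes

delta = reference_energy_per_atom(composition, shifted) - reference_energy_per_atom(
    composition, refs
)
print(f"{delta:.3f} eV/atom")  # -0.055 eV/atom
```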

Below is an example of fitting the AtomRef layer:


### A quick and easy way to turn on training of AtomRef in the trainer (off by default):


In [None]:
trainer.train(train_loader, val_loader, test_loader, train_composition_model=True)

Begin Training: using cpu device
training targets: efsm
Epoch: [0][1/12]	Time (0.475)  Data (0.001)  Loss 0.0028 (0.0028)  MAEs:  e 0.047 (0.047)  f 0.005 (0.005)  s 0.003 (0.003)  m 0.014 (0.014)  
Epoch: [0][6/12]	Time (0.379)  Data (0.000)  Loss 0.0027 (0.0037)  MAEs:  e 0.046 (0.053)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.015 (0.014)  
Epoch: [0][12/12]	Time (0.359)  Data (0.000)  Loss 0.0010 (0.0038)  MAEs:  e 0.030 (0.054)  f 0.005 (0.005)  s 0.003 (0.002)  m 0.012 (0.014)  
*   e_MAE (0.028) 	f_MAE (0.006) 	s_MAE (0.002) 	m_MAE (0.015) 	
Epoch: [1][1/12]	Time (0.417)  Data (0.000)  Loss 0.0011 (0.0011)  MAEs:  e 0.027 (0.027)  f 0.004 (0.004)  s 0.002 (0.002)  m 0.015 (0.015)  
Epoch: [1][6/12]	Time (0.359)  Data (0.000)  Loss 0.0049 (0.0040)  MAEs:  e 0.062 (0.056)  f 0.005 (0.005)  s 0.003 (0.002)  m 0.015 (0.015)  
Epoch: [1][12/12]	Time (0.351)  Data (0.000)  Loss 0.0054 (0.0038)  MAEs:  e 0.073 (0.054)  f 0.004 (0.005)  s 0.002 (0.002)  m 0.013 (0.014)  
*   e_MAE (0.028) 	

### The more rigorous way is to solve for the per-atom contribution by linear regression on your fine-tuning dataset
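
Conceptually, the refit is a linear least-squares problem: each row of the design matrix holds one structure's element fractions, and the targets are its energies per atom. The NumPy sketch below shows that idea with a hypothetical `fit_atom_ref` function on a small, noise-free toy set; it is not CHGNet's `AtomRef.fit` implementation. In this sketch, elements absent from the data are simply left out (in the refitted AtomRef output further down, such elements show up as zeros).

```python
import numpy as np


def fit_atom_ref(compositions: list[dict[str, float]], energies_per_atom: list[float]) -> dict[str, float]:
    """Hypothetical sketch: least-squares fit of per-element reference energies (eV/atom).

    Assumed intensive model: E_per_atom ~ sum_i x_i * mu_i, with x_i the atomic fraction.
    """
    elements = sorted({el for comp in compositions for el in comp})
    # Design matrix: one row per structure, one column per element (atomic fractions)
    X = np.zeros((len(compositions), len(elements)))
    for row, comp in enumerate(compositions):
        n_atoms = sum(comp.values())
        for el, count in comp.items():
            X[row, elements.index(el)] = count / n_atoms
    y = np.asarray(energies_per_atom, dtype=float)
    mu, *_ = np.linalg.lstsq(X, y, rcond=None)
    return dict(zip(elements, mu.tolist()))


# Toy, noise-free data generated from known (made-up) reference energies
true_refs = {"Li": -2.0, "Mn": -8.0, "O": -5.0}
comps = [{"Li": 1, "O": 1}, {"Li": 1, "O": 3}, {"Mn": 1, "O": 1}]
energies = [
    sum(count / sum(comp.values()) * true_refs[el] for el, count in comp.items())
    for comp in comps
]
fitted = fit_atom_ref(comps, energies)
print(fitted)  # should recover the made-up values up to floating-point error
```

With real, noisy labels and more structures than elements, the same call returns the best-fit references in the least-squares sense rather than an exact solution.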


In [None]:
print("The pretrained Atom_Ref (per atom reference energy):")
for param in chgnet.composition_model.parameters():
    print(param)

The pretrained Atom_Ref (per atom reference energy):
Parameter containing:
tensor([[ -3.4431,  -0.1279,  -2.8300,  -3.4737,  -7.4946,  -8.2354,  -8.1611,
          -8.3861,  -5.7498,  -0.0236,  -1.7406,  -1.6788,  -4.2833,  -6.2002,
          -6.1315,  -5.8405,  -3.8795,  -0.0703,  -1.5668,  -3.4451,  -7.0549,
          -9.1465,  -9.2594,  -9.3514,  -8.9843,  -8.0228,  -6.4955,  -5.6057,
          -3.4002,  -0.9217,  -3.2499,  -4.9164,  -4.7810,  -5.0191,  -3.3316,
           0.5130,  -1.4043,  -3.2175,  -7.4994,  -9.3816, -10.4386,  -9.9539,
          -7.9555,  -8.5440,  -7.3245,  -5.2771,  -1.9014,  -0.4034,  -2.6002,
          -4.0054,  -4.1156,  -3.9928,  -2.7003,   2.2170,  -1.9671,  -3.7180,
          -6.8133,  -7.3502,  -6.0712,  -6.1699,  -5.1471,  -6.1925, -11.5829,
         -15.8841,  -5.9994,  -6.0798,  -5.9513,  -6.0400,  -5.9773,  -2.5091,
          -6.0767, -10.6666, -11.8761, -11.8491, -10.7397,  -9.6100,  -8.4755,
          -6.2070,  -3.0337,   0.4726,  -1.6425,  -3.129

In [None]:
# A list of structures / graphs
structures = [
    lmo,
    Structure(
        species=["Li", "Mn", "Mn", "O", "O", "O"],
        lattice=np.random.rand(3, 3),
        coords=np.random.rand(6, 3),
    ),
    Structure(
        species=["Li", "Li", "Mn", "O", "O", "O"],
        lattice=np.random.rand(3, 3),
        coords=np.random.rand(6, 3),
    ),
    Structure(
        species=["Li", "Mn", "Mn", "O", "O", "O", "O"],
        lattice=np.random.rand(3, 3),
        coords=np.random.rand(7, 3),
    ),
]

# A list of energy_per_atom values (random values here)
energies_per_atom = [5.5, 6, 4.8, 5.6]

In [None]:
from chgnet.model.composition_model import AtomRef

print("We initialize another identical AtomRef layer")
new_atom_ref = AtomRef(is_intensive=True)
new_atom_ref.initialize_from_MPtrj()
for param in new_atom_ref.parameters():
    print(param[:, :3])

We initialize another identical AtomRef layer
tensor([[-3.4431, -0.1279, -2.8300]], grad_fn=<SliceBackward0>)


In [None]:
# Solve linear regression to find the per atom contribution in your fine-tuning dataset

new_atom_ref.fit(structures, energies_per_atom)
print("After refitting, the AtomRef looks like:")
for param in new_atom_ref.parameters():
    print(param)

After refitting, the AtomRef looks like:
Parameter containing:
tensor([[ 0.0000e+00,  0.0000e+00,  4.2667e+00, -3.3299e-15,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  2.9999e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1467e+01,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  