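# Fine-tuning the pretrained model
In this notebook, we fine-tune the pretrained NequIP energy model on a custom dataset, training only a selected subset of its layers and keeping the remaining parameters frozen.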

In [1]:
import os
import pickle
from collections.abc import Mapping
from typing import Any

import jax
import numpy as np
import optax
import yaml
from ase.io import Trajectory
from flax.training import train_state
from tqdm import tqdm

from reaxnet.egnn.compute import compute_fn
from reaxnet.egnn.data import AtomicNumberTable, load_from_atomic_list, graph_from_configuration
from reaxnet.egnn.dataloader import GraphDataLoader
from reaxnet.egnn.loss import WeightedLossFunction, EvaluationLossFunction
from reaxnet.egnn.nequip import NequIPEnergyModel

## Loading the dataset
The demo dataset contains only the reference targets for the equivariant graph neural network potential. The long-range contributions have already been subtracted from the DFT energies ($E$), forces ($F$) and stresses ($S$):

$E_{\rm ref} = E_{\rm DFT} - E_{\rm long-range}$

$F_{\rm ref} = F_{\rm DFT} - F_{\rm long-range}$

$S_{\rm ref} = S_{\rm DFT} - S_{\rm long-range}$

In [2]:
def random_split(num, ratio, seed=None):
    """Randomly partition ``range(num)`` into training and validation indices.

    Args:
        num: number of samples to split.
        ratio: fraction of samples assigned to the validation set, in ``[0, 1]``.
            The validation size is ``int(num * ratio)`` (truncated).
        seed: optional seed for a private ``numpy.random.Generator``. When ``None``
            (the default), the global NumPy RNG is used, so ``np.random.seed``
            still controls the split.

    Returns:
        ``(train_idx, val_idx)`` as lists of ints; together they are a
        permutation of ``range(num)``.

    Raises:
        ValueError: if ``ratio`` is outside ``[0, 1]``. Without this check a
            negative ratio would silently swap the roles of the two subsets.
    """
    if not 0.0 <= ratio <= 1.0:
        raise ValueError(f"ratio must be in [0, 1], got {ratio}")
    if seed is None:
        indices = np.random.permutation(num)
    else:
        indices = np.random.default_rng(seed).permutation(num)
    split = int(num * ratio)
    # The first `split` shuffled indices go to validation, the rest to training.
    return indices[split:].tolist(), indices[:split].tolist()
# Seed NumPy's global RNG so the train/validation split is identical on every
# Restart & Run All.
SEED = 0
np.random.seed(SEED)

traj = Trajectory('demo.traj')
val_ratio = 0.2  # fraction of frames held out for validation
train_idx, val_idx = random_split(len(traj), val_ratio)


In [3]:
# Directory holding the pretrained model configuration, weights and element mapping.
model_path = '../pretrained/'
# os.path.join works whether or not model_path ends with a separator.
with open(os.path.join(model_path, 'model_config.yaml'), 'r') as f:
    model_dict = yaml.safe_load(f)
# NOTE: pickle.load can execute arbitrary code; only load parameter files from a trusted source.
with open(os.path.join(model_path, 'params.pickle'), 'rb') as f:
    params = pickle.load(f)
ztable = AtomicNumberTable.from_dict(os.path.join(model_path, 'mapping.yaml'))
model = NequIPEnergyModel(**model_dict)

In [4]:
# Convert the trajectory frames into configurations using the model's `r_max` setting.
r_max = model_dict['r_max']
configs = load_from_atomic_list(traj, r_max)

Loading configs from trajectory: 100%|██████████| 500/500 [00:00<00:00, 779.28it/s]


## Generating the data loaders

In [5]:
multiplier = 2.5  # head-room factor applied to the per-batch padding budgets
batch_size = 32

# Padding budgets shared by both loaders (same heuristic as before).
# NOTE(review): n_edge uses the same formula as n_node; confirm the edge budget is large enough.
n_node_budget = int(model_dict['n_neighbors'] * batch_size * multiplier)
n_edge_budget = int(model_dict['n_neighbors'] * batch_size * multiplier)


def make_loader(indices, shuffle):
    """Build a GraphDataLoader over ``configs[i]`` for every ``i`` in ``indices``."""
    graphs = [graph_from_configuration(configs[i], z_table=ztable) for i in indices]
    return GraphDataLoader(graphs,
                           n_node=n_node_budget,
                           n_edge=n_edge_budget,
                           n_graph=batch_size,
                           shuffle=shuffle)


# Shuffle training batches so the order is not fixed across epochs (assumes
# GraphDataLoader reshuffles when iterated; confirm). Validation order stays fixed.
train_loader = make_loader(train_idx, shuffle=True)
val_loader = make_loader(val_idx, shuffle=False)


## Freezing the model parameters
Here, we freeze all pretrained parameters except those of the `Linear_1` and `Linear_2` layers, which are the only ones updated during fine-tuning.

In [6]:
def split_params(params, trainable_layers=None):
    """Split a parameter tree into frozen and trainable parts by layer name.

    Layer names are matched at two levels: against the keys of every
    mapping-valued top-level entry (e.g. ``params['params']['Linear_1']``), and
    against the top-level keys whose values are not mappings. Top-level groups
    that end up empty on one side are omitted from that side.

    Args:
        params: nested mapping of parameters (at most two levels are inspected).
        trainable_layers: iterable of layer names to train, or a single name as a
            string. ``None`` means "train everything".

    Returns:
        ``(frozen_params, trainable_params)``. With ``trainable_layers=None`` this
        is ``({}, params)``.
    """
    if trainable_layers is None:
        return {}, params
    if isinstance(trainable_layers, str):
        # A bare string would otherwise be matched by substring,
        # e.g. 'Linear_1' in 'Linear_12' is True.
        trainable_layers = [trainable_layers]
    trainable_layers = frozenset(trainable_layers)

    frozen_params = {}
    trainable_params = {}
    for top_key, top_value in params.items():
        # Check Mapping, not dict. NOTE(review): pretrained trees may be flax
        # FrozenDicts, which are Mappings but not dict subclasses. With a dict
        # check the whole group would be frozen and nothing would train.
        if isinstance(top_value, Mapping):
            frozen_sub = {k: v for k, v in top_value.items() if k not in trainable_layers}
            trainable_sub = {k: v for k, v in top_value.items() if k in trainable_layers}
            if frozen_sub:
                frozen_params[top_key] = frozen_sub
            if trainable_sub:
                trainable_params[top_key] = trainable_sub
        elif top_key in trainable_layers:
            trainable_params[top_key] = top_value
        else:
            frozen_params[top_key] = top_value

    return frozen_params, trainable_params

# Names of the layers to optimise; split_params puts every other layer into the frozen tree.
trainable_layers = ['Linear_1', 'Linear_2']
frozen_params, trainable_params = split_params(params, trainable_layers=trainable_layers)

In [7]:
if 'params' in params:
    # Show which entries of params['params'] will be fine-tuned and which stay frozen.
    all_layers = [layer_name for layer_name in params['params'].keys()]
    frozen_layers = list(filter(lambda layer_name: layer_name not in trainable_layers, all_layers))
    for label, layer_list in (("All layers", all_layers),
                              ("Fine-tuning layers", trainable_layers),
                              ("Frozen layers", frozen_layers)):
        print(f"{label}: {layer_list}")

All layers: ['BesselEmbedding_0', 'Linear_0', 'Linear_1', 'Linear_2', 'NequIPConvolution_0', 'NequIPConvolution_1', 'NequIPConvolution_2', 'NequIPConvolution_3', 'NequIPConvolution_4', 'NequIPConvolution_5']
Fine-tuning layers: ['Linear_1', 'Linear_2']
Frozen layers: ['BesselEmbedding_0', 'Linear_0', 'NequIPConvolution_0', 'NequIPConvolution_1', 'NequIPConvolution_2', 'NequIPConvolution_3', 'NequIPConvolution_4', 'NequIPConvolution_5']


## Define the training state and optimizer

In [8]:
class FineTuneTrainState(train_state.TrainState):
    """TrainState that also carries the parameters excluded from optimisation.

    The inherited ``params`` field holds only the trainable subtree (gradients in
    ``train_step`` are taken with respect to it). ``frozen_params`` holds the rest,
    so the full tree can be rebuilt for the forward pass.
    """
    frozen_params: Any  # frozen subtree returned by split_params
    loss_scale: float  # set to 1.0 by create_finetune_state; not read by the training code in this notebook

def create_finetune_state(
    pretrained_params,
    trainable_layers,
    model,
    learning_rate=1e-3,
    weight_decay=0
):
    """Build a FineTuneTrainState that optimises only ``trainable_layers``.

    Args:
        pretrained_params: full pretrained parameter tree.
        trainable_layers: layer names to train (``None`` trains everything,
            following split_params).
        model: module whose ``apply`` is stored as the state's ``apply_fn``.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.

    Returns:
        A FineTuneTrainState with the trainable subtree as ``params``, the rest as
        ``frozen_params``, and ``loss_scale=1.0``.
    """
    frozen, trainable = split_params(pretrained_params, trainable_layers)
    state_kwargs = dict(
        apply_fn=model.apply,
        params=trainable,
        tx=optax.adamw(learning_rate=learning_rate, weight_decay=weight_decay),
        frozen_params=frozen,
        loss_scale=1.0,
    )
    return FineTuneTrainState.create(**state_kwargs)

## Define the loss functions and the training/evaluation steps

In [9]:
def merged_forward(state, batch):
    """Run the model on ``batch`` with the frozen and trainable parameters recombined.

    For every key in either tree, ``state.frozen_params`` is applied first and
    ``state.params`` second. Dict-valued entries are merged one level deep, and
    any other value replaces what was there. The merged tree is passed to
    ``compute_fn`` with the notebook-level ``model``.
    """
    def _overlay(key):
        merged_entry = {}
        for source in (state.frozen_params, state.params):
            if key not in source:
                continue
            value = source[key]
            if isinstance(value, dict):
                for sub_key, sub_value in value.items():
                    merged_entry[sub_key] = sub_value
            else:
                merged_entry = value
        return merged_entry

    all_keys = set(state.frozen_params) | set(state.params)
    merged_params = {key: _overlay(key) for key in all_keys}
    return compute_fn(model=model, params=merged_params, graph=batch)

# Weights for the energy, force and stress terms, passed to both the training
# loss and the evaluation loss.
energy_weight = 1.0
forces_weight = 1.0
stress_weight = 1.0
loss_weights = dict(energy_weight=energy_weight,
                    forces_weight=forces_weight,
                    stress_weight=stress_weight)
_loss_fn = WeightedLossFunction(**loss_weights)
_evaluate_fn = EvaluationLossFunction(**loss_weights)

@jax.jit
def compute_loss(state, batch):
    """Weighted training loss of the merged model on ``batch``."""
    predictions = merged_forward(state, batch)
    return _loss_fn(batch, predictions)

@jax.jit
def evaluate_fn(state, batch):
    """Raw evaluation-loss output of the merged model on ``batch``."""
    predictions = merged_forward(state, batch)
    return _evaluate_fn(batch, predictions)

@jax.jit
def train_step(state, batch):
    """Take one optimiser step on the trainable parameters only.

    Gradients are taken with respect to ``state.params``, so ``frozen_params``
    is never updated.

    Returns:
        ``(new_state, loss)``, where ``loss`` is evaluated before the update.
    """
    def objective(trainable):
        return compute_loss(state.replace(params=trainable), batch)

    loss, grads = jax.value_and_grad(objective)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss

@jax.jit
def evaluate(state, batch):
    """Return the first seven evaluation outputs as a tuple.

    The order is ``(total_loss, rmse_e, rmse_f, rmse_s, mae_e, mae_f, mae_s)``,
    i.e. positions 0..6 of ``evaluate_fn``'s result.
    """
    metrics = evaluate_fn(state, batch)
    return tuple(metrics[position] for position in range(7))

## Fine-tuning the model

In [10]:
def finetune(
    state,
    train_loader,
    val_loader,
    epochs=100,
):
    """Fine-tune ``state`` and return the state with the lowest validation loss.

    Each epoch runs ``train_step`` over all training batches, then ``evaluate``
    over all validation batches. It reports the mean training loss, mean
    validation loss and mean energy/force/stress MAE (averaged per batch).

    Args:
        state: initial train state (e.g. from ``create_finetune_state``).
        train_loader: iterable of training batches, re-iterated every epoch.
        val_loader: iterable of validation batches, re-iterated every epoch.
        epochs: number of training epochs.

    Returns:
        The state after the epoch with the lowest mean validation loss. If no
        epoch beats ``inf`` (e.g. NaN losses), the initial ``state`` is returned.

    Raises:
        ValueError: if ``val_loader`` yields no batches, since no best state
            could be selected.
    """
    header = "{:>15} {:>15} {:>15} {:>15} {:>15} {:>15}".format(
        "Epoch", "Train Loss", "Val Loss", "Energy MAE", "Force MAE", "Stress MAE")
    row_fmt = "{:>15} {:>15.6f} {:>15.6f} {:>15.6f} {:>15.6f} {:>15.6f}"

    best_val_loss = float('inf')
    best_state = state

    epochs_bar = tqdm(range(epochs), desc="Training")
    # tqdm.write prints above the progress bar instead of corrupting it.
    # The column header only needs to be printed once.
    tqdm.write(header)

    for epoch in epochs_bar:
        epoch_train_losses = []
        for batch in train_loader:
            state, loss = train_step(state, batch)
            epoch_train_losses.append(loss)
        avg_train_loss = np.mean(epoch_train_losses)

        epoch_val_losses = []
        epoch_val_energy_mae = []
        epoch_val_force_mae = []
        epoch_val_stress_mae = []
        for batch in val_loader:
            # The RMSE outputs of `evaluate` are not reported here.
            val_loss, _, _, _, mae_e, mae_f, mae_s = evaluate(state, batch)
            epoch_val_losses.append(val_loss)
            epoch_val_energy_mae.append(mae_e)
            epoch_val_force_mae.append(mae_f)
            epoch_val_stress_mae.append(mae_s)
        if not epoch_val_losses:
            # Without validation batches np.mean yields NaN, which never beats
            # best_val_loss, so the untrained state would be returned silently.
            raise ValueError("val_loader yielded no batches; cannot select the best state")

        avg_val_loss = np.mean(epoch_val_losses)
        avg_val_energy_loss = np.mean(epoch_val_energy_mae)
        avg_val_force_loss = np.mean(epoch_val_force_mae)
        avg_val_stress_loss = np.mean(epoch_val_stress_mae)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_state = state

        epochs_bar.set_postfix(val_loss=f"{avg_val_loss:.6f}")
        tqdm.write(row_fmt.format(
            epoch, avg_train_loss, avg_val_loss,
            avg_val_energy_loss, avg_val_force_loss, avg_val_stress_loss))

    return best_state
        
# Optimise only the selected layers with a small learning rate and no weight decay.
finetune_state = create_finetune_state(
    pretrained_params=params,
    trainable_layers=trainable_layers,
    model=model,
    learning_rate=1e-4,
    weight_decay=0,
)
best_state = finetune(finetune_state, train_loader, val_loader, epochs=100)

## Saving the model


In [None]:
def get_full_params(state):
    """Rebuild the complete parameter tree from a fine-tuning state.

    For every key in ``state.frozen_params`` or ``state.params``, the frozen
    entry is applied first and the trainable entry second. Dict-valued entries
    are merged one level deep (trainable sub-keys win), and any other value
    replaces what was there.

    Returns:
        A new dict containing every parameter, suitable for saving.
    """
    combined = {}
    for name in set(state.frozen_params) | set(state.params):
        combined[name] = {}
        for tree in (state.frozen_params, state.params):
            if name not in tree:
                continue
            entry = tree[name]
            if isinstance(entry, dict):
                for sub_name, sub_entry in entry.items():
                    combined[name][sub_name] = sub_entry
            else:
                combined[name] = entry
    return combined

full_params = get_full_params(best_state)

# Set to a file path (e.g. 'finetuned_params.pickle') to write the fine-tuned
# parameters to disk; None skips saving.
save_path = None
if save_path is not None:
    with open(save_path, 'wb') as f:
        pickle.dump(full_params, f)

# Fine-tuning long-range parameters

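The long-range (PQEq) parameters can also be fine-tuned. The module below is one way to do this: each parameter is learned in an unconstrained form and mapped into a fixed range with a scaled sigmoid.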
```python
import flax.linen as nn
import jax
import jax.numpy as jnp
from typing import Tuple

class OptPQEqParameters(nn.Module):
    """Trainable PQEq parameters (radius, chi, eta, ks), each bounded to a range.

    Every parameter is stored as an unconstrained value and mapped into its
    ``(min, max)`` range with a scaled sigmoid, so optimiser updates can never
    move it outside that range.
    """
    init_radius: float
    init_chi: float
    init_eta: float
    init_ks: float
    radius_range: Tuple[float, float]
    chi_range: Tuple[float, float]
    eta_range: Tuple[float, float]
    ks_range: Tuple[float, float]

    def setup(self):
        # Each parameter is initialised to the inverse sigmoid of its init_* value,
        # so the constrained value starts at init_* (clipped just inside the range).
        self.radius_unconstrained = self.param('radius_unconstrained',
                                               lambda _: self._inverse_sigmoid_transform(
                                                   self.init_radius, self.radius_range[0], self.radius_range[1]
                                               ))
        self.chi_unconstrained = self.param('chi_unconstrained',
                                            lambda _: self._inverse_sigmoid_transform(
                                                self.init_chi, self.chi_range[0], self.chi_range[1]
                                            ))
        self.eta_unconstrained = self.param('eta_unconstrained',
                                            lambda _: self._inverse_sigmoid_transform(
                                                self.init_eta, self.eta_range[0], self.eta_range[1]
                                            ))
        self.ks_unconstrained = self.param('ks_unconstrained',
                                           lambda _: self._inverse_sigmoid_transform(
                                               self.init_ks, self.ks_range[0], self.ks_range[1]
                                           ))

    def _sigmoid_transform(self, x, min_val, max_val):
        """Map an unconstrained value ``x`` into ``(min_val, max_val)``."""
        return min_val + (max_val - min_val) * jax.nn.sigmoid(x)

    def _inverse_sigmoid_transform(self, y, min_val, max_val):
        """Inverse of ``_sigmoid_transform``; ``y`` is first clipped just inside the range."""
        y_norm = (y - min_val) / (max_val - min_val)
        # Clip so the logit stays finite when y lies on (or outside) a range boundary.
        y_norm = jnp.clip(y_norm, 1e-6, 1 - 1e-6)
        return jnp.log(y_norm / (1 - y_norm))

    def get_constrained_parameters(self):
        """Return ``(radius, chi, eta, ks)`` mapped into their configured ranges."""
        radius = self._sigmoid_transform(self.radius_unconstrained,
                                         self.radius_range[0], self.radius_range[1])
        chi = self._sigmoid_transform(self.chi_unconstrained,
                                      self.chi_range[0], self.chi_range[1])
        eta = self._sigmoid_transform(self.eta_unconstrained,
                                      self.eta_range[0], self.eta_range[1])
        ks = self._sigmoid_transform(self.ks_unconstrained,
                                     self.ks_range[0], self.ks_range[1])
        return radius, chi, eta, ks

    def calculate_pqeq(self, radius, chi, eta, ks, **args_for_calculate_pqeq):
        """Compute the PQEq result from the constrained parameters.

        Placeholder: the implementation depends on the specific use case. Raising
        here avoids silently returning None from ``__call__``.
        """
        raise NotImplementedError("calculate_pqeq must be implemented for the target system")

    # All parameters are declared in `setup`, so `__call__` needs no `@nn.compact`.
    def __call__(self, **args_for_calculate_pqeq):
        radius, chi, eta, ks = self.get_constrained_parameters()
        # Alternatively, return both the value and its gradients.
        return self.calculate_pqeq(radius, chi, eta, ks, **args_for_calculate_pqeq)
```

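Alternatively, fine-tune the whole model: passing `trainable_layers=None` to `create_finetune_state` makes every parameter trainable.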