In [None]:
!pip install mace-torch
!pip install cuequivariance cuequivariance-torch cuequivariance-ops-torch-cu12

In [None]:
import google.colab.drive as drive

# Expose Google Drive under /content/drive so the project files are reachable.
drive.mount('/content/drive')

In [None]:
import sys

# Make the project's helper modules (e.g. utils.py) importable.
# Uses the same "MyDrive" mount path as the %cd cell, and guards against
# appending a duplicate entry every time this cell is re-run.
SCRIPTS_DIR = '/content/drive/MyDrive/colab_temp/mace_fine-tuning_expts/scripts'
if SCRIPTS_DIR not in sys.path:
    sys.path.append(SCRIPTS_DIR)

In [None]:
import utils

In [None]:
%cd "drive/MyDrive/colab_temp/mace_fine-tuning_expts"

In [None]:
from mace.calculators import mace_mp

# Untuned baseline calculator: the pretrained MACE-MP foundation model, "small" variant.
macemp = mace_mp(
    model="small",
)

In [None]:
# Reference data directory, relative to the current working directory
# (the project folder after the %cd cell).
data_path = './data'

# Reference configurations, indexed by set name (e.g. "HfOx_test_no0K").
reference_configs = utils.load_configurations(
    data_path
)

In [None]:
from ase.units import create_units

# Reference atomic energies (E0s) in Rydberg, back-computed with the
# 2006 CODATA constants.
E0_Hf_Ry = -112.41607411999959316330
E0_O_Ry = -32.6203517400001970321

# ASE unit table built from the 2006 CODATA set; units_2006['Ry'] is one
# Rydberg expressed in eV, so multiplying converts Ry -> eV.
units_2006 = create_units('2006')
E0_Hf_eV, E0_O_eV = (e_ry * units_2006['Ry'] for e_ry in (E0_Hf_Ry, E0_O_Ry))

# E0s keyed by atomic number: O (Z=8) and Hf (Z=72).
new_E0s = dict([(8, E0_O_eV), (72, E0_Hf_eV)])

In [None]:
# Sanity check of both converted E0s in eV (expected roughly O: -443.82, Hf: -1529.50);
# shown as the cell's last expression so Jupyter renders it, instead of print().
new_E0s

In [None]:
# Evaluate the baseline only on the "HfOx_test_no0K" subset of the reference data.
alt_configs = {name: reference_configs[name] for name in ("HfOx_test_no0K",)}

In [None]:
# Predictions of the untuned foundation model on the selected subset; new_E0s is
# passed through to the helper (presumably as the reference atomic energies — see utils).
predicted_alt_configs = utils.compute_predictions(
    alt_configs,
    macemp,
    new_E0s,
)

In [None]:
# Presumably compares the baseline predictions against the reference values (see utils).
errors = utils.calculate_errors(
    alt_configs,
    predicted_alt_configs,
)

In [None]:
# Render the baseline error summary; formatting is delegated to the project helper.
utils.display_errors(errors)

In [None]:
%%writefile configs/small_ft-HfOx_v0.1.yml

# Fine-tune the MACE-MP "small" foundation model on HfOx data.
model: 'MACE'
foundation_model: 'small'
multiheads_finetuning: False
# NOTE(review): train_file and test_file below point at the SAME file, so the
# reported test errors are measured on training data. Switch train_file to a
# trainval set (see the commented options) for a held-out evaluation.
#train_file: ['data/Hf/trainval', 'data/HfOx/trainval']
#train_file: 'data/HfOx/trainval'
train_file: 'data/HfOx/test/HfOx_amorphous_MC_rattled_form_sorted_test.xyz'
valid_fraction: 0.1
#test_file: ['data/Hf/test', 'data/HfOx/test']
#test_file: 'data/HfOx/test'
test_file: 'data/HfOx/test/HfOx_amorphous_MC_rattled_form_sorted_test.xyz'
energy_key: "REF_energy"
forces_key: "REF_forces"
# Reference atomic energies in eV, keyed by atomic number: O (Z=8) and Hf (Z=72).
# These are the notebook's E0_O_Ry / E0_Hf_Ry converted with the 2006 CODATA
# Rydberg, i.e. the same values as new_E0s used for the baseline predictions.
E0s: {8: -443.8224565134432, 72: -1529.4984727695407}
name: "small_ft-HfOx_v0.1"
model_dir: "models/small_HfOx-ft_1"
log_dir: "models/small_HfOx-ft_1"
results_dir: "models/small_HfOx-ft_1"
#checkpoint_dir: "MACE_models"
device: cuda
batch_size: 5
max_num_epochs: 500
lr: 0.01
swa: True
seed: 123
stress_weight: 0.0
forces_weight: 10.0
energy_weight: 1.0

In [None]:
import warnings
# NOTE(review): this silences *all* warnings for the rest of the kernel session,
# not just those from MACE/torch; consider a narrower filter (by category or
# module) so unexpected problems stay visible.
warnings.filterwarnings("ignore")
# MACE's training CLI entry point, called in-process by train_mace.
from mace.cli.run_train import main as mace_run_train_main
import sys
import logging
def train_mace(config_file_path, extra_args=None):
    """Run MACE's training CLI entry point in-process on a YAML config.

    ``mace_run_train_main`` is called with no arguments, so its command line
    is supplied through ``sys.argv``. This function installs a synthetic
    ``sys.argv`` for the duration of the call and restores the previous value
    afterwards, even if training raises, so later cells see an unchanged
    interpreter state.

    Args:
        config_file_path: Path (str or path-like) to the YAML config passed
            as ``--config``.
        extra_args: Optional sequence of additional CLI tokens appended after
            the ``--config`` option, e.g. ``["--max_num_epochs", 10]``. Each
            token is converted with ``str()``.

    Side effects:
        Removes every handler attached to the root logger before training.
    """
    # Drop handlers left over from a previous run in this kernel; presumably
    # MACE installs its own on each call, so stale ones would duplicate output.
    logging.getLogger().handlers.clear()
    saved_argv = sys.argv
    sys.argv = ["program", "--config", str(config_file_path), *map(str, extra_args or ())]
    try:
        mace_run_train_main()
    finally:
        sys.argv = saved_argv

In [None]:
# Launch fine-tuning with the config written by the %%writefile cell.
train_mace(config_file_path="configs/small_ft-HfOx_v0.1.yml")