In [None]:
# Install into the *kernel's* environment (%pip, not !pip). Quote the extras spec so the
# shell cannot glob-expand the brackets.
# TODO(review): pin exact versions (e.g. "mace-torch[wandb]==X.Y.Z") so runs are reproducible.
%pip install -q "mace-torch[wandb]"
%pip install -q cuequivariance cuequivariance-torch cuequivariance-ops-torch-cu12

In [None]:
import google.colab.drive as drive

# Mount Google Drive; later cells reach the project files through paths under drive/MyDrive.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import sys

# Make the project's helper scripts importable. Use the same "MyDrive" mount path that
# the %cd cell relies on. Guard against duplicates so re-running this cell is idempotent.
SCRIPTS_DIR = '/content/drive/MyDrive/colab_temp/mace_fine-tuning_expts/scripts'
if SCRIPTS_DIR not in sys.path:
    sys.path.append(SCRIPTS_DIR)

In [None]:
import wandb
# Authenticate with Weights & Biases up front. The training config below sets `wandb: True`,
# so the training run will initialise a W&B run. Without stored credentials this presumably
# prompts interactively for an API key. Never paste a key into the notebook itself.
wandb.login()

In [None]:
# Absolute path so the cell works regardless of the current directory and can be re-run safely.
%cd "/content/drive/MyDrive/colab_temp/mace_fine-tuning_expts"

In [None]:
import json
import mace.tools.scripts_utils as scripts_utils
import argparse
import logging
import json
import numpy as np
from mace import tools
from mace.data import KeySpecification

def patched_setup_wandb(
    args: argparse.Namespace,
    exclude_keys: tuple = ("compute_atomic_dipole",),
):
    """Drop-in replacement for ``mace.tools.scripts_utils.setup_wandb``.

    Initialises a W&B run via ``tools.init_wandb``. The run config holds the
    arguments named in ``args.wandb_log_hypers``, and every JSON-encodable
    argument is stored as a JSON string in ``wandb.run.summary["params"]``.
    Arguments that cannot be JSON-encoded are skipped with a warning instead
    of aborting training.

    Args:
        args: Parsed MACE training arguments. Not modified: numpy arrays are
            converted to lists on a private copy.
        exclude_keys: Argument names always left out of the JSON summary.
            ``compute_atomic_dipole`` is excluded by default (presumably
            because its value is not JSON-encodable; confirm).

    Raises:
        KeyError: if ``args.wandb_log_hypers`` names an argument missing from ``args``.
    """
    logging.info("Using Weights and Biases for logging")
    import wandb

    # Build a converted *copy*. Converting through vars(args) directly would
    # mutate the caller's Namespace.
    args_dict = {
        key: value.tolist() if isinstance(value, np.ndarray) else value
        for key, value in vars(args).items()
    }

    class CustomEncoder(json.JSONEncoder):
        """JSON encoder that also serializes MACE ``KeySpecification`` objects."""

        def default(self, o):
            if isinstance(o, KeySpecification):
                return o.__dict__
            return super().default(o)

    clean_dict = {}
    for key, value in args_dict.items():
        if key in exclude_keys:
            logging.debug("Excluding argument %r from the W&B params summary", key)
            continue
        try:
            json.dumps(value, cls=CustomEncoder)
        except (TypeError, ValueError) as exc:
            # Skip just this argument rather than crashing the whole training run.
            logging.warning(
                "Omitting argument %r from the W&B params summary: %s", key, exc
            )
            continue
        clean_dict[key] = value

    args_dict_json = json.dumps(clean_dict, cls=CustomEncoder)
    wandb_config = {key: args_dict[key] for key in args.wandb_log_hypers}
    tools.init_wandb(
        project=args.wandb_project,
        entity=args.wandb_entity,
        name=args.wandb_name,
        config=wandb_config,
        directory=args.wandb_dir,
    )
    wandb.run.summary["params"] = args_dict_json


import sys

scripts_utils.setup_wandb = patched_setup_wandb
# NOTE(review): if mace.cli.run_train was already imported and bound `setup_wandb`
# by name, patching scripts_utils alone has no effect. Patch that binding too.
_run_train_module = sys.modules.get("mace.cli.run_train")
if _run_train_module is not None and hasattr(_run_train_module, "setup_wandb"):
    _run_train_module.setup_wandb = patched_setup_wandb

In [None]:
%%writefile configs/dpp_small_ft-HfOx_v0.9.yml

# MACE fine-tuning config for HfOx, written to disk by the %%writefile magic and
# passed to training via `--config` by train_mace.

# Start from MACE's 'small' foundation model; multi-head fine-tuning is disabled.
model: 'MACE'
foundation_model: 'small'
multiheads_finetuning: False
# Dataset splits (extended-XYZ files under data/).
# NOTE(review): train/valid files use a `dpp_` prefix but the test file does not. Confirm
# data/test.xyz is the intended test set.
train_file: 'data/dpp_train.xyz'
valid_file: 'data/dpp_validation.xyz'
test_file: 'data/test.xyz'
# Names of the reference energy/force fields in the XYZ files.
energy_key: "REF_energy"
forces_key: "REF_forces"
# Per-element reference (isolated-atom) energies E0, keyed by atomic number: 8 = O, 72 = Hf.
E0s: {8: -443.8224565134432, 72: -1529.4984727695407}
name: "small_ft-HfOx"
# Models, logs, results and checkpoints all go to the same run directory.
model_dir: "models/small_dpp_HfOx-ft_8"
log_dir: "models/small_dpp_HfOx-ft_8"
results_dir: "models/small_dpp_HfOx-ft_8"
checkpoints_dir: "models/small_dpp_HfOx-ft_8"
device: cuda
batch_size: 14
max_num_epochs: 500
# Optimisation settings.
lr: 0.01
ema: True
ema_decay: 0.99
optimizer: adam
weight_decay: 5e-4
# SWA phase enabled; start_swa and the swa_* loss weights are not set here, so defaults apply.
swa: True
seed: 123
# Loss weights: stress is ignored; forces are weighted 100x relative to energy.
stress_weight: 0.0
forces_weight: 100.0
energy_weight: 1.0
# NOTE(review): scheduler_patience / lr_factor presumably only matter for a plateau scheduler,
# so they are inactive with ExponentialLR (they are still logged via wandb_log_hypers). Confirm.
#scheduler_patience: 20
#lr_factor: 0.2
# NOTE(review): no `loss:` key is set, so huber_delta presumably has no effect unless the
# default loss is a Huber loss. Confirm.
huber_delta: 0.001
scheduler: ExponentialLR
wandb: True
# Use the cuEquivariance kernels installed in the first cell.
enable_cueq: True
# NOTE(review): the project name says "from-scratch" but this run fine-tunes the 'small'
# foundation model. Confirm this is the intended W&B project.
wandb_project: "MACE_HfOx_from-scratch"
wandb_name: "HfOx_1"
# Argument names copied into the W&B run config. Each must exist on the parsed
# arguments, otherwise the wandb setup raises KeyError.
wandb_log_hypers: ["lr", "lr_factor", "lr_scheduler_gamma", "scheduler",
                   "swa", "swa_lr", "start_swa", "scheduler_patience",
                   "swa_energy_weight", "swa_forces_weight",
                   "energy_weight", "forces_weight",
                   "ema", "ema_decay", "weight_decay",
                   "huber_delta", "optimizer", "amsgrad", "beta",
                   "batch_size", "max_num_epochs", "seed"]



In [None]:
import warnings
warnings.filterwarnings("ignore")
from mace.cli.run_train import main as mace_run_train_main
import sys
import logging

def train_mace(config_file_path, plot=True, plot_frequency=10, extra_args=()):
    """Run MACE training (``mace.cli.run_train.main``) on a YAML config from the notebook.

    ``mace_run_train_main`` takes no arguments and is driven through ``sys.argv``.
    This helper therefore installs a synthetic command line for the call and
    restores the previous ``sys.argv`` afterwards, even if training raises.
    The ``try/finally`` also covers ``SystemExit``.

    Args:
        config_file_path: Path to the YAML config file (str or path-like).
        plot: Value passed to ``--plot`` ("true"/"false").
        plot_frequency: Value passed to ``--plot_frequency``.
        extra_args: Extra command-line tokens appended verbatim (e.g. ``["--seed", "1"]``).
    """
    # Drop root-logger handlers from any previous run, presumably so output does not
    # keep flowing to an old run's handlers when training configures logging again.
    logging.getLogger().handlers.clear()
    saved_argv = sys.argv
    sys.argv = [
        "program",
        "--config", str(config_file_path),
        "--plot", "true" if plot else "false",
        "--plot_frequency", str(plot_frequency),
        *extra_args,
    ]
    try:
        mace_run_train_main()
    finally:
        sys.argv = saved_argv

In [None]:
# Train on the config file written by the %%writefile cell.
train_mace(config_file_path="configs/dpp_small_ft-HfOx_v0.9.yml")

In [None]:
import warnings
warnings.filterwarnings("ignore")
from mace.cli.plot_train import main as mace_plot_train_main
import sys
import logging

def plot_mace_training(results_path, extra_args=()):
    """Plot MACE training curves with ``mace.cli.plot_train.main``.

    ``mace_plot_train_main`` is driven through ``sys.argv``. This helper installs a
    synthetic command line for the call and restores the previous ``sys.argv``
    afterwards, even if plotting raises.

    Args:
        results_path: Results file or directory to plot. Presumably the
            ``results_dir`` from the training config (e.g. "models/small_dpp_HfOx-ft_8").
        extra_args: Extra command-line tokens appended verbatim (e.g. ``["--keys", "rmse_e,rmse_f"]``).
    """
    saved_argv = sys.argv
    # NOTE(review): assumes the plot_train CLI takes `--path <results file or dir>`.
    # Confirm against the installed mace version.
    sys.argv = ["program", "--path", str(results_path), *extra_args]
    try:
        mace_plot_train_main()
    finally:
        sys.argv = saved_argv