In [8]:
from ase import units
from ase.md.langevin import Langevin
from ase.io import read
import torch
import numpy as np
import time

from mace.calculators import MACECalculator
from mace.modules import MACE

# --- Configuration ---
model_path = '/global/cfs/cdirs/m5047/train_sam/mace_gs/checkpoints/MACE_model_run-123_epoch-264.pt'
device = 'cuda'

# 1. Load the full training checkpoint
print("Loading checkpoint...")
state_dict_loaded = torch.load(model_path, map_location=device)

# 2. Extract the model state dictionary (parameters) from the 'model' key
# This is where the weights and configuration parameters are stored
model_data = state_dict_loaded['model']

# 3. Extract necessary architecture parameters
r_max = model_data['r_max'].item()
num_interactions = model_data['num_interactions'].item()

atomic_energies_tensor = model_data['atomic_energies_fn.atomic_energies']
atomic_numbers_tensor = model_data['atomic_numbers']

# Robustly extract atomic numbers and energies
# .flatten().tolist() ensures we get a flat list of values, regardless of tensor shape
numbers_list = atomic_numbers_tensor.flatten().tolist()
energies_list = atomic_energies_tensor.flatten().tolist()

# Convert to the required dictionary format: {atomic_number: energy}
atomic_energies = {
    int(z): float(e) 
    for z, e in zip(numbers_list, energies_list)
}

# 4. Construct the configuration dictionary (some parameters must be inferred)
# WARNING: 'hidden_irreps', 'MLP_irreps', and 'avg_num_neighbors' are common MACE defaults.
# If the simulation fails or energies are wrong, you must confirm these three values from your training logs.
config = {
    "r_max": r_max,
    "num_interactions": num_interactions,
    "num_species": len(numbers_list),
    "hidden_irreps": "128x0e + 128x1o", 
    "MLP_irreps": "128x0e",
    "atomic_energies": atomic_energies,
    "gate": "silu",
    "avg_num_neighbors": 18.0, 
}

# 5. Create the model architecture
print(f"Reconstructing MACE model (r_max={config['r_max']}, interactions={config['num_interactions']})...")
model = MACE(
    r_max=config["r_max"],
    num_interactions=config["num_interactions"],
    num_species=config["num_species"],
    hidden_irreps=config["hidden_irreps"],
    MLP_irreps=config["MLP_irreps"],
    atomic_energies=config["atomic_energies"],
    gate=config["gate"],
    avg_num_neighbors=config["avg_num_neighbors"],
)

# 6. Load the learned parameters into the model architecture
# model_data is the dictionary containing all the correct key-value pairs for the model state dict.
model.load_state_dict(model_data, strict=True) 

# 7. Initialize the MACECalculator with the model object
print("Initializing MACECalculator...")
calculator = MACECalculator(models=[model], device=device)

# --- ASE MD Setup ---
init_conf = read('./initial.xyz', '0')
init_conf.set_calculator(calculator)

dyn = Langevin(
    init_conf, 
    0.5 * units.fs, 
    temperature_K=310, 
    friction=5e-3
)

# Define callback function to write frames
def write_frame():
    dyn.atoms.write('mace_ace_gs.xyz', append=True)

# Attach the writer function to the MD run
dyn.attach(write_frame, interval=50)

# Run MD
print("Starting MD simulation for 100 steps...")
dyn.run(100)
print("MD finished!")

Loading checkpoint...
Reconstructing MACE model (r_max=6.0, interactions=2)...


  state_dict_loaded = torch.load(model_path, map_location=device)


TypeError: MACE.__init__() got an unexpected keyword argument 'num_species'