In [1]:
import os
import time

import numpy as np
from ase import units
from ase.constraints import FixAtoms, Hookean
from ase.io import read, write
from ase.io.trajectory import Trajectory
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.md.verlet import VelocityVerlet
from ase.optimize import FIRE
from mace.calculators.mace import MACECalculator

# --- Configuration -----------------------------------------------------------
# Paths default to the original author's checkout but can be overridden with
# environment variables so the notebook runs on other machines unchanged.
project_dir = os.environ.get(
    'HDPE_HBN_PROJECT_DIR',
    '/Users/sarahmatesic/Documents/GitHub/ML4MSD---Final-Project---SarahM/hdpe_hbn_project',
)
structures_dir = os.path.join(project_dir, 'structures')    # input config_*.xyz files
traj_dir = os.path.join(project_dir, 'results', 'new')      # all outputs go here
os.makedirs(traj_dir, exist_ok=True)

mace_model_path = os.environ.get(
    'MACE_MODEL_PATH',
    os.path.expanduser('~/.cache/mace/20231210mace128L0_energy_epoch249model'),
)

num_slab_atoms = 32   # the first 32 atoms of every config are the slab; the rest are polymer
bond_k = 50.0         # Hookean spring constant between successive polymer carbons (eV/Å^2)
bond_r0 = 1.54        # Hookean threshold distance (Å); the spring only acts beyond it
temperature = 300.0   # temperature used to draw initial Maxwell-Boltzmann velocities (K)
md_steps = 50         # number of velocity-Verlet steps per config
dt = 0.2              # MD timestep passed to VelocityVerlet

# Only the structure files following the config_*.xyz naming scheme, in a
# deterministic (sorted) order.
files_to_run = sorted(
    name for name in os.listdir(structures_dir)
    if name.startswith('config_') and name.endswith('.xyz')
)

def compute_adsorption_energy(system, slab, polymer, include_constraint_energy=False):
    """Return E_ads = E(system) - [E(slab) + E(polymer)] in eV.

    Parameters
    ----------
    system, slab, polymer : ase.Atoms
        The combined slab+polymer structure and its two fragments, each with a
        calculator attached. They should share the same atomic positions for
        the difference to be meaningful. A negative result means the combined
        system is lower in energy than the separated fragments.
    include_constraint_energy : bool, optional
        ``Atoms.get_potential_energy()`` adds the energy of energy-bearing
        constraints (e.g. ``Hookean`` springs) by default. Those springs are
        artificial restraints, not part of the physical interaction, so they
        are excluded unless this is True.

    Returns
    -------
    float
        The adsorption energy.
    """
    def energy(atoms):
        # apply_constraint=False -> raw calculator energy, no restraint terms.
        return atoms.get_potential_energy(apply_constraint=include_constraint_energy)

    E_total = energy(system)
    E_slab = energy(slab)
    E_polymer = energy(polymer)
    return E_total - (E_slab + E_polymer)

def process_config(config_file, seed=None, default_dtype='float64'):
    """Relax, briefly run MD on one slab+polymer config, and return E_ads (eV).

    Workflow:
      1. Read ``structures_dir/config_file``.
      2. Freeze the first ``num_slab_atoms`` atoms (the slab) and tie successive
         polymer carbons with one-sided Hookean springs (k=bond_k, rt=bond_r0).
      3. FIRE-minimise (fmax=0.05, at most 100 steps).
      4. Draw Maxwell-Boltzmann velocities at ``temperature`` and run
         ``md_steps`` velocity-Verlet steps of ``dt`` femtoseconds.
      5. Write every MD step (including the starting frame) to
         ``traj_dir/<name>.traj`` and the last frame to
         ``traj_dir/<name>_final.xyz``.
      6. Cut slab and polymer out of the final frame and compute E_ads.

    Parameters
    ----------
    config_file : str
        File name (not a path) inside ``structures_dir``.
    seed : int or None, optional
        Seed for the initial velocities. ``None`` uses NumPy's global RNG.
    default_dtype : str, optional
        Precision for MACE ('float64' or 'float32').

    Returns
    -------
    float
        Adsorption energy of the final MD frame. This is a single thermal
        snapshot, not a relaxed minimum.
    """
    stem = os.path.splitext(config_file)[0]
    system = read(os.path.join(structures_dir, config_file))

    # Polymer carbons, selected by element rather than by a fixed stride.
    # NOTE(review): assumes carbons are listed in chain order, so index-adjacent
    # carbons are bonded neighbours -- confirm against the structure generator.
    symbols = system.get_chemical_symbols()
    carbons = [i for i in range(num_slab_atoms, len(system)) if symbols[i] == 'C']
    springs = [Hookean(a, b, k=bond_k, rt=bond_r0) for a, b in zip(carbons, carbons[1:])]
    system.set_constraint([FixAtoms(indices=list(range(num_slab_atoms)))] + springs)

    # MACECalculator's precision keyword is `default_dtype`. A `dtype=` kwarg is
    # silently ignored, and the model then falls back to its own (float64) dtype.
    calc = MACECalculator(mace_model_path, default_dtype=default_dtype)
    system.calc = calc

    # Small energy minimisation before dynamics.
    FIRE(system).run(fmax=0.05, steps=100)

    mb_kwargs = {} if seed is None else {'rng': np.random.default_rng(seed)}
    MaxwellBoltzmannDistribution(system, temperature_K=temperature, **mb_kwargs)
    # ASE timesteps are in ASE time units (1 unit ~ 10.18 fs); dt is in fs.
    dyn = VelocityVerlet(system, timestep=dt * units.fs)
    with Trajectory(os.path.join(traj_dir, f"{stem}.traj"), 'w', system) as traj:
        dyn.attach(traj.write, interval=1)
        dyn.run(md_steps)

    write(os.path.join(traj_dir, f"{stem}_final.xyz"), system)

    # Cut the fragments from the *final* geometry so all three energies refer
    # to the same positions. Slicing carries over constraints (e.g. the Hookean
    # springs on the polymer), so clear them before evaluating fragment energies.
    slab = system[:num_slab_atoms]
    polymer = system[num_slab_atoms:]
    for fragment in (slab, polymer):
        fragment.set_constraint()
        fragment.calc = calc

    E_ads = compute_adsorption_energy(system, slab, polymer)
    print(f"{config_file}: Adsorption energy = {E_ads:.4f} eV")
    return E_ads

adsorption_energies = []   # (config file name, E_ads in eV) in processing order
batch_size = 2
batch_pause = 5  # seconds to pause between batches (to cool the laptop)

results_file = os.path.join(traj_dir, "adsorption_energies.txt")
# Each result is written and flushed as soon as it is computed, so an
# interrupted run (e.g. KeyboardInterrupt) keeps every finished config on disk.
with open(results_file, 'w') as results:
    for start in range(0, len(files_to_run), batch_size):
        for cfg in files_to_run[start:start + batch_size]:
            E_ads = process_config(cfg)
            adsorption_energies.append((cfg, E_ads))
            results.write(f"{cfg} {E_ads:.6f}\n")
            results.flush()
        # No need to pause after the final batch.
        if start + batch_size < len(files_to_run):
            print(f"Batch complete. Waiting {batch_pause}s...")
            time.sleep(batch_pause)



  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.


  torch.load(f=model_path, map_location=device)


Using head Default out of ['Default']
No dtype selected, switching to float64 to match model dtype.
      Step     Time          Energy          fmax
FIRE:    0 19:48:39      281.038069      295.889408
FIRE:    1 19:48:39      202.828326      271.812347
FIRE:    2 19:48:40      131.623115      244.066428
FIRE:    3 19:48:40       68.468851      208.977920
FIRE:    4 19:48:40       14.855206      152.466322
FIRE:    5 19:48:40      -26.792020      112.372710
FIRE:    6 19:48:40      -22.341702      908.579662
FIRE:    7 19:48:40      -65.963453      130.657362
FIRE:    8 19:48:41      -90.811304      111.481277
FIRE:    9 19:48:41     -122.728717       78.073311
FIRE:   10 19:48:41     -149.427238       92.141962
FIRE:   11 19:48:41     -171.245130       68.429871
FIRE:   12 19:48:41     -178.761089      548.522762
FIRE:   13 19:48:41     -187.792547      195.107335
FIRE:   14 19:48:41     -191.080528      168.243151
FIRE:   15 19:48:42     -194.877146       66.160073
FIRE:   16 19:48:4

  torch.load(f=model_path, map_location=device)


      Step     Time          Energy          fmax
FIRE:    0 19:49:33      281.952160      296.053854
FIRE:    1 19:49:33      203.679250      272.030860
FIRE:    2 19:49:33      132.432850      244.240375
FIRE:    3 19:49:33       69.313022      209.130973
FIRE:    4 19:49:33       15.831310      152.479101
FIRE:    5 19:49:34      -25.543583      113.568821
FIRE:    6 19:49:34      -21.367126      889.594204
FIRE:    7 19:49:34      -64.132151      131.926591
FIRE:    8 19:49:34      -91.211440      112.319023
FIRE:    9 19:49:34     -122.754770       73.868204
FIRE:   10 19:49:34     -149.294022       84.885224
FIRE:   11 19:49:34     -171.407371       80.718767
FIRE:   12 19:49:35     -182.197873      493.340416
FIRE:   13 19:49:35     -190.809259      119.779571
FIRE:   14 19:49:35     -193.026056      108.704894
FIRE:   15 19:49:35     -196.996808       83.504002
FIRE:   16 19:49:35     -198.431099       89.688908
FIRE:   17 19:49:35     -200.891405       97.506807
FIRE:   18 19:

  torch.load(f=model_path, map_location=device)


Using head Default out of ['Default']
No dtype selected, switching to float64 to match model dtype.
      Step     Time          Energy          fmax
FIRE:    0 19:50:36      281.117430      295.935415
FIRE:    1 19:50:37      202.898268      271.938957
FIRE:    2 19:50:37      131.679134      244.240801
FIRE:    3 19:50:37       68.516339      209.124937
FIRE:    4 19:50:37       14.916546      152.439417
FIRE:    5 19:50:37      -26.690126      112.794331
FIRE:    6 19:50:38      -22.180135      907.845601
FIRE:    7 19:50:38      -65.921521      130.619989
FIRE:    8 19:50:38      -90.735724      111.542408
FIRE:    9 19:50:38     -122.660488       77.471722
FIRE:   10 19:50:38     -149.333410       90.983710
FIRE:   11 19:50:39     -171.032673       67.670760
FIRE:   12 19:50:39     -178.019516      553.631661
FIRE:   13 19:50:39     -187.767772      190.858406
FIRE:   14 19:50:39     -190.988506      167.290348
FIRE:   15 19:50:39     -194.541269       65.826762
FIRE:   16 19:50:3

KeyboardInterrupt: 