In [51]:
import os
import numpy as np
import copy
import yaml
import argparse

import ase.build
import ase.optimize
# import sella
import ase.visualize
import ase.io.trajectory
import ase.constraints
import logging

import fairchem.core.models.model_registry
import fairchem.core.common.relaxation.ase_utils

import matplotlib.pyplot as plt


In [11]:
# Pretrained model to load; a commented-out alternative is kept for quick switching.
# model_name = 'EquiformerV2-31M-S2EF-OC20-All+MD'
model_name = 'GemNet-OC-S2EFS-OC20+OC22'
checkpoint_cache_dir = '/home/moon/surface/tmp/fairchem_checkpoints/'

# Map the registry model name to a checkpoint file, using the local cache directory.
checkpoint_path = fairchem.core.models.model_registry.model_name_to_local_file(
    model_name,
    local_cache=checkpoint_cache_dir,
)

# ASE calculator backed by that checkpoint; runs on CPU with a fixed seed.
calc = fairchem.core.common.relaxation.ase_utils.OCPCalculator(
    checkpoint_path=checkpoint_path,
    cpu=True,
    seed=400,
)


  _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
  _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
  _Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
  checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
INFO:root:amp: true
cmd:
  checkpoint_dir: /home/moon/surface/surface_thermo/slab/checkpoints/2025-09-09-12-46-56
  commit: core:None,experimental:NA
  identifier: ''
  logs_dir: /home/moon/surface/surface_thermo/slab/logs/wandb/2025-09-09-12-46-56
  print_every: 100
  results_dir: /home/moon/surface/surface_thermo/slab/results/2025-09-09-12-46-56
  seed: null
  timestamp_id: 2025-09-09-12-46-56
  version: 1.10.0
dataset:
  format: oc22_lmdb
  key_mapping:
    force: forces
    y: energy
  normalize_labels: false
  oc20_ref: /checkpoint/janlan/ocp/other_data/final_ref_energies_02_07_2021.pkl
  train_on_oc20_total_energies: true
evaluation_metrics:
  metrics:
    energy:
    - mae
    forces:
    - forcesx_m

In [70]:
# Build the Fe bcc(100) slab: size=(4, 4, 4), 10.0 of vacuum, lattice constant a=2.85.
# NOTE(review): presumably size is (nx, ny, layers) and lengths are in Angstrom - confirm against the ASE docs.
slab100 = ase.build.bcc100('Fe', size=(4, 4, 4), vacuum=10.0, a=2.85)

In [71]:
# Freeze the lowest LAYERS_TO_FIX atomic layers of slab100 with a FixAtoms constraint.
LAYERS_TO_FIX = 2
# Atoms whose heights differ by no more than this are treated as the same layer.
LAYER_TOLERANCE = 1e-3

# Identify layers by sorting the z-coordinates and starting a new layer wherever
# the gap to the previous height exceeds the tolerance. This replaces exact float
# equality on coordinates, which splits a layer as soon as heights differ by
# rounding noise.
z_coords = slab100.get_positions()[:, 2]
z_sorted = np.sort(z_coords)
starts_new_layer = np.ones(len(z_sorted), dtype=bool)
starts_new_layer[1:] = np.diff(z_sorted) > LAYER_TOLERANCE

# Lowest height found in each layer, in ascending order (one entry per layer).
z_values = z_sorted[starts_new_layer].tolist()
fixed_z_values = z_values[:LAYERS_TO_FIX]

# Every atom strictly below the lowest atom of the first free layer is fixed.
# If there are no free layers left, the cutoff is infinite and all atoms are fixed.
if LAYERS_TO_FIX < len(z_values):
    free_layer_cutoff = z_values[LAYERS_TO_FIX]
else:
    free_layer_cutoff = np.inf
fixed_indices = np.flatnonzero(z_coords < free_layer_cutoff).tolist()

fix_bottom_layers = ase.constraints.FixAtoms(indices=fixed_indices)
slab100.set_constraint(fix_bottom_layers)

In [72]:
# Attach the OCPCalculator instance (`calc`) so energies and forces for slab100 come from it.
slab100.calc = calc

In [73]:
# BFGS geometry optimizer operating on slab100 (the constraint set on the slab still applies).
opt = ase.optimize.BFGS(slab100)

In [74]:
# Relax with a force threshold of fmax=0.1 and at most 100 optimizer steps.
# The recorded run returned True after step 6, where fmax had dropped to 0.058.
opt.run(fmax=0.1, steps=100)

      Step     Time          Energy          fmax
BFGS:    0 14:48:38     -413.040192        1.269088
BFGS:    1 14:48:39     -413.468201        1.214955
BFGS:    2 14:48:40     -416.624573        0.292450
BFGS:    3 14:48:42     -416.938141        0.335229
BFGS:    4 14:48:43     -416.956451        0.299171
BFGS:    5 14:48:44     -417.532928        0.220524
BFGS:    6 14:48:45     -418.497498        0.058217


True

In [75]:

# Display the (100) slab with ASE's x3d viewer.
ase.visualize.view(slab100, viewer='x3d')

In [53]:
# Build the Fe bcc(110) slab: size=(2, 2, 9), 10.0 of vacuum, lattice constant a=2.85.
# The positions printed further down show 4 atoms per z-height, i.e. a 2x2 cell per layer.
slab = ase.build.bcc110('Fe', size=(2, 2, 9), vacuum=10.0, a=2.85)

In [54]:
# Freeze the lowest LAYERS_TO_FIX atomic layers of slab with a FixAtoms constraint.
LAYERS_TO_FIX = 2
# Atoms whose heights differ by no more than this are treated as the same layer.
LAYER_TOLERANCE = 1e-3

# Identify layers by sorting the z-coordinates and starting a new layer wherever
# the gap to the previous height exceeds the tolerance. This replaces exact float
# equality on coordinates, which splits a layer as soon as heights differ by
# rounding noise.
z_coords = slab.get_positions()[:, 2]
z_sorted = np.sort(z_coords)
starts_new_layer = np.ones(len(z_sorted), dtype=bool)
starts_new_layer[1:] = np.diff(z_sorted) > LAYER_TOLERANCE

# Lowest height found in each layer, in ascending order (one entry per layer).
z_values = z_sorted[starts_new_layer].tolist()
fixed_z_values = z_values[:LAYERS_TO_FIX]

# Every atom strictly below the lowest atom of the first free layer is fixed.
# If there are no free layers left, the cutoff is infinite and all atoms are fixed.
if LAYERS_TO_FIX < len(z_values):
    free_layer_cutoff = z_values[LAYERS_TO_FIX]
else:
    free_layer_cutoff = np.inf
fixed_indices = np.flatnonzero(z_coords < free_layer_cutoff).tolist()

fix_bottom_layers = ase.constraints.FixAtoms(indices=fixed_indices)
slab.set_constraint(fix_bottom_layers)

In [55]:
# Attach the OCPCalculator instance (`calc`) so energies and forces for the (110) slab come from it.
slab.calc = calc

In [56]:
# Potential energy of the as-built (110) slab; matches the step-0 energy in the BFGS log below.
slab.get_potential_energy()

-248.6311798095703

In [57]:
# Inspect the atomic positions of the (110) slab (one row per atom: x, y, z).
slab.positions

array([[ 0.        ,  0.        , 10.        ],
       [ 2.85      ,  0.        , 10.        ],
       [ 1.425     ,  2.01525433, 10.        ],
       [ 4.275     ,  2.01525433, 10.        ],
       [ 0.        ,  2.01525433, 12.01525433],
       [ 2.85      ,  2.01525433, 12.01525433],
       [ 1.425     ,  4.03050865, 12.01525433],
       [ 4.275     ,  4.03050865, 12.01525433],
       [ 0.        ,  0.        , 14.03050865],
       [ 2.85      ,  0.        , 14.03050865],
       [ 1.425     ,  2.01525433, 14.03050865],
       [ 4.275     ,  2.01525433, 14.03050865],
       [ 0.        ,  2.01525433, 16.04576298],
       [ 2.85      ,  2.01525433, 16.04576298],
       [ 1.425     ,  4.03050865, 16.04576298],
       [ 4.275     ,  4.03050865, 16.04576298],
       [ 0.        ,  0.        , 18.06101731],
       [ 2.85      ,  0.        , 18.06101731],
       [ 1.425     ,  2.01525433, 18.06101731],
       [ 4.275     ,  2.01525433, 18.06101731],
       [ 0.        ,  2.01525433, 20.076

In [58]:
# Per-atom force magnitude: Euclidean norm across the components of each force row.
# In the recorded output the first 8 entries are exactly zero, matching the two fixed bottom layers.
np.linalg.norm(slab.get_forces(), axis=1)

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.02330373, 0.02330804,
       0.02330698, 0.0233066 , 0.00661873, 0.00662696, 0.00657817,
       0.00657122, 0.00561225, 0.00556841, 0.0055673 , 0.00553522,
       0.02711308, 0.02700921, 0.02697612, 0.02661962, 0.00549996,
       0.00704345, 0.00608834, 0.00552422, 0.35862106, 0.3559503 ,
       0.35680455, 0.3568937 , 0.77343106, 0.77309555, 0.7722886 ,
       0.7718477 ], dtype=float32)

In [59]:
# NOTE(review): bare literal with no context - presumably a max force noted from an earlier run; confirm or delete this cell.
0.7453342

0.7453342

In [60]:
# BFGS geometry optimizer operating on the (110) slab; rebinds `opt`, replacing the earlier optimizer object.
opt = ase.optimize.BFGS(slab)

In [61]:
# Relax with a force threshold of fmax=0.1 and at most 100 optimizer steps.
# NOTE(review): in the recorded log fmax climbs from ~0.25 (step 5) to 2.16 (step 10) while the
# energy keeps dropping by several units - inspect the relaxed geometry before trusting it.
opt.run(fmax=0.1, steps=100)

      Step     Time          Energy          fmax
BFGS:    0 13:00:07     -248.631180        0.773431
BFGS:    1 13:00:08     -248.719757        0.714872
BFGS:    2 13:00:09     -248.906357        0.450381
BFGS:    3 13:00:09     -248.914948        0.303911
BFGS:    4 13:00:10     -248.929993        0.257799
BFGS:    5 13:00:11     -249.030518        0.245155
BFGS:    6 13:00:12     -249.227112        0.325691
BFGS:    7 13:00:13     -249.272232        0.393225
BFGS:    8 13:00:13     -249.337250        0.481450
BFGS:    9 13:00:14     -249.543060        0.912496
BFGS:   10 13:00:15     -249.685303        2.161289
BFGS:   11 13:00:15     -249.870087        1.888738
BFGS:   12 13:00:16     -250.899002        1.304098
BFGS:   13 13:00:17     -251.377380        1.070114
BFGS:   14 13:00:18     -251.462250        0.967750
BFGS:   15 13:00:18     -251.768311        1.239699
BFGS:   16 13:00:19     -251.987335        0.701335
BFGS:   17 13:00:20     -252.254242        0.620891
BFGS:   18 13:

True

In [50]:
# Inspect the atomic positions of the (110) slab.
# NOTE(review): this cell's execution count (50) predates the relaxation cells, and its recorded
# output equals the as-built positions - re-run it to see the relaxed coordinates.
slab.positions

array([[ 0.        ,  0.        , 10.        ],
       [ 2.85      ,  0.        , 10.        ],
       [ 1.425     ,  2.01525433, 10.        ],
       [ 4.275     ,  2.01525433, 10.        ],
       [ 0.        ,  2.01525433, 12.01525433],
       [ 2.85      ,  2.01525433, 12.01525433],
       [ 1.425     ,  4.03050865, 12.01525433],
       [ 4.275     ,  4.03050865, 12.01525433],
       [ 0.        ,  0.        , 14.03050865],
       [ 2.85      ,  0.        , 14.03050865],
       [ 1.425     ,  2.01525433, 14.03050865],
       [ 4.275     ,  2.01525433, 14.03050865],
       [ 0.        ,  2.01525433, 16.04576298],
       [ 2.85      ,  2.01525433, 16.04576298],
       [ 1.425     ,  4.03050865, 16.04576298],
       [ 4.275     ,  4.03050865, 16.04576298],
       [ 0.        ,  0.        , 18.06101731],
       [ 2.85      ,  0.        , 18.06101731],
       [ 1.425     ,  2.01525433, 18.06101731],
       [ 4.275     ,  2.01525433, 18.06101731],
       [ 0.        ,  2.01525433, 20.076

In [62]:

# Display the (110) slab with ASE's x3d viewer.
ase.visualize.view(slab, viewer='x3d')

In [9]:
# Position row of atom 0, used as the reference point for the distance calculation.
origin = slab.positions[0, :]

In [10]:
# Straight-line distance from atom 0 to every other atom (plain coordinate difference,
# no periodic-image handling).
# NOTE(review): the recorded output has 63 entries, which does not match the 36-atom slab
# built in this notebook - the output is presumably from an earlier `slab`; re-run.
np.linalg.norm(slab.positions[1:, :] - origin, axis=1)

array([ 2.85      ,  5.7       ,  8.55      ,  2.4681724 ,  4.72619033,
        7.4045172 , 10.17653551,  4.9363448 ,  6.98104577,  9.45238065,
       12.09152596,  7.4045172 ,  9.3443499 , 11.6641277 , 14.17857098,
        2.85      ,  4.03050865,  6.37279374,  9.01249133,  2.4681724 ,
        4.72619033,  7.4045172 , 10.17653551,  4.03050865,  6.37279374,
        9.01249133, 11.75085103,  6.21143099,  8.43041369, 10.94563269,
       13.59363362,  4.03050865,  4.9363448 ,  6.98104577,  9.45238065,
        4.72619033,  6.21143099,  8.43041369, 10.94563269,  6.37279374,
        8.06101731, 10.27582114, 12.74558747,  8.43041369, 10.17653551,
       12.340862  , 14.74031462,  6.37279374,  6.98104577,  8.55      ,
       10.66372355,  6.21143099,  7.4045172 ,  9.3443499 , 11.6641277 ,
        6.98104577,  8.55      , 10.66372355, 13.06034073,  8.43041369,
       10.17653551, 12.340862  , 14.74031462])

In [12]:
# Number of atoms in the slab.
# NOTE(review): the recorded value (64) does not match the 2x2x9 slab built above, whose force
# array has 36 entries - this output is presumably stale from an earlier `slab`; re-run.
len(slab)

64