In [None]:
import os
import numpy as np
from ase.io import read, write
from ase.constraints import FixAtoms, Hookean
from ase.md.verlet import VelocityVerlet
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.optimize import FIRE
from mace.calculators.mace import MACECalculator

# Project root; set HDPE_HBN_PROJECT_DIR to run from a different checkout.
# Defaults to the original author's local path.
project_dir = os.environ.get(
    'HDPE_HBN_PROJECT_DIR',
    '/Users/sarahmatesic/Documents/GitHub/ML4MSD---Final-Project---SarahM/hdpe_hbn_project',
)
structures_dir = os.path.join(project_dir, 'structures')  # input config_*.xyz files
traj_dir = os.path.join(project_dir, 'results')            # trajectories, final frames, energies
os.makedirs(traj_dir, exist_ok=True)

# MACE model file; set MACE_MODEL_PATH to use a different model.
mace_model_path = os.environ.get(
    'MACE_MODEL_PATH',
    '/Users/sarahmatesic/.cache/mace/20231210mace128L0_energy_epoch249model',
)

num_slab_atoms = 32  # the first num_slab_atoms atoms of each structure are the (frozen) slab
bond_k = 50.0        # Hookean spring constant between consecutive carbons (ASE units: eV/Å^2)
bond_r0 = 1.54       # Hookean threshold distance in Å (presumably the C–C bond length)
temperature = 300.0  # Maxwell–Boltzmann temperature (K) used to initialise MD velocities
md_steps = 50        # number of velocity-Verlet steps per configuration
dt = 0.2             # MD timestep value handed to VelocityVerlet by process_config

# Every config_*.xyz in structures_dir, processed in sorted (deterministic) order.
files_to_run = sorted(
    f for f in os.listdir(structures_dir)
    if f.startswith('config_') and f.endswith('.xyz')
)

def compute_adsorption_energy(system, slab, polymer):
    """Return the adsorption energy E(system) - [E(slab) + E(polymer)].

    Each argument is an ASE ``Atoms``-like object with a calculator attached.
    The result is in the calculator's energy unit (eV for MACE). A negative
    value means the combined system is lower in energy than the two parts.

    Energies are requested with ``apply_constraint=False``. By default,
    ``Atoms.get_potential_energy()`` adds the energy of constraints that
    define ``adjust_potential_energy`` (for example ``Hookean`` springs).
    That artificial restraint energy must not leak into the physical
    adsorption energy.
    """
    E_total = system.get_potential_energy(apply_constraint=False)
    E_slab = slab.get_potential_energy(apply_constraint=False)
    E_polymer = polymer.get_potential_energy(apply_constraint=False)
    return E_total - (E_slab + E_polymer)

# MACE calculators keyed by model path. The model file is then loaded once
# per session instead of once per configuration (the log shows a torch.load
# for every run).
_calculator_cache = {}


def _get_calculator():
    """Return a shared MACECalculator for ``mace_model_path``, building it on first use."""
    calc = _calculator_cache.get(mace_model_path)
    if calc is None:
        # MACECalculator's keyword is ``default_dtype``. The previous
        # ``dtype='float32'`` was silently ignored: the log reports "No dtype
        # selected, switching to float64", so float64 is what actually ran.
        # Keep it explicit, since E_ads is a small difference of ~-330 eV totals.
        calc = MACECalculator(mace_model_path, default_dtype='float64')
        _calculator_cache[mace_model_path] = calc
    return calc


def process_config(config_file):
    """Relax, run short MD on, and score one slab + polymer configuration.

    Reads ``config_file`` from ``structures_dir``. The first ``num_slab_atoms``
    atoms are the slab and are held fixed; the remaining atoms are the polymer.
    Consecutive polymer carbons are linked by Hookean restraints. The system
    is pre-relaxed with FIRE, then run through ``md_steps`` velocity-Verlet
    steps with velocities drawn at ``temperature``.

    Writes to ``traj_dir``:
      * ``<stem>.traj``      -- every MD frame, including the starting frame
      * ``<stem>_final.xyz`` -- the final MD frame

    Returns:
        The adsorption (interaction) energy of the final MD frame, in eV. It
        is E(slab+polymer) - E(slab) - E(polymer), with all three terms at the
        same geometry and without restraint energies.
    """
    from ase import units
    from ase.io import Trajectory

    system = read(os.path.join(structures_dir, config_file))
    stem = os.path.splitext(config_file)[0]

    # The slab stays frozen through both relaxation and MD.
    constraints = [FixAtoms(indices=list(range(num_slab_atoms)))]

    # Restrain consecutive backbone carbons. NOTE(review): this assumes the
    # polymer atoms are stored as repeating C, H, H units starting right after
    # the slab -- confirm against the structure generator.
    carbon_indices = list(range(num_slab_atoms, len(system), 3))
    constraints.extend(
        Hookean(a1, a2, k=bond_k, rt=bond_r0)
        for a1, a2 in zip(carbon_indices, carbon_indices[1:])
    )
    system.set_constraint(constraints)

    calc = _get_calculator()
    system.calc = calc

    # Short pre-relaxation to remove the worst of the initial strain.
    FIRE(system).run(fmax=0.05, steps=100)

    MaxwellBoltzmannDistribution(system, temperature_K=temperature)
    # ASE's internal time unit is ~10.18 fs, so convert explicitly. The bare
    # ``timestep=dt`` made dt=0.2 a ~2 fs step instead of 0.2 fs.
    dyn = VelocityVerlet(system, timestep=dt * units.fs)

    # Record every MD frame. Previously only the final frame reached the .traj.
    traj_path = os.path.join(traj_dir, f"{stem}.traj")
    with Trajectory(traj_path, 'w', system) as traj:
        dyn.attach(traj.write, interval=1)
        dyn.run(md_steps)

    write(os.path.join(traj_dir, f"{stem}_final.xyz"), system)

    # Score the final frame on an unconstrained copy, so the Hookean restraint
    # energy is not counted. Split that same copy so slab and polymer are
    # evaluated at exactly the geometry they have in the complex. A polymer
    # copy taken before relaxation/MD would fold the polymer's own relaxation
    # energy into E_ads.
    complex_final = system.copy()
    complex_final.set_constraint()
    slab_final = complex_final[:num_slab_atoms]
    polymer_final = complex_final[num_slab_atoms:]
    for atoms in (complex_final, slab_final, polymer_final):
        atoms.calc = calc

    E_ads = compute_adsorption_energy(complex_final, slab_final, polymer_final)
    print(f"{config_file}: Adsorption energy = {E_ads:.4f} eV")
    return E_ads

import time

adsorption_energies = []  # (config filename, E_ads in eV), in processing order
batch_size = 2
batch_pause = 5  # seconds to pause between batches, to let the laptop cool

results_file = os.path.join(traj_dir, "adsorption_energies.txt")
# Write each result as soon as it is computed. A crash or interrupt part-way
# through the run then keeps the configurations that already finished.
with open(results_file, 'w', encoding='utf-8') as f:
    for start in range(0, len(files_to_run), batch_size):
        for cfg in files_to_run[start:start + batch_size]:
            E_ads = process_config(cfg)
            adsorption_energies.append((cfg, E_ads))
            f.write(f"{cfg} {E_ads:.6f}\n")
            f.flush()
        # Only pause when another batch follows; sleeping after the last
        # batch just delays the end of the run.
        if start + batch_size < len(files_to_run):
            print(f"Batch complete. Waiting {batch_pause}s...")
            time.sleep(batch_pause)



  torch.load(f=model_path, map_location=device)


Using head Default out of ['Default']
No dtype selected, switching to float64 to match model dtype.
      Step     Time          Energy          fmax
FIRE:    0 17:26:42     -327.884923        8.041129
FIRE:    1 17:26:43     -329.558341        5.347364
FIRE:    2 17:26:43     -331.007573        2.976269
FIRE:    3 17:26:43     -331.489840        5.819813
FIRE:    4 17:26:43     -331.794136        3.802242
FIRE:    5 17:26:43     -332.110139        2.000332
FIRE:    6 17:26:43     -332.259566        1.651493
FIRE:    7 17:26:43     -332.298107        2.772379
FIRE:    8 17:26:44     -332.319224        2.647429
FIRE:    9 17:26:44     -332.358689        2.400411
FIRE:   10 17:26:44     -332.411384        2.037654
FIRE:   11 17:26:44     -332.470649        1.569462
FIRE:   12 17:26:44     -332.529422        1.310695
FIRE:   13 17:26:44     -332.581648        1.265206
FIRE:   14 17:26:44     -332.624030        1.216421
FIRE:   15 17:26:44     -332.660947        1.155611
FIRE:   16 17:26:4

  torch.load(f=model_path, map_location=device)


      Step     Time          Energy          fmax
FIRE:    0 17:26:55     -327.648480        8.135503
FIRE:    1 17:26:55     -329.339648        5.330086
FIRE:    2 17:26:55     -330.757162        3.142701
FIRE:    3 17:26:55     -331.216199        5.968620
FIRE:    4 17:26:55     -331.527826        3.888370
FIRE:    5 17:26:55     -331.845363        2.003524
FIRE:    6 17:26:55     -331.987234        1.655917
FIRE:    7 17:26:56     -332.015968        2.801396
FIRE:    8 17:26:56     -332.037758        2.672617
FIRE:    9 17:26:56     -332.078360        2.417848
FIRE:   10 17:26:56     -332.132277        2.043326
FIRE:   11 17:26:56     -332.192375        1.559466
FIRE:   12 17:26:56     -332.251135        1.326422
FIRE:   13 17:26:56     -332.302208        1.280234
FIRE:   14 17:26:57     -332.342347        1.231325
FIRE:   15 17:26:57     -332.376119        1.170557
FIRE:   16 17:26:57     -332.409877        1.560503
FIRE:   17 17:26:57     -332.455743        1.819660
FIRE:   18 17:

  torch.load(f=model_path, map_location=device)


Using head Default out of ['Default']
No dtype selected, switching to float64 to match model dtype.
      Step     Time          Energy          fmax
FIRE:    0 17:27:22     -327.784897        8.088717
FIRE:    1 17:27:23     -329.458339        5.343470
FIRE:    2 17:27:23     -330.893858        3.083657
FIRE:    3 17:27:23     -331.362881        5.917048
FIRE:    4 17:27:23     -331.671051        3.853389
FIRE:    5 17:27:24     -331.987800        1.998398
FIRE:    6 17:27:24     -332.133268        1.649288
FIRE:    7 17:27:24     -332.166977        2.806362
FIRE:    8 17:27:24     -332.188449        2.678314
FIRE:    9 17:27:24     -332.228523        2.425035
FIRE:   10 17:27:24     -332.281896        2.052802
FIRE:   11 17:27:24     -332.341677        1.572047
FIRE:   12 17:27:25     -332.400571        1.310240
FIRE:   13 17:27:25     -332.452359        1.264590
FIRE:   14 17:27:25     -332.493743        1.215691
FIRE:   15 17:27:25     -332.529186        1.154802
FIRE:   16 17:27:2

  torch.load(f=model_path, map_location=device)


      Step     Time          Energy          fmax
FIRE:    0 17:27:44     -327.268052        7.962640
FIRE:    1 17:27:44     -328.912236        5.295750
FIRE:    2 17:27:45     -330.296415        3.251274
FIRE:    3 17:27:45     -330.737823        6.175913
FIRE:    4 17:27:45     -331.064975        3.991904
FIRE:    5 17:27:45     -331.387123        2.000216
FIRE:    6 17:27:45     -331.518591        1.653845
FIRE:    7 17:27:45     -331.538930        2.916949
FIRE:    8 17:27:46     -331.562562        2.782351
FIRE:    9 17:27:46     -331.606508        2.515119
FIRE:   10 17:27:46     -331.664613        2.120001
FIRE:   11 17:27:46     -331.728840        1.605867
FIRE:   12 17:27:46     -331.790669        1.326119
FIRE:   13 17:27:47     -331.842940        1.279856
FIRE:   14 17:27:47     -331.882229        1.231429
FIRE:   15 17:27:47     -331.913733        1.171162
FIRE:   16 17:27:47     -331.945555        1.719247
FIRE:   17 17:27:47     -331.991778        1.967121
FIRE:   18 17:

  torch.load(f=model_path, map_location=device)


Using head Default out of ['Default']
No dtype selected, switching to float64 to match model dtype.
      Step     Time          Energy          fmax
FIRE:    0 17:28:14     -327.793757        8.174734
FIRE:    1 17:28:14     -329.471322        5.425750
FIRE:    2 17:28:14     -330.915626        3.288643
FIRE:    3 17:28:15     -331.396015        6.009830
FIRE:    4 17:28:15     -331.708506        3.842747
FIRE:    5 17:28:15     -332.026876        1.991471
FIRE:    6 17:28:15     -332.170897        1.651148
FIRE:    7 17:28:15     -332.205695        2.988947
FIRE:    8 17:28:15     -332.227948        2.852100
FIRE:    9 17:28:15     -332.269443        2.581078
FIRE:   10 17:28:15     -332.324605        2.181836
FIRE:   11 17:28:16     -332.386183        1.664358
FIRE:   12 17:28:16     -332.446489        1.310226
FIRE:   13 17:28:16     -332.498994        1.264765
FIRE:   14 17:28:16     -332.540361        1.215801
FIRE:   15 17:28:16     -332.575433        1.154742
FIRE:   16 17:28:1

  torch.load(f=model_path, map_location=device)


      Step     Time          Energy          fmax
FIRE:    0 17:28:35     -327.933090        8.056630
FIRE:    1 17:28:35     -329.603864        5.384347
FIRE:    2 17:28:35     -331.059899        2.953953
FIRE:    3 17:28:35     -331.538133        5.875247
FIRE:    4 17:28:35     -331.843322        3.811261
FIRE:    5 17:28:35     -332.159004        1.991580
FIRE:    6 17:28:35     -332.307248        1.652243
FIRE:    7 17:28:36     -332.345399        2.811746
FIRE:    8 17:28:36     -332.366560        2.686201
FIRE:    9 17:28:36     -332.406116        2.437864
FIRE:   10 17:28:36     -332.458946        2.072717
FIRE:   11 17:28:36     -332.518376        1.600332
FIRE:   12 17:28:36     -332.577300        1.310231
FIRE:   13 17:28:36     -332.629590        1.264847
FIRE:   14 17:28:36     -332.671875        1.215956
FIRE:   15 17:28:37     -332.708497        1.154883
FIRE:   16 17:28:37     -332.745088        1.502147
FIRE:   17 17:28:37     -332.793143        1.773391
FIRE:   18 17:

  torch.load(f=model_path, map_location=device)


Using head Default out of ['Default']
No dtype selected, switching to float64 to match model dtype.
      Step     Time          Energy          fmax
FIRE:    0 17:28:55     -327.918917        8.126436
FIRE:    1 17:28:55     -329.597450        5.371953
FIRE:    2 17:28:56     -331.047337        3.093644
FIRE:    3 17:28:56     -331.530183        5.913544
FIRE:    4 17:28:56     -331.836727        3.830227
FIRE:    5 17:28:56     -332.152686        1.992077
FIRE:    6 17:28:56     -332.299106        1.652381
FIRE:    7 17:28:57     -332.335057        2.879542
FIRE:    8 17:28:57     -332.356608        2.747794
FIRE:    9 17:28:57     -332.396830        2.487370
FIRE:   10 17:28:57     -332.450397        2.104963
FIRE:   11 17:28:57     -332.510392        1.611377
FIRE:   12 17:28:57     -332.569497        1.310234
FIRE:   13 17:28:57     -332.621476        1.264803
FIRE:   14 17:28:57     -332.663049        1.215819
FIRE:   15 17:28:58     -332.698778        1.154646
FIRE:   16 17:28:5

  torch.load(f=model_path, map_location=device)


FIRE:    1 17:29:08     -329.638135        5.348200
FIRE:    2 17:29:08     -331.094768        2.839076
FIRE:    3 17:29:08     -331.575071        5.733857
FIRE:    4 17:29:08     -331.875650        3.757389
FIRE:    5 17:29:08     -332.189931        1.990059
FIRE:    6 17:29:08     -332.341260        1.650417
FIRE:    7 17:29:08     -332.382309        2.711353
FIRE:    8 17:29:08     -332.402921        2.591747
FIRE:    9 17:29:09     -332.441499        2.355342
FIRE:   10 17:29:09     -332.493142        2.008201
FIRE:   11 17:29:09     -332.551468        1.559923
FIRE:   12 17:29:09     -332.609675        1.308621
FIRE:   13 17:29:09     -332.661874        1.263367
FIRE:   14 17:29:09     -332.704741        1.214693
FIRE:   15 17:29:09     -332.742457        1.153903
FIRE:   16 17:29:09     -332.780100        1.407225
FIRE:   17 17:29:09     -332.828541        1.695428
FIRE:   18 17:29:10     -332.896611        1.597510
FIRE:   19 17:29:10     -332.979257        1.075249
FIRE:   20 1

  torch.load(f=model_path, map_location=device)


Using head Default out of ['Default']
No dtype selected, switching to float64 to match model dtype.
      Step     Time          Energy          fmax
FIRE:    0 17:29:25     -327.875726        8.121537
FIRE:    1 17:29:25     -329.554908        5.360355
FIRE:    2 17:29:25     -331.000717        3.115531
FIRE:    3 17:29:25     -331.482003        5.912838
FIRE:    4 17:29:26     -331.788762        3.837863
FIRE:    5 17:29:26     -332.105137        1.992026
FIRE:    6 17:29:26     -332.252029        1.651729
FIRE:    7 17:29:26     -332.288092        2.850124
FIRE:    8 17:29:26     -332.309574        2.720120
FIRE:    9 17:29:26     -332.349676        2.462975
FIRE:   10 17:29:26     -332.403110        2.085025
FIRE:   11 17:29:27     -332.462999        1.596734
FIRE:   12 17:29:27     -332.522058        1.310159
FIRE:   13 17:29:27     -332.574078        1.264700
FIRE:   14 17:29:27     -332.615771        1.215860
FIRE:   15 17:29:27     -332.651663        1.154940
FIRE:   16 17:29:2

  torch.load(f=model_path, map_location=device)


      Step     Time          Energy          fmax
FIRE:    0 17:29:45     -327.972294        7.952165
FIRE:    1 17:29:45     -329.630085        5.347962
FIRE:    2 17:29:45     -331.087926        2.825445
FIRE:    3 17:29:45     -331.571473        5.708565
FIRE:    4 17:29:46     -331.871168        3.742342
FIRE:    5 17:29:46     -332.184866        1.993036
FIRE:    6 17:29:46     -332.336339        1.650663
FIRE:    7 17:29:46     -332.377778        2.705538
FIRE:    8 17:29:46     -332.398351        2.586061
FIRE:    9 17:29:46     -332.436856        2.349908
FIRE:   10 17:29:47     -332.488404        2.003143
FIRE:   11 17:29:47     -332.546627        1.555432
FIRE:   12 17:29:47     -332.604745        1.308841
FIRE:   13 17:29:47     -332.656891        1.263581
FIRE:   14 17:29:47     -332.699763        1.214895
FIRE:   15 17:29:47     -332.737555        1.154096
FIRE:   16 17:29:47     -332.775330        1.401569
FIRE:   17 17:29:47     -332.823908        1.688763
FIRE:   18 17:

  torch.load(f=model_path, map_location=device)


Using head Default out of ['Default']
No dtype selected, switching to float64 to match model dtype.
      Step     Time          Energy          fmax
FIRE:    0 17:30:04     -327.855232        7.977965
FIRE:    1 17:30:04     -329.458276        5.363080
FIRE:    2 17:30:04     -330.877514        3.122818
FIRE:    3 17:30:05     -331.354351        5.670693
FIRE:    4 17:30:05     -331.648295        3.683286
FIRE:    5 17:30:05     -331.949413        1.942261
FIRE:    6 17:30:05     -332.084722        1.603563
FIRE:    7 17:30:05     -332.113045        2.788081
FIRE:    8 17:30:05     -332.133695        2.662009
FIRE:    9 17:30:05     -332.172167        2.413136
FIRE:   10 17:30:05     -332.223249        2.048421
FIRE:   11 17:30:06     -332.280189        1.578861
FIRE:   12 17:30:06     -332.335864        1.288674
FIRE:   13 17:30:06     -332.384223        1.242418
FIRE:   14 17:30:06     -332.422157        1.194043
FIRE:   15 17:30:06     -332.454043        1.135251
FIRE:   16 17:30:0

  torch.load(f=model_path, map_location=device)


      Step     Time          Energy          fmax
FIRE:    0 17:30:26     -327.961898        7.927239
FIRE:    1 17:30:26     -329.618216        5.348427
FIRE:    2 17:30:26     -331.076300        2.827719
FIRE:    3 17:30:26     -331.559966        5.695150
FIRE:    4 17:30:26     -331.859559        3.737879
FIRE:    5 17:30:26     -332.173343        1.991001
FIRE:    6 17:30:27     -332.325170        1.650514
FIRE:    7 17:30:27     -332.367092        2.693622
FIRE:    8 17:30:27     -332.387650        2.574778
FIRE:    9 17:30:27     -332.426131        2.339861
FIRE:   10 17:30:27     -332.477662        1.994892
FIRE:   11 17:30:27     -332.535890        1.549486
FIRE:   12 17:30:27     -332.594051        1.309058
FIRE:   13 17:30:27     -332.646291        1.264890
FIRE:   14 17:30:28     -332.689309        1.217814
FIRE:   15 17:30:28     -332.727297        1.158854
FIRE:   16 17:30:28     -332.765287        1.395812
FIRE:   17 17:30:28     -332.814072        1.684763
FIRE:   18 17:

  torch.load(f=model_path, map_location=device)


Using head Default out of ['Default']
No dtype selected, switching to float64 to match model dtype.
      Step     Time          Energy          fmax
FIRE:    0 17:30:46     -327.590329        8.018114
FIRE:    1 17:30:46     -329.247143        5.342596
FIRE:    2 17:30:46     -330.668827        3.230656
FIRE:    3 17:30:46     -331.132252        6.038186
FIRE:    4 17:30:47     -331.447963        3.895304
FIRE:    5 17:30:47     -331.767507        1.996639
FIRE:    6 17:30:47     -331.909119        1.652195
FIRE:    7 17:30:47     -331.940041        2.891673
FIRE:    8 17:30:47     -331.962374        2.759628
FIRE:    9 17:30:47     -332.004021        2.497684
FIRE:   10 17:30:48     -332.059388        2.110900
FIRE:   11 17:30:48     -332.121180        1.608426
FIRE:   12 17:30:48     -332.181649        1.316279
FIRE:   13 17:30:48     -332.234206        1.270387
FIRE:   14 17:30:48     -332.275460        1.221522
FIRE:   15 17:30:48     -332.310155        1.160887
FIRE:   16 17:30:4

  torch.load(f=model_path, map_location=device)


      Step     Time          Energy          fmax
FIRE:    0 17:31:08     -327.835396        8.187131
FIRE:    1 17:31:09     -329.516932        5.420283
FIRE:    2 17:31:09     -330.962315        3.268367
FIRE:    3 17:31:09     -331.443305        6.001253
FIRE:    4 17:31:09     -331.754659        3.842439
FIRE:    5 17:31:09     -332.072229        1.992074
FIRE:    6 17:31:09     -332.216136        1.651936
FIRE:    7 17:31:09     -332.250652        2.972985
FIRE:    8 17:31:10     -332.272778        2.836729
FIRE:    9 17:31:10     -332.314035        2.567058
FIRE:   10 17:31:10     -332.368887        2.170222
FIRE:   11 17:31:10     -332.430132        1.656471
FIRE:   12 17:31:10     -332.490148        1.310441
FIRE:   13 17:31:10     -332.542450        1.264975
FIRE:   14 17:31:10     -332.583712        1.215963
FIRE:   15 17:31:10     -332.618724        1.154807
FIRE:   16 17:31:10     -332.654150        1.628205
FIRE:   17 17:31:11     -332.702384        1.857051
FIRE:   18 17:

  torch.load(f=model_path, map_location=device)


Using head Default out of ['Default']
No dtype selected, switching to float64 to match model dtype.
      Step     Time          Energy          fmax
FIRE:    0 17:31:34     -327.946175        8.042061
FIRE:    1 17:31:34     -329.615142        5.349309
FIRE:    2 17:31:35     -331.068314        2.958336
FIRE:    3 17:31:35     -331.550874        5.816451
FIRE:    4 17:31:35     -331.854079        3.792356
FIRE:    5 17:31:35     -332.169108        1.991137
FIRE:    6 17:31:35     -332.318352        1.651539
FIRE:    7 17:31:36     -332.357205        2.780889
FIRE:    8 17:31:36     -332.378215        2.656112
FIRE:    9 17:31:36     -332.417492        2.409460
FIRE:   10 17:31:36     -332.469958        2.047245
FIRE:   11 17:31:36     -332.529008        1.579598
FIRE:   12 17:31:36     -332.587624        1.309525
FIRE:   13 17:31:36     -332.639771        1.264192
FIRE:   14 17:31:36     -332.682140        1.215380
FIRE:   15 17:31:36     -332.719065        1.154414
FIRE:   16 17:31:3

  torch.load(f=model_path, map_location=device)


FIRE:    1 17:31:46     -329.635974        5.348100
FIRE:    2 17:31:46     -331.095906        2.816405
FIRE:    3 17:31:47     -331.578745        5.682688
FIRE:    4 17:31:47     -331.877696        3.728708
FIRE:    5 17:31:47     -332.191175        1.990961
FIRE:    6 17:31:47     -332.343329        1.650314
FIRE:    7 17:31:47     -332.385693        2.685794
FIRE:    8 17:31:47     -332.406157        2.567969
FIRE:    9 17:31:47     -332.444474        2.335067
FIRE:   10 17:31:47     -332.495811        1.993027
FIRE:   11 17:31:47     -332.553863        1.551259
FIRE:   12 17:31:48     -332.611915        1.308563
FIRE:   13 17:31:48     -332.664137        1.263329
FIRE:   14 17:31:48     -332.707218        1.214688
FIRE:   15 17:31:48     -332.745311        1.153948
FIRE:   16 17:31:48     -332.783362        1.383215
FIRE:   17 17:31:48     -332.832081        1.675137
FIRE:   18 17:31:48     -332.900252        1.586983
FIRE:   19 17:31:49     -332.982987        1.070628
FIRE:   20 1

  torch.load(f=model_path, map_location=device)


Using head Default out of ['Default']
No dtype selected, switching to float64 to match model dtype.
      Step     Time          Energy          fmax
FIRE:    0 17:32:06     -327.873336        8.151734
FIRE:    1 17:32:06     -329.550460        5.399556
FIRE:    2 17:32:07     -330.996292        3.230210
FIRE:    3 17:32:07     -331.480774        5.958207
FIRE:    4 17:32:07     -331.790223        3.830253
FIRE:    5 17:32:07     -332.106980        1.991674
FIRE:    6 17:32:07     -332.251253        1.651310
FIRE:    7 17:32:07     -332.285850        2.964067
FIRE:    8 17:32:07     -332.307903        2.827111
FIRE:    9 17:32:07     -332.349014        2.556244
FIRE:   10 17:32:08     -332.403646        2.158139
FIRE:   11 17:32:08     -332.464613        1.643681
FIRE:   12 17:32:08     -332.524327        1.309714
FIRE:   13 17:32:08     -332.576364        1.264274
FIRE:   14 17:32:08     -332.617464        1.215286
FIRE:   15 17:32:08     -332.652442        1.154162
FIRE:   16 17:32:0

  torch.load(f=model_path, map_location=device)


      Step     Time          Energy          fmax
FIRE:    0 17:32:18     -327.911670        8.067560
FIRE:    1 17:32:19     -329.583899        5.368197
FIRE:    2 17:32:19     -331.035978        3.013198
FIRE:    3 17:32:19     -331.516723        5.872987
FIRE:    4 17:32:19     -331.822084        3.814965
FIRE:    5 17:32:19     -332.138047        1.991249
FIRE:    6 17:32:19     -332.286378        1.651609
FIRE:    7 17:32:19     -332.324321        2.814906
FIRE:    8 17:32:19     -332.345553        2.688183
FIRE:    9 17:32:20     -332.385228        2.437575
FIRE:   10 17:32:20     -332.438187        2.069295
FIRE:   11 17:32:20     -332.497713        1.593401
FIRE:   12 17:32:20     -332.556673        1.309746
FIRE:   13 17:32:20     -332.608941        1.264327
FIRE:   14 17:32:20     -332.651185        1.215468
FIRE:   15 17:32:20     -332.687805        1.154488
FIRE:   16 17:32:20     -332.724488        1.506687
FIRE:   17 17:32:21     -332.772736        1.773569
FIRE:   18 17:

  torch.load(f=model_path, map_location=device)


Using head Default out of ['Default']
No dtype selected, switching to float64 to match model dtype.
      Step     Time          Energy          fmax
FIRE:    0 17:32:43     -327.798563        7.966021
FIRE:    1 17:32:43     -329.421520        5.343591
FIRE:    2 17:32:43     -330.843594        3.101457
FIRE:    3 17:32:43     -331.310327        5.812252
FIRE:    4 17:32:44     -331.611465        3.766899
FIRE:    5 17:32:44     -331.919695        1.975567
FIRE:    6 17:32:44     -332.058951        1.631066
FIRE:    7 17:32:44     -332.089546        2.820986
FIRE:    8 17:32:44     -332.110750        2.693791
FIRE:    9 17:32:44     -332.150285        2.442176
FIRE:   10 17:32:44     -332.202844        2.072185
FIRE:   11 17:32:44     -332.261528        1.593598
FIRE:   12 17:32:45     -332.319033        1.301062
FIRE:   13 17:32:45     -332.369132        1.254933
FIRE:   14 17:32:45     -332.408597        1.207238
FIRE:   15 17:32:45     -332.441932        1.147387
FIRE:   16 17:32:4

  torch.load(f=model_path, map_location=device)


FIRE:    1 17:33:03     -329.637171        5.348138
FIRE:    2 17:33:03     -331.096362        2.815973
FIRE:    3 17:33:03     -331.577867        5.706002
FIRE:    4 17:33:04     -331.877512        3.740530
FIRE:    5 17:33:04     -332.191274        1.990079
FIRE:    6 17:33:04     -332.343015        1.650426
FIRE:    7 17:33:04     -332.384857        2.696983
FIRE:    8 17:33:04     -332.405383        2.578488
FIRE:    9 17:33:04     -332.443809        2.344256
FIRE:   10 17:33:04     -332.495276        2.000241
FIRE:   11 17:33:04     -332.553448        1.555859
FIRE:   12 17:33:05     -332.611572        1.308635
FIRE:   13 17:33:05     -332.663791        1.263394
FIRE:   14 17:33:05     -332.706781        1.214736
FIRE:   15 17:33:05     -332.744703        1.153971
FIRE:   16 17:33:05     -332.782560        1.394666
FIRE:   17 17:33:05     -332.831136        1.685146
FIRE:   18 17:33:05     -332.899241        1.592348
FIRE:   19 17:33:05     -332.981904        1.072638
FIRE:   20 1