In [24]:
import yaml
import json
from xtalpaint.inpainting.config_schema import InpaintingWorkGraphConfig
from xtalpaint.data import BatchedStructures
from xtalpaint.utils.relaxation_utils import relax_structures
from ase.io import read
from pymatgen.io.ase import AseAtomsAdaptor
from IPython.display import clear_output

In [2]:
input_structures = read("test-structures.extxyz", index=':')
input_structures = {
    a.info['uuid'].replace("-", "_"): AseAtomsAdaptor.get_structure(a) for a in input_structures
}

In [3]:
len(input_structures)

5

In [4]:
param_grid = {
    "N_steps": 5,
    "coordinates_snr": 0.2,
    "n_corrector_steps": 1,
    "batch_size": 1000,
}

In [16]:
ENV_ACTIVATION_CMD = "source ~/.aiida_venvs/dev-mattergen-inpainting/bin/activate"
ENV_ACTIVATION_CMD = "source ~/.aiida_venvs/test-xtalpaint/bin/activate"

inputs = InpaintingWorkGraphConfig(
    inpainting_pipeline_params={
        "record_trajectories": False,
        "predictor_corrector": "baseline",
        "inpainting_model_params": param_grid,
        # "pretrained_name": "mattergen_base",
        'model_path': '/home/reents_t/project/test-xtalpaint/new-td-20-perc',
        "sampling_config_path": "/home/reents_t/project/test-xtalpaint/git/mattergen/sampling_conf",
    },
    structures=BatchedStructures(
        {k.replace("-", "_"): s for k, s in input_structures.items()}
    ),
    gen_inpainting_candidates_params={
        "n_inp": {
            k.replace("-", "_"): int(s.composition["H"])
            for k, s in input_structures.items()
        },
        "element": "H",
        "num_samples": 1,
    },
    relax=True,
    full_relax=True,
    full_relax_wo_pre_relax=False,
    relax_kwargs={
        "elements_to_relax": ["H"],
        "fmax": 0.01,
        "max_natoms_per_batch": 5000,
        "load_path": "MatterSim-v1.0.0-5M.pth",
        "max_n_steps": 50,
        "device": "cuda",
        "mlip": "mattersim",
        "optimizer": "BFGS",
        "return_initial_energies": False,
        "return_initial_forces": False,
        "return_final_forces": False,
    },
    gen_inpainting_candidates_options={
        "custom_scheduler_commands": f"{ENV_ACTIVATION_CMD}\nexport PYTHONBREAKPOINT=0",
    },
    options={
        "prepend_text": f"{ENV_ACTIVATION_CMD}\nexport PYTHONBREAKPOINT=0",
    },
    evaluate_params={"max_workers": 5, "metrics": ["match", "rmsd"]},
    evaluate=True,
)

## Running the inpainting workflow without AiiDA

In [17]:
print(f"Processing {len(input_structures)} structures")

Processing 5 structures


## Generate inpainting candidates

In [18]:
from xtalpaint.inpainting.generate_candidates import (
    generate_inpainting_candidates,
)

In [19]:
# Step 1: Generate inpainting candidates
print("Running inpainting pipeline...")

n_inp_dict = inputs.gen_inpainting_candidates_params.n_inp
element = inputs.gen_inpainting_candidates_params.element
num_samples = inputs.gen_inpainting_candidates_params.num_samples

inpainting_candidates = generate_inpainting_candidates(
    structures=input_structures,
    n_inp=n_inp_dict,
    element=element,
    num_samples=num_samples,
)

print(f"Generated {len(inpainting_candidates)} inpainted structures")

Running inpainting pipeline...
Generated 5 inpainted structures


In [20]:
inpainting_candidates

{'20acc66e_8e38_4e5e_9e7a_c2400262cdc8': Structure Summary
 Lattice
     abc : 5.169783747 5.169783747 5.169783747
  angles : 90.0 90.0 90.0
  volume : 138.17107311088554
       A : 5.169783747 0.0 0.0
       B : 0.0 5.169783747 0.0
       C : 0.0 0.0 5.169783747
     pbc : True True True
 PeriodicSite: N (2.309, 4.894, 2.86) [0.4467, 0.9467, 0.5533]
 PeriodicSite: N (4.894, 2.86, 2.309) [0.9467, 0.5533, 0.4467]
 PeriodicSite: N (2.86, 2.309, 4.894) [0.5533, 0.4467, 0.9467]
 PeriodicSite: N (0.2756, 0.2756, 0.2756) [0.0533, 0.0533, 0.0533]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite

## Run inpainting

In [21]:
from xtalpaint.inpainting.inpainting_process import (
    run_inpainting_pipeline,
    run_mpi_parallel_inpainting_pipeline,
)

In [22]:
USE_MPI_FOR_PARALLEL_INPAINTING = False

inpainting_method = (
    run_mpi_parallel_inpainting_pipeline
    if USE_MPI_FOR_PARALLEL_INPAINTING
    else run_inpainting_pipeline
)

In [23]:
config = inputs.inpainting_pipeline_params.model_dump(
                exclude_none=True
            )

inpainting_outputs = run_inpainting_pipeline(
    structures=inpainting_candidates, config=config
)

clear_output()

Converting structures to numpy:   0%|          | 0/5 [00:00<?, ?it/s]

INFO:mattergen.common.utils.eval_utils:Loading model from checkpoint: /home/reents_t/project/test-xtalpaint/new-td-20-perc/checkpoints/last.ckpt



Model config:
auto_resume: false
checkpoint_path: null
data_module:
  _recursive_: true
  _target_: mattergen.common.data.datamodule.CrystDataModule
  average_density: 0.05771451654022283
  batch_size:
    train: 128
    val: 128
  dataset_transforms:
  - _partial_: true
    _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
  max_epochs: 2200
  num_workers:
    train: 128
    val: 128
  properties: []
  root_dir: /data/user/reents_t/projects/mlip/git/mattergen/mattergen/../datasets/cache/alex_mp_20_wo_mc3d_H
  train_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: /data/user/reents_t/projects/mlip/git/mattergen/mattergen/../datasets/cache/alex_mp_20_wo_mc3d_H/train
    dataset_transforms:
    - _partial_: true
      _target_: mattergen.common.data.dataset_transform.filter_sparse_properties
    properties: []
    transforms:
    - _partial_: true
      _target_: mattergen.common.data.transform.symmetrize_lattic



  0%|          | 0/5 [00:00<?, ?it/s]

Generating samples:   0%|          | 0/1 [00:00<?, ?it/s]


ValueError: The TD-Paint compatible GemNetT model expects the latent vector z to have the same first dimension as the number of atoms.

In [None]:
from xtalpaint.time_dependent.gemnet import TDGemNetT

In [14]:
inpainting_outputs['structures'].get_structures(strct_type='pymatgen')

{'20acc66e_8e38_4e5e_9e7a_c2400262cdc8': Structure Summary
 Lattice
     abc : 5.1697835922241255 5.1697835922241255 5.169783592224121
  angles : 90.00000250447799 90.00000250447799 90.00000250447816
  volume : 138.171060700957
       A : 5.169783592224121 0.0 -2.2597841109472938e-07
       B : -2.2597843664143413e-07 5.169783592224116 -2.2597841109472938e-07
       C : 0.0 0.0 5.169783592224121
     pbc : True True True
 PeriodicSite: N (2.309, 4.894, 2.86) [0.4467, 0.9467, 0.5533]
 PeriodicSite: N (4.894, 2.86, 2.309) [0.9467, 0.5533, 0.4467]
 PeriodicSite: N (2.86, 2.309, 4.894) [0.5533, 0.4467, 0.9467]
 PeriodicSite: N (0.2756, 0.2756, 0.2756) [0.0533, 0.0533, 0.0533]
 PeriodicSite: H (1.0, 3.885, 2.78) [0.1935, 0.7514, 0.5377]
 PeriodicSite: H (3.033, 0.5016, 4.072) [0.5867, 0.09703, 0.7876]
 PeriodicSite: H (0.5171, 2.514, 3.879) [0.1, 0.4863, 0.7503]
 PeriodicSite: H (2.79, 0.9618, 2.302) [0.5397, 0.186, 0.4453]
 PeriodicSite: H (0.08463, 3.354, 2.823) [0.01637, 0.6488, 0.546]
 

## Relax structures

In [15]:
from xtalpaint.utils.relaxation_utils import relax_structures

In [16]:
relax_kwargs = inputs.relax_kwargs.model_dump()
print(json.dumps(relax_kwargs, indent=4))

structure_labels, inpainted_structures = map(
    list,
    zip(
        *inpainting_outputs["structures"].get_structures(strct_type="pymatgen").items()
    ),
)

if inputs.relax:
    constrained_relaxation_outputs = relax_structures(
        structures=inpainted_structures,
        **relax_kwargs,
    )
if inputs.full_relax:
    relax_kwargs.pop('elements_to_relax', None)
    full_relaxation_outputs = relax_structures(
        structures=inpainted_structures,
        **relax_kwargs,
    )

constrained_relaxation_structures = dict(
    zip(
        structure_labels, constrained_relaxation_outputs[0]
    )
)

full_relaxation_structures = dict(
    zip(
        structure_labels, full_relaxation_outputs[0]
    )
)


{
    "load_path": "MatterSim-v1.0.0-5M.pth",
    "fmax": 0.01,
    "elements_to_relax": [
        "H"
    ],
    "max_natoms_per_batch": 5000,
    "max_n_steps": 50,
    "device": "cuda",
    "filter": null,
    "optimizer": "BFGS",
    "mlip": "mattersim",
    "return_initial_energies": false,
    "return_initial_forces": false,
    "return_final_forces": false
}
[32m2026-01-07 15:38:28.324[0m | [1mINFO    [0m | [36mmattersim.forcefield.potential[0m:[36mfrom_checkpoint[0m:[36m891[0m - [1mLoading the pre-trained mattersim-v1.0.0-5M.pth model[0m
  0%|          | 0/5 [00:00<?, ?it/s]

  atoms.set_calculator(DummyBatchCalculator())


100%|██████████| 5/5 [00:33<00:00,  6.73s/it]
[32m2026-01-07 15:39:02.189[0m | [1mINFO    [0m | [36mmattersim.forcefield.potential[0m:[36mfrom_checkpoint[0m:[36m891[0m - [1mLoading the pre-trained mattersim-v1.0.0-5M.pth model[0m
100%|██████████| 5/5 [00:49<00:00,  9.90s/it]


## Evaluate the inpainted structures with repspect to the initial reference

In [17]:
from xtalpaint.eval import evaluate_inpainting
import pandas as pd

In [18]:
rmsd_inpainted_structures = evaluate_inpainting(
    inpainted_structures=inpainting_outputs["structures"],
    reference_structures=input_structures,
    metric="rmsd",
    max_workers=3,
    normalization_element='H',
)

matches_inpainted_structures = evaluate_inpainting(
    inpainted_structures=inpainting_outputs["structures"],
    reference_structures=input_structures,
    metric="match",
    max_workers=3,
)

inpainted_evaluation = pd.merge(
    rmsd_inpainted_structures, matches_inpainted_structures, left_index=True, right_index=True
)

rmsd_constrained_relaxation = evaluate_inpainting(
    inpainted_structures=constrained_relaxation_structures,
    reference_structures=input_structures,
    metric="rmsd",
    max_workers=3,
    normalization_element='H',
)
matches_constrained_relaxation = evaluate_inpainting(
    inpainted_structures=constrained_relaxation_structures,
    reference_structures=input_structures,
    metric="match",
    max_workers=3,
)

constrained_relaxation_evaluation = pd.merge(
    rmsd_constrained_relaxation, matches_constrained_relaxation, left_index=True, right_index=True
)

 80%|████████  | 4/5 [00:00<00:00, 13.49it/s]
 80%|████████  | 4/5 [00:00<00:00, 27.33it/s]
 80%|████████  | 4/5 [00:00<00:00, 15.04it/s]
 80%|████████  | 4/5 [00:00<00:00, 31.99it/s]


In [19]:
inpainted_evaluation

Unnamed: 0_level_0,rmsd,match
keys,Unnamed: 1_level_1,Unnamed: 2_level_1
20acc66e_8e38_4e5e_9e7a_c2400262cdc8,1.184509,False
47b9a869_9b1e_438b_8c93_f5ac654bfdd8,1.319913,False
662c7351_ee76_48ea_bab7_b733e1fdf607,1.804719,False
7fa282c5_4971_46f4_8b3b_776595a0fa06,0.813573,False
c436bbf4_9aef_44a8_8960_00227f79a32f,1.317161,False


In [20]:
constrained_relaxation_evaluation

Unnamed: 0_level_0,rmsd,match
keys,Unnamed: 1_level_1,Unnamed: 2_level_1
20acc66e_8e38_4e5e_9e7a_c2400262cdc8,0.783026,False
47b9a869_9b1e_438b_8c93_f5ac654bfdd8,0.326994,True
662c7351_ee76_48ea_bab7_b733e1fdf607,2.046668,False
7fa282c5_4971_46f4_8b3b_776595a0fa06,0.017497,True
c436bbf4_9aef_44a8_8960_00227f79a32f,1.001336,False
