In [28]:
%load_ext aiida
%aiida

The aiida extension is already loaded. To reload it, use:
  %reload_ext aiida


In [2]:
from dbcsi_inpainting.aiida.workgraphs import setup_inpainting_wg
from aiida.orm import StructureData
from aiida_workgraph.orm.workgraph import WorkGraphNode

from aiida_workgraph import WorkGraph, task
import yaml
import json
from dbcsi_inpainting.inpainting.config_schema import InpaintingWorkGraphConfig
from dbcsi_inpainting.aiida.data import BatchedStructures, BatchedStructuresData
from dbcsi_inpainting.utils.relaxation_utils import relax_structures

MODELS_PROJECT_ROOT: /home/reents_t/project/dev-mattergen-inpainting/git/mattergen/mattergen


In [3]:
def load_mc3d_h_structures(label):
    """Return every ``StructureData`` node stored in group ``MC3D-with-H/<label>``.

    :param label: suffix appended to ``MC3D-with-H/`` to form the group label.
    :return: flat list of the matching ``StructureData`` nodes.
    """
    group_label = f'MC3D-with-H/{label}'

    builder = QueryBuilder()
    builder.append(Group, filters={'label': group_label}, tag='group')
    builder.append(StructureData, tag='structure', with_group='group')

    return builder.all(flat=True)

In [8]:
# Reference structures: relaxed MC3D entries that contain hydrogen.
REFERENCE_GROUP_LABEL = "MC3D-with-H-relaxed"
MAX_QUERIED_STRUCTURES = 30  # cap applied in the query, i.e. BEFORE the size filter below
MIN_NUM_SITES = 10  # keep only structures with at least this many sites

query_structures = (
    QueryBuilder()
    .append(Group, filters={"label": REFERENCE_GROUP_LABEL}, tag="group")
    .append(StructureData, tag="structure", with_group="group")
    # Without an explicit ordering, `limit` returns an arbitrary subset of the
    # group, so a re-run could select different structures.
    .order_by({"structure": "id"})
    .limit(MAX_QUERIED_STRUCTURES)
)

structures = [
    s for s in query_structures.all(flat=True) if len(s.sites) >= MIN_NUM_SITES
]

In [29]:
# Sampler hyper-parameters forwarded as `inpainting_model_params`
# (a single set of scalar values, despite the name).
param_grid = dict(
    N_steps=5,
    coordinates_snr=0.2,
    n_corrector_steps=1,
    batch_size=1000,
)

In [None]:
ENV_ACTIVATION_CMD = "source ~/.aiida_venvs/dev-mattergen-inpainting/bin/activate"
# Shell lines run before each job: activate the venv, and set
# PYTHONBREAKPOINT=0 so any stray `breakpoint()` call is a no-op.
JOB_PREPEND_TEXT = f"{ENV_ACTIVATION_CMD}\nexport PYTHONBREAKPOINT=0"

# Element whose sites are inpainted; also the only element moved in the
# constrained relaxation (`elements_to_relax`).
INPAINT_ELEMENT = "H"


def structure_key(structure):
    """Return the dict key used for ``structure``: its UUID with '-' replaced by '_'."""
    return structure.uuid.replace("-", "_")


def count_element_sites(structure, element):
    """Count the sites of ``structure`` whose chemical symbol is ``element``.

    Counts chemical symbols rather than kind names, because an AiiDA kind name
    may differ from its element symbol (e.g. ``H1``/``H2``).
    """
    return structure.get_ase().get_chemical_symbols().count(element)


inputs = InpaintingWorkGraphConfig(
    inpainting_pipeline_params={
        "record_trajectories": False,
        "predictor_corrector": "baseline",
        "inpainting_model_params": param_grid,
        "pretrained_name": "mattergen_base",
        "sampling_config_path": "/home/reents_t/project/dev-mattergen-inpainting/git/mattergen/sampling_conf",
    },
    structures=BatchedStructures(
        {structure_key(s): s.get_pymatgen() for s in structures}
    ),
    gen_inpainting_candidates_params={
        # Number of INPAINT_ELEMENT sites to regenerate, per structure.
        "n_inp": {
            structure_key(s): count_element_sites(s, INPAINT_ELEMENT)
            for s in structures
        },
        "element": INPAINT_ELEMENT,
        "num_samples": 1,
    },
    relax=True,
    full_relax=True,
    full_relax_wo_pre_relax=False,
    relax_kwargs={
        "elements_to_relax": [INPAINT_ELEMENT],
        "fmax": 0.01,
        "max_natoms_per_batch": 5000,
        "load_path": "MatterSim-v1.0.0-5M.pth",
        "max_n_steps": 50,
        "device": "cuda",
        "mlip": "mattersim",
        "optimizer": "BFGS",
        "return_initial_energies": False,
        "return_initial_forces": False,
        "return_final_forces": False,
    },
    gen_inpainting_candidates_options={
        "custom_scheduler_commands": JOB_PREPEND_TEXT,
    },
    options={
        "prepend_text": JOB_PREPEND_TEXT,
    },
    evaluate_params={"max_workers": 5, "metrics": ["match", "rmsd"]},
    evaluate=True,
)

## Running the inpainting workflow without AiiDA

In [30]:
# pymatgen copies of the reference structures, keyed by UUID with '-' -> '_'.
input_structures = dict(
    (node.uuid.replace('-', '_'), node.get_pymatgen()) for node in structures
)

print(f"Processing {len(input_structures)} structures")

Processing 8 structures


## Generate inpainting candidates

In [31]:
from dbcsi_inpainting.inpainting.generate_candidates import (
    generate_inpainting_candidates,
)

In [14]:
# Step 1: Generate inpainting candidates
# (This step only prepares candidates; the diffusion model runs later.)
print("Generating inpainting candidates...")

candidate_params = inputs.gen_inpainting_candidates_params
n_inp_dict = candidate_params.n_inp  # per-structure number of sites to inpaint
element = candidate_params.element
num_samples = candidate_params.num_samples

inpainted_candidates = generate_inpainting_candidates(
    structures=input_structures,
    n_inp=n_inp_dict,
    element=element,
    num_samples=num_samples,
)

print(f"Generated {len(inpainted_candidates)} inpainted structures")

Running inpainting pipeline...
Generated 8 inpainted structures


In [16]:
# Candidate structures before inpainting: the sites to be inpainted appear
# with NaN coordinates in the output below.
inpainted_candidates

{'20acc66e_8e38_4e5e_9e7a_c2400262cdc8': Structure Summary
 Lattice
     abc : 5.169783747 5.169783747 5.169783747
  angles : 90.0 90.0 90.0
  volume : 138.17107311088554
       A : 5.169783747 0.0 0.0
       B : 0.0 5.169783747 0.0
       C : 0.0 0.0 5.169783747
     pbc : True True True
 PeriodicSite: N (2.309, 4.894, 2.86) [0.4467, 0.9467, 0.5533]
 PeriodicSite: N (4.894, 2.86, 2.309) [0.9467, 0.5533, 0.4467]
 PeriodicSite: N (2.86, 2.309, 4.894) [0.5533, 0.4467, 0.9467]
 PeriodicSite: N (0.2756, 0.2756, 0.2756) [0.0533, 0.0533, 0.0533]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
 PeriodicSite

## Run inpainting

In [17]:
from dbcsi_inpainting.inpainting.inpainting_process import (
    run_inpainting_pipeline,
    run_mpi_parallel_inpainting_pipeline,
)

In [18]:
# Toggle between the MPI-parallel and the single-process inpainting entry point.
USE_MPI_FOR_PARALLEL_INPAINTING = False

if USE_MPI_FOR_PARALLEL_INPAINTING:
    inpainting_method = run_mpi_parallel_inpainting_pipeline
else:
    inpainting_method = run_inpainting_pipeline

In [19]:
# Plain-dict view of the pipeline parameters, with None-valued fields dropped.
config = inputs.inpainting_pipeline_params.model_dump(exclude_none=True)

# Dispatch through `inpainting_method` so USE_MPI_FOR_PARALLEL_INPAINTING
# actually takes effect (calling run_inpainting_pipeline directly ignored it).
# NOTE(review): assumes the MPI variant accepts the same keyword arguments.
inpainting_outputs = inpainting_method(
    structures=inpainted_candidates, config=config
)


Converting structures to numpy:   0%|          | 0/8 [00:00<?, ?it/s]

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize_config_dir(str(self.model_path)):
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with hydra.initialize_config_dir(os.path.abspath(str(sampling_config_path))):
INFO:mattergen.common.utils.eval_utils:Loading model from checkpoint: /home/reents_t/.cache/huggingface/hub/models--microsoft--mattergen/snapshots/ea430eab64b80855029c2941b9fda15f245a771a/checkpoints/mattergen_base/checkpoints/last.ckpt



Model config:
auto_resume: true
checkpoint_path: null
data_module:
  _recursive_: true
  _target_: mattergen.common.data.datamodule.CrystDataModule
  average_density: 0.05771451654022283
  batch_size:
    train: 32
    val: 32
  max_epochs: 2200
  num_workers:
    train: 0
    val: 0
  properties:
  - dft_bulk_modulus
  - dft_band_gap
  - dft_mag_density
  - ml_bulk_modulus
  - hhi_score
  - space_group
  - energy_above_hull
  root_dir: datasets/cache/alex_mp_20/
  train_dataset:
    _target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
    cache_path: datasets/cache/alex_mp_20/train
    properties:
    - dft_bulk_modulus
    - dft_band_gap
    - dft_mag_density
    - ml_bulk_modulus
    - hhi_score
    - space_group
    - energy_above_hull
    transforms:
    - _partial_: true
      _target_: mattergen.common.data.transform.symmetrize_lattice
    - _partial_: true
      _target_: mattergen.common.data.transform.set_chemical_system_string
  transforms:
  - _partial_: 

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize_config_dir(str(self.model_path)):


{'pos': <mattergen.common.diffusion.corruption.NumAtomsVarianceAdjustedWrappedVESDE object at 0x7f6e52b65c00>}


Generating samples:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

Generating samples: 100%|██████████| 1/1 [00:08<00:00,  8.54s/it]

Returning mean trajectories as well.





In [20]:
# Inpainted structures as pymatgen objects, keyed by structure label; the
# previously NaN sites now carry generated coordinates.
inpainting_outputs['structures'].get_structures(strct_type='pymatgen')

{'20acc66e_8e38_4e5e_9e7a_c2400262cdc8': Structure Summary
 Lattice
     abc : 5.1697835922241255 5.1697835922241255 5.169783592224121
  angles : 90.00000250447799 90.00000250447799 90.00000250447816
  volume : 138.171060700957
       A : 5.169783592224121 0.0 -2.2597841109472938e-07
       B : -2.2597843664143413e-07 5.169783592224116 -2.2597841109472938e-07
       C : 0.0 0.0 5.169783592224121
     pbc : True True True
 PeriodicSite: N (2.309, 4.894, 2.86) [0.4467, 0.9467, 0.5533]
 PeriodicSite: N (4.894, 2.86, 2.309) [0.9467, 0.5533, 0.4467]
 PeriodicSite: N (2.86, 2.309, 4.894) [0.5533, 0.4467, 0.9467]
 PeriodicSite: N (0.2756, 0.2756, 0.2756) [0.0533, 0.0533, 0.0533]
 PeriodicSite: H (0.9766, 5.076, 4.842) [0.1889, 0.9819, 0.9366]
 PeriodicSite: H (4.015, 0.9061, 2.424) [0.7766, 0.1753, 0.4689]
 PeriodicSite: H (4.212, 2.475, 4.856) [0.8147, 0.4788, 0.9394]
 PeriodicSite: H (4.531, 4.754, 0.2646) [0.8764, 0.9195, 0.05118]
 PeriodicSite: H (3.03, 1.298, 4.513) [0.5861, 0.2511, 0.87

## Relax structures

In [21]:
from dbcsi_inpainting.utils.relaxation_utils import relax_structures

In [22]:
relax_kwargs = inputs.relax_kwargs.model_dump()
print(json.dumps(relax_kwargs, indent=4))

# Split into parallel label/structure lists; `.keys()` and `.values()` share
# the same order, and this also works for an empty result.
inpainted_by_label = inpainting_outputs["structures"].get_structures(strct_type="pymatgen")
structure_labels = list(inpainted_by_label.keys())
inpainted_structures = list(inpainted_by_label.values())

# Defined up front so both names exist even when a relaxation is disabled
# (previously the dict construction raised NameError in that case).
constrained_relaxation_structures = None
full_relaxation_structures = None

if inputs.relax:
    # Constrained relaxation, restricted via `elements_to_relax`.
    constrained_relaxation_outputs = relax_structures(
        structures=inpainted_structures,
        **relax_kwargs,
    )
    # Element [0] of the outputs holds the relaxed structures, in input order.
    constrained_relaxation_structures = dict(
        zip(structure_labels, constrained_relaxation_outputs[0])
    )

if inputs.full_relax:
    # Same settings without the element restriction. A new dict is built so
    # `relax_kwargs` still reflects the configured settings after this cell.
    full_relax_kwargs = {
        key: value for key, value in relax_kwargs.items() if key != "elements_to_relax"
    }
    # NOTE(review): starts from the raw inpainted structures; the config flag
    # `full_relax_wo_pre_relax` is not consulted here — confirm this matches the workgraph.
    full_relaxation_outputs = relax_structures(
        structures=inpainted_structures,
        **full_relax_kwargs,
    )
    full_relaxation_structures = dict(
        zip(structure_labels, full_relaxation_outputs[0])
    )


{
    "load_path": "MatterSim-v1.0.0-5M.pth",
    "fmax": 0.01,
    "elements_to_relax": [
        "H"
    ],
    "max_natoms_per_batch": 5000,
    "max_n_steps": 50,
    "device": "cuda",
    "filter": null,
    "optimizer": "BFGS",
    "mlip": "mattersim",
    "return_initial_energies": false,
    "return_initial_forces": false,
    "return_final_forces": false
}
[32m2025-12-12 11:52:54.391[0m | [1mINFO    [0m | [36mmattersim.forcefield.potential[0m:[36mfrom_checkpoint[0m:[36m891[0m - [1mLoading the pre-trained mattersim-v1.0.0-5M.pth model[0m
  0%|          | 0/8 [00:00<?, ?it/s]

  atoms.set_calculator(DummyBatchCalculator())


100%|██████████| 8/8 [00:44<00:00,  5.55s/it]
[32m2025-12-12 11:53:39.021[0m | [1mINFO    [0m | [36mmattersim.forcefield.potential[0m:[36mfrom_checkpoint[0m:[36m891[0m - [1mLoading the pre-trained mattersim-v1.0.0-5M.pth model[0m

[32m2025-12-12 11:53:39.021[0m | [1mINFO    [0m | [36mmattersim.forcefield.potential[0m:[36mfrom_checkpoint[0m:[36m891[0m - [1mLoading the pre-trained mattersim-v1.0.0-5M.pth model[0m
100%|██████████| 8/8 [01:27<00:00, 10.97s/it]
100%|██████████| 8/8 [01:27<00:00, 10.97s/it]


## Evaluate the inpainted structures with respect to the initial reference structures

In [23]:
from dbcsi_inpainting.eval import evaluate_inpainting
import pandas as pd

In [24]:
# Parallel workers for each structure comparison (the value used in the original runs).
EVAL_MAX_WORKERS = 3


def evaluate_against_reference(
    candidate_structures,
    reference_structures,
    max_workers=EVAL_MAX_WORKERS,
    normalization_element="H",
):
    """Compare candidate structures with their references by RMSD and structure match.

    :param candidate_structures: structures to evaluate, passed to
        ``evaluate_inpainting`` as ``inpainted_structures``.
    :param reference_structures: reference structures, keyed like the candidates.
    :param max_workers: number of workers for each ``evaluate_inpainting`` call.
    :param normalization_element: forwarded to the RMSD evaluation only.
    :return: DataFrame joining the RMSD and match results on their shared index.
    """
    rmsd = evaluate_inpainting(
        inpainted_structures=candidate_structures,
        reference_structures=reference_structures,
        metric="rmsd",
        max_workers=max_workers,
        normalization_element=normalization_element,
    )
    matches = evaluate_inpainting(
        inpainted_structures=candidate_structures,
        reference_structures=reference_structures,
        metric="match",
        max_workers=max_workers,
    )
    return pd.merge(rmsd, matches, left_index=True, right_index=True)


# Normalize the RMSD by the inpainted element, taken from the config.
inpainted_element = inputs.gen_inpainting_candidates_params.element

inpainted_evaluation = evaluate_against_reference(
    inpainting_outputs["structures"],
    input_structures,
    normalization_element=inpainted_element,
)

if inputs.relax:
    constrained_relaxation_evaluation = evaluate_against_reference(
        constrained_relaxation_structures,
        input_structures,
        normalization_element=inpainted_element,
    )

if inputs.full_relax:
    # The fully relaxed structures are evaluated too, so all three stages can be compared.
    full_relaxation_evaluation = evaluate_against_reference(
        full_relaxation_structures,
        input_structures,
        normalization_element=inpainted_element,
    )

 88%|████████▊ | 7/8 [00:00<00:00, 18.24it/s]

 88%|████████▊ | 7/8 [00:00<00:00, 38.83it/s]

 88%|████████▊ | 7/8 [00:00<00:00, 18.12it/s]

 88%|████████▊ | 7/8 [00:00<00:00, 33.47it/s]



In [26]:
# RMSD and match flag of the raw inpainted structures against the input references.
inpainted_evaluation

Unnamed: 0_level_0,rmsd,match
keys,Unnamed: 1_level_1,Unnamed: 2_level_1
20acc66e_8e38_4e5e_9e7a_c2400262cdc8,0.999285,False
47b9a869_9b1e_438b_8c93_f5ac654bfdd8,1.252835,False
662c7351_ee76_48ea_bab7_b733e1fdf607,1.525733,False
7fa282c5_4971_46f4_8b3b_776595a0fa06,1.00269,False
c436bbf4_9aef_44a8_8960_00227f79a32f,1.271059,False
dadf40a3_42bb_4247_a5bd_1bde7e84be75,0.953096,False
3e97806e_f4a1_49de_9030_ac1a8e3f2b35,1.323598,False
be415a08_7666_4f5b_9cbc_0a68f476086f,1.002258,False


In [27]:
# RMSD and match flag of the constrained-relaxation structures against the input references.
constrained_relaxation_evaluation

Unnamed: 0_level_0,rmsd,match
keys,Unnamed: 1_level_1,Unnamed: 2_level_1
20acc66e_8e38_4e5e_9e7a_c2400262cdc8,0.889634,False
47b9a869_9b1e_438b_8c93_f5ac654bfdd8,1.214163,False
662c7351_ee76_48ea_bab7_b733e1fdf607,1.619039,False
7fa282c5_4971_46f4_8b3b_776595a0fa06,0.051663,True
c436bbf4_9aef_44a8_8960_00227f79a32f,1.00319,False
dadf40a3_42bb_4247_a5bd_1bde7e84be75,1.03112,False
3e97806e_f4a1_49de_9030_ac1a8e3f2b35,1.112725,False
be415a08_7666_4f5b_9cbc_0a68f476086f,0.320487,True
