## Generating and evaluating samples from MatterGen

In this tutorial, you will learn about:
1. **Conditional generation of crystal structures**: speed-accuracy trade-offs and how to steer the generation process.
2. **Evaluation of generated samples**: approximately evaluating stability, accounting for disorder when assessing diversity and novelty.

### Loading model checkpoint

Suppose you want to sample stable structures with a particular space group (e.g., 225). Here, we load a conditional model checkpoint (called `space_group`) from [Hugging Face](https://huggingface.co/microsoft/mattergen) and use the model for sampling. The model has been pre-trained on a large number of stable structures and fine-tuned on a set of structures labeled with their space groups.

To **speed up sampling**, we will reduce the number of denoising steps from 1000 to 100. This will result in lower quality samples, but will allow us to generate samples ~10x faster.

In addition, we are **preventing the model from generating certain elements** (e.g., radioactive ones) by explicitly setting the logits of the model's element predictions to `-inf` for these elements. While the model should have learned to avoid these elements during training, it may still accidentally generate them.
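The effect of the mask can be illustrated with a small NumPy sketch. The function `masked_softmax` and the toy logits are hypothetical. This is not MatterGen's `mask_disallowed_elements` implementation. It only shows why a logit of `-inf` guarantees zero probability after the softmax, so a masked element can never be sampled.

```python
import numpy as np


def masked_softmax(logits: np.ndarray, disallowed: list[int]) -> np.ndarray:
    """Softmax over element logits after forcing disallowed classes to -inf."""
    masked = logits.astype(float).copy()
    masked[..., disallowed] = -np.inf  # exp(-inf) == 0, so these classes get zero probability
    # Subtract the row-wise max for numerical stability (the max ignores -inf entries)
    shifted = masked - masked.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    return probs / probs.sum(axis=-1, keepdims=True)


# Toy example: 5 "element" classes; pretend indices 3 and 4 are radioactive
logits = np.array([[2.0, 1.0, 0.5, 3.0, 0.1]])
probs = masked_softmax(logits, disallowed=[3, 4])
print(probs.round(3))
```

Without the mask, class 3 would be the most likely prediction here. With the mask, its probability is exactly zero, and the remaining probability mass is renormalized over the allowed classes.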

```python
from mattergen.common.utils.data_classes import MatterGenCheckpointInfo

pretrained_model_name = "space_group"
num_steps = 100  # Number of denoising steps for sampling (default: 1000)

# Set number of steps for denoising atomic numbers
config_overrides = [f"lightning_module.diffusion_module.corruption.discrete_corruptions.atomic_numbers.d3pm.schedule.num_steps={num_steps}"]

# Disable generating unsupported elements (should be unlikely anyway, but can happen)
config_overrides += [
    "++lightning_module.diffusion_module.model.element_mask_func={_target_:'mattergen.denoiser.mask_disallowed_elements',_partial_:True}"
]

checkpoint_info = MatterGenCheckpointInfo.from_hf_hub(
    model_name=pretrained_model_name,
    config_overrides=config_overrides
)
```
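The override strings above appear to follow Hydra's override grammar. In that grammar, `key.path=value` changes an existing entry, while a leading `++` adds the entry if it is missing or overrides it if present. That would explain why the element-mask override, a key the checkpoint config may not contain, carries the `++` prefix. The sketch below is our own simplified model of those two rules on a plain nested `dict`. `apply_override` is a hypothetical helper, not part of Hydra or MatterGen, and it keeps every value as a string.

```python
from typing import Any


def apply_override(cfg: dict[str, Any], override: str) -> None:
    """Apply a simplified Hydra-style override such as 'a.b.c=100' in place.

    A plain 'key=value' must target an existing key; '++key=value' adds the
    key (creating missing parents) or overrides it if it already exists.
    Values are stored as raw strings; real Hydra also parses their types.
    """
    add_or_override = override.startswith("++")
    if add_or_override:
        override = override[2:]
    dotted_key, value = override.split("=", 1)
    *parents, leaf = dotted_key.split(".")

    node = cfg
    for key in parents:
        if key not in node:
            if not add_or_override:
                raise KeyError(f"{key!r} not found; use '++' to add it")
            node[key] = {}
        node = node[key]

    if leaf not in node and not add_or_override:
        raise KeyError(f"{leaf!r} not found; use '++' to add it")
    node[leaf] = value


# Hypothetical miniature config, loosely mirroring the key paths used above
cfg = {"model": {"schedule": {"num_steps": "1000"}}}
apply_override(cfg, "model.schedule.num_steps=100")
apply_override(cfg, "++model.element_mask_func=mask_fn")
print(cfg)
```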

### Generating samples

To conditionally generate samples, we will use [classifier-free guidance](https://arxiv.org/pdf/2207.12598):
$$\nabla_x \log p_\gamma(x \mid y) = (1 - \gamma) \nabla_x \log p(x) + \gamma \nabla_x \log p(x \mid y),$$
where $\gamma$ controls the guidance strength. Setting $\gamma = 0$ corresponds to unconditional generation, and $\gamma = 1$ corresponds to sampling from the conditional model. Values of $\gamma > 1$ tend to produce samples that adhere more closely to the target property values, at the expense of diversity and realism.
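As a sanity check on the formula, the sketch below combines two stand-in score vectors with NumPy. The arrays are made up; in MatterGen they would come from the denoiser evaluated with and without the condition. The combination can be rewritten as $\nabla_x \log p(x \mid y) + (\gamma - 1)\,[\nabla_x \log p(x \mid y) - \nabla_x \log p(x)]$. In this form, $\gamma > 1$ visibly pushes past the conditional score, along the direction that separates the conditional from the unconditional score.

```python
import numpy as np


def guided_score(score_uncond: np.ndarray, score_cond: np.ndarray, gamma: float) -> np.ndarray:
    """Classifier-free guidance: (1 - gamma) * uncond + gamma * cond."""
    return (1.0 - gamma) * score_uncond + gamma * score_cond


s_u = np.array([1.0, -2.0, 0.5])  # stand-in for grad_x log p(x)
s_c = np.array([0.0, 1.0, 1.5])   # stand-in for grad_x log p(x | y)

for gamma in (0.0, 1.0, 2.0):
    print(gamma, guided_score(s_u, s_c, gamma))
```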


We instantiate a `CrystalGenerator` object given the model checkpoint and the following (main) parameters:
- `batch_size` * `num_batches`: The number of samples to generate. In general, you want to maximize the batch size subject to GPU memory constraints. Here we keep it small so that the tutorial runs quickly.
- `properties_to_condition_on`: A dictionary of properties mapping to target values. The properties must match the properties the model has seen during fine-tuning.
- `diffusion_guidance_factor`: Corresponds to the γ parameter in classifier-free guidance.
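Because the total number of samples is `batch_size` × `num_batches`, a target sample count must be rounded up to a whole number of batches. `plan_batches` below is a hypothetical helper for that arithmetic, not a MatterGen function.

```python
import math


def plan_batches(num_samples: int, batch_size: int) -> tuple[int, int]:
    """Return (num_batches, total_generated) needed to get at least num_samples."""
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError("num_samples and batch_size must be positive")
    num_batches = math.ceil(num_samples / batch_size)
    return num_batches, num_batches * batch_size


print(plan_batches(num_samples=100, batch_size=8))  # (13, 104)
```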

```python
from pathlib import Path

from mattergen.generator import CrystalGenerator

output_path = Path("outputs/")  # Directory to save generated structures
batch_size = 8

generator = CrystalGenerator(
    checkpoint_info=checkpoint_info,
    batch_size=batch_size,
    num_batches=1,
    properties_to_condition_on={'space_group': 225},
    diffusion_guidance_factor=2.0,
    record_trajectories=False,  # whether to store intermediate denoising steps
    sampling_config_overrides=[f"sampler_partial.N={num_steps}"],  # additional sampling overrides
)

# Create the output directory if it does not exist yet
output_path.mkdir(parents=True, exist_ok=True)

structures = generator.generate(output_dir=output_path)
print(len(structures))
```

```python
# Since sampling a large enough batch of structures takes too long for this tutorial,
# we will continue with some pre-sampled structures.
import pickle

with open("sampled_structures.pkl", "rb") as f:
    structures = pickle.load(f)
```

### Relaxing the generated structures and computing their properties

Since many properties are only well defined for structures at a (local) energy minimum, we relax the generated structures using the [MatterSim](https://arxiv.org/pdf/2405.04967) machine-learned force field. This gives us both the relaxed structures and their energies. Ideally, one would use DFT for this step, but even then, pre-relaxing with a force field is useful.
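What a relaxation does can be illustrated on a toy system: a Lennard-Jones dimer whose bond length is moved downhill in energy until the residual force is negligible. This sketch rests on simplifying assumptions: one degree of freedom, reduced units with $\varepsilon = \sigma = 1$, and plain steepest descent. A real relaxation with MatterSim optimizes all atomic positions (and possibly the lattice) using a learned potential. The analytic minimum is at $r = 2^{1/6}\sigma$ with energy $-\varepsilon$, which the sketch should recover.

```python
def lj_energy(r: float, epsilon: float = 1.0, sigma: float = 1.0) -> float:
    """Lennard-Jones pair energy E(r) = 4*eps*((sigma/r)**12 - (sigma/r)**6)."""
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 * sr6 - sr6)


def lj_dE_dr(r: float, epsilon: float = 1.0, sigma: float = 1.0) -> float:
    """dE/dr; the force along the bond is its negative."""
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (6.0 * sr6 - 12.0 * sr6 * sr6) / r


def relax_dimer(r0: float, step_size: float = 0.01, fmax: float = 1e-8,
                max_steps: int = 10_000) -> tuple[float, float]:
    """Steepest descent on the bond length until |dE/dr| < fmax (or max_steps)."""
    r = r0
    for _ in range(max_steps):
        grad = lj_dE_dr(r)
        if abs(grad) < fmax:
            break
        r -= step_size * grad  # step downhill in energy
    return r, lj_energy(r)


r_relaxed, e_relaxed = relax_dimer(1.5)
print(f"relaxed bond length: {r_relaxed:.4f}, energy: {e_relaxed:.4f}")
```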

```python
from mattergen.evaluation.utils.relaxation import relax_structures

relaxed_structures, energies = relax_structures(structures)
```

### Visualizing marginal statistics

Before we move on, let's try to get a feel for what we have generated.

```python
# Plot distribution over number of distinct elements
import matplotlib.pyplot as plt

element_counts = [len(set(structure.composition.elements)) for structure in relaxed_structures]
plt.hist(element_counts, bins=range(1, max(element_counts) + 2), align='left', rwidth=0.8)
plt.xlabel('Number of Elements')
plt.ylabel('Count')
plt.show()
```

```python
# Plot the distribution of elements, counted by number of atoms across all structures
from collections import Counter

elements = Counter()
for structure in relaxed_structures:
    elements.update(structure.composition.as_dict())

# Keep only the 15 most common elements, sorted by count
elements = dict(elements.most_common(15))

plt.bar(elements.keys(), elements.values())
plt.xlabel("Element")
plt.ylabel("Number of atoms")
plt.show()
```
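`Composition.as_dict()` maps each element to its amount in the cell, so the element bar chart counts atoms, and a few large cells can dominate it. To count how many structures contain each element instead, iterate over the keys. The compositions below are hypothetical dicts shaped like `as_dict()` output, so the sketch runs without `pymatgen`.

```python
from collections import Counter

# Hypothetical compositions, shaped like the output of Composition.as_dict()
compositions = [
    {"Na": 1.0, "Cl": 1.0},
    {"Mg": 1.0, "O": 1.0},
    {"Na": 2.0, "O": 1.0},
]

# Weighted by the number of atoms of each element
atoms_per_element = Counter()
for comp in compositions:
    atoms_per_element.update(comp)

# Number of structures that contain each element at least once
structures_per_element = Counter()
for comp in compositions:
    structures_per_element.update(comp.keys())

print(atoms_per_element.most_common(3))
print(structures_per_element.most_common(3))
```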


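Next, we compute their properties. Here, we determine the space groups with `pymatgen`'s `SpacegroupAnalyzer` (via MatterGen's `DefaultSpaceGroupAnalyzer` in the code below). We use looser tolerances than the `pymatgen` default: `symprec=0.1` instead of `0.01`, with `angle_tolerance=5.0`. The looser distance tolerance is more forgiving of the small deviations from ideal symmetric positions that typically remain after a relaxation.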
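To see why the tolerance matters, here is a toy 1D check of a single symmetry operation, inversion about the origin. It accepts a mapped atom if it lands within `symprec` Å of an atom of the same species. This only illustrates the idea: `SpacegroupAnalyzer` delegates to spglib, which searches over all candidate operations in 3D. The function `has_inversion_about_origin` and the structure are hypothetical.

```python
import numpy as np


def has_inversion_about_origin(frac: np.ndarray, species: list[str],
                               a: float, symprec: float) -> bool:
    """Toy 1D check: does x -> -x (mod 1) map every atom onto a same-species atom
    within a Cartesian distance of `symprec` (in Angstrom)?"""
    for x, sp in zip(frac, species):
        image = (-x) % 1.0
        # Minimum-image (periodic) distance in Angstrom from the image to every atom
        d = np.abs(frac - image)
        d = np.minimum(d, 1.0 - d) * a
        if not any(di <= symprec and s == sp for di, s in zip(d, species)):
            return False
    return True


a = 5.0  # hypothetical lattice constant in Angstrom
species = ["Na", "Na", "Cl"]
ideal = np.array([0.20, 0.80, 0.50])
distorted = ideal + np.array([0.01, 0.0, 0.0])  # one Na displaced by 0.05 Angstrom

for symprec in (0.01, 0.1):
    print(symprec, has_inversion_about_origin(ideal, species, a, symprec),
          has_inversion_about_origin(distorted, species, a, symprec))
```

With the strict tolerance, the slightly displaced chain loses its inversion symmetry. With `symprec=0.1`, the symmetry is still detected.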

```python
from collections import Counter

from pymatgen.core import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

from mattergen.evaluation.utils.symmetry_analysis import DefaultSpaceGroupAnalyzer


def get_space_group(
    structure: Structure,
    space_group_analyzer_cls: type[SpacegroupAnalyzer] = DefaultSpaceGroupAnalyzer,
) -> int:
    """Return the international space group number, falling back to 1 (P1) on failure."""
    try:
        return space_group_analyzer_cls(structure=structure).get_space_group_number()
    except TypeError:
        # space group analysis failed, most likely due to overlapping atoms
        return 1

properties = {"space_group": [get_space_group(s) for s in relaxed_structures]}
Counter(properties["space_group"]).most_common(5)
```

### Evaluating the samples

Ideally, we would like to generate a batch of synthesizable, diverse, and novel structures that satisfy the target constraints. However, many of these properties are hard to evaluate; for example, predicting synthesizability is still an open research problem. We therefore use thermodynamic stability at 0 K as a proxy for synthesizability. Likewise, we use uniqueness as a proxy for diversity.

- **Stable**: We consider a structure to be stable if its energy above the convex hull is less than 0.1 eV/atom, where the convex hull is constructed from a reference dataset. To account for deficiencies in PBE-GGA DFT energies, we apply the [Materials Project energy correction scheme](https://docs.materialsproject.org/methodology/materials-methodology/thermodynamic-stability/thermodynamic-stability).
- **Unique**: We consider a structure to be unique if it is different from all other structures in the batch according to the `StructureMatcher` from `pymatgen`.
- **Novel**: We consider a structure to be novel if it is different from all other structures in the reference dataset according to the `StructureMatcher` from `pymatgen`.
- **Satisfies property constraints**: A structure satisfies the property constraints if each of its properties lies within the interval specified for it in the `property_constraints` input dictionary.
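For a binary A-B system, the energy above the hull can be sketched in a few lines. First, build the lower convex hull of the reference formation energies (per atom) as a function of composition. Then measure how far a candidate lies above it. This simplified illustration uses made-up energies, ignores the energy corrections mentioned above, and handles only two elements. `pymatgen`'s `PhaseDiagram.get_e_above_hull` handles any number of components. The hull here is built with Andrew's monotone chain algorithm.

```python
import numpy as np


def lower_hull(points: list[tuple[float, float]]) -> list[tuple[float, float]]:
    """Lower convex hull of (composition, formation energy) points, sorted by composition.

    Only the lowest energy at each composition is kept; Andrew's monotone chain
    then removes points that lie on or above the hull.
    """
    best: dict[float, float] = {}
    for x, e in points:
        best[x] = min(e, best.get(x, float("inf")))
    hull: list[tuple[float, float]] = []
    for p in sorted(best.items()):
        while len(hull) >= 2:
            (x1, y1), (x2, y2) = hull[-2], hull[-1]
            # cross <= 0 means hull[-1] is not below the segment from hull[-2] to p
            cross = (x2 - x1) * (p[1] - y1) - (y2 - y1) * (p[0] - x1)
            if cross > 0:
                break
            hull.pop()
        hull.append(p)
    return hull


def e_above_hull(x: float, e_form: float, hull: list[tuple[float, float]]) -> float:
    """Distance (eV/atom) of a candidate above the hull at composition x."""
    xs, ys = zip(*hull)
    return e_form - float(np.interp(x, xs, ys))


# Hypothetical reference data: (fraction of B, formation energy in eV/atom).
# Pure A and pure B sit at 0 eV/atom by construction.
reference = [(0.0, 0.0), (1.0, 0.0), (0.25, -0.15), (0.5, -0.40), (0.75, -0.10)]
hull = lower_hull(reference)

stability_threshold = 0.1  # eV/atom, as in the definition above
candidates = [(0.25, -0.15), (0.5, -0.25), (0.75, -0.25)]
for x, e in candidates:
    dist = e_above_hull(x, e, hull)
    print(f"x={x:.2f}  E_above_hull={dist:+.3f} eV/atom  stable={dist < stability_threshold}")
```

The third candidate lies below the reference hull (negative distance). It would therefore be a new lowest-energy phase at that composition.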

However, determining whether a structure is different from another one is [not straightforward](https://chemrxiv.org/engage/api-gateway/chemrxiv/assets/orp/resource/item/65957d349138d231611ad8f7/original/challenges-in-high-throughput-inorganic-material-prediction-and-autonomous-synthesis.pdf), e.g., because of disorder effects. To account for compositional disorder, we use a disordered structure matcher.
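The choice of matcher directly changes the uniqueness and novelty numbers. The sketch below makes that explicit with a pluggable `match` function. The structures are hypothetical site-occupation tuples on a fixed lattice. `composition_match` is a deliberately crude stand-in for a disorder-aware matcher: it treats any arrangement of the same species on the lattice as one disordered phase. We do not claim it reflects how `DefaultDisorderedStructureMatcher` works internally. `uniqueness_mask` follows the literal definition above, so every member of a matching group is marked non-unique. Other conventions, such as keeping one representative per group, are also possible.

```python
from collections import Counter
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")


def uniqueness_mask(batch: Sequence[T], match: Callable[[T, T], bool]) -> list[bool]:
    """True where an item matches no other item in the batch."""
    return [
        not any(match(a, b) for j, b in enumerate(batch) if j != i)
        for i, a in enumerate(batch)
    ]


def novelty_mask(batch: Sequence[T], reference: Sequence[T],
                 match: Callable[[T, T], bool]) -> list[bool]:
    """True where an item matches nothing in the reference set."""
    return [not any(match(a, r) for r in reference) for a in batch]


def strict_match(a: tuple[str, ...], b: tuple[str, ...]) -> bool:
    """Site-by-site comparison: different orderings count as different structures."""
    return a == b


def composition_match(a: tuple[str, ...], b: tuple[str, ...]) -> bool:
    """Crude disorder-aware rule: any arrangement of the same species matches."""
    return Counter(a) == Counter(b)


# Hypothetical structures: species on four sites of one fixed lattice
batch = [("Au", "Cu", "Au", "Cu"), ("Cu", "Au", "Cu", "Au"), ("Au", "Au", "Au", "Cu")]
reference = [("Au", "Au", "Cu", "Cu")]

for name, matcher in [("strict", strict_match), ("composition", composition_match)]:
    print(name, uniqueness_mask(batch, matcher), novelty_mask(batch, reference, matcher))
```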

**NOTE**: If running the next cell results in an error, you may have to download the reference dataset first:
```bash
git lfs pull -I data-release/alex-mp/reference_MP2020correction.gz --exclude=""
```

```python
import warnings

from pymatgen.io.vasp.sets import BadInputSetWarning
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility

from mattergen.evaluation.metrics.evaluator import MetricsEvaluator
from mattergen.evaluation.utils.structure_matcher import DefaultDisorderedStructureMatcher

# pymatgen can be quite verbose...
warnings.filterwarnings("ignore", category=UserWarning, message="Failed to guess oxidation states for Entry")
warnings.filterwarnings("ignore", category=UserWarning, message=".* without an oxidation state is initialized as low spin by default")
warnings.filterwarnings("ignore", category=BadInputSetWarning)

evaluator = MetricsEvaluator.from_structures_and_energies(
    structures=relaxed_structures,
    energies=energies,
    properties=properties,
    reference=None,  # load default (Alex-MP with MP energy correction)
    stability_threshold=0.1,  # energy above hull threshold (eV/atom)
    energy_correction_scheme=MaterialsProject2020Compatibility(),
    structure_matcher=DefaultDisorderedStructureMatcher(),
    property_constraints={"space_group": (225, 225)},  # (min, max)
    original_structures=structures
)
```

```python
S = evaluator.is_stable
U = evaluator.is_unique
N = evaluator.is_novel
P = evaluator.property_capability.satisfies_property_constraints

print(f"S: {S.mean():.2f}")
print(f"U: {U.mean():.2f}")
print(f"N: {N.mean():.2f}")
print(f"SUN: {(S & U & N).mean():.2f}")

print(f"P: {P.mean():.2f}")
print(f"SUN+P: {(S & U & N & P).mean():.2f}")
```
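Note that SUN is the fraction of samples that satisfy all three criteria simultaneously (the elementwise AND of the masks). In general, this differs from the product of the individual rates. The sketch below reproduces the combination logic, including the `(min, max)` interval check for property constraints, on made-up per-sample values. It does not call the evaluator.

```python
import numpy as np

# Hypothetical per-sample results for six generated structures
e_hull = np.array([0.02, 0.15, 0.05, 0.00, 0.30, 0.08])  # energy above hull, eV/atom
U = np.array([True, True, False, True, True, True])      # unique within the batch
N = np.array([True, False, True, True, True, False])     # novel w.r.t. the reference
space_group = np.array([225, 225, 225, 221, 225, 225])

stability_threshold = 0.1    # eV/atom
sg_min, sg_max = (225, 225)  # (min, max) interval, as in property_constraints

S = e_hull < stability_threshold
P = (space_group >= sg_min) & (space_group <= sg_max)
SUN = S & U & N

print(f"SUN rate:         {SUN.mean():.2f}")
print(f"product of rates: {S.mean() * U.mean() * N.mean():.2f}")
print(f"SUN+P rate:       {(SUN & P).mean():.2f}")
```

With these made-up values, the joint SUN rate (0.33) is below the product of the individual rates (0.37), because the criteria are not independent here.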

We can also inspect many more metrics to understand the results in more detail. Since the evaluator caches intermediate results, recomputing or combining metrics is fast.

```python
from mattergen.evaluation.utils.logging import logger

logger.setLevel("WARN")  # reduce verbosity

evaluator.compute_metrics(metrics="all", pretty_print=True)
```

Optimizing all of these metrics at the same time is difficult because many of them conflict. Fortunately, users can prioritize some metrics over others at sampling time. For example, using more denoising steps tends to increase stability at the cost of diversity. Increasing the guidance factor tends to improve adherence to the target properties at the cost of diversity and stability.