In [1]:
pip install -e metatrain/

In [None]:
!python -m pip install pet-neighbors-convert --no-build-isolation

In [11]:
import torch

# Use the GPU when torch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [13]:
# Root directory that the merge cell walks to find the per-subset files.
DATASET_FOLDER = "v1.5"
# Sub-directory name looked up inside every directory of DATASET_FOLDER.
TYPE = "cleaned-pbc"
# Directory receiving one merged file per split.
OUTPUT_FOLDER = "_".join((DATASET_FOLDER, TYPE, "merged"))
# Split names; each one maps to a "<split>.xyz" file.
SPLITS = ["train", "val", "test"]
# Number of systems passed to the model per forward call.
BATCH_SIZE = 64

### Merge subsets

In [36]:
import os
from ase.io import read, write


os.makedirs(OUTPUT_FOLDER, exist_ok=True)


# For every split, collect "<root>/<TYPE>/<split>.xyz" from each directory under
# DATASET_FOLDER and write all frames into a single merged file.
for split in SPLITS:
    output_path = os.path.join(OUTPUT_FOLDER, f"{split}.xyz")
    all_atoms = []

    for root, dirs, _files in os.walk(DATASET_FOLDER):
        # os.walk visits sub-directories in filesystem order, which is not
        # guaranteed to be stable; sorting in place makes the traversal (and
        # therefore the order of the merged frames) reproducible.
        dirs.sort()

        target_file = os.path.join(root, TYPE, f"{split}.xyz")

        if os.path.exists(target_file):
            all_atoms.extend(read(target_file, index=":"))

    if not all_atoms:
        print(f"WARNING: no '{split}.xyz' found in any '{TYPE}' folder under '{DATASET_FOLDER}'")

    # ASE's plain "xyz" writer emits only element symbols and positions, which
    # would drop the cell, the periodic boundary flags and any stored
    # properties. "extxyz" keeps them, and is also what ASE infers when reading
    # a ".xyz" file back.
    write(output_path, all_atoms, format="extxyz")

In [37]:
from metatrain.experimental.nativepet import NativePET

# Load the network from the checkpoint file, then call .eval() on it and move
# it to DEVICE.
model = NativePET.load_checkpoint("pet-mad-latest.ckpt")
model = model.eval().to(DEVICE)

In [38]:
from metatensor.torch.atomistic import ModelOutput

from metatrain.utils.data import read_systems
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists

from metatensor.torch.atomistic import ModelOutput
import torch


def get_llf(dataset_path):
    """Compute the energy last-layer features for every structure in a file.

    The structures are loaded with ``read_systems``, given the neighbor lists
    requested by the module-level ``model``, converted to float32 on ``DEVICE``
    and evaluated in chunks of ``BATCH_SIZE`` under ``torch.no_grad()``.

    Args:
        dataset_path: Path of the structure file handed to ``read_systems``.

    Returns:
        A CPU tensor obtained by concatenating, along dim 0, the values of the
        ``"mtt::aux::energy_last_layer_features"`` output of every batch, in
        the order the systems were read. The output is requested with
        ``per_atom=False``, so presumably there is one row per system
        (TODO confirm).
    """
    systems = read_systems(dataset_path)
    print("systems", len(systems))

    neighbor_list_options = model.requested_neighbor_lists()

    outputs = {
        "energy": ModelOutput(per_atom=False),
        "mtt::aux::energy_last_layer_features": ModelOutput(per_atom=False),
    }

    # Loop-invariant: total number of batches, used only for progress output.
    n_batches = (len(systems) - 1) // BATCH_SIZE + 1

    all_last_layer_features = []

    for batch_number, start in enumerate(range(0, len(systems), BATCH_SIZE), start=1):
        processed_batch = [
            get_system_with_neighbor_lists(system, neighbor_list_options).to(
                dtype=torch.float32, device=DEVICE
            )
            for system in systems[start : start + BATCH_SIZE]
        ]

        with torch.no_grad():
            batch_predictions = model(processed_batch, outputs)
            batch_features = (
                batch_predictions["mtt::aux::energy_last_layer_features"].block().values
            )

        # Move each batch's result to host memory right away: otherwise the
        # features of the whole dataset pile up on the accelerator, and the
        # tensor saved by the caller would stay tied to that device.
        all_last_layer_features.append(batch_features.cpu())

        print(f"Processed batch {batch_number}/{n_batches}")

    last_layer_features = torch.cat(all_last_layer_features, dim=0)
    print("last_layer_features", len(last_layer_features))

    return last_layer_features

In [None]:
# Folder receiving one "<split>.pt" feature tensor per split. It is derived from
# the config constants (instead of repeating their literal values) so that it
# follows any change to DATASET_FOLDER or TYPE.
# NOTE(review): "LFF" is presumably a transposition of "LLF" (cf. get_llf and the
# "_llfs" suffix); the name is kept so other references to it keep working.
LFF_OUTPUT_FOLDER = f"{DATASET_FOLDER}_{TYPE}_llfs"

os.makedirs(LFF_OUTPUT_FOLDER, exist_ok=True)

for split in SPLITS:
    print(f"Processing {split}...")
    # Read from the folder the merge cell wrote to, not a hard-coded copy of it.
    llfs_tensor = get_llf(os.path.join(OUTPUT_FOLDER, f"{split}.xyz"))

    torch.save(llfs_tensor, os.path.join(LFF_OUTPUT_FOLDER, f"{split}.pt"))