In [1]:
# Libraries for QM9 dataset
from typing import Dict
from torch_geometric.datasets import QM9
import py3Dmol
from utils import qm9_to_xyz


# Libraries for OC20 dataset
from fairchem.core.preprocessing import AtomsToGraphs
from multiprocessing import Pool
import lmdb
import ase
import pickle
import os
import numpy as np
import torch
from tqdm import tqdm
import re
import pandas as pd
import random
from fairchem.core.datasets import LmdbDataset


# Libraries for Gold dataset
from ase.db import connect
from utils import AtomGraphConverter
from torch_geometric.data import Data
from ase import Atoms

# QM9 dataset

In [2]:
path ='data/qm9_data'
dataset = QM9(path)

In [3]:
# Visualizing an entry of the QM9 dataset
xyz_str = qm9_to_xyz(dataset[10])
viewer = py3Dmol.view(width=400, height=400)
viewer.addModel(xyz_str, 'xyz')  # Load the molecule from XYZ string.
viewer.setStyle({'sphere': {}})  # Use stick representations (alternatives: 'sphere', 'line').
viewer.zoomTo()  # Automatically zoom to the molecule.
viewer.show()

# OC20 dataset
The dataset can be downloaded from [this link](https://fair-chem.github.io/core/datasets/oc20.html) using the command-line parameters `--task s2ef --split 200k`. `s2ef_200k.csv` containing energy and force values is available in data directory. We then extract a specific subset focusing on inter-metallic slabs and nonmetal slabs.

In addition, trajectory files for all the catalyst slabs are available for download [here](https://dl.fbaipublicfiles.com/opencatalystproject/data/slab_trajectories.tar), and a pickle file containing the mapping between adslab and slab can be obtained from [this link](https://dl.fbaipublicfiles.com/opencatalystproject/data/mapping_adslab_slab.pkl).



In [4]:
# Load mappings between adslab and slab
with open('data/mapping_adslab_slab.pkl', 'rb') as f:
    adslab_slab_mappings = pickle.load(f)

In [5]:
# Set configuration variables
NUM_WORKERS = 8

material_classes = {
    "intermetallics": 0,
    "metalloids": 1,
    "nonmetals": 2,
    "halides": 3
}
MATERIAL_CLASS = "intermetallics"  # Select the desired material class

DFT_DATA_PATH = "data/"  # Path to DFT data
SIZE = 5000  # Output split size (number of entries to process)
OUTPUT_LMDB_PATH = f"data/{MATERIAL_CLASS}_{SIZE}"  # Output LMDB directory

os.makedirs(OUTPUT_LMDB_PATH, exist_ok=True)

# Load and preprocess the 200k split of the OC20 training dataset
df_200k = pd.read_csv("data/s2ef_200k.csv")
df_200k = df_200k.drop_duplicates(subset=['sid'], keep='first')
df_200k["sid_fid"] = df_200k["sid"] + df_200k["fid"].astype(str)
df_200k.set_index("sid_fid", inplace=True)

# Select entries for the chosen material class
df_material = df_200k.query(f"cat_class == {material_classes[MATERIAL_CLASS]}")

# Randomly select a subset of the data
indices = list(range(len(df_material)))
random.shuffle(indices)
selected_indices = indices[:SIZE]
df_subset = df_material.iloc[selected_indices]

# Create a dictionary from the DataFrame. Keys are assumed to be in the format "random<sid>frame<fid>"
dict_subset = df_subset.to_dict("index")
dict_keys = list(dict_subset.keys())

# split the numbers from the letters in the fid "frame123" > ("frame","123")
categorize = re.compile("([a-zA-Z]+)([0-9]+)")


In [6]:

# Initialize the Atoms-to-Graphs converter and compile regex pattern
a2g = AtomsToGraphs(
    max_neigh=50,
    radius=6,
    r_energy=False,    # For test data
    r_forces=False,    # For test data
    r_distances=False,
    r_fixed=True,
)



In [7]:
# Define the function to write LMDB files for a given split
def write_lmdbs(mp_args):
    lmdb_idx, key_split = mp_args
    idx = 0
    lmdb_filepath = os.path.join(OUTPUT_LMDB_PATH, f"{lmdb_idx}.lmdb")
    pattern = re.compile(r'^random(\d+)frame(\d+)$')

    # Open an LMDB database for the current split
    train_db = lmdb.open(
        lmdb_filepath,
        map_size=1099511627776 * 2,
        subdir=False,
        meminit=False,
        map_async=True,
    )
    
    for key in tqdm(key_split, desc=f"Processing LMDB {lmdb_idx}"):

        match = pattern.fullmatch(key)
        if match:
            # If it matches, extract the sid as the first captured number
            system_id = int(match.group(1))
        else:
            # Retrieve the system id and get the corresponding slab id
            system_id = df_subset[key]["sid"]

        if system_id not in adslab_slab_mappings:
            continue
        
        slab_sid = adslab_slab_mappings[system_id]
        
        # Read the last frame of the trajectory for the slab
        traj_path = os.path.join(DFT_DATA_PATH, f"{slab_sid}.traj")
        atoms = ase.io.read(traj_path, -1)
        
        # Convert the atoms object to its graph representation
        image = a2g.convert(atoms)
        image.y = atoms.get_potential_energy()
        image.force = torch.tensor(atoms.get_forces())
        image.sid = torch.LongTensor([int(slab_sid[6:])])  # Extract numeric part from slab_sid (e.g., "slab123" -> 123)
        image.fid = torch.LongTensor([-1])
        
        # Set tags: mobile atoms (1) and fixed atoms (0)
        tags = np.ones(len(atoms))
        fixed_indices = atoms.constraints[0].index
        tags[fixed_indices] = 0
        image.tags = torch.LongTensor(tags)
        
        # Write the current image data to LMDB
        txn = train_db.begin(write=True)
        txn.put(str(idx).encode("ascii"), pickle.dumps(image, protocol=-1))
        txn.commit()
        train_db.sync()
        idx += 1
    
    # Store the total number of entries in the LMDB
    txn = train_db.begin(write=True)
    txn.put("length".encode("ascii"), pickle.dumps(idx, protocol=-1))
    txn.commit()
    train_db.sync()
    train_db.close()


In [8]:
# Parallelize the writing of LMDB files
dict_subset = df_subset.to_dict("index")
dict_keys = list(dict_subset.keys())
key_splits = np.array_split(dict_keys, NUM_WORKERS)
mp_args = [(i, split) for i, split in enumerate(key_splits)]

with Pool(NUM_WORKERS) as pool:
    list(pool.imap(write_lmdbs, mp_args))
    

Processing LMDB 5: 100%|██████████| 625/625 [00:00<00:00, 289438.00it/s]



Processing LMDB 3: 100%|██████████| 625/625 [00:00<00:00, 197650.61it/s]





Executing this code creates a folder with 5000 inter-metallic slab systems. By setting `MATERIAL_CLASS` to `nonmetals`, you can similarly generate a dataset of 5000 nonmetal slab systems. Both categories—approximately 5000 systems each—will be saved in the `metal_data5000` and `nonmetal_data5000` folders within the data directory. You can change the number of systems by changing the value for `SIZE`.

In [9]:
lmdb_path1 = "data/metals_5000"
lmdb_path2 = "data/nonmetals_5000"
metal_data = LmdbDataset({'src': lmdb_path1})
nonmetal_data = LmdbDataset({'src': lmdb_path2})

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
converter = AtomGraphConverter(cutoff=6.0, max_neighbors=50)

# Initialize the list to store SchNet input data
metal_input_data = []

# Iterate over the results list
for slab in metal_data:
    atomic_numbers = slab.atomic_numbers.clone().detach().long()  
    positions = slab.pos.clone().detach().float()  
    cell = slab.cell[0].clone().detach().float()  
    energy = torch.tensor(slab.y, dtype=torch.float32)  # Energy as a tensor

    # # Step 2: Create ASE Atoms object
    atoms = Atoms(numbers=atomic_numbers.numpy(), positions=positions.numpy(), cell=cell.numpy(), pbc=True)
    
    # # Step 3: Use the AtomGraphConverter to get edge_index and edge_weight (distances)
    edge_index, edge_weight, offsets = converter(atoms)
    
    # # Step 4: Create a PyTorch Geometric Data object
    data = Data(
        z=atomic_numbers.to(device),  # Atomic numbers
        pos=positions.to(device),     # Atomic positions
        y=energy.to(device),          # Target property (energy)
        cell=cell.to(device),         # Cell tensor
        edge_index=edge_index.to(device),  # Add computed edge index
        edge_weight=edge_weight.to(device),  # Add computed edge weights (distances)
    )
    
    # # Step 5: Append the data to the list
    metal_input_data.append(data)

  return torch._C._cuda_getDeviceCount() > 0


In [11]:
# Visualizing an entry of the OC20 dataset
print(metal_input_data[100])
xyz_str = qm9_to_xyz(metal_input_data[100])
viewer = py3Dmol.view(width=400, height=400)
viewer.addModel(xyz_str, 'xyz')  # Load the molecule from XYZ string.
viewer.setStyle({'sphere': {}})  # Use stick representations (alternatives: 'sphere', 'line').
viewer.zoomTo()  # Automatically zoom to the molecule.
viewer.show()

Data(edge_index=[2, 1621], y=-60.20738983154297, pos=[80, 3], z=[80], cell=[3, 3], edge_weight=[1621])


# Gold Dataset

The dataset is taken from the article :
Boes, J. R., Groenenboom, M. C., Keith, J. A., & Kitchin, J. R. (2016). Neural network and Reaxff comparison for Au properties. Int. J. Quantum Chem., 116(13), 979–987. http://dx.doi.org/10.1002/qua.25115

In [12]:
!wget https://figshare.com/ndownloader/files/11948267 -O data/gold.db

--2025-04-11 09:47:33--  https://figshare.com/ndownloader/files/11948267
Resolving figshare.com (figshare.com)... 54.229.133.209, 54.75.186.94, 54.246.172.114, ...
Connecting to figshare.com (figshare.com)|54.229.133.209|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://s3-eu-west-1.amazonaws.com/pstorage-cmu-348901238291901/11948267/data.db?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAI266R7V6O36O5JUA/20250411/eu-west-1/s3/aws4_request&X-Amz-Date=20250411T164735Z&X-Amz-Expires=10&X-Amz-SignedHeaders=host&X-Amz-Signature=f68a3ec7f184af8a3a2832dbf8090ddbba6deadfc813b648c681069009d135a9 [following]
--2025-04-11 09:47:34--  https://s3-eu-west-1.amazonaws.com/pstorage-cmu-348901238291901/11948267/data.db?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAI266R7V6O36O5JUA/20250411/eu-west-1/s3/aws4_request&X-Amz-Date=20250411T164735Z&X-Amz-Expires=10&X-Amz-SignedHeaders=host&X-Amz-Signature=f68a3ec7f184af8a3a2832dbf8090ddbba6deadfc813b648c681

In [13]:
# Loading the from DB file
db = connect('data/gold.db')
query = ['xc=PBE']

results = list(db.select(query))

results_amorphous = list(db.select(['xc=PBE', f'cluster={'amorphous'}']))
results_bulk = list(db.select(['xc=PBE','bulk']))

In [14]:
# Initialize the AtomGraphConverter class
converter = AtomGraphConverter(cutoff=6.0, max_neighbors=50)

# Initialize the list to store SchNet input data
model_data = []

# Iterate over the results list
for i, result in enumerate(results):
    # Step 1: Extract atomic numbers, positions, and cell
    if result.structure=='surface':
        spec = result.surf
    else:
        spec = result[f'{result.structure}']
    
    atomic_numbers = torch.tensor(result.numbers, dtype=torch.long)  # Atomic numbers as a tensor
    positions = torch.tensor(result.positions, dtype=torch.float32)  # Atomic positions as a tensor
    cell = torch.tensor(result.cell, dtype=torch.float32)  # Cell tensor
    energy = torch.tensor(result.energy, dtype=torch.float32)  # Energy as a tensor
    forces = torch.tensor(result.forces,dtype=torch.float32)
    force_magnitudes = torch.norm(forces, dim=1)
    max_force = torch.max(force_magnitudes)

    max_force_per_atom = torch.max(force_magnitudes)/atomic_numbers.size(0)
    energy_per_atom = energy/atomic_numbers.size(0)
    

    # Step 2: Create ASE Atoms object
    atoms = Atoms(numbers=atomic_numbers.numpy(), positions=positions.numpy(), cell=cell.numpy(), pbc=True)
    
    # Step 3: Use the AtomGraphConverter to get edge_index and edge_weight (distances)
    edge_index, edge_weight, offsets = converter(atoms)

    # Step 4: Create a PyTorch Geometric Data object
    data = Data(
        z=atomic_numbers,  # Atomic numbers
        pos=positions,     # Atomic positions
        y=energy,          # Target property (energy)
        y_atom = energy_per_atom,
        fmax_atom= max_force_per_atom,
        cell=cell,         # Cell tensor
        edge_index=edge_index,  # Add computed edge index
        edge_weight=edge_weight,  # Add computed edge weights (distances)
        structure=result.structure,
        spec=spec,
        idx=i
    )
    
    # Step 5: Append the data to the list
    model_data.append(data)

In [15]:
# Visualizing an entry of the Gold dataset
print(model_data[-500])
xyz_str = qm9_to_xyz(model_data[-500])
viewer = py3Dmol.view(width=400, height=400)
viewer.addModel(xyz_str, 'xyz')  # Load the molecule from XYZ string.
viewer.setStyle({'sphere': {}})  # Use stick representations (alternatives: 'sphere', 'line').
viewer.zoomTo()  # Automatically zoom to the molecule.
viewer.show()

Data(edge_index=[2, 750], y=-47.75715637207031, pos=[15, 3], z=[15], y_atom=-3.1838104724884033, fmax_atom=0.020437462255358696, cell=[3, 3], edge_weight=[750], structure='bulk', spec='hcp', idx=8458)
