# Install necessary libraries

In [None]:
# Fetch the latest Miniforge installer, run it, then activate the base environment
installer=Miniforge3-Linux-x86_64.sh
wget "https://github.com/conda-forge/miniforge/releases/latest/download/${installer}"
bash "${installer}"
source ~/miniforge3/bin/activate
# mamba activate latentdiff

In [None]:
# Make `mamba activate` usable in this shell session; this has to run before any activate call
eval "$(mamba shell hook --shell bash)"

# create the environment ONCE, with pytorch pinned (use CUDA version for your system)
# https://pytorch.org/get-started/
# alternative for CUDA 12.1, to be run inside an already-created environment:
# mamba install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia -c defaults
mamba create -n latentdiff python=3.10 pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=11.8 -c pytorch -c nvidia

# activate only after the environment exists, so every install below lands inside it
mamba activate latentdiff

# install pytorch geometric (use CUDA version for your system)
# https://pytorch-geometric.readthedocs.io/
pip install torch_geometric
pip install torch_sparse torch_scatter torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.3.1+cu118.html

# install other libraries (see requirements.txt for versions)
# the version spec is quoted so the shell cannot glob-expand the `*`
pip install lightning==2.4.0 "hydra-core==1.*" hydra-colorlog
mamba install ase==3.23.0  # individually installed due to dependency conflict
mamba install matminer==0.9.2  # individually installed due to dependency conflict
mamba install smact==2.6 openbabel==3.1.1 jupyterlab pandas seaborn joblib yaml -c conda-forge
pip install pyxtal==0.6.7 mofchecker==0.9.6 rdkit==2024.3.5 e3nn==0.5.1 posebusters==0.3.1 download==0.3.5 ipdb wandb rootutils rich pathos p-tqdm einops svgwrite cairosvg reportlab lmdb torchdiffeq huggingface_hub

# notebook support, plus a named kernel so Jupyter can select this environment
mamba install notebook ipykernel -c conda-forge

python -m ipykernel install --user --name=latentdiff --display-name "Python (latentdiff)"


# DataLoader

In [1]:
import os
import warnings
from typing import Callable, List, Optional

import numpy as np
import torch
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader

from src.data.components.preprocessing_utils import preprocess
from src.data.mp_20_datamodule import JointDataModule
# from src.data.joint_datamodule import JointDataModule


warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", DeprecationWarning)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MP20(InMemoryDataset):
    """The MP20 dataset from Materials Project, as a PyG InMemoryDataset.

    In order to create a torch_geometric.data.InMemoryDataset, you need to implement four fundamental methods:
    - InMemoryDataset.raw_file_names(): A list of files in the raw_dir which needs to be found in order to skip the download.
    - InMemoryDataset.processed_file_names(): A list of files in the processed_dir which needs to be found in order to skip the processing.
    - InMemoryDataset.download(): Downloads raw data into raw_dir.
    - InMemoryDataset.process(): Processes raw data and saves it into the processed_dir.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload)
        # Load the collated samples written by `process()`.
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        """Files expected in the raw directory."""
        return ["all.csv"]

    @property
    def processed_file_names(self) -> List[str]:
        """Files expected in the processed directory."""
        return ["mp20.pt"]

    def download(self) -> None:
        """Fetch ``raw/all.csv`` from the Hugging Face dataset repo into ``self.root``."""
        # Imported lazily so the class can be defined without huggingface_hub installed.
        from huggingface_hub import hf_hub_download

        hf_hub_download(
            repo_id="chaitjo/MP20_ADiT",
            filename="raw/all.csv",
            repo_type="dataset",
            local_dir=self.root,
        )

    def process(self) -> None:
        """Convert the raw CSV into PyG ``Data`` objects and save them.

        The output of ``preprocess`` is cached next to the raw CSV as ``all.pt``
        and reused on later runs. Each cached entry becomes one ``Data`` object;
        ``pre_filter`` and ``pre_transform`` are applied before saving.
        """
        cache_path = os.path.join(self.raw_dir, "all.pt")
        if os.path.exists(cache_path):
            # The cache is a pickled list of dicts, not a tensor state dict, so it
            # needs full unpickling. Newer PyTorch releases default to
            # weights_only=True, which may refuse such payloads; be explicit.
            # NOTE: unpickling can execute arbitrary code -- only load a cache
            # that this method wrote itself.
            cached_data = torch.load(cache_path, weights_only=False)
        else:
            cached_data = preprocess(
                self.raw_paths[0],
                niggli=True,
                primitive=False,
                graph_method="crystalnn",
                prop_list=["formation_energy_per_atom"],
                use_space_group=True,
                tol=0.1,
                num_workers=32,
            )
            torch.save(cached_data, cache_path)

        data_list = []
        for data_dict in cached_data:
            # extract attributes from data_dict
            graph_arrays = data_dict["graph_arrays"]
            atom_types = graph_arrays["atom_types"]
            frac_coords = graph_arrays["frac_coords"]
            cell = graph_arrays["cell"]
            lattices = graph_arrays["lattices"]
            lengths = graph_arrays["lengths"]
            angles = graph_arrays["angles"]
            num_atoms = graph_arrays["num_atoms"]

            # normalize the lengths of lattice vectors, which makes
            # lengths for materials of different sizes at same scale
            lengths_scaled = lengths / float(num_atoms) ** (1 / 3)
            # convert angles of lattice vectors to be in radians
            angles_radians = np.radians(angles)
            # scaled lengths followed by angles in radians, as one vector
            lattices_scaled = np.concatenate([lengths_scaled, angles_radians])

            data = Data(
                id=data_dict["mp_id"],
                atom_types=torch.LongTensor(atom_types),
                frac_coords=torch.Tensor(frac_coords),
                cell=torch.Tensor(cell).unsqueeze(0),
                lattices=torch.Tensor(lattices).unsqueeze(0),
                lattices_scaled=torch.Tensor(lattices_scaled).unsqueeze(0),
                lengths=torch.Tensor(lengths).view(1, -1),
                lengths_scaled=torch.Tensor(lengths_scaled).view(1, -1),
                angles=torch.Tensor(angles).view(1, -1),
                angles_radians=torch.Tensor(angles_radians).view(1, -1),
                num_atoms=torch.LongTensor([num_atoms]),
                num_nodes=torch.LongTensor([num_atoms]),  # special attribute used for PyG batching
                token_idx=torch.arange(num_atoms),
                dataset_idx=torch.tensor(
                    [0], dtype=torch.long
                ),  # 0 --> indicates periodic/crystal
            )
            # 3D coordinates (NOTE do not zero-center prior to graph construction):
            # each atom's fractional coordinates are multiplied by the cell matrix,
            # which is repeated once per atom.
            data.pos = torch.einsum(
                "bi,bij->bj",
                data.frac_coords,
                torch.repeat_interleave(data.cell, data.num_atoms, dim=0),
            )
            # space group number
            data.spacegroup = torch.LongTensor([data_dict["spacegroup"]])

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        # Write to the same path that __init__ loads from.
        self.save(data_list, self.processed_paths[0])

In [3]:
# Root folder for the dataset; raw and processed files are kept underneath it.
MP20_ROOT = "data/mp_20"
mp20_dataset = MP20(root=MP20_ROOT)

In [4]:
# Fixed, contiguous split: the first NUM_TRAIN samples train, the next NUM_VAL
# validate, and whatever remains is the test set.
NUM_TRAIN, NUM_VAL = 27138, 9046
val_end = NUM_TRAIN + NUM_VAL

mp20_train_dataset = mp20_dataset[:NUM_TRAIN]
mp20_val_dataset = mp20_dataset[NUM_TRAIN:val_end]
mp20_test_dataset = mp20_dataset[val_end:]

In [5]:
# Keep only the leading fraction of each split (1.0 keeps everything).
# NOTE: the split names are reassigned in place, so with a fraction below 1.0
# re-running this cell would shrink the splits again.
SUBSET_FRACTION = 1.0

mp20_train_dataset = mp20_train_dataset[: int(len(mp20_train_dataset) * SUBSET_FRACTION)]
mp20_val_dataset = mp20_val_dataset[: int(len(mp20_val_dataset) * SUBSET_FRACTION)]
mp20_test_dataset = mp20_test_dataset[: int(len(mp20_test_dataset) * SUBSET_FRACTION)]

In [6]:
# Shuffled training loader; drop_last discards the final incomplete batch.
train_loader_options = dict(
    batch_size=256,
    num_workers=8,
    pin_memory=False,
    shuffle=True,
    drop_last=True,
)
train_dataloader = DataLoader(mp20_train_dataset, **train_loader_options)

In [7]:
# Grab a single batch for inspection and smoke tests. Unlike `for ...: break`,
# this raises StopIteration right here if the loader yields no batches (possible
# with drop_last=True and fewer samples than one batch), instead of leaving `d`
# undefined for later cells.
d = next(iter(train_dataloader))

In [8]:
# Inspect the sampled batch; its repr lists each attribute with its batched shape.
d

DataBatch(id=[256], atom_types=[2643], frac_coords=[2643, 3], cell=[256, 3, 3], lattices=[256, 6], lattices_scaled=[256, 6], lengths=[256, 3], lengths_scaled=[256, 3], angles=[256, 3], angles_radians=[256, 3], num_atoms=[256], token_idx=[2643], dataset_idx=[256], pos=[2643, 3], spacegroup=[256], num_nodes=[1], batch=[2643], ptr=[257])

# Model Definition

In [9]:
from hydra import initialize_config_dir, compose
import hydra
from omegaconf import OmegaConf

# Initialize Hydra with the directory containing train_autoencoder.yaml.
# initialize_config_dir expects an absolute path, so the value is normalized with
# abspath. Set the ADIT_CONFIG_DIR environment variable to point at another
# checkout; the fallback is the location this notebook was originally run from.
config_path = os.path.abspath(
    os.environ.get(
        "ADIT_CONFIG_DIR",
        "/notebooks/Latent_Pretraining_new/all-atom-diffusion-transformer/configs",
    )
)
config_name = "train_autoencoder.yaml"

with initialize_config_dir(config_dir=config_path, version_base="1.3"):
    cfg = compose(config_name=config_name)

# Now you can access config parameters
# print(OmegaConf.to_yaml(cfg))

In [10]:
from src.models.vae_module import VariationalAutoencoderLitModule
from src.models.encoders.transformer import TransformerEncoder
from src.models.decoders.transformer import TransformerDecoder

In [21]:
# The encoder and decoder are built with the same transformer hyperparameters.
transformer_kwargs = dict(
    max_num_elements=100,
    d_model=512,
    nhead=8,
    dim_feedforward=2048,
    activation="gelu",
    dropout=0.0,
    norm_first=True,
    bias=True,
    num_layers=8,
)
encoder = TransformerEncoder(**transformer_kwargs)
decoder = TransformerDecoder(**transformer_kwargs)

In [14]:
# Pull the autoencoder module settings out of the composed Hydra config.
ae_cfg = cfg.autoencoder_module

# Optimizer and scheduler are instantiated through Hydra.
optimizer = hydra.utils.instantiate(ae_cfg.optimizer)
scheduler = hydra.utils.instantiate(ae_cfg.scheduler)

# The remaining entries are passed along as plain config values.
scheduler_frequency = ae_cfg.scheduler_frequency
loss_weights = ae_cfg.loss_weights
augmentations = ae_cfg.augmentations
visualization = ae_cfg.visualization
# NOTE(review): this name shadows the builtin `compile`; it is kept because the
# model-construction cell reads it.
compile = ae_cfg.compile

In [31]:
# Show the visualization settings taken from the Hydra config.
visualization

{'visualize': True, 'save_dir': '${paths.viz_dir}/'}

In [22]:
# Assemble the VAE Lightning module from the encoder, decoder and config values.
vae_model = VariationalAutoencoderLitModule(
    encoder=encoder,
    decoder=decoder,
    latent_dim=8,
    optimizer=optimizer,
    scheduler=scheduler,
    scheduler_frequency=scheduler_frequency,
    loss_weights=loss_weights,
    augmentations=augmentations,
    visualization=visualization,
    compile=compile,
)

In [25]:
# Smoke test: run one training step on the sampled batch
# (the second argument is presumably the batch index -- confirm against the module).
output = vae_model.training_step(d, 0)

In [26]:
# Display the value returned by the training step.
output

tensor(46.7873, grad_fn=<AddBackward0>)

In [None]:
# Launch the autoencoder training script with only GPU 0 visible to the process
# (the relative script path assumes the command is run from the repository root -- confirm)
CUDA_VISIBLE_DEVICES=0 python src/train_autoencoder.py

In [None]:
# NOTE: `pip install performer-pytorch` downloads torch 2.7.0, which differs from the torch 2.3.1 pinned in the install section above.