# Loading data for training ML models
The goal of this tutorial is to provide a simple hands-on example of how a user might load the PLINDER dataset in preparation for training machine learning models. Here, we show how to get the FASTA sequences of proteins and the SMILES strings of ligands. In the process, we will show:
- How to use `PlinderDataset` to load data directly from a given split dataframe
- How one might use a diversity sampler along with the dataset loader
- How to extract the specific data one might need for training a particular ML model

### Load dataset
We recommend that users interact with the dataset through the PLINDER Python API. The {class}`PlinderDataset` class is the primary way to access the data. Its parameters include:

- `df`: the split to use
- `split`: the split to sample from
- `file_with_system_ids`: path to a file containing a list of system ids (default: full index)
- `store_file_path`: if True, include the file path of the source structures in the dataset
- `load_alternative_structures`: if True, include alternative structures in the dataset
- `num_alternative_structures`: number of alternative structures (apo and pred) to include

Below, we provide a wrapper function, `load_dataset_path`, that builds a {class}`PlinderDataset` and yields samples from it. If a user-defined `sampler_func` is given, samples are drawn according to that sampler, for example based on protein-ligand similarity clusters; otherwise the dataset is iterated in order.



In [None]:
from __future__ import annotations

from typing import Callable
import pandas as pd
from plinder.core import get_split, PlinderDataset
from pathlib import Path
from plinder.core.utils.log import setup_logger

LOG = setup_logger(__name__)

def load_dataset_path(
        split_df: pd.DataFrame,
        split: str = "train",
        split_parquet_path: Path | None = None,
        store_file_path: bool = True,
        load_alternative_structures: bool = True,
        num_alternative_structures: int = 1,
        max_num_sample: int | None = None,
        sampler_func: Callable | None = None):
    """
    Load dataset from splits dataframe

    Parameters
    ----------
    split_df : pd.DataFrame
        the split to use
    split : str
        the split to sample from
    split_parquet_path : Path | None, default=None
        path to the split parquet file, passed on to PlinderDataset
    store_file_path : bool, default=True
        if True, include the file path of the source structures in the dataset
    load_alternative_structures : bool, default=True
        if True, include alternative structures in the dataset
        (accepted here but not forwarded to PlinderDataset in this example)
    num_alternative_structures : int, default=1
        number of alternative structures (apo and pred) to include
    max_num_sample : int | None, default=None
        maximum number of samples to return
    sampler_func : Callable | None, default=None
        user-defined diversity sampler
    """
    dataset = PlinderDataset(
        df=split_df,
        split=split,
        split_parquet_path=split_parquet_path,
        store_file_path=store_file_path,
        num_alternative_structures=num_alternative_structures)
    # Draw indices from the user-defined sampler if given, else iterate in order
    if sampler_func is not None:
        indices = sampler_func(split_df, split)
    else:
        indices = range(len(dataset))
    num_returned = 0
    for i in indices:
        # Stop once the requested number of samples has been returned
        if max_num_sample is not None and num_returned >= max_num_sample:
            break
        try:
            sample = dataset[i]
        except Exception as e:
            # Skip systems that fail to load
            LOG.warning(e)
            continue
        yield sample
        num_returned += 1
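
The "skip failures, stop after N samples" pattern used by a loader like this can also be expressed with the standard library. The sketch below is a minimal, self-contained illustration that does not use PLINDER: `ToyDataset` and `iter_samples` are hypothetical names, and the toy dataset deliberately fails on some indices to mimic systems that cannot be loaded.

```python
from itertools import islice
from typing import Iterable, Iterator, Optional


class ToyDataset:
    """Hypothetical stand-in for a dataset; every third index fails to load."""

    def __len__(self) -> int:
        return 10

    def __getitem__(self, i: int) -> dict:
        if i % 3 == 2:
            raise ValueError(f"could not load item {i}")
        return {"index": i}


def iter_samples(
    dataset,
    indices: Iterable[int],
    max_num_sample: Optional[int] = None,
) -> Iterator[dict]:
    """Yield successfully loaded items, skipping failures, up to max_num_sample."""

    def loaded() -> Iterator[dict]:
        for i in indices:
            try:
                item = dataset[i]
            except Exception as e:
                print(f"skipping {i}: {e}")
                continue
            yield item

    # islice(..., None) places no limit on the number of items
    return islice(loaded(), max_num_sample)


toy = ToyDataset()
samples = list(iter_samples(toy, range(len(toy)), max_num_sample=4))
all_samples = list(iter_samples(toy, range(len(toy))))
```

Because the cap is applied to the stream of successfully loaded items, a failed load does not count towards `max_num_sample`.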

#### Define diversity sampler function
Here, we provide an example of how one might use `torch.utils.data.WeightedRandomSampler`. However, users are free to sample for diversity in any way they see fit. In this example, we sample for diversity based on the `cluster` column in the splits dataframe.

In [None]:
def make_sampler(split_df, split="train"):
    from torch.utils.data import WeightedRandomSampler

    # Keep only the rows that belong to the requested split
    split_df = split_df[split_df.split == split]
    # Count how many systems fall into each cluster
    cluster_counts = split_df["cluster"].value_counts().rename("cluster_count")
    split_df = split_df.merge(
        cluster_counts,
        left_on="cluster",
        right_index=True)
    # Weight each system by the inverse of its cluster size, so that
    # every cluster is equally likely to be drawn from
    cluster_weights = 1.0 / split_df.cluster_count.values
    return WeightedRandomSampler(
        weights=cluster_weights,
        num_samples=len(cluster_weights))
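
To make the weighting concrete, here is a minimal sketch of the same inverse-cluster-size idea using only the standard library instead of `torch`. The cluster labels are made up for illustration, and `random.choices` stands in for `WeightedRandomSampler` (both draw with replacement by default).

```python
import random
from collections import Counter

# Hypothetical cluster labels for six systems: cluster "a" is over-represented
clusters = ["a", "a", "a", "a", "b", "c"]

# Weight each system by 1 / (size of its cluster)
counts = Counter(clusters)
weights = [1.0 / counts[c] for c in clusters]

# The weights within each cluster sum to 1, so every cluster is equally
# likely to be drawn regardless of how many systems it contains
cluster_mass = Counter()
for c, w in zip(clusters, weights):
    cluster_mass[c] += w

# Draw system indices with replacement, proportionally to the weights
rng = random.Random(0)
sampled = rng.choices(range(len(clusters)), weights=weights, k=len(clusters))
```

Each of the four systems in cluster "a" gets weight 0.25, while the single systems in "b" and "c" get weight 1.0, so the three clusters carry equal total weight.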

### Extract specific molecular format needed for training
The function `get_model_input` ties everything together: for each sampled system, it collects the path to the protein FASTA file and the SMILES string of each ligand, which are the inputs needed for training.

In [None]:
def get_model_input(
        split="train",
        sampler_func=None,
        max_num_sample=10,
        num_alternative_structures=1):
    from rdkit import Chem
    split_df = get_split()
    training_set = load_dataset_path(
            split_df,
            split=split,
            split_parquet_path=None,
            store_file_path=True,
            num_alternative_structures=num_alternative_structures,
            max_num_sample=max_num_sample,
            sampler_func=sampler_func)
    protein_ligand_path = []
    for data in training_set:
        system_dir = Path(data["path"]).parent
        protein_fasta = system_dir / "sequences.fasta"
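        for ligand_sdf in (system_dir / "ligand_files").glob("*.sdf"):
            # str() keeps this working with RDKit versions that do not accept Path objects
            mol = next(Chem.SDMolSupplier(str(ligand_sdf)))
            if mol is None:
                # Skip ligands that RDKit fails to parse
                continue
            smiles = Chem.MolToSmiles(mol)
            protein_ligand_path.append((protein_fasta, smiles))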
    return protein_ligand_path
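
`get_model_input` collects the path to each system's `sequences.fasta` rather than the sequences themselves. If a model needs the sequence strings, the file has to be parsed. The following is a minimal sketch of a FASTA reader using only the standard library; `read_fasta` is a hypothetical helper, and the record names in the usage example are made up and need not match the headers in PLINDER's files.

```python
import tempfile
from pathlib import Path


def read_fasta(path):
    """Parse a FASTA file into a dict mapping record name to sequence."""
    records = {}
    name = None
    with open(path) as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                # Header line: use the first whitespace-separated token as the name
                tokens = line[1:].split()
                name = tokens[0] if tokens else ""
                records[name] = []
            elif name is not None:
                # Sequence lines may be wrapped, so collect and join them later
                records[name].append(line)
    return {name: "".join(parts) for name, parts in records.items()}


# Usage with a small, made-up FASTA file
with tempfile.TemporaryDirectory() as tmp:
    fasta_path = Path(tmp) / "sequences.fasta"
    fasta_path.write_text(">chainA\nMKT\nAYI\n>chainB\nGSH\n")
    parsed = read_fasta(fasta_path)
```

For anything beyond a quick look, an established parser such as Biopython's `SeqIO` is likely a more robust choice.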


### Sample diverse training set 

In [None]:
# Get splits dataframe
split_df = get_split()
# Sample training set
training_set = get_model_input(split="train", sampler_func=make_sampler)

In [None]:
# Inspect result
for i in training_set:
    print(i)

### Get validation set without cluster sampling
Here, we show how to get the validation set without cluster sampling.

In [None]:
# Sample validation set without cluster sampling
validation_set = get_model_input(split="val", sampler_func=None)

In [None]:
# Inspect result
for i in validation_set:
    print(i)