# Loading data for training ML models
This goal of this tutorial is to provide a simple hands-on example of how a user might load PLINDER dataset in prepparation for training machine learning models. Here, we are going to show an example of how to get the fasta sequence and smiles proteins and ligands. In the process, we will show:
- How load use `PlinderDataset` to load data direclty from a given split dataframe
- How one might use a diversity sampler along with the dataset loader
- Extract specific data one might want to use for training specific ML model

### Load dataset
We recommend users interact with the dataset using PLINDER Python API. class {class}`PlinderDataset` is the primary method of access data.

- `df`: the split to use
- `split`: the split to sample from
- `file_with_system_ids`: path to a file containing a list of system ids (default: full index)
- `store_file_path`: if True, include the file path of the source structures in the dataset
- `load_alternative_structures`: if True, include alternative structures in the dataset
- `num_alternative_structures`: number of alternative structures (apo and pred) to include

Below, we are providing a function to access class {class}`PlinderDataset` and samples from protein-ligand similarity cluster based on sampling user-defined function `sampler_func` via a warpper function `load_dataset_path`. 



In [93]:
from __future__ import annotations

from typing import Callable
import pandas as pd
from plinder.core import get_split, PlinderDataset
from pathlib import Path
from plinder.core.utils.log import setup_logger

LOG = setup_logger(__name__)

def load_dataset_path(
        split_df: pd.DataFrame,
        split: str = "train",
        split_parquet_path: Path | None =None,
        store_file_path: bool = True,
        load_alternative_structures: bool = True,
        num_alternative_structures: int = 1,
        max_num_sample: int | None = None,
        sampler_func: Callable | None =None):
    """
    Load dataset from splits dataframe

    Parameters
    ----------
    split_df : pd.DataFrame | None
        the split to use
    split : str
        the split to sample from
    file_with_system_ids : str | Path
        path to a file containing a list of system ids (default: full index)
    store_file_path : bool, default=True
        if True, include the file path of the source structures in the dataset
    load_alternative_structures : bool, default=False
        if True, include alternative structures in the dataset
    num_alternative_structures : int, default=1
        number of alternative structures (apo and pred) to include
    max_num_sample: int | None, default = None
        maximum number of sample to return
    sampler_func: Callable | None, default=None
        user-defined diversity sampler
    """
    dataset = PlinderDataset(
        df=split_df,
        split=split,
        split_parquet_path=split_parquet_path,
        store_file_path=store_file_path,
        load_alternative_structures=load_alternative_structures,
        num_alternative_structures=num_alternative_structures)
    if sampler_func is not None:
        sampler = sampler_func(split_df)
        for i in sampler:
            if (max_num_sample is not None) & ((i+1)%max_num_sample == 0):
                break
            try:
                yield dataset[i]
            except Exception as e:
                LOG.error(e)
    else:
        for i in range(len(dataset)):
            if (i+1)%max_num_sample == 0:
                break
            try:
                yield dataset[i]
            except Exception as e:
                LOG.error(e)

#### Define diversity sampler function
Here, we have provided an example of how one might use `torch.utils.data.WeightedRandomSampler`. However, users are free to sample diversity any how they see fit. For this example, we are going to use the sample dversity based on the `cluster` column in the splits dataframe.

In [90]:
def make_sampler(split_df):
    from torch.utils.data import WeightedRandomSampler
    cluster_counts = (split_df["cluster"].value_counts().rename("cluster_count"))
    split_df = split_df.merge(
        cluster_counts,
        left_on="cluster",
        right_index=True)
    cluster_weights = 1.0 / split_df.cluster_count.values
    return WeightedRandomSampler(
        weights=cluster_weights,
        num_samples=len(cluster_counts))

### Extract specific molecular format needed for training
The function `get_model_input` wraps it all together, allowing us to extract the sequence fasta and smiles needed for training.

In [96]:
def get_model_input(
        split="train",
        sampler_func=None,
        max_num_sample=10,
        num_alternative_structures=1):
    from rdkit import Chem
    split_df = get_split()
    training_set = load_dataset_path(
            split_df,
            split=split,
            split_parquet_path=None,
            store_file_path=True,
            load_alternative_structures=True,
            num_alternative_structures=num_alternative_structures,
            max_num_sample=max_num_sample,
            sampler_func=sampler_func)
    protein_ligand_path = []
    for data in training_set:
        system_dir = Path(data["path"]).parent
        protein_fasta = system_dir / "sequences.fasta"
        for ligand_sdf in (system_dir / "ligand_files/").glob("*sdf"):
            smiles = Chem.MolToSmiles(next(Chem.SDMolSupplier(ligand_sdf)))
            protein_ligand_path.append((protein_fasta, smiles))
    return protein_ligand_path


### Sample diverse training set 

In [94]:
# Get splits dataframe
split_df = get_split()
# Sample training set
training_set = get_model_input(split="train", sampler_func=make_sampler)

2024-08-28 21:00:59,439 | plinder.core.split.utils.get_split:24 | INFO : runtime succeeded: 0.00s
2024-08-28 21:00:59,440 | plinder.core.split.utils.get_split:24 | INFO : runtime succeeded: 0.00s
2024-08-28 21:00:59,570 | __main__:57 | ERROR : 373800
2024-08-28 21:00:59,571 | __main__:57 | ERROR : 340393
2024-08-28 21:00:59,571 | __main__:57 | ERROR : 397286


[GSPath('gs://plinder/2024-06/v2/systems/xk.zip')]


2024-08-28 21:01:06,311 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 3.19s
2024-08-28 21:01:06,450 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:01:09,751 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.15s
2024-08-28 21:01:09,886 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.81s
2024-08-28 21:01:09,888 | __main__:57 | ERROR : 385364
2024-08-28 21:01:09,888 | __main__:57 | ERROR : 347662
2024-08-28 21:01:09,889 | __main__:57 | ERROR : 407110
2024-08-28 21:01:09,889 | __main__:57 | ERROR : 332777
2024-08-28 21:01:09,889 | __main__:57 | ERROR : 327980
2024-08-28 21:01:09,890 | __main__:57 | ERROR : 377411
2024-08-28 21:01:09,890 | __main__:57 | ERROR : 349934
2024-08-28 21:01:09,891 | __main__:57 | ERROR : 343094
2024-08-28 21:01:09,891 | __main__:57 | ERROR : 394661
2024-08-28 21:01:09,891 | __main__:57 | ERROR : 392186
2024-08-28 21:01:09,892 | __main__:57 | ERROR 

[GSPath('gs://plinder/2024-06/v2/systems/uk.zip')]


2024-08-28 21:01:15,323 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 2.77s
2024-08-28 21:01:15,421 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.10s
2024-08-28 21:01:19,532 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.20s
2024-08-28 21:01:19,894 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 2.11s


[GSPath('gs://plinder/2024-06/v2/linked_structures/uk.zip')]


2024-08-28 21:01:24,605 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 2.12s
2024-08-28 21:01:24,747 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:01:26,175 | __main__:57 | ERROR : 349915
2024-08-28 21:01:26,176 | __main__:57 | ERROR : 407318
2024-08-28 21:01:26,176 | __main__:57 | ERROR : 375131
2024-08-28 21:01:26,176 | __main__:57 | ERROR : 358255
2024-08-28 21:01:26,177 | __main__:57 | ERROR : 389181
2024-08-28 21:01:26,177 | __main__:57 | ERROR : 356046
2024-08-28 21:01:26,177 | __main__:57 | ERROR : 395434


[GSPath('gs://plinder/2024-06/v2/systems/s5.zip')]


2024-08-28 21:01:31,816 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 3.01s
2024-08-28 21:01:31,960 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:01:36,062 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.18s
2024-08-28 21:01:36,429 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 2.08s


[GSPath('gs://plinder/2024-06/v2/linked_structures/s5.zip')]


2024-08-28 21:01:40,051 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 1.00s
2024-08-28 21:01:40,176 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.12s
2024-08-28 21:01:40,679 | __main__:57 | ERROR : 395508
2024-08-28 21:01:40,679 | __main__:57 | ERROR : 390288


[GSPath('gs://plinder/2024-06/v2/systems/r1.zip')]


2024-08-28 21:01:50,471 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 7.10s
2024-08-28 21:01:50,612 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:01:58,112 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.18s
2024-08-28 21:01:58,285 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 2.01s
2024-08-28 21:01:58,289 | __main__:57 | ERROR : 363070
2024-08-28 21:01:58,290 | __main__:57 | ERROR : 401401
2024-08-28 21:01:58,290 | __main__:57 | ERROR : 387767
2024-08-28 21:01:58,290 | __main__:57 | ERROR : 402107
2024-08-28 21:01:58,291 | __main__:57 | ERROR : 362212


### Get validation set without cluster sampling
Here, we will show how to get validation set without cluster sampling.

In [99]:
# Sample validation set
validation_set = get_model_input(split="val", sampler_func=None)

2024-08-28 21:15:54,151 | plinder.core.split.utils.get_split:24 | INFO : runtime succeeded: 0.00s


[GSPath('gs://plinder/2024-06/v2/systems/av.zip')]


2024-08-28 21:16:00,629 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 3.16s
2024-08-28 21:16:00,742 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.11s
2024-08-28 21:16:03,830 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.18s
2024-08-28 21:16:03,995 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.85s


[GSPath('gs://plinder/2024-06/v2/systems/av.zip')]


2024-08-28 21:16:06,926 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.30s
2024-08-28 21:16:07,063 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:16:08,892 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.20s
2024-08-28 21:16:09,098 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.97s


[GSPath('gs://plinder/2024-06/v2/linked_structures/av.zip')]


2024-08-28 21:16:13,370 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 1.65s
2024-08-28 21:16:13,503 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.13s


[GSPath('gs://plinder/2024-06/v2/systems/av.zip')]


2024-08-28 21:16:17,059 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.29s
2024-08-28 21:16:17,199 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:16:19,030 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.18s
2024-08-28 21:16:19,230 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.97s


[GSPath('gs://plinder/2024-06/v2/linked_structures/av.zip')]


2024-08-28 21:16:22,519 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.30s
2024-08-28 21:16:22,662 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s


[GSPath('gs://plinder/2024-06/v2/systems/b4.zip')]


2024-08-28 21:16:28,322 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 2.26s
2024-08-28 21:16:28,460 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:16:31,801 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.18s
2024-08-28 21:16:31,962 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.95s


[GSPath('gs://plinder/2024-06/v2/systems/b4.zip')]


2024-08-28 21:16:34,989 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.30s
2024-08-28 21:16:35,097 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.11s
2024-08-28 21:16:36,854 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.15s
2024-08-28 21:16:37,022 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.85s


[GSPath('gs://plinder/2024-06/v2/systems/b4.zip')]


2024-08-28 21:16:40,010 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.28s
2024-08-28 21:16:40,113 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.10s
2024-08-28 21:16:41,930 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.17s
2024-08-28 21:16:42,096 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.90s


[GSPath('gs://plinder/2024-06/v2/systems/bl.zip')]


2024-08-28 21:16:47,138 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 2.36s
2024-08-28 21:16:47,273 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.13s
2024-08-28 21:16:51,180 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.18s
2024-08-28 21:16:51,347 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.90s


[GSPath('gs://plinder/2024-06/v2/systems/bl.zip')]


2024-08-28 21:16:54,357 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.29s
2024-08-28 21:16:54,495 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:16:56,478 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.17s
2024-08-28 21:16:56,655 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.91s


[GSPath('gs://plinder/2024-06/v2/systems/by.zip')]


2024-08-28 21:17:01,563 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 2.24s
2024-08-28 21:17:01,708 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:17:05,159 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.18s
2024-08-28 21:17:05,338 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.99s


In [100]:
validation_set

[(PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/1avb__1__1.A_1.B__1.C/sequences.fasta'),
  'CC(=O)N[C@H]1[C@H](O[C@H]2[C@H](O)[C@@H](NC(C)=O)CO[C@@H]2CO)O[C@H](CO)[C@@H](O)[C@@H]1O'),
 (PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/1avb__1__1.A__1.E/sequences.fasta'),
  'CC(=O)N[C@H]1CO[C@H](CO)[C@@H](O)[C@@H]1O'),
 (PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/1avb__1__1.B__1.H/sequences.fasta'),
  'CC(=O)N[C@H]1CO[C@H](CO)[C@@H](O)[C@@H]1O'),
 (PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/1b4k__1__1.A__1.C_1.E/sequences.fasta'),
  'O=S(=O)([O-])[O-]'),
 (PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/1b4k__1__1.A__1.C_1.E/sequences.fasta'),
  'CCCCC(=O)O'),
 (PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/1b4k__1__2.A__2.C_2.E/sequences.fasta'),
  'O=S(=O)([O-])[O-]'),
 (PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/1b4k__1__2.A__2.C_2.E/sequences.fasta'),
  '