# Loading data for training ML models
This goal of this tutorial is to provide a simple hands-on example of how a user might load PLINDER dataset in prepparation for training machine learning models. Here, we are going to show an example of how to get the fasta sequence and smiles proteins and ligands. In the process, we will show:
- How load use `PlinderDataset` to load data direclty from a given split dataframe
- How one might use a diversity sampler along with the dataset loader
- Extract specific data one might want to use for training specific ML model

### Load dataset
We recommend users interact with the dataset using PLINDER Python API. class {class}`PlinderDataset` is the primary method of access data.

- `df`: the split to use
- `split`: the split to sample from
- `file_with_system_ids`: path to a file containing a list of system ids (default: full index)
- `store_file_path`: if True, include the file path of the source structures in the dataset
- `load_alternative_structures`: if True, include alternative structures in the dataset
- `num_alternative_structures`: number of alternative structures (apo and pred) to include

Below, we are providing a function to access class {class}`PlinderDataset` and samples from protein-ligand similarity cluster based on sampling user-defined function `sampler_func` via a warpper function `load_dataset_path`. 



In [23]:
from __future__ import annotations

from typing import Callable
import pandas as pd
from plinder.core import get_split, PlinderDataset
from pathlib import Path
from plinder.core.utils.log import setup_logger

LOG = setup_logger(__name__)

def load_dataset_path(
        split_df: pd.DataFrame,
        split: str = "train",
        split_parquet_path: Path | None =None,
        store_file_path: bool = True,
        load_alternative_structures: bool = True,
        num_alternative_structures: int = 1,
        max_num_sample: int | None = None,
        sampler_func: Callable | None =None):
    """
    Load dataset from splits dataframe

    Parameters
    ----------
    split_df : pd.DataFrame | None
        the split to use
    split : str
        the split to sample from
    file_with_system_ids : str | Path
        path to a file containing a list of system ids (default: full index)
    store_file_path : bool, default=True
        if True, include the file path of the source structures in the dataset
    load_alternative_structures : bool, default=False
        if True, include alternative structures in the dataset
    num_alternative_structures : int, default=1
        number of alternative structures (apo and pred) to include
    max_num_sample: int | None, default = None
        maximum number of sample to return
    sampler_func: Callable | None, default=None
        user-defined diversity sampler
    """
    dataset = PlinderDataset(
        df=split_df,
        split=split,
        split_parquet_path=split_parquet_path,
        store_file_path=store_file_path,
        load_alternative_structures=load_alternative_structures,
        num_alternative_structures=num_alternative_structures)
    if sampler_func is not None:
        sampler = sampler_func(split_df, split)
        for i in sampler:
            if (max_num_sample is not None) & ((i+1)%max_num_sample == 0):
                break
            try:
                yield dataset[i]
            except Exception as e:
                LOG.warn(e)
    else:
        for i in range(len(dataset)):
            if (i+1)%max_num_sample == 0:
                break
            try:
                yield dataset[i]
            except Exception as e:
                LOG.warn(e)

#### Define diversity sampler function
Here, we have provided an example of how one might use `torch.utils.data.WeightedRandomSampler`. However, users are free to sample diversity any how they see fit. For this example, we are going to use the sample dversity based on the `cluster` column in the splits dataframe.

In [47]:
def make_sampler(split_df, split="train"):
    from torch.utils.data import WeightedRandomSampler
    split_df = split_df[split_df.split == split]
    cluster_counts = (split_df["cluster"].value_counts().rename("cluster_count"))
    split_df = split_df.merge(
        cluster_counts,
        left_on="cluster",
        right_index=True)
    cluster_weights = 1.0 / split_df.cluster_count.values
    return WeightedRandomSampler(
        weights=cluster_weights,
        num_samples=len(cluster_weights))

### Extract specific molecular format needed for training
The function `get_model_input` wraps it all together, allowing us to extract the sequence fasta and smiles needed for training.

In [33]:
def get_model_input(
        split="train",
        sampler_func=None,
        max_num_sample=10,
        num_alternative_structures=1):
    from rdkit import Chem
    split_df = get_split()
    training_set = load_dataset_path(
            split_df,
            split=split,
            split_parquet_path=None,
            store_file_path=True,
            load_alternative_structures=True,
            num_alternative_structures=num_alternative_structures,
            max_num_sample=max_num_sample,
            sampler_func=sampler_func)
    protein_ligand_path = []
    for data in training_set:
        system_dir = Path(data["path"]).parent
        protein_fasta = system_dir / "sequences.fasta"
        for ligand_sdf in (system_dir / "ligand_files/").glob("*sdf"):
            smiles = Chem.MolToSmiles(next(Chem.SDMolSupplier(ligand_sdf)))
            protein_ligand_path.append((protein_fasta, smiles))
    return protein_ligand_path


### Sample diverse training set 

In [48]:
# Get splits dataframe
split_df = get_split()
# Sample training set
training_set = get_model_input(split="train", sampler_func=make_sampler)

2024-08-28 21:58:57,748 | plinder.core.split.utils.get_split:24 | INFO : runtime succeeded: 0.00s
2024-08-28 21:58:57,749 | plinder.core.split.utils.get_split:24 | INFO : runtime succeeded: 0.00s


[GSPath('gs://plinder/2024-06/v2/systems/sn.zip')]


2024-08-28 21:59:01,081 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.30s
2024-08-28 21:59:01,223 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
  LOG.warn(e)


[GSPath('gs://plinder/2024-06/v2/systems/vf.zip')]


2024-08-28 21:59:06,345 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 2.46s
2024-08-28 21:59:06,483 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:59:10,063 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:59:10,396 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.96s


[GSPath('gs://plinder/2024-06/v2/linked_structures/vf.zip')]


2024-08-28 21:59:15,371 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 2.36s
2024-08-28 21:59:15,507 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s


[GSPath('gs://plinder/2024-06/v2/systems/hr.zip')]


2024-08-28 21:59:22,184 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 3.10s
2024-08-28 21:59:22,284 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.10s
2024-08-28 21:59:26,731 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.18s
2024-08-28 21:59:26,925 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.85s


[GSPath('gs://plinder/2024-06/v2/systems/c2.zip')]


2024-08-28 21:59:32,604 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 3.02s
2024-08-28 21:59:32,740 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:59:37,130 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.18s
2024-08-28 21:59:37,452 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 2.03s


[GSPath('gs://plinder/2024-06/v2/linked_structures/c2.zip')]


2024-08-28 21:59:42,284 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 2.22s
2024-08-28 21:59:42,394 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.11s


[GSPath('gs://plinder/2024-06/v2/systems/d5.zip')]


2024-08-28 21:59:50,254 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 3.52s
2024-08-28 21:59:50,395 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:59:55,512 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.18s
2024-08-28 21:59:55,671 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.87s


[GSPath('gs://plinder/2024-06/v2/systems/zc.zip')]


2024-08-28 21:59:58,594 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.25s
2024-08-28 21:59:58,693 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.10s
2024-08-28 22:00:01,061 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.18s
2024-08-28 22:00:01,230 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.90s


In [49]:
# Inspect result
for i in training_set:
    print(i)

(PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/8vfv__1__1.M__1.ZA/sequences.fasta'), 'CC(=O)N[C@H]1CO[C@H](CO)[C@@H](O)[C@@H]1O')
(PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/8hrp__1__1.B_1.C__1.R_1.S_1.U/sequences.fasta'), 'OCC(O)CO')
(PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/8hrp__1__1.B_1.C__1.R_1.S_1.U/sequences.fasta'), 'O=C[C@H](O)COP(=O)(O)O')
(PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/8hrp__1__1.B_1.C__1.R_1.S_1.U/sequences.fasta'), 'NC(=O)c1ccc[n+]([C@@H]2O[C@H](CO[P@@](=O)([O-])O[P@@](=O)(O)OC[C@H]3O[C@@H](n4cnc5c(N)ncnc54)[C@H](O)[C@@H]3O)[C@@H](O)[C@H]2O)c1')
(PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/8c2r__1__1.C__1.NA/sequences.fasta'), 'CC(=O)N[C@H]1CO[C@H](CO)[C@@H](O)[C@@H]1O')
(PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/8d55__1__1.A__1.KA/sequences.fasta'), 'CC(=O)N[C@H]1CO[C@H](CO)[C@@H](O)[C@@H]1O')
(PosixPath('/Users/yusuf/.local/share/plind

### Get validation set without cluster sampling
Here, we will show how to get validation set without cluster sampling.

In [13]:
# Sample validation set without cluster sampling
validation_set = get_model_input(split="val", sampler_func=None)

2024-08-28 21:29:56,522 | plinder.core.split.utils.get_split:24 | INFO : runtime succeeded: 0.00s


[GSPath('gs://plinder/2024-06/v2/systems/av.zip')]


2024-08-28 21:29:59,543 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.31s
2024-08-28 21:29:59,691 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.15s
2024-08-28 21:30:01,478 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.17s
2024-08-28 21:30:01,674 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.87s


[GSPath('gs://plinder/2024-06/v2/systems/av.zip')]


2024-08-28 21:30:04,555 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.28s
2024-08-28 21:30:04,697 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:30:06,456 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.17s
2024-08-28 21:30:06,657 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.89s


[GSPath('gs://plinder/2024-06/v2/linked_structures/av.zip')]


2024-08-28 21:30:09,549 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.29s
2024-08-28 21:30:09,686 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s


[GSPath('gs://plinder/2024-06/v2/systems/av.zip')]


2024-08-28 21:30:13,513 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.32s
2024-08-28 21:30:13,647 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.13s
2024-08-28 21:30:15,376 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.19s
2024-08-28 21:30:15,580 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.87s


[GSPath('gs://plinder/2024-06/v2/linked_structures/av.zip')]


2024-08-28 21:30:18,517 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.31s
2024-08-28 21:30:18,666 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.15s


[GSPath('gs://plinder/2024-06/v2/systems/b4.zip')]


2024-08-28 21:30:22,415 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.29s
2024-08-28 21:30:22,555 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:30:24,409 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.18s
2024-08-28 21:30:24,572 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.93s


[GSPath('gs://plinder/2024-06/v2/systems/b4.zip')]


2024-08-28 21:30:27,518 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.28s
2024-08-28 21:30:27,658 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:30:29,494 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.19s
2024-08-28 21:30:29,654 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.90s


[GSPath('gs://plinder/2024-06/v2/systems/b4.zip')]


2024-08-28 21:30:32,583 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.30s
2024-08-28 21:30:32,690 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.11s
2024-08-28 21:30:34,519 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.18s
2024-08-28 21:30:34,679 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.90s


[GSPath('gs://plinder/2024-06/v2/systems/bl.zip')]


2024-08-28 21:30:37,580 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.29s
2024-08-28 21:30:37,722 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:30:39,688 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.18s
2024-08-28 21:30:39,848 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.87s


[GSPath('gs://plinder/2024-06/v2/systems/bl.zip')]


2024-08-28 21:30:42,783 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.30s
2024-08-28 21:30:42,918 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.13s
2024-08-28 21:30:44,853 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.18s
2024-08-28 21:30:45,013 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.83s


[GSPath('gs://plinder/2024-06/v2/systems/by.zip')]


2024-08-28 21:30:48,017 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.29s
2024-08-28 21:30:48,154 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-08-28 21:30:50,155 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.17s
2024-08-28 21:30:50,338 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 1.86s


In [14]:
# Inspect result
for i in validation_set:
    print(i)

(PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/1avb__1__1.A_1.B__1.C/sequences.fasta'), 'CC(=O)N[C@H]1[C@H](O[C@H]2[C@H](O)[C@@H](NC(C)=O)CO[C@@H]2CO)O[C@H](CO)[C@@H](O)[C@@H]1O')
(PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/1avb__1__1.A__1.E/sequences.fasta'), 'CC(=O)N[C@H]1CO[C@H](CO)[C@@H](O)[C@@H]1O')
(PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/1avb__1__1.B__1.H/sequences.fasta'), 'CC(=O)N[C@H]1CO[C@H](CO)[C@@H](O)[C@@H]1O')
(PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/1b4k__1__1.A__1.C_1.E/sequences.fasta'), 'O=S(=O)([O-])[O-]')
(PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/1b4k__1__1.A__1.C_1.E/sequences.fasta'), 'CCCCC(=O)O')
(PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/1b4k__1__2.A__2.C_2.E/sequences.fasta'), 'O=S(=O)([O-])[O-]')
(PosixPath('/Users/yusuf/.local/share/plinder/2024-06/v2/systems/1b4k__1__2.A__2.C_2.E/sequences.fasta'), 'CCCCC(=O)O')
(PosixPath('/U