# Datasets and Dataloaders

In [1]:
from plinder.core.loader.dataset import PlinderDataset, get_torch_loader


`PlinderDataset` provides an interface to interact with _PLINDER_ data as a dataset. It is a subclass of `torch.utils.data.Dataset`, so subclassing and extending it should feel familiar to most users. Flexibility and general applicability were our top concerns when designing this interface: `PlinderDataset` lets users not only define their own split but also bring their own featurizer.
It can be initialized with the following parameters:
```
Parameters
----------
df : pd.DataFrame | None
    the split to use, as a DataFrame
split : str
    the split to sample from
split_parquet_path : str | Path, default=None
    split parquet file
input_structure_priority : str, default="apo"
    which alternate structure to prioritize
featurizer : Callable[[Structure, int], dict[str, torch.Tensor]], default=structure_featurizer
    transformation that turns a structure into input tensors
padding_value : int
    value used to pad arrays of uneven length
**kwargs : Any
    any other keyword arguments

For an example of how to write your own featurizer, see [Featurizer Example](https://github.com/plinder-org/plinder/blob/c36eef9b02823ce572de905c094f6c85c03800ca/src/plinder/core/loader/featurizer.py#L16). The signature is shown below:
```
def structure_featurizer(
    structure: Structure, pad_value: int = -100
    ) -> dict[str, Any]:
```
It takes a `Structure` object (and an optional `pad_value`) and returns a dictionary of padded tensor features.


:::{note}
This is where you may want to load a `train` dataset, but for demonstration purposes we start with `val` because of its smaller memory footprint, and load only a small subset of systems that have a single protein chain and `ATP` as a ligand. We also set `use_alternate_structures=False` to avoid downloading and loading alternate structures when building the docs.
:::

In [2]:
# Restrict the validation split to single-protein-chain systems containing ATP to keep
# the demo small; alternate structures are skipped to avoid extra downloads.
single_chain_atp_filters = [
    ("system_num_protein_chains", "==", 1),
    ("ligand_unique_ccd_code", "in", {"ATP"}),
]
val_dataset = PlinderDataset(
    split="val",
    filters=single_chain_atp_filters,
    use_alternate_structures=False,
)
len(val_dataset)

2024-11-27 10:19:00,554 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-11-27 10:19:01,000 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.13s
2024-11-27 10:19:01,569 | plinder.core.loader.dataset:51 | INFO : Loading 9 systems


9

In [3]:
# Index into the dataset to load a single system (the second one in the filtered dataset).
val_data = val_dataset[1]

2024-11-27 10:19:01,755 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.00s
2024-11-27 10:19:01,756 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.00s
2024-11-27 10:19:01,902 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.00s
2024-11-27 10:19:01,902 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.00s
2024-11-27 10:19:01,903 | plinder.core.index.utils:148 | INFO : loading entries from 1 zips
2024-11-27 10:19:01,905 | plinder.core.index.utils:163 | INFO : loaded 1 entries
2024-11-27 10:19:01,906 | plinder.core.index.utils.load_entries:24 | INFO : runtime succeeded: 0.15s
2024-11-27 10:19:02,775 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.14s
2024-11-27 10:19:03,165 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 0.72s
2024-11-27 10:19:03,547 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.00s
2024-11-27 10:19:03

In [4]:
# Apo-structure features of the single system loaded above
val_data["features_and_coords"]["apo_features"]

{'apo_sequence_atom_mask_stacked': tensor([[0, 0, 0,  ..., 0, 0, 0]]),
 'apo_input_sequence_residue_mask_stacked': tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 

In [5]:
# Identifier of the system loaded above
val_data["system_id"]

'5ya1__1__1.A__1.E'

In [6]:
# Wrap the dataset in a torch loader and pull only the first batch for inspection.
# `next(iter(...))` fetches one batch without leaving a throwaway loop variable behind.
val_loader = get_torch_loader(val_dataset)
test_torch = next(iter(val_loader))

2024-11-27 10:19:07,298 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.00s
2024-11-27 10:19:07,298 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.00s
2024-11-27 10:19:07,407 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.00s
2024-11-27 10:19:07,407 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.00s
2024-11-27 10:19:07,407 | plinder.core.index.utils:148 | INFO : loading entries from 1 zips
2024-11-27 10:19:07,410 | plinder.core.index.utils:163 | INFO : loaded 1 entries
2024-11-27 10:19:07,410 | plinder.core.index.utils.load_entries:24 | INFO : runtime succeeded: 0.11s
2024-11-27 10:19:08,233 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.17s
2024-11-27 10:19:08,373 | plinder.core.scores.links.query_links:24 | INFO : runtime succeeded: 0.49s
2024-11-27 10:19:08,531 | plinder.core.utils.cpl.download_paths:24 | INFO : runtime succeeded: 0.00s
2024-11-27 10:19:08

In [7]:
# Top-level keys of the first batch returned by the loader
test_torch.keys()

dict_keys(['system_ids', 'alternate_structure_id', 'plinder_system', 'paths', 'features_and_coords'])

In [8]:
# System IDs of every system included in this batch
test_torch["system_ids"]

['3ab8__1__1.A__1.C_1.E', '8vjk__1__1.H__1.IA']

In [9]:
# Print each feature group name followed by its contents for the whole batch
for group_name, group_features in test_torch["features_and_coords"].items():
    print(group_name, group_features)

sequence_features {'input_sequence_residue_feat_stack': tensor([[[[   0.,    0.,    0.,  ...,    0.,    0.,    0.],
          [   0.,    1.,    0.,  ...,    0.,    0.,    0.],
          [   0.,    0.,    0.,  ...,    0.,    0.,    0.],
          ...,
          [-100., -100., -100.,  ..., -100., -100., -100.],
          [-100., -100., -100.,  ..., -100., -100., -100.],
          [-100., -100., -100.,  ..., -100., -100., -100.]]],


        [[[   0.,    0.,    0.,  ...,    0.,    0.,    0.],
          [   0.,    0.,    0.,  ...,    0.,    0.,    0.],
          [   0.,    0.,    0.,  ...,    0.,    0.,    0.],
          ...,
          [   0.,    0.,    0.,  ...,    0.,    0.,    0.],
          [   0.,    0.,    0.,  ...,    0.,    0.,    0.],
          [   0.,    0.,    0.,  ...,    0.,    0.,    0.]]]],
       dtype=torch.float64), 'input_sequence_full_atom_feat_stack': tensor([[[[   0.,    0.,    0.,  ...,    0.,    0.,    0.],
          [   0.,    0.,    0.,  ...,    0.,    0.,    0.],

In [10]:
# Names of the feature groups available in the batch
test_torch["features_and_coords"].keys()

dict_keys(['sequence_features', 'holo_features', 'apo_features', 'ligand_features'])

In [11]:
# Apo features for the batch; each tensor is stacked across the systems in the batch
test_torch["features_and_coords"]["apo_features"]

{'apo_sequence_atom_mask_stacked': tensor([[[   1,    1,    1,  ..., -100, -100, -100]],
 
         [[   0,    0,    0,  ...,    1,    1,    1]]]),
 'apo_input_sequence_residue_mask_stacked': tensor([[[   1.,    1.,    1.,  ..., -100., -100., -100.]],
 
         [[   0.,    0.,    0.,  ...,    1.,    1.,    1.]]],
        dtype=torch.float64),
 'apo_protein_coordinates_stacked': tensor([[[[   7.5160,   41.8130,   -8.8570],
           [   8.7190,   42.3560,   -8.1650],
           [   8.9950,   41.4740,   -6.9490],
           ...,
           [-100.0000, -100.0000, -100.0000],
           [-100.0000, -100.0000, -100.0000],
           [-100.0000, -100.0000, -100.0000]]],
 
 
         [[[ 230.5450,  234.2060,  226.0810],
           [ 231.5230,  235.2870,  226.0580],
           [ 231.2910,  236.2520,  227.2160],
           ...,
           [ 196.5490,  215.8520,  201.2660],
           [ 199.7080,  216.2970,  200.0450],
           [ 200.3650,  216.3350,  201.3000]]]]),
 'apo_protein_calpha_coor