<a href="https://colab.research.google.com/github/plinder-org/moving_beyond_memorisation/blob/main/notebooks/pinder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pinder

In [163]:
!pip install -q git+https://github.com/conda-incubator/condacolab.git@0.1.x
import condacolab
condacolab.install()

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
✨🍰✨ Everything looks OK!


In [164]:
!pip install -q pinder

In [162]:
!pip install torch-cluster py3Dmol



# Pinder index

## Download the dataset

NOTE: the default location for the dataset is `~/.local/share/pinder/<release version>`

If you want to use a different location, you can do so by setting the `PINDER_BASE_DIR` environment variable.

The base dir refers to a fully qualified path name up until the `<release version>` (not inclusive).

For instance, you could:
```
export PINDER_BASE_DIR=~/my-custom-location-for-pinder/pinder
```

You can always check the current location of the dataset like so:
```python
from pinder.core import get_pinder_location
get_pinder_location()
```

In [None]:
from pinder.core import get_pinder_location
get_pinder_location()
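
In a notebook, an `export` run through `!` executes in a subshell and does not persist into the Python process, so one option is to set the variable from Python instead. The sketch below assumes the variable needs to be set before `pinder` is imported; the directory is the placeholder path from the shell example above, not a required location.

```python
import os
from pathlib import Path

# Placeholder location taken from the shell example above; change as needed.
custom_base = Path.home() / "my-custom-location-for-pinder" / "pinder"

# Assumption: set this before importing pinder so the package picks it up.
os.environ["PINDER_BASE_DIR"] = str(custom_base)

print(os.environ["PINDER_BASE_DIR"])
```

After this, calling `get_pinder_location()` is a quick way to confirm which directory is actually in use.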


### To download the complete dataset run the following
Note: the Pinder dataset contains millions of structures and takes a long time to download. For this reason, we will explore Pinder by loading the dataset incrementally, using its lazy loading capabilities.


In [None]:
from pinder.core import download_dataset
# download_dataset()


### Alternatively, use the CLI script `pinder_download`

```bash
pinder_download --help

usage: Download latest pinder dataset to disk [-h] [--pinder_base_dir PINDER_BASE_DIR] [--pinder_release PINDER_RELEASE] [--skip_inflation]

optional arguments:
  -h, --help            show this help message and exit
  --pinder_base_dir PINDER_BASE_DIR
                        specify a non-default pinder base directory
  --pinder_release PINDER_RELEASE
                        specify a pinder dataset version
  --skip_inflation      if passed, will only download the compressed archives without unpacking
```

### The full dataset should look like this

```bash
~/.local/share/pinder/<release version>/
    pdbs/
    test_set_pdbs/
    mappings/
    index.parquet
    metadata.parquet
```
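
After a download, a quick sanity check is to confirm that the entries listed above are present. The helper below is a minimal sketch using only the standard library; `missing_entries` is a hypothetical name and not part of the `pinder` package.

```python
import tempfile
from pathlib import Path

# Entries listed in the layout above.
EXPECTED_ENTRIES = ("pdbs", "test_set_pdbs", "mappings", "index.parquet", "metadata.parquet")


def missing_entries(release_dir):
    """Return the expected entries that are absent from release_dir."""
    root = Path(release_dir).expanduser()
    return [name for name in EXPECTED_ENTRIES if not (root / name).exists()]


# Demonstration on a temporary directory holding only part of the layout.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "pdbs").mkdir()
    (Path(tmp) / "index.parquet").touch()
    print(missing_entries(tmp))  # ['test_set_pdbs', 'mappings', 'metadata.parquet']
```

For a real installation you would pass the release directory itself, i.e. the path returned by `get_pinder_location()`.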

## Pinder Annotations

### RCSB Annotations

Annotations were obtained from the RCSB NextGen database. The following annotations are included:

1. Oligomeric state of the protein complex (homodimer, heterodimer, oligomer or higher-order complexes)

2. Structure determination method (X-Ray, CryoEM, NMR)

3. Resolution

4. Interfacial gaps, defined as structurally-unresolved segments on PPI interfaces

5. Number of distinct atom types. Many earlier Cryo-EM structures contain only a few atom-types such as only Cα or backbone atoms

6. Whether the interface is likely to be a physiological or crystal contact, annotated using Prodigy

7. Structural elongation, defined as the maximum variance of coordinates projected onto the largest principal component. This allows detection of long end-to-end stacked complexes, likely to be repetitive with small interfaces

8. Planarity, defined as deviation of interfacial Cα atoms from the fitted plane. This interface characteristic quantifies interfacial shape complementarity. Transient complexes have smaller and more planar interfaces than permanent and structural scaffold complexes

9. Number of components, defined as the number of connected components of a 10Å Cα radius graph. This allows detection of structurally discontinuous domains

10. Intermolecular contacts (labeled as polar or apolar)
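
To make item 9 concrete, the sketch below counts the connected components of a radius graph over a set of 3D points using a small union-find. It is an illustration of the definition, not pinder's implementation: the function name is hypothetical, the pairwise loop is quadratic, and treating a distance exactly equal to the radius as connected is an assumption.

```python
import math


def count_components(coords, radius=10.0):
    """Count connected components of a radius graph over 3D points."""
    n = len(coords)
    parent = list(range(n))

    def find(i):
        # Follow parent links to the root, shortening the path as we go.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    # Connect every pair of points closer than the radius.
    for i in range(n):
        for j in range(i + 1, n):
            if math.dist(coords[i], coords[j]) <= radius:
                parent[find(i)] = find(j)

    return len({find(i) for i in range(n)})


# Two toy "domains" placed far apart, with points spaced 3.8 Å along x.
domain_a = [(3.8 * k, 0.0, 0.0) for k in range(5)]
domain_b = [(50.0 + 3.8 * k, 0.0, 0.0) for k in range(5)]
print(count_components(domain_a + domain_b))  # 2
```

A single chain whose Cα atoms split into more than one component under this definition is structurally discontinuous in the sense described above.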

### Pinder Annotations

A core philosophy behind pinder is to provide a large, unfiltered training dataset to derive data mixes for evaluating the impact of different data selection strategies. To that end, we provide extensive tooling for leveraging annotations in filters.

A large set of quality control annotations including interface cluster, resolution, interfacial gaps, planarity, elongation, and more can be accessed via the PinderSystem object or directly in data frames. We also provide the effective MSA depth (number of effective sequences, $N_{eff}$) calculated for each of the test members in `PINDER-XL/S/AF2` to allow accurate performance assessment by evolutionary information.

All systems are indexed and their annotations are stored in two main `parquet` files: `index.parquet` and `metadata.parquet`. Upon loading a pinder system, this data is embedded into individual `PinderSystem` objects. Each `PinderSystem` object features an `.entry` property, exposing primary annotations from the index, and a `.metadata` property, providing detailed metadata. To see the detailed schemas of these properties, we will first explore the index and metadata files and later the `IndexEntry` and `MetadataEntry` objects. Their fields are shown below for reference:

</br></br>

Table 1: Pinder Index entry fields

| Field | Type | Description |
|-------|------|-------------|
| split | string | The type of data split (e.g., 'train', 'test'). |
| id | string | The unique identifier for the dataset entry. |
| pdb_id | string | The PDB identifier associated with the entry. |
| cluster_id | string | The cluster identifier associated with the entry. |
| cluster_id_R | string | The cluster identifier associated with receptor dimer body. |
| cluster_id_L | string | The cluster identifier associated with ligand dimer body. |
| pinder_s | boolean | Flag indicating if the entry is part of the Pinder-S dataset. |
| pinder_xl | boolean | Flag indicating if the entry is part of the Pinder-XL dataset. |
| pinder_af2 | boolean | Flag indicating if the entry is part of the Pinder-AF2 dataset. |
| uniprot_R | string | The UniProt identifier for the receptor protein. |
| uniprot_L | string | The UniProt identifier for the ligand protein. |
| holo_R_pdb | string | The PDB identifier for the holo form of the receptor protein. |
| holo_L_pdb | string | The PDB identifier for the holo form of the ligand protein. |
| predicted_R_pdb | string | The PDB identifier for the predicted structure of the receptor protein. |
| predicted_L_pdb | string | The PDB identifier for the predicted structure of the ligand protein. |
| apo_R_pdb | string | The PDB identifier for the apo form of the receptor protein. |
| apo_L_pdb | string | The PDB identifier for the apo form of the ligand protein. |
| apo_R_pdbs | string | The PDB identifiers for the apo forms of the receptor protein. |
| apo_L_pdbs | string | The PDB identifiers for the apo forms of the ligand protein. |
| holo_R | boolean | Flag indicating if the holo form of the receptor protein is available. |
| holo_L | boolean | Flag indicating if the holo form of the ligand protein is available. |
| predicted_R | boolean | Flag indicating if the predicted structure of the receptor protein is available. |
| predicted_L | boolean | Flag indicating if the predicted structure of the ligand protein is available. |
| apo_R | boolean | Flag indicating if the apo form of the receptor protein is available. |
| apo_L | boolean | Flag indicating if the apo form of the ligand protein is available. |
| apo_R_quality | string | Classification of apo receptor pairing quality. |
| apo_L_quality | string | Classification of apo ligand pairing quality. |
| chain1_neff | number | The Neff value for the first chain in the protein complex. |
| chain2_neff | number | The Neff value for the second chain in the protein complex. |
| chain_R | string | The chain identifier for the receptor protein. |
| chain_L | string | The chain identifier for the ligand protein. |
| contains_antibody | boolean | Flag indicating if the protein complex contains an antibody as per SAbDab. |
| contains_antigen | boolean | Flag indicating if the protein complex contains an antigen as per SAbDab. |
| contains_enzyme | boolean | Flag indicating if the protein complex contains an enzyme as per EC ID number. |

</br></br>

Table 2: Metadata Entry Fields

| Field | Type | Description |
|-------|------|-------------|
| id | string | The unique identifier for the PINDER entry. |
| entry_id | string | The RCSB entry identifier associated with the PINDER entry. |
| method | string | The experimental method for structure determination. |
| date | string | Date of deposition into RCSB PDB. |
| release_date | string | Date of initial public release in RCSB PDB. |
| resolution | number | The resolution of the experimental structure. |
| label | string | Classification of the interface. |
| probability | number | Probability that the protein complex is a true biological complex. |
| chain1_id | string | The Receptor chain identifier associated with the dimer entry. |
| chain2_id | string | The Ligand chain identifier associated with the dimer entry. |
| assembly | integer | Which bioassembly is used to derive the structure. |
| assembly_details | string | How the bioassembly information was derived. |
| oligomeric_details | string | Description of the oligomeric state of the protein complex. |
| oligomeric_count | integer | The oligomeric count associated with the dataset entry. |
| biol_details | string | The biological assembly details associated with the dataset entry. |
| complex_type | string | The type of the complex in the dataset entry. |
| chain_1 | string | New chain id generated post-bioassembly generation (receptor chain). |
| asym_id_1 | string | The first asymmetric identifier (author chain ID) |
| chain_2 | string | New chain id generated post-bioassembly generation (ligand chain). |
| asym_id_2 | string | The second asymmetric identifier (author chain ID) |
| length1 | integer | The number of amino acids in the first (receptor) chain. |
| length2 | integer | The number of amino acids in the second (ligand) chain. |
| length_resolved_1 | integer | The structurally resolved (CA) length of the first (receptor) chain. |
| length_resolved_2 | integer | The structurally resolved (CA) length of the second (ligand) chain. |
| number_of_components_1 | integer | The number of connected components in the first (receptor) chain. |
| number_of_components_2 | integer | The number of connected components in the second (ligand) chain. |
| link_density | number | Density of contacts at the interface as reported by PRODIGY-cryst. |
| planarity | number | Deviation of interfacial Cα atoms from the fitted plane. |
| max_var_1 | number | The maximum variance of coordinates projected onto the largest principal component, for the first (receptor) chain. |
| max_var_2 | number | The maximum variance of coordinates projected onto the largest principal component, for the second (ligand) chain. |
| num_atom_types | integer | Number of unique atom types in structure. |
| n_residue_pairs | integer | The number of residue pairs at the interface. |
| n_residues | integer | The number of residues at the interface. |
| buried_sasa | number | The buried solvent accessible surface area upon complex formation. |
| intermolecular_contacts | integer | The total number of intermolecular contacts at the interface. |
| charged_charged_contacts | integer | Intermolecular contacts between charged amino acids. |
| charged_polar_contacts | integer | Intermolecular contacts between charged and polar amino acids. |
| charged_apolar_contacts | integer | Intermolecular contacts between charged and apolar amino acids. |
| polar_polar_contacts | integer | Intermolecular contacts between polar amino acids. |
| apolar_polar_contacts | integer | Intermolecular contacts between apolar and polar amino acids. |
| apolar_apolar_contacts | integer | Intermolecular contacts between apolar amino acids. |
| interface_atom_gaps_4A | integer | Number of interface atoms within a 4Å radius of a residue gap. |
| missing_interface_residues_4A | integer | Number of interface residues within a 4Å radius of a residue gap. |
| interface_atom_gaps_8A | integer | Number of interface atoms within an 8Å radius of a residue gap. |
| missing_interface_residues_8A | integer | Number of interface residues within an 8Å radius of a residue gap. |
| entity_id_R | integer | The RCSB PDB entity_id corresponding to the receptor dimer chain. |
| entity_id_L | integer | The RCSB PDB entity_id corresponding to the ligand dimer chain. |
| pdb_strand_id_R | string | The RCSB PDB pdb_strand_id (author chain) corresponding to the receptor dimer chain. |
| pdb_strand_id_L | string | The RCSB PDB pdb_strand_id (author chain) corresponding to the ligand dimer chain. |
| ECOD_names_R | string | The RCSB-derived ECOD domain protein family name(s) for the receptor dimer chain. |
| ECOD_names_L | string | The RCSB-derived ECOD domain protein family name(s) for the ligand dimer chain. |


## Pinder metadata API

In [None]:
from pinder.core import get_index

index = get_index()
index

In [None]:
from pinder.core import get_metadata

metadata = get_metadata()
metadata


# Break - Q&A

# Pinder System

In [None]:
from pathlib import Path
import pandas as pd
from pinder.core import PinderSystem


Example usage of the Pinder index API is shown below. For more detailed usage examples, check the [pinder-index](https://pinder-org.github.io/pinder/pinder-index.html) tutorial.

In [None]:
index = get_index()
hetero_test_apo = index.query(
    '(uniprot_L != uniprot_R) and split == "test" and (apo_R and apo_L)'
)
hetero_test_apo.reset_index(drop=True, inplace=True)
hetero_test_apo


In [None]:
pinder_id = list(hetero_test_apo.id)[2]
pinder_id


## PinderSystem API - base class representing `Structure`'s in a pinder entry

In [None]:
# Simplest interface - get a single pinder system
ps = PinderSystem(pinder_id)
ps


In [None]:
holo_L, holo_R = ps.holo_ligand, ps.holo_receptor
pred_L, pred_R = ps.pred_ligand, ps.pred_receptor
apo_L, apo_R = ps.apo_ligand, ps.apo_receptor

holo_L


Notice the printed `PinderSystem` object has the following properties:
* `native` - the ground-truth dimer complex
* `holo_receptor` - the receptor chain (monomer) from the ground-truth complex
* `holo_ligand` - the ligand chain (monomer) from the ground-truth complex
* `apo_receptor` - the canonical _apo_ chain (monomer) paired to the receptor chain
* `apo_ligand` - the canonical _apo_ chain (monomer) paired to the ligand chain
* `pred_receptor` - the AlphaFold2 predicted monomer paired to the receptor chain  
* `pred_ligand` - the AlphaFold2 predicted monomer paired to the ligand chain


These properties are pointers to `Structure` objects. The `Structure` object provides the most direct mode of access to structures and associated properties.

**Note: not all systems have an apo and/or predicted structure for all chains of the ground-truth dimer complex!**

As was the case in the example above, when the alternative monomers are not available, the property will have a value of `None`.

You can determine which systems have which alternative monomer pairings _a priori_ by looking at the boolean columns in the index `apo_R` and `apo_L` for the apo receptor and ligand, and `predicted_R` and `predicted_L` for the predicted receptor and ligand, respectively.


For instance, we can load a different system that _does_ have apo receptor and ligand as such:

In [None]:
apo_system = PinderSystem(index.query('apo_R and apo_L').id.iloc[0])
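
Because unavailable monomers are exposed as `None`, it can be handy to list which of the properties above are populated before using them. The helper below is a hedged sketch: `available_monomers` is a hypothetical name, and the `SimpleNamespace` object is only a stand-in so the example runs without the dataset; in practice you would pass a real `PinderSystem`.

```python
from types import SimpleNamespace

# Monomer properties listed above for PinderSystem objects.
MONOMER_PROPERTIES = (
    "holo_receptor", "holo_ligand",
    "apo_receptor", "apo_ligand",
    "pred_receptor", "pred_ligand",
)


def available_monomers(system):
    """Names of monomer properties that are not None on the given system."""
    return [
        name for name in MONOMER_PROPERTIES
        if getattr(system, name, None) is not None
    ]


# Stand-in object for illustration only, mimicking a system without
# an apo ligand or any predicted monomers.
toy_system = SimpleNamespace(
    holo_receptor="R", holo_ligand="L",
    apo_receptor="apo R", apo_ligand=None,
    pred_receptor=None, pred_ligand=None,
)
print(available_monomers(toy_system))  # ['holo_receptor', 'holo_ligand', 'apo_receptor']
```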

In [None]:
# Visualize some of the complexes
from pathlib import Path
import py3Dmol

ground_truth = apo_system.native.filepath
apo_complex = apo_system.create_apo_complex()
apo_complex.to_pdb()

# Extract interface residues in the apo-superimposed complex for visualization
apo_contacts = apo_complex.get_contacts()
apo_interface_res = apo_complex.get_interface_residues(apo_contacts)

# Viewer documentation: https://3dmol.org/doc/GLViewer.html
view = py3Dmol.view(width=900, height=900)
view.removeAllModels()
view.setViewStyle({'style':'outline','color':'black','width':0.1})
view.setBackgroundColor('black')
# Show the reference (holo) structure
view.addModel(open(ground_truth, 'r').read(),'pdb',
             {"style": {'cartoon': {'color':'hotpink'}}})
view.setStyle({'chain':'R'},{'cartoon': {'color':'dodgerblue', 'arrows':True, 'tubes': False, 'ribbon': False, 'style': 'edged'}})
view.setStyle({'chain':'L'},{'cartoon': {'color':'hotpink', 'arrows':True, 'tubes':False, 'ribbon': False, 'style':'edged'}})
view.addSurface(py3Dmol.VDW,{'opacity':0.6,'color':'white'})

# Show the monomer-superposed apo complex
view.addModel(
    open(apo_complex.filepath, 'r').read(),
    'pdb',
    {"style": {'cartoon': {'color':'lightskyblue'}}}
)
# Color the apo ligand chain in light pink
view.setStyle({'chain':'L', "model": 1},{'cartoon': {'color':'lightpink', 'arrows':True, 'tubes': False, 'ribbon': False, 'style': 'edged'}})
# Show the interface residues as stick representation
view.setStyle({"chain": "R", "resi": list(map(str, sorted(apo_interface_res["R"])))}, {"stick": {"colorscheme": "cyanCarbon", "radius": 0.15}})
view.setStyle({"chain": "L", "resi": list(map(str, sorted(apo_interface_res["L"])))}, {"stick": {"colorscheme": "pinkCarbon", "radius": 0.15}})
view.zoomTo()
view.spin("y")
view.show()

## Classify system difficulty based on the degree of conformational shift between unbound and bound states

In [None]:
ps.unbound_difficulty("apo")

In [None]:
ps.unbound_difficulty("predicted")

## Illustrating utilities available in `Structure` instances

In [None]:
holo_L.filter("atom_name", mask=["CA"])


In [None]:
apo_L.filter("atom_name", mask=["CA"])


## Can also filter "in place" rather than returning a copy (a la pandas)

In [None]:
apo_L.filter("atom_name", mask=["CA"], copy=False)

In [None]:
(
    ps.apo_ligand.filter("atom_name", mask=["CA"]),
    ps.holo_ligand.filter("atom_name", mask=["CA"])
)


## Create masked unbound complex aligned to bound for apo

In [None]:
apo_complex = ps.create_apo_complex()
apo_complex


In [None]:
# dataframe representation of the Structure atom_array
apo_complex.dataframe

# Accessing and loading data for training

To access the train and val splits for PINDER, please refer to the [pinder documentation](https://github.com/pinder-org/pinder/tree/main?tab=readme-ov-file#%EF%B8%8F-getting-the-dataset).

Once you have downloaded the pinder dataset, either via the `pinder` package or directly through `gsutil`, you will have all of the necessary files for training.

To get a list of those systems and their split labels, refer to the `pinder` index.

**We will start by looking at the most basic way to load items from the training and validation set: via `PinderSystem` objects**

In [None]:
index = get_index()

n_samples = 50

system_ids = index.query(
    '(apo_R & apo_L) and (split == "train")'
).sample(n_samples).id.tolist()

# Generator expression: each PinderSystem is only constructed when iterated over
loader = (PinderSystem(system_id) for system_id in system_ids)
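
Since `loader` here is a plain Python generator, nothing is loaded until it is iterated, and it can only be consumed once. The sketch below shows that behaviour with `itertools.islice`; `load_system` and the IDs are hypothetical stand-ins for `PinderSystem(...)` and real pinder IDs so that the example runs without the dataset.

```python
from itertools import islice


def load_system(system_id):
    # Hypothetical stand-in for PinderSystem(system_id); prints when called
    # so that the lazy behaviour is visible.
    print(f"loading {system_id}")
    return {"id": system_id}


system_ids = ["id_a", "id_b", "id_c", "id_d"]  # placeholder IDs
lazy_loader = (load_system(system_id) for system_id in system_ids)

# Only the first two systems are constructed here; the rest stay untouched.
first_two = list(islice(lazy_loader, 2))
print([system["id"] for system in first_two])
```

Iterating `lazy_loader` again continues from the third ID rather than starting over, so rebuild the generator if you need a second pass.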

### Using the PinderLoader to load, filter and transform systems

While the `PinderSystem` object provides a self-contained access to structures associated with a dimer system, the `PinderLoader` provides a base abstraction for how to iterate over systems, apply optional filters and/or transforms, and return the systems as an iterator. This construct is covered in a [different tutorial](https://pinder-org.github.io/pinder/pinder-loader.html).

Using the `PinderLoader` is **not** necessary to load systems in your own framework. It is simply one of the provided mechanisms if you find it useful.

Pinder loader brings together filters, transforms and writers to create a generic `PinderSystem` iterator. It takes either a split name or a list of system IDs as input and can be used to sample alternative monomers to form dimer complexes to serve as e.g. features.


### Loading a specific split
Note: only the test split has subsets defined (`pinder_s`, `pinder_xl`, `pinder_af2`).

For train and val, you could just do:
```python
train_loader = PinderLoader(split="train")
val_loader = PinderLoader(split="val")
```


In [None]:
import torch
from pinder.core import PinderLoader
from pinder.core.loader import filters

base_filters = [
    filters.FilterByMissingHolo(),
    filters.FilterSubByContacts(min_contacts=5, radius=10.0, calpha_only=True),
    filters.FilterDetachedHolo(radius=12, max_components=2),
]
sub_filters = [
    filters.FilterSubByAtomTypes(min_atom_types=4),
    filters.FilterByHoloOverlap(min_overlap=5),
    filters.FilterByHoloSeqIdentity(min_sequence_identity=0.8),
    filters.FilterSubRmsds(rmsd_cutoff=7.5),
    filters.FilterDetachedSub(radius=12, max_components=2),
]

loader = PinderLoader(
    split="test",
    subset="pinder_af2",
    monomer_priority="holo",
    base_filters = base_filters,
    sub_filters = sub_filters
)

loader

In [None]:
len(loader)

In [None]:
data = loader[0]
print(f"Data is a {type(data)}")
system, feature_complex, target_complex = data
type(system), type(feature_complex), type(target_complex)

In [None]:
# You can also use it as an iterator
from tqdm import tqdm
max_samples = 10
loaded_ids = []
for (system, feature_complex, target_complex) in tqdm(loader):
    if len(loaded_ids) >= max_samples:
        break
    loaded_ids.append(system.entry.id)

### Loading a specific list of systems


In [None]:
systems = [
    "1df0__A1_Q07009--1df0__B1_Q64537",
    "117e__A1_P00817--117e__B1_P00817",
]
loader = PinderLoader(
    ids=systems,
    monomer_priority="holo",
    base_filters = base_filters,
    sub_filters = sub_filters
)
passing_ids = []
for item in loader:
    passing_ids.append(item[0].entry.id)

systems_removed_by_filters = set(systems) - set(passing_ids)
systems_removed_by_filters

In [None]:
len(systems) == len(passing_ids)

### Optional Pinder writer

Without defining a writer for the `PinderLoader`, the loaded systems are available as a tuple of (`PinderSystem`, `Structure`, `Structure`) objects, containing the original `PinderSystem` and the sampled feature and target complexes, respectively.

If you want to explicitly write the (potentially transformed) structure objects to a custom location or in a custom format (e.g. PDB, pickle, etc.), you can implement a subclass of `PinderWriterBase`.

The default writer implements writing to PDB files (leveraging the `Structure.to_pdb` method on the structure objects).



In [None]:
from pinder.core.loader.writer import PinderDefaultWriter

from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp_dir:
    temp_dir = Path(tmp_dir)
    loader = PinderLoader(
        ids=systems,
        monomer_priority="pred",
        writer=PinderDefaultWriter(temp_dir)
    )
    assert set(loader.index.id) == set(systems)
    for i, r in loader.index.iterrows():
        loaded = loader[i]
        pinder_id = r.id
        system_dir = loader.writer.output_path / pinder_id
        assert system_dir.is_dir()
        print(list(system_dir.glob("af_*.pdb")))


## Constructing torch datasets and dataloaders from pinder systems

The remaining sections of this tutorial will be for those interested specifically in torch datasets and dataloaders.

Specifically, we will show how to:
* Implement a PyTorch `Dataset` to interface with pinder data
* Include apo and predicted monomers in the data pipeline, with an option to target specific monomer types or randomly sample from the available types
* Leverage `PinderSystem` and its associated methods to crop apo/predicted monomers to match the ground-truth holo monomers
* Write filters and transforms that operate on `Structure` objects
* Integrate annotations in data filtering and featurization
* Create example features to use for training (you will of course choose your own features)
* Incorporate diversity sampling in the data loader


The `pinder.core.loader.dataset` module provides two example implementations of how to integrate the pinder dataset into a torch-based machine learning pipeline.

1. `PinderDataset`: A map-style `torch.utils.data.Dataset` that can be used with torch `DataLoader`'s.
2. `PPIDataset`: A `torch_geometric.data.Dataset` that can be used with torch-geometric `DataLoader`'s. This dataset is designed to be used with the `torch_geometric` package.

Together, the two datasets provide an example implementation of how to abstract away the complexity of loading and processing multiple structures associated with each `PinderSystem` by leveraging the following utilities from pinder:

* `pinder.core.PinderLoader`
* `pinder.core.loader.filters`
* `pinder.core.loader.transforms`

The examples cover two different batch data item structures to illustrate two different use-cases:

* `PinderDataset`: A batch of `(target_complex, feature_complex)` pairs, where `target_complex` and `feature_complex` are `torch.Tensor` objects representing the atomic coordinates and atom types of the holo and sampled (decoy, holo/apo/pred) complexes, respectively.
* `PPIDataset`: A batch of `PairedPDB` objects, where the receptor and ligand are encoded separately in a heterogeneous graph, via `torch_geometric.data.HeteroData`, holding multiple node and/or edge types in disjunct storage objects.


The remaining sections will be split into:
1. Using the `PinderDataset` torch dataset
2. Using the `PPIDataset` torch-geometric dataset
3. How you could implement your own dataset & dataloader


### PinderDataset (torch Dataset)


The `PinderDataset` is an example implementation of a `torch.utils.data.Dataset` that represents its data items as a dict containing the following key, value pairs:
* `target_complex`: The ground-truth holo dimer, represented with a set of default properties encoded as `Tensor`'s
* `feature_complex`: The sampled dimer complex, representing "features", also represented with a set of default properties encoded as `Tensor`'s
* `id`: The pinder ID for the selected system
* `target_id`: The IDs of the receptor and ligand holo monomers, concatenated into a single ID string
* `sample_id`: The IDs of the sampled receptor and ligand monomers, concatenated into a single ID string. This can be useful for debugging purposes or generally tracking which specific monomers are selected when targeting alternative monomers (more on this shortly)


Each of the `target_complex` and `feature_complex` values is a dictionary with structural properties encoded, by default, by the `pinder.core.loader.geodata.structure2tensor` function:
* `atom_coordinates`
* `atom_types`
* `chain_ids`
* `residue_coordinates`
* `residue_types`
* `residue_ids`

You can choose to use a different representation by overriding the default values of `transform` and `target_transform`.

It leverages the `PinderLoader` to apply optional filters and/or transforms, provide an interface for sampling alternative monomers, and exposes `transform` and `target_transform` arguments used by the torch Dataset API.

For more details on the torch Dataset APIs, please refer to the [tutorials](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#datasets-dataloaders).

#### Example `PinderDataset`

Now, we'll go through an example dataset preparation using the `PinderDataset` built on PyTorch `Dataset`. We start by merging our structure index with relevant metadata such as resolution and experimental method. Then, we define a handy `get_subset` function to filter out low-quality or irrelevant structures.

We'll do this by limiting the data to systems determined by X-ray diffraction at a resolution better than 2 Å whose interface is labeled as biologically relevant (`label == "BIO"`). To make our initial experiments faster and more manageable, we'll sample a small subset of 10 structures each for training and validation. Finally, we'll grab one random sample from the test set to showcase our model's predictions later on.

In [None]:
index = get_index()
metadata = get_metadata()

index = pd.merge(index, metadata[["id", "resolution", "label", "method"]], on="id")

# We'll start by restricting the training and validation datasets to a smaller size (10) for testing purposes.

def get_subset(index, split, n_samples=50):
    # Keep high-resolution X-ray structures with biologically relevant interfaces
    query = '(resolution < 2) and (method == "X-RAY DIFFRACTION") and (label == "BIO")'
    subset = (
        index.query(f'{query} and (split == "{split}")')
        .sample(n_samples)
        .id.to_list()
    )
    return subset

n_samples = 10
training_subset = get_subset(index, "train", n_samples)
validation_subset = get_subset(index, "val", n_samples)
test_subset = get_subset(index, "test", 1) # take a random test as example case

In [None]:
from pinder.core.loader.dataset import PinderDataset, structure2tensor_transform
from pinder.core.loader import filters, transforms

# Define the filters and transforms applied to the small subsets selected in the previous cell.

base_filters = [
    filters.FilterByMissingHolo(),
    filters.FilterSubByContacts(min_contacts=5, radius=10.0, calpha_only=True),
    filters.FilterDetachedHolo(radius=12, max_components=2),
]
sub_filters = [
    filters.FilterSubByAtomTypes(min_atom_types=4),
    filters.FilterByHoloOverlap(min_overlap=5),
    filters.FilterByHoloSeqIdentity(min_sequence_identity=0.8),
    filters.FilterSubRmsds(rmsd_cutoff=7.5),
    filters.FilterDetachedSub(radius=12, max_components=2),
]


# We can include Structure-level transforms (and filters) which will operate on the target and/or feature complexes
target_transforms = [
    transforms.SelectAtomTypes(atom_types=["CA", "N", "C", "O"]),
]
# In addition to slicing only backbone atoms, we introduce random rotation to the ligand protein
# in the feature complex while preserving the target (ground-truth) complex orientations.
feature_transforms = [
    transforms.SelectAtomTypes(atom_types=["CA", "N", "C", "O"]),
    transforms.RandomLigandTransform(max_translation=10.0),
]

datasets = []

for split, subset_ids in [("train", training_subset), ("val", validation_subset), ("test", test_subset)]:
    print(f"Loading {split} split with {len(subset_ids)} systems...")
    dataset = PinderDataset(
        split=split,
        ids=subset_ids,
        # We can leverage holo, apo, pred, random and random_mixed monomer sampling strategies
        monomer_priority="random_mixed" if split != "test" else "holo",
        base_filters = base_filters,
        sub_filters = sub_filters,
        structure_transforms_target=target_transforms,
        structure_transforms_feature=feature_transforms)
    datasets.append(dataset)

train_dataset, val_dataset, test_dataset = datasets

assert len(train_dataset) == len(training_subset) == n_samples
assert len(test_subset) == 1

### Sampling alternative monomers

The `monomer_priority` argument can be used to target different mixes of bound and unbound monomers to use for creating the decoy/feature complex.

The allowed values for `monomer_priority` are "apo", "holo", "pred", "random" or "random_mixed".

When `monomer_priority` is set to one of the available monomer types (holo, apo, pred), the same monomer type will be selected for both receptor and ligand.

When the monomer priority is "random", a random monomer type will be selected from the set of monomer types available for both the receptor and ligand. This option ensures the same type of monomer is used for the receptor and ligand.

When the monomer priority is "random_mixed", a random monomer type will be selected for each of receptor and ligand, separately.

Enabling the `fallback_to_holo` option (default) will enable silent fallback to holo when the `monomer_priority` is set to one of apo or pred, but the corresponding monomer is not available for the dimer.

This is useful when only one of receptor or ligand has an unbound monomer, but you wish to include apo or predicted structures in your workflow.

If `fallback_to_holo` is disabled, an error will be raised when the `monomer_priority` is set to one of apo or pred, but the corresponding monomer is not available for the dimer.
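
The selection rules described above can be summarized in a few lines of plain Python. This is a sketch of the described behaviour, not pinder's implementation: `choose_monomers` is a hypothetical function, the holo monomer is assumed to always be available, and applying the holo fallback separately to receptor and ligand is an assumption.

```python
import random

MONOMER_TYPES = ("holo", "apo", "pred")


def choose_monomers(available_R, available_L, monomer_priority,
                    fallback_to_holo=True, rng=random):
    """Return (receptor_type, ligand_type) for the feature complex."""
    if monomer_priority in MONOMER_TYPES:
        chosen = []
        for available in (available_R, available_L):
            if monomer_priority in available:
                chosen.append(monomer_priority)
            elif fallback_to_holo:
                # Assumption: fall back per side, silently.
                chosen.append("holo")
            else:
                raise ValueError(f"{monomer_priority} monomer not available")
        return tuple(chosen)
    if monomer_priority == "random":
        # Same type for both sides, drawn from the types both sides have.
        shared = sorted(set(available_R) & set(available_L))
        kind = rng.choice(shared)
        return kind, kind
    if monomer_priority == "random_mixed":
        # Independent draw for each side.
        return rng.choice(sorted(available_R)), rng.choice(sorted(available_L))
    raise ValueError(f"unknown monomer_priority: {monomer_priority}")


print(choose_monomers({"holo", "apo"}, {"holo"}, "apo"))     # ('apo', 'holo')
print(choose_monomers({"holo", "apo"}, {"holo"}, "random"))  # ('holo', 'holo')
```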


By default, when apo monomers are selected, the "canonical" apo monomer is used. Although a single canonical apo monomer should be used for eval, pinder provides multiple apo monomers paired to a single holo monomer (when available). In order to include these non-canonical/alternative monomers, you can specify `use_canonical_apo=False` when constructing the `PinderLoader` or `PinderDataset` objects.


In [None]:
from pprint import pprint
data_item = train_dataset[0]
pprint(data_item)


In [None]:
pprint(data_item["feature_complex"])

In [None]:
# Since we used the default crop_equal_monomer_shapes option, the feature and target complex coordinates should have identical shapes
assert (
    data_item["feature_complex"]["atom_coordinates"].shape
    == data_item["target_complex"]["atom_coordinates"].shape
)

data_item["feature_complex"]["atom_coordinates"].shape

In [None]:
help(PinderDataset)

### Torch DataLoader for PinderDataset

The `PinderDataset` can be served by a `torch.utils.data.DataLoader`.

There is a convenience function `pinder.core.loader.dataset.get_torch_loader` for taking a `PinderDataset` and returning a `DataLoader` for the dataset object.

We can leverage the default `collate_fn` (`pinder.core.loader.dataset.collate_batch`) to merge multiple systems (`Dataset` items) to create mini-batches of tensors:


In [None]:
import torch
from pinder.core.loader.dataset import collate_batch, get_torch_loader
from torch.utils.data import DataLoader

batch_size = 2
train_dataloader = get_torch_loader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_batch,
    num_workers=0,
)
assert isinstance(train_dataloader, DataLoader)
assert hasattr(train_dataloader, "dataset")

# Get a batch from the dataloader
batch = next(iter(train_dataloader))

# expected batch dict keys
assert set(batch.keys()) == {
    "target_complex",
    "feature_complex",
    "id",
    "sample_id",
    "target_id",
}
assert isinstance(batch["target_complex"], dict)
assert isinstance(batch["target_complex"]["atom_coordinates"], torch.Tensor)
feature_coords = batch["feature_complex"]["atom_coordinates"]

# Ensure batch size propagates to tensor dims
assert feature_coords.shape[0] == batch_size

# Ensure coordinates have dim 3
assert feature_coords.shape[2] == 3


# Break - Q&A

# Train a model

Now, we'll use a "dummy" model to explore the `PinderDataset` further. Given a complex, we're going to predict the rotation and translation of the ligand.

This has direct application to the rigid docking task, where we assume the ligand's internal structure doesn't change during docking. We'll use a straightforward iterator to handle batched data during training. The model's forward pass will output the predicted rotation and translation, as well as the transformed position of the ligand.

We'll measure the accuracy of our predictions with the Mean Squared Error loss of the predicted vs ground truth ligand coordinates.
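
Written out, and assuming the row-vector convention used in the dummy model's `forward` method (coordinates are multiplied from the right by the predicted matrix), the prediction and loss for one complex with $N$ ligand atoms are:

$$
\hat{x}_i = x_i R + t, \qquad
\mathrm{MSE} = \frac{1}{3N} \sum_{i=1}^{N} \lVert \hat{x}_i - y_i \rVert^2
$$

Here $x_i$ is the input coordinate of ligand atom $i$, $y_i$ its ground-truth coordinate, $R$ the predicted $3 \times 3$ matrix and $t$ the predicted translation. The factor $1/(3N)$ reflects the mean reduction over all coordinate entries, which is the default for `F.mse_loss`. Note that the dummy model predicts the nine entries of $R$ without constraints, so $R$ is not guaranteed to be a proper rotation ($R R^\top = I$, $\det R = 1$).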

Finally, we'll demonstrate how the inferred rotation and translation can be used to write the docking results as PDB and evaluate using `pinder-eval`.

### Dummy model implementation

The `DummyModelPredRotTrans` class defines a PyTorch `Module` for predicting the rotation and translation of a ligand, given a random initial pose. We'll use this model as an example for the training and eval workflow using `pinder`'s existing functionalities.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DummyModelPredRotTrans(nn.Module):
    """
    A simple model for predicting rotation and translation of a ligand given its coordinates and those of the receptor.

    Pipeline:
    1. A multi-layer perceptron (MLP) to embed the receptor and ligand coordinates.
    2. Cross-attention between the receptor and ligand.
    3. Finally, it predicts the rotation matrix and translation vector.
    """
    def __init__(self, input_dim, embed_dim, num_heads):
        """
        Initialize the model.

        Args:
            input_dim: The dimension of the input features.
            embed_dim: The dimension of the embedding space.
            num_heads: The number of attention heads in the multi-head attention layer.
        """
        super(DummyModelPredRotTrans, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim)  # Added normalization layer
        )
        self.cross_attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.fc_rotation = nn.Linear(embed_dim, 9)
        self.fc_translation = nn.Linear(embed_dim, 3)

    def naive_single(self, receptor, ligand):
        """
        Processes a single receptor-ligand pair.

        Args:
            receptor: Tensor of shape (1, num_receptor_atoms, 3) (receptor coordinates)
            ligand: Tensor of shape (1, num_ligand_atoms, 3) (ligand coordinates)

        Returns:
            rotation_matrix: Tensor of shape (1, 3, 3) predicted rotation matrix for the ligand.
            translation_vector: Tensor of shape (1, 3) predicted translation vector for the ligand.
        """
        emb_features_receptor, emb_features_ligand = self.mlp(receptor), self.mlp(ligand)
        attn_output, _ = self.cross_attention(emb_features_receptor, emb_features_ligand, emb_features_ligand)
        rotation_matrix = self.fc_rotation(attn_output.mean(dim=1))
        rotation_matrix = rotation_matrix.view(-1, 3, 3)
        translation_vector = self.fc_translation(attn_output.mean(dim=1))
        return rotation_matrix, translation_vector

    def forward_rot_trans(self, batch):
        """
        Predicts rotation and translation for a batch of receptor-ligand complexes.

        Args:
            batch: A dictionary containing the batch data. It should have the following keys:
                - "feature_complex": A dictionary containing:
                    - "atom_coordinates": Tensor of shape (batch_size, num_atoms, 3) representing atom coordinates.
                    - "chain_ids": Tensor of shape (batch_size, num_atoms) representing chain IDs (0 for receptor, 1 for ligand).

        Returns:
            rotation_matrix: Tensor of shape (batch_size, 3, 3) representing predicted rotation matrices.
            translation_vector: Tensor of shape (batch_size, 3) representing predicted translation vectors.
            ligands: List of tensors, each of shape (1, num_ligand_atoms, 3), representing the original ligand coordinates.
        """
        rotation_matrices = []
        translation_vectors = []
        ligands = []
        for i in range(batch["feature_complex"]["chain_ids"].shape[0]):
            current_complex_coords = batch["feature_complex"]["atom_coordinates"][i]
            ligand_mask = batch["feature_complex"]["chain_ids"][i] == 1
            receptor_mask = batch["feature_complex"]["chain_ids"][i] == 0
            ligand_coords = current_complex_coords[ligand_mask]
            receptor_coords = current_complex_coords[receptor_mask]
            ligand_coords = ligand_coords.unsqueeze(0)
            receptor_coords = receptor_coords.unsqueeze(0)
            rotation_matrix, translation_vector = self.naive_single(receptor_coords, ligand_coords)
            rotation_matrices.append(rotation_matrix)
            translation_vectors.append(translation_vector)
            ligands.append(ligand_coords)
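        # Concatenate along the batch dimension so the shapes match the docstring:
        # (batch_size, 3, 3) and (batch_size, 3)
        rotation_matrix = torch.cat(rotation_matrices, dim=0)
        translation_vector = torch.cat(translation_vectors, dim=0)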
        return rotation_matrix, translation_vector, ligands

    def forward(self, batch):
        """
        The main forward pass of the model.

        Args:
            batch: Same as in forward_rot_trans.

        Returns:
            transformed_ligands: List of tensors, each of shape (1, num_ligand_atoms, 3)
            representing the transformed ligand coordinates after applying the predicted
            rotation and translation.
        """
        rotation_matrix, translation_vector, ligands = self.forward_rot_trans(batch)
        for i in range(len(ligands)):
            ligands[i] = ligands[i] @ rotation_matrix[i] + translation_vector[i]
        return ligands

In [None]:
# Define model hyperparameters
input_dim = 3  # 3D coordinates (x, y, z)
embed_dim = 4  # Dimension of the embedding space after MLP transformation
num_heads = 1  # Number of attention heads for cross-attention

# Instantiate the model
model = DummyModelPredRotTrans(input_dim, embed_dim, num_heads)

# Get a sample batch from the training dataloader
batch = next(iter(train_dataloader))
# Perform a forward pass to get predicted ligand coordinates
with torch.inference_mode():
    predicted_ligand_coords = model(batch)

# Verify output dimensions: Ensure the number of predicted ligand coordinates matches the actual number of ligand atoms in the batch
for i in range(len(predicted_ligand_coords)):
    assert predicted_ligand_coords[i].shape[1] == (batch["target_complex"]["chain_ids"][i] == 1).sum()

### Training and validation

In [None]:
import numpy as np
import pandas as pd
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def loss_fn(pred_l, gt_l, masks):
    """
    Calculates the MSE loss between predicted and ground truth ligand coordinates
    for the batch.

    Args:
        pred_l: predicted ligand coordinates.
        gt_l: ground truth coordinates for the entire complex.
        masks: tensor representing chain IDs (0 for receptor, 1 for ligand).

    Returns:
        The mean MSE loss for batch.
    """
    losses = []
    for i in range(len(pred_l)):
        pred_l_i = pred_l[i]
        gt_l_i = gt_l[i][masks[i] == 1].unsqueeze(0)
        losses.append(F.mse_loss(pred_l_i, gt_l_i))
    return torch.mean(torch.stack(losses))

def validation_loop(dataloader, model):
    """
    Performs validation on the given dataloader and model.

    Args:
        dataloader: The dataloader for the validation set.
        model: The model to be evaluated.

    Returns:
        A list with the MSE loss of each validation batch.
    """
    val_losses = []
    with torch.no_grad():
        for batch in dataloader:
            pred_l = model(batch)
            val_loss = loss_fn(pred_l,
                                batch["target_complex"]["atom_coordinates"],
                                batch["target_complex"]["chain_ids"]).item()
            val_losses.append(val_loss)
    print(f"Validation loss: {np.mean(val_losses):>8f}")
    return val_losses

def train_loop(train_dataloader, val_dataloader, model, optimizer, num_epoch=10, val_interval=1):
    """
    Trains the model.

    Args:
        train_dataloader: The dataloader for the training set.
        val_dataloader: The dataloader for the validation set.
        model: The model to be trained.
        optimizer: The optimizer used for training.
        num_epoch: The number of epochs to train for.
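        val_interval: The interval (in epochs) at which to perform validation.

    Returns:
        A pandas DataFrame with one row per validation batch and the columns "epoch" and "val/loss".
    """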
    loss_log = []
    for epoch in range(num_epoch):
        print(f"\nEpoch {epoch+1}\n-------------------------------")
        for batch in train_dataloader:

            # Compute prediction and loss
            pred_l = model(batch)
            loss = loss_fn(pred_l,
                           batch["target_complex"]["atom_coordinates"],
                           batch["target_complex"]["chain_ids"])

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if epoch % val_interval == 0:
            val_losses = validation_loop(val_dataloader, model)
            loss_log.extend([{"epoch": epoch, "val/loss": vl} for vl in val_losses])
    return pd.DataFrame(loss_log)



We can also create dataloaders for the `val` and `test` datasets:

In [None]:
val_dataloader = get_torch_loader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_batch,
    num_workers=0,
)

test_dataloader = get_torch_loader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)
# We can start the training for 5 epochs
val_loss = train_loop(train_dataloader, val_dataloader, model, optimizer, num_epoch=5)


In [None]:
import plotly.express as px
fig = px.line(
    val_loss.groupby("epoch", as_index=False)["val/loss"].mean(),
    x="epoch",
    y="val/loss",
    template="plotly_dark",
    title="Mean validation MSE loss",
    width=850, height=650,
)
fig.show()


## Inference

Our test data contains a single example, which we can use to make a prediction with the trained model.

In [None]:
test_batch = next(iter(test_dataloader))
with torch.inference_mode():
    predicted_ligand_coords = model(test_batch)


In [None]:
# we can see more info on our test data
test_dataloader.dataset.loader.index

In [None]:
test_index = test_dataloader.dataset.loader.index
loaded_idx = test_index[test_index["id"] == test_batch['id'][0]].index.values[0]
loaded_idx

In [None]:
predicted_ligand_coords[0].shape

In [None]:
# now we can load them as pinder's Structure objects
system, feature_complex, target_complex = test_dataloader.dataset.loader[loaded_idx]

In [None]:
# take the predicted coords of the ligand and convert to numpy array
pred_coords = predicted_ligand_coords[0].detach().numpy()[0]
# make sure ligand coordinate dimensions match with the original input complex dimensions
assert pred_coords.shape[0] == (feature_complex.atom_array.chain_id == "L").sum()

In [None]:
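# Update the ligand coordinates in place. Boolean indexing of an AtomArray returns a copy,
# so assign through the `coord` array of the original rather than the indexed copy.
ligand_mask = feature_complex.atom_array.chain_id == "L"
feature_complex.atom_array.coord[ligand_mask] = pred_coords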

In [None]:
import biotite.structure as struc

# Superimpose predicted receptor chain to ground-truth receptor chain using biotite
nat_arr = target_complex.atom_array.copy()
pred_arr = feature_complex.atom_array.copy()
_, transformation = struc.superimpose(nat_arr[nat_arr.chain_id == "R"], pred_arr[pred_arr.chain_id == "R"])
pred_arr = transformation.apply(pred_arr)
feature_complex.atom_array = pred_arr.copy()



In [None]:
from pathlib import Path
# Write the prediction to PDB to visualize
output_pdb = Path('./predicted_model_rank1.pdb')
feature_complex.to_pdb(output_pdb)

In [None]:
import py3Dmol

# Viewer documentation: https://3dmol.org/doc/GLViewer.html
view = py3Dmol.view(width=900, height=900)
view.removeAllModels()
view.setViewStyle({'style':'outline','color':'black','width':0.1})
view.setBackgroundColor('black')
# Show the reference structure

ps = PinderSystem(test_batch['id'][0])
ref_pdb = ps.native.filepath
view.addModel(open(ref_pdb, 'r').read(),'pdb',
             {"style": {'cartoon': {'color':'gold'}}})


view.setStyle({'chain':'R'},{'cartoon': {'color':'gray', 'arrows':True, 'tubes': False, 'ribbon': False, 'style': 'edged'}})
view.setStyle({'chain':'L'},{'cartoon': {'color':'gold', 'arrows':True, 'tubes':False, 'ribbon': False, 'style':'edged'}})

# Show the receptor-chain superposed model structure
view.addModel(
    open(output_pdb, 'r').read(),
    'pdb',
    {"style": {'cartoon': {'color':'lightgray'}}}
)
# Color the predicted ligand chain in light pink
view.setStyle({'chain':'L', "model": 1},{'cartoon': {'color':'lightpink', 'arrows':True, 'tubes': False, 'ribbon': False, 'style': 'edged'}})
view.zoomTo()
view.show()

In [None]:
# Evaluate the prediction

from pinder.eval.dockq import BiotiteDockQ
ps = PinderSystem(test_batch['id'][0])
ref_pdb = ps.native.filepath
bdq = BiotiteDockQ(ref_pdb, output_pdb, parallel_io=False)
metrics = bdq.calculate()
metrics.T


## Pinder eval

The evaluation harness can be used either through methods in `pinder.eval` or as a CLI script:


```
pinder_eval --help

usage: pinder_eval [-h] --eval_dir eval_dir [--serial] [--method_name method_name] [--allow_missing]

optional arguments:
  -h, --help            show this help message and exit
  --eval_dir eval_dir, -f eval_dir
                        Path to eval
  --serial, -s          Whether to disable parallel eval over systems
  --method_name method_name, -m method_name, -n method_name
                        Optional name for output csv
  --allow_missing, -a   Whether to allow missing systems for a given pinder-set + monomer

```

The expected format for the contents of `eval_dir` is shown below:
```
eval_dir_example/
└── some_method
    ├── 1ldt__A1_P00761--1ldt__B1_P80424
    │   ├── apo_decoys
    │   │   ├── model_1.pdb
    │   │   └── model_2.pdb
    │   ├── holo_decoys
    │   │   ├── model_1.pdb
    │   │   └── model_2.pdb
    │   └── predicted_decoys
    │       ├── model_1.pdb
    │       └── model_2.pdb
    └── 1b8m__B1_P34130--1b8m__A1_P23560
        ├── holo_decoys
        │   ├── model_1.pdb
        │   └── model_2.pdb
        └── predicted_decoys
            ├── model_1.pdb
            └── model_2.pdb
```

The eval directory should contain one or more methods to evaluate as sub-directories.

Each method sub-directory should contain sub-directories that are named by pinder system ID.

Inside each pinder system sub-directory, you can have up to three subdirectories, one per monomer type used for the predictions:
* `holo_decoys` (predictions that were made using holo monomers)
* `apo_decoys` (predictions made using apo monomers)
* `predicted_decoys` (predictions made using predicted, e.g. AF2, monomers)
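
Before running the evaluation, it can help to check that a directory follows this layout. The snippet below is a minimal sketch using only `pathlib`; `summarize_eval_dir` is a hypothetical helper, not part of pinder, and it simply counts `.pdb` files per method, system and decoy sub-directory.

```python
from pathlib import Path

DECOY_DIRS = ("holo_decoys", "apo_decoys", "predicted_decoys")


def summarize_eval_dir(eval_dir):
    """Count decoy PDB files per method, system and monomer type.

    Hypothetical helper for sanity-checking a directory; not part of pinder.
    """
    summary = {}
    for method_dir in sorted(p for p in Path(eval_dir).iterdir() if p.is_dir()):
        systems = {}
        for system_dir in sorted(p for p in method_dir.iterdir() if p.is_dir()):
            # Only report the decoy sub-directories that actually exist
            systems[system_dir.name] = {
                name: len(list((system_dir / name).glob("*.pdb")))
                for name in DECOY_DIRS
                if (system_dir / name).is_dir()
            }
        summary[method_dir.name] = systems
    return summary
```

For the example tree above, `summarize_eval_dir("eval_dir_example")` would return a nested dictionary keyed by `some_method`, then by system ID, with the number of decoys found in each decoy sub-directory.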

You can have any number of decoys in each directory; however, the decoys should be named in a way that the prediction rank can be extracted. In the above example, the decoys are named using a `model_<rank>.pdb` convention. Other names for decoy models are accepted, so long as they can match the regex pattern used in `pinder.eval.dockq.MethodMetrics`: `r"\d+(?=\D*$)"`
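
To see what this pattern matches, the sketch below applies it to a few file names with Python's `re` module. The pattern is the one quoted above; the helper `extract_rank` and the example names are made up for illustration, and how pinder applies the pattern internally is not shown here.

```python
import re

# Pattern quoted above: matches the last run of digits in the name
RANK_PATTERN = re.compile(r"\d+(?=\D*$)")


def extract_rank(decoy_name):
    """Return the rank encoded in a decoy file name, or None if it has no digits."""
    match = RANK_PATTERN.search(decoy_name)
    return int(match.group()) if match else None


for name in ["model_1.pdb", "predicted_model_rank1.pdb", "run3_decoy_12.pdb", "decoy.pdb"]:
    print(name, extract_rank(name))
# model_1.pdb 1
# predicted_model_rank1.pdb 1
# run3_decoy_12.pdb 12
# decoy.pdb None
```

Because only the last run of digits is used, earlier numbers in the name (such as the `3` in `run3_decoy_12.pdb`) do not affect the extracted rank.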

Each model decoy should have exactly two chains: {R, L} for {Receptor, Ligand}, respectively.
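
A quick way to check this requirement without extra dependencies is to read the chain identifier column of the PDB records. This is a rough sketch with hypothetical helper names; it relies only on the fixed-column PDB format, where the chain identifier sits in column 22.

```python
def chain_ids_in_pdb(pdb_text):
    """Collect chain identifiers from ATOM/HETATM records of PDB-format text.

    In the fixed-column PDB format the chain identifier is column 22
    (index 21 when counting from zero).
    """
    chains = set()
    for line in pdb_text.splitlines():
        if line.startswith(("ATOM", "HETATM")) and len(line) > 21:
            chains.add(line[21])
    return chains


def has_receptor_and_ligand(pdb_text):
    """True if the chains found are exactly R and L."""
    return chain_ids_in_pdb(pdb_text) == {"R", "L"}
```

For a decoy on disk, you could call `has_receptor_and_ligand(Path("model_1.pdb").read_text())` after importing `Path` from `pathlib`.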


⚠️ **Note: in order to make a fair comparison of methods across complete test sets, if a method is missing predictions for a system, the following metrics are used as a penalty**

```python

{
    "iRMS": 100.0,
    "LRMS": 100.0,
    "Fnat": 0.0,
    "DockQ": 0.0,
    "CAPRI": "Incorrect",
}
```
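
The effect of this penalty on an averaged metric can be illustrated with a small sketch. The helper `fill_missing`, the system names and the metric values are made up; how pinder aggregates metrics internally is not shown here.

```python
# Penalty values quoted above
PENALTY = {"iRMS": 100.0, "LRMS": 100.0, "Fnat": 0.0, "DockQ": 0.0, "CAPRI": "Incorrect"}


def fill_missing(metrics_by_system, expected_systems):
    """Return metrics for every expected system, using PENALTY where a prediction is missing."""
    return {sid: metrics_by_system.get(sid, dict(PENALTY)) for sid in expected_systems}


# Made-up numbers for illustration only
predicted = {"system_a": {"iRMS": 1.5, "LRMS": 4.0, "Fnat": 0.6, "DockQ": 0.7, "CAPRI": "Medium"}}
filled = fill_missing(predicted, ["system_a", "system_b"])
mean_dockq = sum(m["DockQ"] for m in filled.values()) / len(filled)
print(mean_dockq)  # 0.35
```

In this toy case, skipping one of two systems halves the mean DockQ, which is why a method cannot improve its score by omitting hard systems.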


Under the hood, the leaderboard makes use of the `MethodMetrics` class from `pinder.eval.dockq.method`. This interface is itself an abstraction over the underlying `BiotiteDockQ` API.

Below is an example of how you could use the `BiotiteDockQ` class directly.

In [None]:
import shutil
from pathlib import Path
from pinder.eval.dockq.biotite_dockq import BiotiteDockQ

method_dir = Path("./dummy-model").absolute()
system_dir = method_dir / test_batch['id'][0]
decoy_dir = system_dir / "holo_decoys"
decoy_dir.mkdir(exist_ok=True, parents=True)
if not (decoy_dir / output_pdb.name).is_file():
    shutil.copy(output_pdb, decoy_dir)

native = ps.native.filepath
decoys = list(decoy_dir.glob("*.pdb"))

R_chain, L_chain = ["R"], ["L"]
bdq = BiotiteDockQ(
    native=native, decoys=decoys,
    # These are optional and if not specified will be assigned based on number of atoms (receptor > ligand)
    native_receptor_chain=R_chain,
    native_ligand_chain=L_chain,
    decoy_receptor_chain=R_chain,
    decoy_ligand_chain=L_chain,
)
metrics = bdq.calculate()
metrics


In [None]:
from pinder.eval.dockq.method import MethodMetrics

mm = MethodMetrics(method_dir, allow_missing_systems=True, parallel=False)
metrics = mm.metrics
metrics

In [None]:
leaderboard_row = mm.get_leaderboard_entry()
leaderboard_row.T