In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

import pandas as pd

In [3]:
# Replace with your path to the magneton-data huggingface dataset
data_path = Path("/weka/scratch/weka/kellislab/rcalef/data/magneton-data")

fasta_path = (
    data_path /
    "sequences" /
    "uniprot_sprot.fasta.gz"
)

# Core types overview

The core datasets used in this repo pertain to protein substructures, i.e. small subsets of a protein's overall 3D structure that are evolutionarily conserved and observed across many distinct proteins. These substructures form the lower-level building blocks that make up the 3D structure of a full-length protein, and some also have distinct functions.

We use two main sources for our substructure annotations:
- InterPro: "provides functional analysis of proteins by classifying them into families and predicting domains and important sites". InterPro annotations include functional domains, binding sites, and higher-order structural groupings like protein families.
- DSSP: provides annotation of secondary structure elements given an experimental or predicted protein structure

In this notebook, we explore the representation of these datatypes in our current codebase.

The core types are defined in `magneton/types.py`, the most notable of which is the `Protein` class.

```
@dataclass
class Protein:
    uniprot_id: str
    kb_id: str
    name: str
    length: int
    parsed_entries: int
    total_entries: int
    entries: List[InterproEntry]
    secondary_structs: List[SecondaryStructure] = field(default_factory=list)
```

This class represents a single protein, specified by a UniProt ID, and provides its InterPro annotations and secondary structure predictions.

These `Protein` objects are currently stored as Python pickle files.

In [4]:
from magneton.io.internal import parse_from_pkl

In [5]:
pkl_path = (
    data_path /
    "interpro_103.0" /
    "seq_splits" /
    "train_sharded" /
    "swissprot.with_ss.train.0.pkl.bz2"
)

example_prots = []
for i, prot in enumerate(parse_from_pkl(pkl_path)):
    example_prots.append(prot)
    if i == 10:
        break

# The `print()` method just pretty prints the object.
example_prots[0].print()

Protein(uniprot_id='A0A009IHW8',
        kb_id='sp|A0A009IHW8|ABTIR_ACIB9',
        name='ABTIR_ACIB9',
        length=269,
        parsed_entries=5,
        total_entries=5,
        entries=[InterproEntry(id='IPR035897',
                               element_type='Homologous_superfamily',
                               match_id='G3DSA:3.40.50.10140',
                               element_name='Toll/interleukin-1 receptor '
                                            'homology (TIR) domain superfamily',
                               representative=False,
                               positions=[(80, 266)]),
                 InterproEntry(id='IPR000157',
                               element_type='Domain',
                               match_id='PF13676',
                               element_name='Toll/interleukin-1 receptor '
                                            'homology (TIR) domain',
                               representative=False,
                               p

In the above example, `parse_from_pkl` returns an iterator that reads `Protein` objects from the pickle file one at a time.

Each `Protein` object contains the corresponding protein's InterPro annotations as `InterproEntry` objects:
```
@dataclass
class InterproEntry:
    id: str
    element_type: str
    match_id: str
    element_name: str
    representative: bool
    # Note that positions are 1-indexed, i.e. exactly as given in InterPro.
    positions: List[Tuple[int]]
```
These objects specify the unique InterPro ID for this specific substructure annotation, what type of annotation it is (e.g. Domain, Family, etc.), and where it occurs in the given protein. Note that since these substructures occur in 3D, they can be made up of amino acids that are not adjacent in the underlying amino acid sequence, which is why positions are represented as a list of ranges.
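As a minimal sketch of how such a list of ranges maps onto a sequence, the snippet below pulls out the residues covered by each range. It assumes the `(start, end)` pairs are 1-indexed and inclusive at both ends, which is consistent with annotations in this notebook whose end position equals the protein length. The helper `extract_segments` and the toy sequence are hypothetical and not part of `magneton`.

```python
from typing import List, Tuple


def extract_segments(seq: str, positions: List[Tuple[int, int]]) -> List[str]:
    """Return the subsequence covered by each (start, end) range.

    Assumption: ranges are 1-indexed and inclusive on both ends, so
    (start, end) corresponds to the Python slice [start - 1 : end].
    """
    return [seq[start - 1 : end] for start, end in positions]


# Toy 20-residue sequence with a two-segment (discontinuous) annotation.
toy_seq = "ACDEFGHIKLMNPQRSTVWY"
segments = extract_segments(toy_seq, [(2, 4), (10, 12)])
print(segments)
```

Keeping the segments separate, rather than concatenating them, preserves the information that the annotation is discontinuous in sequence space.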

In [6]:
example_prots[0].entries[0].print()

InterproEntry(id='IPR035897',
              element_type='Homologous_superfamily',
              match_id='G3DSA:3.40.50.10140',
              element_name='Toll/interleukin-1 receptor homology (TIR) domain '
                           'superfamily',
              representative=False,
              positions=[(80, 266)])


We also provide the secondary structure annotations as `SecondaryStructure` objects, which provide the type of secondary structure as an enum and the contiguous range it covers on the protein sequence (all secondary structures are contiguous):
```
@unique
class DsspType(Enum):
    H = 0
    B = auto()
    E = auto()
    G = auto()
    I = auto()
    P = auto()
    T = auto()
    S = auto()
    # In place of ' ' (space) for OTHER
    X = auto()

@dataclass
class SecondaryStructure:
    dssp_type: DsspType
    # Note that positions are 1-indexed, as output by dssp.
    # Coordinates are half-open, i.e. [start, end)
    start: int
    end: int
```

In [7]:
example_prots[0].secondary_structs[0].print()

Alphahelix: 3 - 21
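To make the coordinate convention concrete, here is a small sketch that expands a list of secondary structure elements into a per-residue label string. It follows the comment in the class definition above (1-indexed, half-open `[start, end)`), so an element covers `end - start` residues. The function `per_residue_labels`, the plain-tuple representation, and the `-` fill character for uncovered residues are assumptions made for illustration only.

```python
from typing import List, Tuple


def per_residue_labels(
    length: int,
    structs: List[Tuple[str, int, int]],
    fill: str = "-",
) -> str:
    """Expand (dssp_code, start, end) elements into one label per residue.

    Assumption: coordinates are 1-indexed and half-open, i.e. [start, end),
    which maps to the 0-indexed Python range [start - 1, end - 1).
    """
    labels = [fill] * length
    for code, start, end in structs:
        for i in range(start - 1, end - 1):
            labels[i] = code
    return "".join(labels)


# Toy protein of length 10 with a helix over residues 2-5 and a strand over 7-8.
toy_structs = [("H", 2, 6), ("E", 7, 9)]
labels = per_residue_labels(10, toy_structs)
print(labels)
```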


# Dataset interface and implementation

Our overarching dataset is UniProtKB, which contains over 200 million proteins. To work with a smaller dataset, we use the manually reviewed subset of UniProtKB, which is referred to as SwissProt and contains ~550k proteins. However, since this can still be a large and unwieldy dataset to store as a single file, we store the dataset as many files, i.e. as a "sharded" dataset. 

This means that we can choose to either process the dataset as a whole by reading all the shards in sequence, or can trivially parallelize any operation that operates on single `Protein`s by parallelizing over the many shards.
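The following is a hedged, self-contained sketch of the shard-parallel pattern described above, not the `magneton` implementation. It assumes a toy shard layout in which each shard is a bz2-compressed file holding sequentially pickled records; the file names, the record dictionaries, and the helpers `write_shard`, `read_shard`, and `count_long` are all hypothetical. A thread pool is used to keep the sketch simple to run anywhere; for CPU-bound work a process pool would be the more usual choice.

```python
import bz2
import pickle
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Iterator, List


def write_shard(path: Path, records: List[dict]) -> None:
    # Toy format (assumption): records are pickled one after another.
    with bz2.open(path, "wb") as fh:
        for rec in records:
            pickle.dump(rec, fh)


def read_shard(path: Path) -> Iterator[dict]:
    # Stream records one at a time until the end of the file is reached.
    with bz2.open(path, "rb") as fh:
        while True:
            try:
                yield pickle.load(fh)
            except EOFError:
                return


def count_long(path: Path, min_length: int = 100) -> int:
    # Per-shard work: count records at or above a length threshold.
    return sum(1 for rec in read_shard(path) if rec["length"] >= min_length)


with tempfile.TemporaryDirectory() as tmp:
    tmp_dir = Path(tmp)
    toy = [{"uniprot_id": f"P{i:05d}", "length": 50 + 10 * i} for i in range(10)]
    for shard_num in range(2):
        chunk = toy[shard_num * 5 : (shard_num + 1) * 5]
        write_shard(tmp_dir / f"toy.{shard_num}.pkl.bz2", chunk)

    shard_paths = sorted(tmp_dir.glob("toy.*.pkl.bz2"))
    # Each shard is processed independently; results are combined afterwards.
    with ThreadPoolExecutor(max_workers=2) as pool:
        per_shard_counts = list(pool.map(count_long, shard_paths))

total_long = sum(per_shard_counts)
print(per_shard_counts, total_long)
```

Because each shard is handled independently, the only coordination needed is the final reduction over per-shard results.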

To abstract over whether the underlying dataset is a single file or a directory of sharded files, we introduce the `ShardedProteinDataset` object. Here we read a small number of proteins from a sharded directory.

Note that we also define a simpler `InMemoryProteinDataset` for testing with smaller datasets.

In [8]:
from magneton.data import get_protein_dataset

In [9]:
prot_dataset = get_protein_dataset(
    input_path=(
        data_path /
        "interpro_103.0" /
        "seq_splits" /
        "train_sharded"
    ),
    # Provide the prefix for the sharded files, named like
    # {prefix}.{shard_num}.pkl.bz2
    prefix="swissprot.with_ss.train",
)
len(prot_dataset)

423885

In [10]:
# These shards each contain 10000 proteins, so we'll demonstrate
# reading from multiple shards seamlessly.
example_prots = []
for i, prot in enumerate(prot_dataset):
    if i == 15000:
        break
    example_prots.append(prot)

len(example_prots)

15000

To speed up fetching a single protein from the dataset, we also use an index to identify which shard a protein is contained in, and can then fetch just from that shard.
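The idea of a shard index can be sketched in a few lines. This is an illustration under simplifying assumptions, not the actual index used by the dataset class: shards are modelled as in-memory lists of dictionaries, and `build_index` and `fetch` are hypothetical names.

```python
from typing import Dict, List

# Toy shards (assumption): each shard is a list of records keyed by uniprot_id.
toy_shards: List[List[dict]] = [
    [{"uniprot_id": "P00001"}, {"uniprot_id": "P00002"}],
    [{"uniprot_id": "P00003"}, {"uniprot_id": "P00004"}],
]


def build_index(shards: List[List[dict]]) -> Dict[str, int]:
    """Map each protein ID to the number of the shard that contains it."""
    return {
        rec["uniprot_id"]: shard_num
        for shard_num, shard in enumerate(shards)
        for rec in shard
    }


def fetch(uniprot_id: str, shards: List[List[dict]], index: Dict[str, int]) -> dict:
    """Look up the shard via the index, then scan only that shard."""
    for rec in shards[index[uniprot_id]]:
        if rec["uniprot_id"] == uniprot_id:
            return rec
    raise KeyError(uniprot_id)


index = build_index(toy_shards)
hit = fetch("P00003", toy_shards, index)
print(index["P00003"], hit)
```

The index costs one pass over the dataset to build, after which a lookup touches a single shard instead of all of them.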

In [11]:
prot_dataset.fetch_protein("A1S1K5").print()

Protein(uniprot_id='A1S1K5',
        kb_id='sp|A1S1K5|TSAC_SHEAM',
        name='TSAC_SHEAM',
        length=187,
        parsed_entries=6,
        total_entries=7,
        entries=[InterproEntry(id='IPR023535',
                               element_type='Family',
                               match_id='MF_01852',
                               element_name='Threonylcarbamoyl-AMP synthase',
                               representative=True,
                               positions=[(3, 185)]),
                 InterproEntry(id='IPR006070',
                               element_type='Domain',
                               match_id='PF01300',
                               element_name='Threonylcarbamoyl-AMP '
                                            'synthase-like domain',
                               representative=False,
                               positions=[(12, 187)]),
                 InterproEntry(id='IPR006070',
                               element_type='Domain',


We can also leverage the sharded structure to quickly run computations over the full dataset or extract subsets of it.

In this example, we use the `filter_proteins` function, which takes a function that receives a `Protein` and returns a `bool`, and fetches all proteins that match that criterion. In particular, we pull out proteins with very long alpha helices.

Note that this example requires a machine with multiple cores; set the `nprocs` parameter accordingly.
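As a sketch of the predicate shape described above (one protein in, one `bool` out), the snippet below defines a predicate with an extra parameter and binds that parameter with `functools.partial` to obtain a one-argument callable. `ToyEntry` and `ToyProtein` are stand-ins that mirror only a few fields of the real classes, and whether `filter_proteins` accepts a `partial` object unchanged has not been verified here.

```python
from dataclasses import dataclass, field
from functools import partial
from typing import List


@dataclass
class ToyEntry:
    # Stand-in for InterproEntry (assumption): only two fields are modelled.
    element_type: str
    representative: bool


@dataclass
class ToyProtein:
    # Stand-in for Protein (assumption): only the ID and entries are modelled.
    uniprot_id: str
    entries: List[ToyEntry] = field(default_factory=list)


def has_representative_entry(prot: ToyProtein, element_type: str = "Domain") -> bool:
    """True if the protein has a representative entry of the given type."""
    return any(
        e.element_type == element_type and e.representative for e in prot.entries
    )


toy_prots = [
    ToyProtein("P00001", [ToyEntry("Domain", True)]),
    ToyProtein("P00002", [ToyEntry("Domain", False), ToyEntry("Family", True)]),
    ToyProtein("P00003"),
]

# Bind the extra parameter so the result takes exactly one protein.
has_representative_family = partial(has_representative_entry, element_type="Family")

domain_hits = [p.uniprot_id for p in toy_prots if has_representative_entry(p)]
family_hits = [p.uniprot_id for p in toy_prots if has_representative_family(p)]
print(domain_hits, family_hits)
```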

In [12]:
from magneton.types import DsspType, Protein
from magneton.io.internal import filter_proteins

In [13]:
def has_long_helix(prot: Protein, min_len: int = 500) -> bool:
    # True if any alpha helix (DSSP type H) spans at least `min_len` residues.
    for ss in prot.secondary_structs:
        if ss.dssp_type == DsspType.H and (ss.end - ss.start) >= min_len:
            return True
    return False

long_helices = filter_proteins(
    prot_dataset.input_path,
    has_long_helix,
    prefix=prot_dataset.prefix,
    nprocs=32,
)
len(long_helices)
long_helices[0].print()

100%|██████████████████████████████████████████████████████████| 43/43 [01:31<00:00,  2.14s/it]


Protein(uniprot_id='A0A166B1A6',
        kb_id='sp|A0A166B1A6|NMCP1_DAUCS',
        name='NMCP1_DAUCS',
        length=1119,
        parsed_entries=1,
        total_entries=1,
        entries=[InterproEntry(id='IPR040418',
                               element_type='Family',
                               match_id='PTHR31908',
                               element_name='Protein crowded nuclei',
                               representative=True,
                               positions=[(3, 1119)])],
        secondary_structs=[SecondaryStructure(dssp_type=<DsspType.S: 7>,
                                              start=5,
                                              end=6),
                           SecondaryStructure(dssp_type=<DsspType.S: 7>,
                                              start=7,
                                              end=10),
                           SecondaryStructure(dssp_type=<DsspType.H: 0>,
                                              start=10

# Associated metadata

## All InterPro types/labels

If you are on OpenMind, an exhaustive list of InterPro entries can be found here:

In [14]:
label_path = (
    data_path /
    "interpro_103.0" /
    "selected_label_set"
)

def load_stats(interpro_type: str) -> pd.DataFrame:
    # One TSV of labels per InterPro entry type, e.g. "Domain.labels.tsv".
    labels_path = label_path / f"{interpro_type}.labels.tsv"
    return pd.read_table(labels_path)

In [15]:
load_stats("Domain").head()

Unnamed: 0,label,element_name,interpro_id
0,0,Kringle,IPR000001
1,1,C2 domain,IPR000008
2,2,Cystatin domain,IPR000010
3,3,PAS domain,IPR000014
4,4,Phosphofructokinase domain,IPR000023


In [16]:
load_stats("Family").head()

Unnamed: 0,label,element_name,interpro_id
0,0,"Metallothionein, vertebrate",IPR000006
1,1,SsrA-binding protein,IPR000037
2,2,Adenosylhomocysteinase-like,IPR000043
3,3,Thymidine/pyrimidine-nucleoside phosphorylase,IPR000053
4,4,Large ribosomal subunit protein eL31,IPR000054
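A common use of a label table like this is to map InterPro IDs to integer class labels. The sketch below does so with the standard library on an in-memory TSV; the three rows are copied from the `Domain` output above, but the exact column layout of the real files (including the unnamed leading column) is an assumption, so treat this as illustrative only.

```python
import csv
import io

# Toy TSV (assumption): the same column names as shown above, minus the
# unnamed leading column.
toy_tsv = (
    "label\telement_name\tinterpro_id\n"
    "0\tKringle\tIPR000001\n"
    "1\tC2 domain\tIPR000008\n"
    "2\tCystatin domain\tIPR000010\n"
)

reader = csv.DictReader(io.StringIO(toy_tsv), delimiter="\t")
# Map each InterPro ID to its integer class label.
id_to_label = {row["interpro_id"]: int(row["label"]) for row in reader}
print(id_to_label)
```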


## Structures and sequences

For those working on OpenMind, you can find the associated 3D structure files from AlphaFoldDB here:

In [17]:
example_id = example_prots[0].uniprot_id
cif_tmpl = "/weka/scratch/weka/kellislab/rcalef/data/cif_alphafolddb/AF-%s-F1-model_v4.cif.gz"
pdb_tmpl = "/weka/scratch/weka/kellislab/rcalef/data/pdb_alphafolddb/AF-%s-F1-model_v4.pdb"

print(os.path.exists(cif_tmpl % example_id))
print(os.path.exists(pdb_tmpl % example_id))

True
True


Similarly, the associated amino acid sequence can be obtained as follows:

In [18]:
from pysam import FastaFile

In [19]:
fa = FastaFile(fasta_path)

# Note the use of kb_id instead of uniprot_id
print(fa[example_prots[0].kb_id])

MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENARIQSKLSDLQKKKIDIDNKLLKEKQNLIKEEILERKKLEVLTKKQQKDEIEHQKKLKREIDAIKASTQYITDVSISSYNNTIPETEPEYDLFISHASEDKEDFVRPLAETLQQLGVNVWYDEFTLKVGDSLRQKIDSGLRNSKYGTVVLSTDFIKKDWTNYELDGLVAREMNGHKMILPIWHKITKNDVLDYSPNLADKVALNTSVNSIEEIAHQLADVILNR
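The `kb_id` used as the FASTA key has three `|`-separated parts, as seen in the `Protein` outputs earlier in this notebook (for example `sp|A0A009IHW8|ABTIR_ACIB9`). A minimal sketch for splitting it into its components follows; `KbId` and `parse_kb_id` are hypothetical helpers, not part of `magneton`.

```python
from typing import NamedTuple


class KbId(NamedTuple):
    database: str    # "sp" for Swiss-Prot entries
    accession: str   # matches Protein.uniprot_id
    entry_name: str  # matches Protein.name


def parse_kb_id(kb_id: str) -> KbId:
    """Split a 'db|accession|entry_name' identifier into its three parts."""
    parts = kb_id.split("|")
    if len(parts) != 3:
        raise ValueError(f"Unexpected kb_id format: {kb_id!r}")
    return KbId(*parts)


parsed = parse_kb_id("sp|A0A009IHW8|ABTIR_ACIB9")
print(parsed.accession, parsed.entry_name)
```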


## Dataset splits

Dataset splits for the SwissProt subset can be found on OpenMind at the following path, and are also structured as directories of sharded pickle files. 

Note that in order to minimize information leakage, we construct our splits using pre-computed clusters of proteins that have been clustered on the basis of either sequence or structure similarity. Both sets of splits are available, with the structure-based splits being more stringent in terms of substructure observations being mostly assigned to a single split.
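The core idea of a cluster-aware split is that whole clusters, rather than individual proteins, are assigned to a split, so similar proteins never straddle train and test. The sketch below illustrates that idea only; it is not the procedure used to build the splits in this repo, and `split_by_cluster`, the split fractions, and the toy cluster IDs are assumptions.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def split_by_cluster(
    protein_to_cluster: Dict[str, str],
    fractions: Sequence[Tuple[str, float]] = (("train", 0.8), ("val", 0.1), ("test", 0.1)),
    seed: int = 0,
) -> Dict[str, str]:
    """Assign every protein in a cluster to the same split."""
    clusters: Dict[str, List[str]] = defaultdict(list)
    for prot, cluster_id in protein_to_cluster.items():
        clusters[cluster_id].append(prot)

    # Shuffle clusters (not proteins) deterministically.
    cluster_ids = sorted(clusters)
    random.Random(seed).shuffle(cluster_ids)

    total = len(protein_to_cluster)
    assignment: Dict[str, str] = {}
    seen = 0
    for cluster_id in cluster_ids:
        # Pick the split whose cumulative fraction has not yet been filled.
        split = fractions[-1][0]
        cumulative = 0.0
        for name, frac in fractions:
            cumulative += frac
            if seen / total < cumulative:
                split = name
                break
        for prot in clusters[cluster_id]:
            assignment[prot] = split
        seen += len(clusters[cluster_id])
    return assignment


# Toy data: 20 proteins in 10 clusters of two.
toy_clusters = {f"P{i:05d}": f"C{i // 2}" for i in range(20)}
assignment = split_by_cluster(toy_clusters)

# Record which splits each cluster ended up in (should be exactly one).
splits_per_cluster: Dict[str, set] = defaultdict(set)
for prot, split in assignment.items():
    splits_per_cluster[toy_clusters[prot]].add(split)
print(sorted(set(assignment.values())))
```

Because assignment happens at cluster granularity, the realized split sizes only approximate the requested fractions when clusters are large or uneven.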

In [20]:
splits_path = (
    data_path /
    "interpro_103.0" /
    "seq_splits"
)
splits_path

PosixPath('/weka/scratch/weka/kellislab/rcalef/data/magneton-data/interpro_103.0/seq_splits')

In [21]:
list(splits_path.glob("*_sharded"))

[PosixPath('/weka/scratch/weka/kellislab/rcalef/data/magneton-data/interpro_103.0/seq_splits/val_sharded'),
 PosixPath('/weka/scratch/weka/kellislab/rcalef/data/magneton-data/interpro_103.0/seq_splits/test_sharded'),
 PosixPath('/weka/scratch/weka/kellislab/rcalef/data/magneton-data/interpro_103.0/seq_splits/train_sharded')]

For more details on split construction, see [`create_dataset_splits.ipynb`](https://github.com/rcalef/magneton/blob/main/experiments/create_dataset_splits.ipynb).

# Data loaders

To actually train and evaluate a model, we'll need to construct data loaders from these protein datasets.

In [4]:
import torch

from torchdata.nodes import BaseNode, MapStyleWrapper, Filter
from torch.utils.data import RandomSampler, DistributedSampler, SequentialSampler

from magneton.config import DataConfig
from magneton.data import MagnetonDataModule
from magneton.data.core import (
    CoreDataset,
    SubstructType,
)
from magneton.types import DataType

In [5]:
interpro_path = (
    data_path /
    "interpro_103.0"
)
label_path = (
    interpro_path /
    "selected_label_set"
)

dataset_path = (
    interpro_path /
    "debug_datasets"
)

data_config = DataConfig(
    data_dir=dataset_path,
    prefix="swissprot.with_ss.train",
    fasta_path=fasta_path,
    labels_path=label_path,
    substruct_types=[SubstructType.DOMAIN],
)

In [6]:
data_module = MagnetonDataModule(
    data_config=data_config,
    model_type="esmc",
)

In [7]:
loader = data_module.train_dataloader()

100%|██████████████████████████████████████████████████████| 1/1 [00:13<00:00, 13.30s/it]


In [8]:
it = iter(loader)

In [9]:
batch = next(it)
batch

ESMCBatch(protein_ids=['A1ABS6', 'A1AWD0', 'A0KU61', 'A0JXU0', 'A1AU61', 'A0BD73', 'A0ALA8', 'A1ADB6', 'A0A2H4HHY6', 'A0AEM3', 'A0AFC3', 'A0A455LLX4', 'A0T0M9', 'A0RJ81', 'A0KYA2', 'A0A0H3NBY9', 'A0RV25', 'A0KEH8', 'A0B9K1', 'A0RCM7', 'A0T0L8', 'A0QL16', 'A0QSG3', 'A1JLK6', 'A0Q3I1', 'A0R1W8', 'A0T0H8', 'A0KZ22', 'A0PW28', 'A0A0H2URG7', 'A0RPF9', 'A1JNH0'], seqs=None, substructures=None, structure_list=None, labels=None, tokenized_seq=tensor([[ 0, 20, 17,  ...,  1,  1,  1],
        [ 0, 20,  5,  ...,  1,  1,  1],
        [ 0, 20, 21,  ...,  1,  1,  1],
        ...,
        [ 0, 20, 11,  ...,  1,  1,  1],
        [ 0, 20, 12,  ...,  1,  1,  1],
        [ 0, 20, 20,  ...,  1,  1,  1]]))
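In the `tokenized_seq` tensor above, every row ends in a run of `1`s, which suggests that shorter sequences are right-padded to the length of the longest sequence in the batch. The pure-Python sketch below illustrates that kind of padding along with the matching mask; it is not the collation code used by `MagnetonDataModule`, and the pad ID of `1`, the end-of-sequence ID of `2` in the toy data, and the helper `pad_token_ids` are assumptions.

```python
from typing import List


def pad_token_ids(seqs: List[List[int]], pad_id: int = 1) -> List[List[int]]:
    """Right-pad every token list to the length of the longest one."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_id] * (max_len - len(s)) for s in seqs]


# Toy token lists (assumption): 0 marks the start and 2 the end of a sequence.
toy_tokens = [[0, 20, 17, 2], [0, 20, 5, 11, 12, 2]]
padded = pad_token_ids(toy_tokens)

# Mask that is True for real tokens and False for padding.
mask = [[tok != 1 for tok in row] for row in padded]
print(padded)
```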