In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import pandas as pd

from magneton.utils import get_data_dir

In [3]:
fasta_path = get_data_dir() / "sequences" / "uniprot_sprot.fasta.gz"

# Core types overview

The core datasets used in this repo pertain to protein substructures, i.e. small subsets of a protein's overall 3D structure that are evolutionarily conserved and observed across many distinct proteins. These substructures form the lower-level building blocks that make up the 3D structure of a full-length protein, and some also have distinct functions.

We use two main sources for our substructure annotations:
- InterPro: "provides functional analysis of proteins by classifying them into families and predicting domains and important sites". InterPro annotations include functional domains, binding sites, and higher-order structural groupings like protein families.
- DSSP: provides annotations of secondary structure elements given an experimental or predicted protein structure.

In this notebook, we explore the representation of these datatypes in our current codebase.

The core types are defined in `magneton/core_types.py`, the most notable of which is the `Protein` class.

```
@dataclass
class Protein:
    uniprot_id: str
    kb_id: str
    name: str
    length: int
    parsed_entries: int
    total_entries: int
    entries: List[InterproEntry]
    secondary_structs: List[SecondaryStructure]
```

This class represents a single protein, specified by a UniProt ID, and provides its InterPro annotations and secondary structure predictions.

These `Protein` objects are currently stored as gzipped JSONL files.
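To make the storage format concrete, here is a minimal sketch of reading and writing gzipped JSONL using only the standard library. It is not the implementation of `parse_from_json`; the helper name `iter_jsonl_gz` and the two-field records are assumptions made for illustration.

```python
import gzip
import json
import tempfile
from pathlib import Path
from typing import Iterator


def iter_jsonl_gz(path: Path) -> Iterator[dict]:
    """Yield one decoded JSON object per non-empty line of a gzipped JSONL file."""
    with gzip.open(path, "rt", encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if line:
                yield json.loads(line)


# Toy records with a hypothetical subset of the `Protein` fields.
records = [
    {"uniprot_id": "A0A009IHW8", "length": 269},
    {"uniprot_id": "A1S1K5", "length": 187},
]

with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo.0.jsonl.gz"
    # One JSON document per line, compressed with gzip.
    with gzip.open(demo_path, "wt", encoding="utf-8") as fh:
        for rec in records:
            fh.write(json.dumps(rec) + "\n")
    loaded = list(iter_jsonl_gz(demo_path))
```

Because each line is an independent JSON document, a reader can stream records one at a time without loading the whole file.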

In [4]:
from magneton.io.internal import parse_from_json

In [5]:
json_path = (
    get_data_dir()
    / "interpro_103.0"
    / "swissprot_subset"
    / "swissprot.with_ss.0.jsonl.gz"
)

example_prots = []
for i, prot in enumerate(parse_from_json(json_path)):
    example_prots.append(prot)
    if i == 10:
        break

# The `print()` method just pretty prints the object.
example_prots[0].print()

Protein(uniprot_id='A0A009IHW8',
        kb_id='sp|A0A009IHW8|ABTIR_ACIB9',
        name='ABTIR_ACIB9',
        length=269,
        parsed_entries=5,
        total_entries=5,
        entries=[InterproEntry(id='IPR035897',
                               element_type=<SubstructType.HOMO_FAMILY: 'Homologous_superfamily'>,
                               match_id='G3DSA:3.40.50.10140',
                               element_name='Toll/interleukin-1 receptor '
                                            'homology (TIR) domain superfamily',
                               representative=False,
                               positions=[(80, 266)]),
                 InterproEntry(id='IPR000157',
                               element_type=<SubstructType.DOMAIN: 'Domain'>,
                               match_id='PF13676',
                               element_name='Toll/interleukin-1 receptor '
                                            'homology (TIR) domain',
                               r

In the above example, `parse_from_json` returns an iterator that reads `Protein` objects from the gzipped JSONL file one at a time.

Each `Protein` object contains the corresponding protein's InterPro annotations as `InterproEntry` objects:
```
@dataclass
class InterproEntry:
    """A single InterPro entry.

    - id (str): InterPro ID for this entry's element.
    - element_type (InterProType): Type of InterPro element.
    - match_id (str): ID of the specific match.
    - element_name (str): Human-readable name of the InterPro element.
    - representative (bool): Whether or not this is the representative entry of this element for this protein.
    - positions (List[Tuple[int]]): List of [start, end) positions for this entry, 1-indexed as in InterPro.
    """

    id: str
    element_type: SubstructType
    match_id: str
    element_name: str
    representative: bool
    # Note that positions are 1-indexed, i.e. exactly as given in InterPro.
    positions: list[tuple[int, int]]
```
These objects specify the unique InterPro ID for this specific substructure annotation, what type of annotation it is (e.g. Domain, Family, etc.), and where it occurs in the given protein. Note that since these substructures occur in 3D, they can be made up of amino acids that are not adjacent in the underlying amino acid sequence, which is why positions are represented as a list of ranges.
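As a sketch of how 1-indexed position ranges map onto Python string slices, the hypothetical helper below extracts the subsequence for each range. The docstring above describes the ranges as `[start, end)`; the `end_inclusive` flag is an assumption added so that either end convention can be checked against real data.

```python
def extract_segments(
    sequence: str,
    positions: list[tuple[int, int]],
    end_inclusive: bool = False,
) -> list[str]:
    """Return the subsequence covered by each 1-indexed (start, end) range."""
    segments = []
    for start, end in positions:
        # Shift the 1-indexed start to a 0-indexed slice start.
        lo = start - 1
        # An inclusive 1-indexed end equals the exclusive 0-indexed stop;
        # an exclusive 1-indexed end needs the same shift as the start.
        hi = end if end_inclusive else end - 1
        segments.append(sequence[lo:hi])
    return segments


toy_seq = "ABCDEFGHIJ"
# A discontinuous annotation made of two separate ranges.
toy_positions = [(2, 4), (8, 10)]

half_open = extract_segments(toy_seq, toy_positions)
closed = extract_segments(toy_seq, toy_positions, end_inclusive=True)
```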

In [6]:
example_prots[0].entries[0].print()

InterproEntry(id='IPR035897',
              element_type=<SubstructType.HOMO_FAMILY: 'Homologous_superfamily'>,
              match_id='G3DSA:3.40.50.10140',
              element_name='Toll/interleukin-1 receptor homology (TIR) domain '
                           'superfamily',
              representative=False,
              positions=[(80, 266)])


We also provide the secondary structure annotations as `SecondaryStructure` objects, which give the type of secondary structure as an enum and its range on the protein sequence (every secondary structure element is contiguous):
```
@unique
class DsspType(Enum):
    H = 0
    B = auto()
    E = auto()
    G = auto()
    I = auto()
    P = auto()
    T = auto()
    S = auto()
    # In place of ' ' (space) for OTHER
    X = auto()

@dataclass
class SecondaryStructure:
    dssp_type: DsspType
    # Note that positions are 1-indexed, as output by dssp.
    # Coordinates are half-open, i.e. [start, end)
    start: int
    end: int
```
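To illustrate the contiguous-range representation, the sketch below collapses a per-residue DSSP string into runs with 1-indexed, half-open coordinates. This is an illustration only and not necessarily how this repo builds its `SecondaryStructure` objects; the function name `dssp_runs` and the toy input string are assumptions.

```python
from itertools import groupby


def dssp_runs(dssp: str) -> list[tuple[str, int, int]]:
    """Collapse a per-residue DSSP string into (code, start, end) runs.

    Coordinates are 1-indexed and half-open, i.e. [start, end).
    """
    runs = []
    pos = 1
    for code, group in groupby(dssp):
        run_len = sum(1 for _ in group)
        # Use X in place of the blank "other" code, as in the enum comment.
        runs.append(("X" if code == " " else code, pos, pos + run_len))
        pos += run_len
    return runs


# Toy string: 2 blanks, 4 helix residues, 2 strand residues, 1 blank, 1 turn.
runs = dssp_runs("  HHHHEE T")
```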

In [7]:
example_prots[0].secondary_structs[0].print()

Alphahelix: 3 - 21


# Dataset interface and implementation

Our overarching dataset is UniProtKB, which contains over 200 million proteins. To work with a smaller dataset, we use the manually reviewed subset of UniProtKB, which is referred to as SwissProt and contains ~550k proteins. Since this is still large and unwieldy to store as a single file, we store it as many files, i.e. as a "sharded" dataset.

This means that we can choose to either process the dataset as a whole by reading all the shards in sequence, or can trivially parallelize any operation that operates on single `Protein`s by parallelizing over the many shards.

To abstract over whether the underlying dataset is a single file or a directory of sharded files, we introduce the `ShardedProteinDataset` object. Here we read a small number of proteins from a sharded directory.

Note that we also define a simpler `InMemoryProteinDataset` for testing with smaller datasets.
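The idea of presenting many shards as a single stream can be sketched with a lazy chain of per-shard generators. This is a toy under stated assumptions, not the `ShardedProteinDataset` implementation: `read_shard` and `iter_dataset` are hypothetical names, and the "proteins" are placeholder strings.

```python
from itertools import chain, islice
from typing import Iterator


def read_shard(shard_num: int, shard_size: int = 4) -> Iterator[str]:
    """Stand-in for reading one shard file; yields placeholder protein IDs."""
    for i in range(shard_size):
        yield f"shard{shard_num}_prot{i}"


def iter_dataset(num_shards: int) -> Iterator[str]:
    """Iterate over all shards in order as if they were one dataset."""
    # Each shard generator is only started once iteration reaches it.
    return chain.from_iterable(read_shard(n) for n in range(num_shards))


# Reading six items crosses the boundary between the first two shards.
first_six = list(islice(iter_dataset(num_shards=3), 6))
```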

In [8]:
from magneton.data.core import get_protein_dataset

In [9]:
prot_dataset = get_protein_dataset(
    input_path=(get_data_dir() / "interpro_103.0" / "swissprot_subset"),
    # Provide the prefix for the sharded files, named like
    # {prefix}.{shard_num}.jsonl.gz
    prefix="swissprot.with_ss",
)
len(prot_dataset)

530601

In [10]:
# These shards each contain 10000 proteins, so we'll demonstrate
# reading from multiple shards seamlessly.
example_prots = []
for i, prot in enumerate(prot_dataset):
    if i == 15000:
        break
    example_prots.append(prot)

len(example_prots)

15000

To speed up fetching a single protein from the dataset, we also use an index to identify which shard a protein is contained in, and can then fetch just from that shard.
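A minimal sketch of such an index, assuming it is simply a mapping from protein ID to shard number. The names `build_index` and `fetch`, the placeholder IDs, and the in-memory shard layout are all hypothetical, and the repo's actual index format may differ.

```python
def build_index(shards: dict[int, list[str]]) -> dict[str, int]:
    """Map each protein ID to the number of the shard that contains it."""
    return {pid: shard_num for shard_num, ids in shards.items() for pid in ids}


def fetch(pid: str, shards: dict[int, list[str]], index: dict[str, int]) -> str:
    """Look up the shard via the index, then scan only that shard."""
    shard = shards[index[pid]]
    return next(p for p in shard if p == pid)


# Toy layout with placeholder IDs: shard number -> IDs stored in that shard.
toy_shards = {0: ["P0", "P1", "P2"], 1: ["P3", "P4"], 2: ["P5"]}
toy_index = build_index(toy_shards)
found = fetch("P4", toy_shards, toy_index)
```

With an index like this, a single-protein lookup reads one shard instead of scanning all of them.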

In [11]:
prot_dataset.fetch_protein("A1S1K5").print()

Protein(uniprot_id='A1S1K5',
        kb_id='sp|A1S1K5|TSAC_SHEAM',
        name='TSAC_SHEAM',
        length=187,
        parsed_entries=6,
        total_entries=7,
        entries=[InterproEntry(id='IPR023535',
                               element_type=<SubstructType.FAMILY: 'Family'>,
                               match_id='MF_01852',
                               element_name='Threonylcarbamoyl-AMP synthase',
                               representative=True,
                               positions=[(3, 185)]),
                 InterproEntry(id='IPR006070',
                               element_type=<SubstructType.DOMAIN: 'Domain'>,
                               match_id='PF01300',
                               element_name='Threonylcarbamoyl-AMP '
                                            'synthase-like domain',
                               representative=False,
                               positions=[(12, 187)]),
                 InterproEntry(id='IPR006070',
      

We can also leverage the sharded structure to quickly run computations over, or extract subsets of, the full dataset.

In this example, we use the `filter_proteins` function, which takes a function that receives a `Protein` and returns a `bool`, and fetches all proteins that match that criterion. In particular, we pull out proteins with very long alpha helices, which may be real or artifacts of the predicted structures.

Note that this example assumes a machine with multiple cores; set the `nprocs` parameter to match the number available.
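The per-shard pattern behind this kind of filtering can be sketched as follows: apply the predicate to each shard independently, then concatenate the matches. This is not the `filter_proteins` implementation. The names are hypothetical, the shards hold toy helix lengths, and a thread pool is used only to keep the sketch self-contained; CPU-bound work would normally use a process pool.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable


def filter_shard(shard: list[int], predicate: Callable[[int], bool]) -> list[int]:
    """Return the items of a single shard that satisfy the predicate."""
    return [item for item in shard if predicate(item)]


def filter_sharded(
    shards: list[list[int]],
    predicate: Callable[[int], bool],
    nworkers: int = 2,
) -> list[int]:
    """Filter every shard independently and concatenate the matches."""
    with ThreadPoolExecutor(max_workers=nworkers) as pool:
        # `map` returns results in shard order, whichever shard finishes first.
        per_shard = pool.map(lambda shard: filter_shard(shard, predicate), shards)
        return [item for matches in per_shard for item in matches]


# Toy data: each shard is a list of helix lengths.
toy_shards = [[120, 640, 80], [510, 30], [45, 700]]
long_ones = filter_sharded(toy_shards, lambda n: n >= 500)
```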

In [12]:
from magneton.core_types import DsspType, Protein
from magneton.io.internal import filter_proteins

In [13]:
def has_long_helix(prot: Protein, min_len: int = 500) -> bool:
    # True if any alpha helix (DSSP type H) spans at least `min_len` residues.
    for ss in prot.secondary_structs:
        if ss.dssp_type == DsspType.H and (ss.end - ss.start) >= min_len:
            return True
    return False


long_helices = filter_proteins(
    prot_dataset.input_path,
    has_long_helix,
    prefix=prot_dataset.prefix,
    nprocs=8,
)
len(long_helices)
long_helices[0].print()

processing proteins: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 55/55 [00:15<00:00,  3.45it/s]

Protein(uniprot_id='A0A166B1A6',
        kb_id='sp|A0A166B1A6|NMCP1_DAUCS',
        name='NMCP1_DAUCS',
        length=1119,
        parsed_entries=1,
        total_entries=1,
        entries=[InterproEntry(id='IPR040418',
                               element_type=<SubstructType.FAMILY: 'Family'>,
                               match_id='PTHR31908',
                               element_name='Protein crowded nuclei',
                               representative=True,
                               positions=[(3, 1119)])],
        secondary_structs=[SecondaryStructure(dssp_type=<DsspType.S: 7>,
                                              start=5,
                                              end=6),
                           SecondaryStructure(dssp_type=<DsspType.S: 7>,
                                              start=7,
                                              end=10),
                           SecondaryStructure(dssp_type=<DsspType.H: 0>,
                              




# Associated metadata

## All InterPro types/labels

In [14]:
label_path = get_data_dir() / "interpro_103.0" / "labels" / "full_set"


def load_stats(interpro_type: str) -> pd.DataFrame:
    labels_path = os.path.join(label_path, f"{interpro_type}.labels.tsv")
    return pd.read_table(labels_path)

In [15]:
load_stats("Domain").head()

Unnamed: 0,label,element_name,interpro_id
0,0,Kringle,IPR000001
1,1,"Tubby, C-terminal",IPR000007
2,2,C2 domain,IPR000008
3,3,Cystatin domain,IPR000010
4,4,PAS domain,IPR000014


In [16]:
load_stats("Family").head()

Unnamed: 0,label,element_name,interpro_id
0,0,Retinoid X receptor/HNF4,IPR000003
1,1,"Metallothionein, vertebrate",IPR000006
2,2,Protein phosphatase 2A regulatory subunit PR55,IPR000009
3,3,Ubiquitin/SUMO-activating enzyme E1-like,IPR000011
4,4,Retroviral VpR/VpX protein,IPR000012


## Structures and sequences

Due to the large size of the structure files for all of the SwissProt proteins, we don't store them in our Hugging Face dataset, but we do provide a placeholder directory in which to download them. Please see the "Structure data" section of our repo's README for details.

In [17]:
example_id = example_prots[0].uniprot_id
struct_template = get_data_dir() / "afdb_structures" / "AF-%s-F1-model_v6.pdb"

print(os.path.exists(str(struct_template) % example_id))

True


Similarly, the associated amino acid sequence can be obtained as follows:

In [18]:
from pysam import FastaFile

In [19]:
fa = FastaFile(fasta_path)

# Note the use of kb_id instead of uniprot_id
print(fa[example_prots[0].kb_id])

MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENARIQSKLSDLQKKKIDIDNKLLKEKQNLIKEEILERKKLEVLTKKQQKDEIEHQKKLKREIDAIKASTQYITDVSISSYNNTIPETEPEYDLFISHASEDKEDFVRPLAETLQQLGVNVWYDEFTLKVGDSLRQKIDSGLRNSKYGTVVLSTDFIKKDWTNYELDGLVAREMNGHKMILPIWHKITKNDVLDYSPNLADKVALNTSVNSIEEIAHQLADVILNR


## Dataset splits

Dataset splits are available as a TSV containing columns for UniProt ID and the split each protein is assigned to.

Note that in order to minimize information leakage, we construct our splits using pre-computed clusters of proteins that have been clustered on the basis of either sequence or structure similarity. Both sets of splits are available. The structure-based splits are the more stringent of the two, in that observations of a given substructure are mostly assigned to a single split.

We also make available more stringent splits that remove from training all proteins that are contained in the test set of any of our downstream evaluation tasks.
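The general idea of cluster-based splitting can be sketched as assigning whole clusters, rather than individual proteins, to splits. This is a simplified illustration and not the procedure used to build the splits above: the function name, the split fractions, and the toy cluster assignments are all assumptions.

```python
import random


def split_by_cluster(
    protein_to_cluster: dict[str, str],
    fractions: dict[str, float],
    seed: int = 0,
) -> dict[str, str]:
    """Assign whole clusters to splits so no cluster straddles two splits."""
    clusters = sorted(set(protein_to_cluster.values()))
    random.Random(seed).shuffle(clusters)

    names = list(fractions)
    cluster_to_split = {}
    upper = 0.0
    start = 0
    for name in names:
        upper += fractions[name]
        # The last split takes whatever clusters remain.
        stop = len(clusters) if name == names[-1] else round(upper * len(clusters))
        for cluster in clusters[start:stop]:
            cluster_to_split[cluster] = name
        start = stop
    return {prot: cluster_to_split[c] for prot, c in protein_to_cluster.items()}


# Toy data: 30 placeholder proteins in 10 clusters of 3.
toy_clusters = {f"prot{i}": f"cluster{i // 3}" for i in range(30)}
assignment = split_by_cluster(toy_clusters, {"train": 0.8, "val": 0.1, "test": 0.1})
```

Since similar proteins share a cluster, keeping each cluster within one split prevents near-duplicates of a training protein from appearing in validation or test.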

In [20]:
splits_path = get_data_dir() / "interpro_103.0" / "dataset_splits"

[x.name for x in splits_path.glob("*splits*.tsv")]

['debug_splits.tsv',
 'seq_splits.mmseq0.3_exact_align.tsv',
 'seq_splits.tsv',
 'struct_splits.tsv',
 'seq_splits.stringent_clusters.tsv']

For more details on split construction, see [`create_dataset_splits.ipynb`](https://github.com/rcalef/magneton/blob/main/experiments/create_dataset_splits.ipynb).

# Data loaders

To actually train and evaluate a model, we'll need to construct data loaders from these protein datasets.

For this example, we use the debug dataset splits, which contain ~1000 proteins per split.

In [21]:
from magneton.config import DataConfig
from magneton.core_types import SubstructType
from magneton.data import MagnetonDataModule

In [22]:
interpro_path = get_data_dir() / "interpro_103.0"
label_path = interpro_path / "labels" / "selected_subset"

dataset_path = interpro_path / "swissprot_subset"

splits_path = interpro_path / "dataset_splits" / "debug_splits.tsv"

data_config = DataConfig(
    data_dir=dataset_path,
    fasta_path=fasta_path,
    labels_path=label_path,
    splits=splits_path,
    batch_size=4,
    substruct_types=[SubstructType.DOMAIN],
    num_loader_workers=8,
)

In [23]:
data_module = MagnetonDataModule(
    data_config=data_config,
    # Model type specifies the model-specific transforms to
    # add on top of the base datasets, e.g. tokenization
    model_type="esmc",
)

In [24]:
loader = data_module.train_dataloader()

Forcing collapse_labels to True for simplicity.


processing proteins: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 55/55 [00:23<00:00,  2.36it/s]
INFO:magneton.data.core.unified_dataset:split train: got 366 proteins
INFO:magneton.data.data_modules:remaining proteins after length filter: 342 / 366


In [25]:
it = iter(loader)

In [26]:
batch = next(it)
batch

ESMCBatch(protein_ids=['A0A0C4DH68', 'A0A0H3NGY1', 'A0A0A1EQ07', 'A0A0H3H456'], lengths=[120, 239, 313, 472], seqs=None, substructures=[[LabeledSubstructure(ranges=[tensor([ 16, 120])], label=431, element_type=<SubstructType.DOMAIN: 'Domain'>)], [LabeledSubstructure(ranges=[tensor([  6, 120])], label=191, element_type=<SubstructType.DOMAIN: 'Domain'>), LabeledSubstructure(ranges=[tensor([135, 234])], label=201, element_type=<SubstructType.DOMAIN: 'Domain'>)], [LabeledSubstructure(ranges=[tensor([ 68, 294])], label=530, element_type=<SubstructType.DOMAIN: 'Domain'>)], [LabeledSubstructure(ranges=[tensor([397, 463])], label=462, element_type=<SubstructType.DOMAIN: 'Domain'>), LabeledSubstructure(ranges=[tensor([156, 243])], label=5, element_type=<SubstructType.DOMAIN: 'Domain'>)]], structure_list=None, labels=None, tokenized_seq=tensor([[ 0, 20, 10,  ...,  1,  1,  1],
        [ 0, 20, 16,  ...,  1,  1,  1],
        [ 0, 20,  5,  ...,  1,  1,  1],
        [ 0, 20,  7,  ..., 10,  6,  2]]))
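In the batch above, the rows of `tokenized_seq` for the shorter proteins end in runs of 1s, which looks like right-padding to the length of the longest sequence in the batch. The sketch below shows that kind of padding in plain Python. The helper name and the choice of 1 as the padding ID are assumptions read off the printed output, not taken from the data module's code.

```python
def pad_batch(token_ids: list[list[int]], pad_id: int = 1) -> list[list[int]]:
    """Right-pad every token list to the length of the longest one."""
    max_len = max(len(tokens) for tokens in token_ids)
    return [tokens + [pad_id] * (max_len - len(tokens)) for tokens in token_ids]


# Two toy token sequences of different lengths.
padded = pad_batch([[0, 20, 10, 2], [0, 20, 16, 5, 7, 2]])
```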