In [1]:
import os

from collections import defaultdict

import numpy as np
import pandas as pd

from magneton.io.internal import ProteinDataset, shard_proteins

In [2]:
# Inputs:
#   - the SwissProt protein shards under interpro/103.0/.../with_ss, which make up the dataset;
#   - FoldSeek cluster assignments, subset to only include SwissProt proteins:
#       * sequence-based clusters (AFDB50)
#       * structure-based clusters (FoldSeek-Cluster)
# Outputs (split tables and per-split shards) are written under `outdir`, created if missing.
data_dir = "/weka/scratch/weka/kellislab/rcalef/data"

interpro_dir = os.path.join(data_dir, "interpro", "103.0", "swissprot", "sharded_swissprot", "with_ss")
outdir = os.path.join(interpro_dir, "dataset_splits")
os.makedirs(outdir, exist_ok=True)

foldseek_dir = os.path.join(data_dir, "foldseek_cluster")
seq_cluster_path, struct_cluster_path = (
    os.path.join(foldseek_dir, fname)
    for fname in ("afdb50_clusters_swissprot.tsv", "foldseek_clusters_swissprot.tsv")
)

In [3]:
# Columns: `rep_id` is the UniProt ID of the cluster representative, and `uniprot_id` is the
# UniProt ID of the member protein. The file also carries an extraneous blank third column,
# read in as `X` and discarded by keeping only the two ID columns.
seq_cluster_assignments = pd.read_table(
    seq_cluster_path,
    names=["rep_id", "uniprot_id", "X"],
)[["rep_id", "uniprot_id"]]
seq_cluster_assignments.head()

Unnamed: 0,rep_id,uniprot_id
0,A0A1Z5LF99,Q7JYV2
1,B9LLY0,B9LLY0
2,B9LLY0,A9WHT8
3,P07316,A3RLE1
4,P07316,P04344


In [4]:
# `rep_id` gives the UniProt ID of the representative for the cluster, `uniprot_id` gives
# the UniProt ID for the actual individual protein.
# `clust_type` describes the reason the protein is included in the cluster (i.e. based on sequence or structure),
# or whether or not FoldSeek authors excluded it and why.
# `tax_id` is just taxonomic ID for organism
# See description for file `5-allmembers-repId-entryId-cluFlag-taxId.tsv.gz` here:
#  https://afdb-cluster.steineggerlab.workers.dev/

# Numeric cluster-flag codes in the raw file -> readable labels.
clust_type_names = {
    1: "seq_clust",
    2: "struct_clust",
    3: "fragments",
    4: "singleton",
}

struct_cluster_assignments = (
    pd.read_table(struct_cluster_path, names=["rep_id", "uniprot_id", "clust_type", "tax_id"])
    .assign(clust_type=lambda x: x.clust_type.map(clust_type_names))
)

# `Series.map` with a dict silently yields NaN for any code missing from the dict, and
# `value_counts()` drops NaN by default, so such rows would vanish from later summaries.
# Fail loudly here instead.
assert struct_cluster_assignments.clust_type.notna().all(), (
    "struct cluster file contains clust_type codes not in clust_type_names"
)
struct_cluster_assignments.head()

Unnamed: 0,rep_id,uniprot_id,clust_type,tax_id
0,A0A009J8A1,P0ACU7,seq_clust,83333
1,A0A009J8A1,P0ACU8,seq_clust,199310
2,A0A009J8A1,P0ACU9,seq_clust,623
3,A0A010QBH7,P87145,struct_clust,284812
4,A0A010R299,Q2U1H5,struct_clust,510516


In [5]:
# Why each protein is in its cluster, across the whole cluster file. dropna=False so that any
# flag codes the label mapping did not cover (turned into NaN by `.map`) are counted here
# instead of being silently omitted.
struct_cluster_assignments.clust_type.value_counts(dropna=False)

clust_type
seq_clust       439154
struct_clust     95602
singleton         9477
fragments         3552
Name: count, dtype: int64

In [6]:
# Load every protein from the sharded dataset into memory. `dataset_ids` holds their
# UniProt IDs in iteration order; the final expression displays how many were loaded.
prot_dataset = ProteinDataset(interpro_dir, prefix="swissprot.with_ss")
prots = [*prot_dataset]
dataset_ids = pd.Series([protein.uniprot_id for protein in prots])
len(dataset_ids)

530601

In [7]:
# Sanity check: count how many dataset proteins do / do not appear in the sequence-cluster file.
dataset_ids.isin(seq_cluster_assignments["uniprot_id"]).value_counts()

True    530601
Name: count, dtype: int64

In [8]:
# Sanity check: count how many dataset proteins do / do not appear in the structure-cluster file.
dataset_ids.isin(struct_cluster_assignments["uniprot_id"]).value_counts()

True    530601
Name: count, dtype: int64

In [9]:
# Cluster-membership reasons, restricted to the proteins that are actually in the dataset.
struct_cluster_assignments.loc[
    struct_cluster_assignments["uniprot_id"].isin(dataset_ids), "clust_type"
].value_counts()

clust_type
seq_clust       430991
struct_clust     92059
singleton         4962
fragments         2589
Name: count, dtype: int64

In [10]:
# Split at the granularity of FoldSeek structure clusters rather than individual proteins, so
# all members of a cluster land in the same split. Seeding numpy's legacy global RNG pins the
# shuffle below, and therefore which clusters go to which split.
np.random.seed(42)

# Cluster representatives in order of first appearance, then shuffled in place.
# NOTE(review): these come from every row of the cluster file, including proteins that are not
# in `dataset_ids`, so some clusters may contain no dataset proteins. Such clusters still count
# toward the cluster totals and the "proteins per cluster" figures printed below.
cluster_ids = struct_cluster_assignments["rep_id"].unique()
np.random.shuffle(cluster_ids)

# 80% / 10% of clusters for train / val (int() truncates); test takes whatever remains.
total_clusters = len(cluster_ids)
train_clusters_count = int(0.8 * total_clusters)
val_clusters_count = int(0.1 * total_clusters)
test_clusters_count = total_clusters - (train_clusters_count + val_clusters_count)

# Contiguous slices of the shuffled IDs, laid out as [train | val | test].
train_clusters = cluster_ids[:train_clusters_count]
val_clusters = cluster_ids[train_clusters_count:total_clusters - test_clusters_count]
test_clusters = cluster_ids[total_clusters - test_clusters_count:]

# Cluster-assignment rows whose representative falls in each split.
train_data = struct_cluster_assignments.loc[lambda df: df["rep_id"].isin(train_clusters)]
val_data = struct_cluster_assignments.loc[lambda df: df["rep_id"].isin(val_clusters)]
test_data = struct_cluster_assignments.loc[lambda df: df["rep_id"].isin(test_clusters)]

# Dataset proteins belonging to each split (order follows `dataset_ids`).
train_datapoints = dataset_ids.loc[dataset_ids.isin(train_data["uniprot_id"])]
val_datapoints = dataset_ids.loc[dataset_ids.isin(val_data["uniprot_id"])]
test_datapoints = dataset_ids.loc[dataset_ids.isin(test_data["uniprot_id"])]


def summary(name, clusters_count, split_datapoints):
    """Print a split's cluster count, protein count and share of the dataset, and mean proteins per cluster."""
    n_points = len(split_datapoints)
    share_pct = n_points / len(dataset_ids) * 100
    per_cluster = n_points / clusters_count
    print(f"{name} clusters: {clusters_count}, datapoints: {n_points} ({share_pct:.2f}%) proteins per cluster: {per_cluster:.2f}")


summary("Train", train_clusters_count, train_datapoints)
summary("Val", val_clusters_count, val_datapoints)
summary("Test", test_clusters_count, test_datapoints)

Train clusters: 59400, datapoints: 423821 (79.88%) proteins per cluster: 7.14
Val clusters: 7425, datapoints: 52623 (9.92%) proteins per cluster: 7.09
Test clusters: 7426, datapoints: 54157 (10.21%) proteins per cluster: 7.29


In [11]:
# Confirm train, val, and test are pairwise disjoint. All three pairs are checked: the size-sum
# check alone cannot rule out a val/test overlap offset by proteins missing from every split.
train_ids, val_ids, test_ids = set(train_datapoints), set(val_datapoints), set(test_datapoints)
assert not (train_ids & val_ids), "train and val splits overlap"
assert not (train_ids & test_ids), "train and test splits overlap"
assert not (val_ids & test_ids), "val and test splits overlap"

# Confirm every dataset protein was assigned to some split (and nothing outside the dataset was).
assert train_ids | val_ids | test_ids == set(dataset_ids), "split union does not match dataset_ids"

In [13]:
# Persist the split in two forms:
#   1. `dataset_splits.tsv` in `outdir`: one row per protein with its UniProt ID and split name.
#   2. `<split>_sharded/` in `outdir`: that split's protein objects, written via `shard_proteins`
#      with 10000 proteins per file.
splits = {
    "train": train_datapoints,
    "val": val_datapoints,
    "test": test_datapoints,
}
# Set versions of each split's IDs for fast membership tests in the loop below.
lookups = {k: set(v) for k, v in splits.items()}

# Bucket the loaded protein objects by split, keeping their original order within each split.
split_prots = defaultdict(list)
for prot in prots:
    for split, lookup in lookups.items():
        if prot.uniprot_id in lookup:
            split_prots[split].append(prot)
            break
    else:
        # Every dataset protein should belong to some split; fail loudly rather than
        # silently leaving it out of every shard.
        raise ValueError(f"Protein {prot.uniprot_id} is not assigned to any split")

split_frames = []
for split_name, datapoints in splits.items():
    split_frames.append(
        datapoints.rename("uniprot_id").to_frame().assign(split=split_name)
    )

    # No exist_ok: os.makedirs raises FileExistsError if the directory is already present, so a
    # re-run will not write into a directory holding shards from an earlier run. Remove the
    # `*_sharded` directories to regenerate.
    split_dir = os.path.join(outdir, f"{split_name}_sharded")
    os.makedirs(split_dir)
    shard_proteins(
        split_prots[split_name],
        split_dir,
        prefix=f"swissprot.with_ss.{split_name}",
        prots_per_file=10000,
    )

all_splits = pd.concat(split_frames)
all_splits.to_csv(os.path.join(outdir, "dataset_splits.tsv"), index=False, sep="\t")

completed file 1, starting file 2
completed file 2, starting file 3
completed file 3, starting file 4
completed file 4, starting file 5
completed file 5, starting file 6
completed file 6, starting file 7
completed file 7, starting file 8
completed file 8, starting file 9
completed file 9, starting file 10
completed file 10, starting file 11
completed file 11, starting file 12
completed file 12, starting file 13
completed file 13, starting file 14
completed file 14, starting file 15
completed file 15, starting file 16
completed file 16, starting file 17
completed file 17, starting file 18
completed file 18, starting file 19
completed file 19, starting file 20
completed file 20, starting file 21
completed file 21, starting file 22
completed file 22, starting file 23
completed file 23, starting file 24
completed file 24, starting file 25
completed file 25, starting file 26
completed file 26, starting file 27
completed file 27, starting file 28
completed file 28, starting file 29
completed