In [1]:
from utils import * 
from src.files.xml import XMLFile
from src.files.fasta import FASTAFile
from sklearn.model_selection import StratifiedShuffleSplit, GroupShuffleSplit, StratifiedGroupKFold

# Reload edited local modules (e.g. utils, src.files.*) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Downloaded AntiFam-annotated sequences from UniProt on 01.24.2026.
# Downloaded taxonomy metadata from NCBI on 04.08.2025.

In [2]:
# ! mamba create -n sprout python=3.10 numpy scipy pandas scikit-learn matplotlib seaborn -c conda-forge
# ! mamba install pytorch torchvision torchaudio -c pytorch -c conda-forge
# ! mamba install jupyterlab ipykernel ipywidgets notebook -c conda-forge
# ! mamba install biopython pyarrow hdf5 pillow lxml pyyaml requests -c conda-forge

In [3]:
# Downloaded AntiFam-annotated sequences from UniProt on 01.24.2026. 
# https://rest.uniprot.org/uniprotkb/stream?compressed=true&format=xml&query=%28%28taxonomy_id%3A2%29+AND+%28database%3Aantifam%29%29

In [4]:
# XMLFile('../data/uniprot-antifam.xml').to_df().to_csv('../data/uniprot-antifam.csv')

In [5]:
# Negatives (label 0): AntiFam-annotated UniProt entries with pfam_id == 'none' (no Pfam hit)
# and a 'predicted'/'uncertain' protein-existence level. Positives (label 1): Swiss-Prot entries.
dataset_df = pd.read_csv('../data/uniprot-antifam.csv').assign(label=0)
# na=False makes it explicit that entries with a missing existence value are excluded
# (previously this only happened implicitly through `&` coercing NaN to False).
dataset_df = dataset_df[(dataset_df.pfam_id == 'none') & dataset_df.existence.str.contains('predicted|uncertain', regex=True, na=False)].copy()
# ignore_index=True: both frames carry their own 0-based index from read_csv, so a plain
# concat would produce duplicate index labels (which also end up in the saved split CSVs).
dataset_df = pd.concat([dataset_df, pd.read_csv('../data/uniprot-sprot.csv').assign(label=1)], ignore_index=True)
dataset_df = dataset_df[dataset_df.domain == 'Bacteria'].copy() # Filter out non-bacteria.
dataset_df = dataset_df[(dataset_df.non_terminal_residue == 'none') | dataset_df.non_terminal_residue.isnull()].copy() # Remove fragments.
# drop_duplicates keeps the first occurrence; AntiFam rows come first in the concat above,
# so an id present in both sources is kept with label 0.
dataset_df = dataset_df.drop_duplicates('id').fillna('none') # Prevent mixed data types.
dataset_df = dataset_df[dataset_df.seq.str.len() <= 1024].copy() # Remove sequences whose length exceeds the ESM2 window size.

In [6]:
# Dereplicate at 80% sequence identity: keep only the representative of each mmseqs cluster.
# mmseqs easy-cluster dataset.fasta dataset tmp --min-seq-id 0.8
dataset_cluster_df = pd.read_csv('../data/dataset_cluster.tsv', sep='\t', names=['rep_id', 'id'])
dataset_df = dataset_df.loc[dataset_df['id'].isin(set(dataset_cluster_df['rep_id']))].copy()

In [7]:
# Cluster the dereplicated sequences at 50% identity; the resulting cluster ids are used later
# as the groups of the train/test split.
# NOTE(review): the file read below is named dataset_dereplicated_cluster.tsv, so the command was
# presumably something like
#   mmseqs easy-cluster dataset_dereplicated.fasta dataset_dereplicated tmp --min-seq-id 0.5
# (the previously recorded `dataset.fasta dataset` form would write dataset_cluster.tsv instead)
# — TODO confirm the exact command.
dataset_cluster_df = pd.read_csv('../data/dataset_dereplicated_cluster.tsv', sep='\t', names=['rep_id', 'id'])
dataset_df['cluster_id'] = dataset_df['id'].map(dataset_cluster_df.set_index('id').rep_id)
# An id missing from the cluster file maps to NaN: groupby would silently drop such rows and the
# group-aware split would likely fail on mixed str/NaN groups, so fail fast with a clear message.
assert dataset_df['cluster_id'].notna().all(), f"{dataset_df['cluster_id'].isna().sum()} sequences have no cluster assignment"

In [8]:
# Cluster diagnostics: whether every cluster carries a single label, the total number of
# clusters, and how many clusters contain exactly one sequence.
print('All clusters homogenous?', dataset_df.groupby('cluster_id')['label'].nunique().eq(1).all())
print('Num. clusters:', dataset_df['cluster_id'].nunique())
print('Num. singleton clusters:', dataset_df.groupby('cluster_id').size().eq(1).sum())

All clusters homogenous? True
Num. clusters: 50496
Num. singleton clusters: 35496


In [9]:
# Split by 50%-identity cluster so that no cluster contributes sequences to both train and test,
# while StratifiedGroupKFold approximately preserves the label ratio in each fold. Only the first
# of the 5 folds is used, i.e. roughly an 80/20 train/test split.
# (Alternatives considered: StratifiedShuffleSplit ignores clusters; GroupShuffleSplit ignores
# the label ratio.)
split = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)

X = dataset_df['id'].values
y = dataset_df.label.values
groups = dataset_df['cluster_id'].values

train_idxs, test_idxs = next(split.split(X, y, groups=groups))
dataset_train_df, dataset_test_df = dataset_df.iloc[train_idxs].copy(), dataset_df.iloc[test_idxs].copy()

# Verify the split before writing it: the two index sets must together account for every row,
# and no cluster may appear on both sides (otherwise similar sequences leak into the test set).
assert len(train_idxs) + len(test_idxs) == len(dataset_df), 'train/test split does not cover the dataset'
assert set(dataset_train_df['cluster_id']).isdisjoint(dataset_test_df['cluster_id']), 'cluster(s) shared between train and test'

dataset_train_df.to_csv('../data/dataset_train.csv')
dataset_test_df.to_csv('../data/dataset_test.csv')

print('Num. elements in training set:', len(train_idxs))
print('Num. elements in testing set:', len(test_idxs))

Num. elements in training set: 109726
Num. elements in testing set: 27734


In [10]:
# Percentage of training rows labelled 0 (the AntiFam-derived "spurious" class).
print(
    'Class balance in training data: '
    f"{100 * dataset_train_df['label'].eq(0).sum() / len(dataset_train_df):.2f}"
    '% spurious'
)

Class balance in training data: 6.33% spurious
