In [1]:
import torch
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Torch version: 2.5.0+cu124
CUDA available: True
Using device: cuda


In [2]:
import zipfile
import os

if not os.path.exists("data/"):
	# path to zip file
	zip_path = "cafa-6-protein-function-prediction.zip"

	# folder to extract to
	extract_dir = "data"

	# make sure the folder exists
	os.makedirs(extract_dir, exist_ok=True)

	# extract
	with zipfile.ZipFile(zip_path, 'r') as zip_ref:
		zip_ref.extractall(extract_dir)

	print(f"Extracted all files to: {extract_dir}")



In [3]:
from Bio import SeqIO
import pandas as pd

# Load sequences
train_sequences = list(SeqIO.parse("data/Train/train_sequences.fasta", "fasta"))
print(f"Total sequences: {len(train_sequences)}")

# Convert to DataFrame
seq_df = pd.DataFrame({
    "entry_id": [record.id for record in train_sequences],
    "sequence": [str(record.seq) for record in train_sequences],
})
seq_df["seq_len"] = seq_df["sequence"].str.len()
seq_df["entry_id"] = seq_df["entry_id"].str.extract(r"sp\|([^|]+)\|")[0]
#seq_df.head()

# Load taxonomy and GO terms
tax_df = pd.read_csv("data/Train/train_taxonomy.tsv", sep="\t",header=None)
tax_df.columns = ["entry_id", "taxonomy"]
terms_df = pd.read_csv("data/Train/train_terms.tsv", sep="\t")
terms_df.columns = terms_df.columns.str.lower()
terms_df.rename(columns={"entryid": "entry_id"}, inplace=True)

# Merge all
train_df =  (
    seq_df
    .merge(tax_df, on="entry_id", how="left")
    .merge(terms_df, on="entry_id", how="left")
)
train_df.head()


Total sequences: 82404


Unnamed: 0,entry_id,sequence,seq_len,taxonomy,term,aspect
0,A0A0C5B5G6,MRWQEMGYIFYPRKLR,16,9606,GO:0001649,P
1,A0A0C5B5G6,MRWQEMGYIFYPRKLR,16,9606,GO:0033687,P
2,A0A0C5B5G6,MRWQEMGYIFYPRKLR,16,9606,GO:0005615,C
3,A0A0C5B5G6,MRWQEMGYIFYPRKLR,16,9606,GO:0005634,C
4,A0A0C5B5G6,MRWQEMGYIFYPRKLR,16,9606,GO:0005739,C


In [4]:
from Bio import SeqIO
import pandas as pd

records = list(SeqIO.parse("data/Test/testsuperset.fasta", "fasta"))

entry_ids = []
tax_ids = []
sequences = []

for r in records:
    parts = r.description.split()
    entry_ids.append(parts[0])        # e.g., A0A0C5B5G6
    tax_ids.append(parts[1])          # e.g., 9606
    sequences.append(str(r.seq))

test_df = pd.DataFrame({
    "entry_id": entry_ids,
    "tax_id": tax_ids,
    "sequence": sequences
})

test_df["seq_len"] = test_df["sequence"].str.len()


In [5]:
# Check data sizes and overlaps
print(f"Test sequences: {test_df.shape[0]}")
print(f"Train protein-function pairs: {train_df.shape[0]}")
print(f"Unique proteins in train: {train_df['entry_id'].nunique()}")
print(f"Unique GO terms: {train_df['term'].nunique()}")

# Check for sequence overlap between train and test
train_proteins = set(train_df['entry_id'].unique())
test_proteins = set(test_df['entry_id'].unique())
overlap = train_proteins.intersection(test_proteins)
print(f"Proteins in both train and test: {len(overlap)}")

Test sequences: 224309
Train protein-function pairs: 537027
Unique proteins in train: 82404
Unique GO terms: 26125
Proteins in both train and test: 82404


In [6]:
# Analyze GO term frequency
term_counts = train_df['term'].value_counts()
print("GO term frequency distribution:")
print(term_counts.describe())

# Focus on frequent terms 
min_occurrence = 1
frequent_terms = term_counts[term_counts >= min_occurrence].index
print(f"Terms with ≥{min_occurrence} occurrences: {len(frequent_terms)}")

# Filter training data to frequent terms
filtered_train_df = train_df[train_df['term'].isin(frequent_terms)]

GO term frequency distribution:
count    26125.000000
mean        20.556057
std        268.143836
min          1.000000
25%          2.000000
50%          4.000000
75%         12.000000
max      33713.000000
Name: count, dtype: float64
Terms with ≥1 occurrences: 26125


In [7]:
import pandas as pd
import numpy as np
import torch
import esm
from Bio import SeqIO
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.multioutput import MultiOutputClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, precision_score, recall_score
import scipy.sparse as sp
import warnings
warnings.filterwarnings('ignore')

print("Loading ESM model...")
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
esm_model = esm_model.to(device)
esm_model.eval()
print("ESM model loaded successfully!")


Loading ESM model...
ESM model loaded successfully!


In [8]:
print("=== DATA SUMMARY ===")
print(f"Train df shape: {train_df.shape}")
print(f"Test df shape: {test_df.shape}")

print("\n=== TRAIN DATA PREVIEW ===")
print(train_df.head())
print("\n=== TEST DATA PREVIEW ===")
print(test_df.head())

print("\n=== COLUMN NAMES ===")
print(f"Train columns: {train_df.columns.tolist()}")
print(f"Test columns: {test_df.columns.tolist()}")

=== DATA SUMMARY ===
Train df shape: (537027, 6)
Test df shape: (224309, 4)

=== TRAIN DATA PREVIEW ===
     entry_id          sequence  seq_len  taxonomy        term aspect
0  A0A0C5B5G6  MRWQEMGYIFYPRKLR       16      9606  GO:0001649      P
1  A0A0C5B5G6  MRWQEMGYIFYPRKLR       16      9606  GO:0033687      P
2  A0A0C5B5G6  MRWQEMGYIFYPRKLR       16      9606  GO:0005615      C
3  A0A0C5B5G6  MRWQEMGYIFYPRKLR       16      9606  GO:0005634      C
4  A0A0C5B5G6  MRWQEMGYIFYPRKLR       16      9606  GO:0005739      C

=== TEST DATA PREVIEW ===
     entry_id tax_id                                           sequence  \
0  A0A0C5B5G6   9606                                   MRWQEMGYIFYPRKLR   
1  A0A1B0GTW7   9606  MLLLLLLLLLLPPLVLRVAASRCLHDETQKSVSLLRPPFSQLPSKS...   
2      A0JNW5   9606  MAGIIKKQILKHLSRFTKNLSPDKINLSTLKGEGELKNLELDEEVL...   
3      A0JP26   9606  MVAEVCSMPAASAVKKPFDLRSKMGKWCHHRFPCCRGSGKSNMGTS...   
4      A0PK11   9606  MPGWFKKAWYGLASLLSFSSFILIIVALVVPHWLSGKILCQTGVDL...   

In [9]:
class ProteinDataset(Dataset):
    def __init__(self, df, max_length=1024):
        """
        df: DataFrame with 'entry_id' and 'sequence' columns
        max_length: Truncate sequences longer than this
        """
        self.df = df.reset_index(drop=True)
        self.max_length = max_length
        
    def __len__(self):
        return len(self.df)
    
    def __hash__(self):
        df_str = self.df.to_string(index=False, header=True)
        return hash(df_str)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        entry_id = row['entry_id']
        sequence = row['sequence'][:self.max_length]  # Truncate if too long
        return entry_id, sequence

import os
import torch
import pickle
from torch.utils.data import DataLoader

def get_esm_embeddings(dataset, batch_size=8, show_progress=True, cache_file="esm_embeddings.pkl"):
    """
    Extract ESM embeddings for all sequences in dataset.
    If embeddings are cached, load them from the cache file.
    
    Returns: dict {entry_id: embedding_vector}
    """
    # Check if the embeddings are already cached
    print("cache file: ", cache_file)
    if os.path.exists(cache_file):
        print(f"Loading cached embeddings from {cache_file}")
        with open(cache_file, 'rb') as f:
            embeddings = pickle.load(f)
        return embeddings
    
    print("Extracting embeddings...")

    # Initialize dataloader
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    embeddings = {}

    if show_progress:
        from tqdm import tqdm
        iterator = tqdm(dataloader, desc="Extracting embeddings")
    else:
        iterator = dataloader
    
    with torch.no_grad():
        for batch in iterator:
            entry_ids, sequences = batch
            
            # Prepare batch for ESM
            batch_data = [(entry_id, sequence) for entry_id, sequence in zip(entry_ids, sequences)]
            batch_labels, batch_strs, batch_tokens = batch_converter(batch_data)
            batch_tokens = batch_tokens.to(device)
            
            # Get embeddings 
            results = esm_model(batch_tokens, repr_layers=[esm_model.num_layers], return_contacts=False)
            token_representations = results["representations"][esm_model.num_layers]
            
            # Create protein embedding (mean of all tokens except CLS and PAD)
            for i, entry_id in enumerate(entry_ids):
                seq_len = len(batch_strs[i])
                embedding = token_representations[i, 1:seq_len+1].mean(dim=0)  # Exclude CLS token
                embeddings[entry_id] = embedding.cpu().numpy()

            # Clear GPU memory
            torch.cuda.empty_cache()

    # Cache the embeddings for future use
    with open(cache_file, 'wb') as f:
        pickle.dump(embeddings, f)
        print(f"Embeddings cached to {cache_file}")
    
    return embeddings


def create_label_matrix(train_df, protein_list, term_list):
    """
    Create binary label matrix for multi-label classification
    Returns: sparse matrix (proteins x terms), protein_to_idx mapping
    """
    protein_to_idx = {pid: idx for idx, pid in enumerate(protein_list)}
    term_to_idx = {term: idx for idx, term in enumerate(term_list)}
    
    rows, cols = [], []
    for _, row in train_df.iterrows():
        if row['entry_id'] in protein_to_idx and row['term'] in term_to_idx:
            rows.append(protein_to_idx[row['entry_id']])
            cols.append(term_to_idx[row['term']])
    
    data = np.ones(len(rows))
    label_matrix = sp.csr_matrix((data, (rows, cols)), 
                                shape=(len(protein_list), len(term_list)))
    
    return label_matrix, protein_to_idx

In [None]:
from sklearn.multioutput import MultiOutputClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score
import numpy as np
import scipy.sparse as sp
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# class MultiOutputNN(nn.Module):
# 	def __init__(self, input_size, output_size, hidden_size=1000):
# 		super(MultiOutputNN, self).__init__()
# 		self.fc_in = nn.Linear(input_size, hidden_size)
# 		self.fc1 = nn.Linear(hidden_size, hidden_size)
# 		self.fc2 = nn.Linear(hidden_size, hidden_size)
# 		self.fc3 = nn.Linear(hidden_size, hidden_size)
# 		self.fc_out = nn.Linear(hidden_size, output_size)
# 		self.relu = nn.ReLU()
# 		self.sigmoid = nn.Sigmoid()
	
# 	def forward(self, x):
# 		x = self.relu(self.fc_in(x))
# 		x = self.relu(self.fc1(x))
# 		x = self.relu(self.fc2(x))
# 		x = self.relu(self.fc3(x))
# 		x = self.fc_out(x)
# 		return self.sigmoid(x)

class MultiOutputCNN(nn.Module):
    def __init__(self, input_size, output_size, hidden_channels=32, kernel_size=10):
        super(MultiOutputCNN, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv1d(1, hidden_channels, kernel_size, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=1),
            nn.ReLU(),
        )

        # Global pooling → fixed-size output
        self.global_pool = nn.AdaptiveMaxPool1d(1)

        self.fc_out = nn.Linear(hidden_channels, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: (batch, embedding_dim)
        x = x.unsqueeze(1)              # → (batch, 1, embedding_dim)
        x = self.conv(x)                # → (batch, hidden_channels, L)
        x = self.global_pool(x)         # → (batch, hidden_channels, 1)
        x = x.squeeze(-1)               # → (batch, hidden_channels)
        x = self.fc_out(x)
        return self.sigmoid(x)



def training(train_df, test_df, term_counts, min_occurrence, batch_size=16):
	# Select GO terms with enough occurrences
	terms = term_counts[term_counts >= min_occurrence].index
	train = train_df[train_df['term'].isin(terms)]
		
	# Select proteins present in test_df
	# common_proteins = set(train['entry_id']).intersection(set(test_df['entry_id']))
	proteins = list(set(train['entry_id']))
		
	train = train[train['entry_id'].isin(proteins)]
	train = train.drop_duplicates(subset=['entry_id'])
	test = test_df[test_df['entry_id'].isin(proteins)]
	
	print(f"Training examples: {len(train)}")
	print(f"Unique proteins: {len(proteins)}")
	print(f"GO terms: {len(terms)}")
	# Create dataset and extract embeddings
	print("Creating ProteinDataset and extracting embeddings...")
	dataset = ProteinDataset(test)
	embeddings = get_esm_embeddings(dataset, batch_size=1, cache_file=f"generated/esm_650m/esm_embeddings_train.pkl")
	
	# Build sparse label matrix
	terms_2 = term_counts[(term_counts > 50)].index
	terms = terms_2
	y_train, _protein_map = create_label_matrix(train, proteins, terms)
	X_train = np.array([embeddings[pid] for pid in proteins])
	
	print(f"Feature matrix: X={X_train.shape}, Label matrix: y={y_train.shape}")
	
	def train_safe_classifier(X, y, terms, test_size=0.2, lr=0.001, epochs=100):
		X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=test_size, random_state=42)
		
		# Identify columns (GO terms) with at least 1 positive sample
		valid_cols = np.where(y_train.sum(axis=0) > 0)[0]
		if len(valid_cols) < y_train.shape[1]:
			removed_terms = [terms[i] for i in range(len(terms)) if i not in valid_cols]
			print(f"Removing {len(removed_terms)} GO terms with no positives in training: {removed_terms[:10]}{'...' if len(removed_terms) > 10 else ''}")
		
		# Filter y to valid columns (GO terms)
		y_train_filtered = y_train[:, valid_cols]
		y_val_filtered = y_val[:, valid_cols]
		filtered_terms = [terms[i] for i in valid_cols]
		
		# Train multi-output neural network model
		# base_clf = MLPClassifier(random_state=42, max_iter=100, solver='adam', hidden_layer_sizes=(100,100), activation='relu', verbose=True)
		# clf = MultiOutputClassifier(base_clf)
		# clf.fit(X_train, y_train_filtered)
		
		# y_pred = clf.predict(X_val)
		# f1 = f1_score(y_val_filtered, y_pred, average='micro')
		# precision = precision_score(y_val_filtered, y_pred, average='micro')
		# recall = recall_score(y_val_filtered, y_pred, average='micro')
		
		# print("=== VALIDATION RESULTS ===")
		# print(f"Micro F1-score: {f1:.4f}")
		# print(f"Micro Precision: {precision:.4f}")
		# print(f"Micro Recall: {recall:.4f}")
		# print(f"Predicted labels: {y_pred.sum()} / {y_val_filtered.sum()} actual")
		
		# return clf, (f1, precision, recall), filtered_terms

		X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
		y_train_tensor = torch.tensor(y_train_filtered, dtype=torch.float32).to(device)
		X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
		y_val_tensor = torch.tensor(y_val_filtered, dtype=torch.float32).to(device)
		
		train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
		train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
		
		# Initialize the model
		input_size = X_train.shape[1]
		output_size = y_train_filtered.shape[1]
		model = MultiOutputCNN(input_size, output_size)
		model = model.to(device)
		
		# Define loss function and optimizer
		criterion = nn.BCELoss()  # Use BCE loss for multi-label classification
		optimizer = optim.Adam(model.parameters(), lr=lr)

		# Training loop
		for epoch in range(epochs):
			model.train()
			epoch_loss = 0.0
			with tqdm(train_loader, desc=f"Epoch [{epoch+1}/{epochs}]", unit="batch", ncols=100) as t:
				for batch_idx, (X_batch, y_batch) in enumerate(t):
					optimizer.zero_grad()

					# Forward pass
					y_pred_train = model(X_batch)

					# Compute loss
					loss = criterion(y_pred_train, y_batch)
					epoch_loss += loss.item()

					# Backward pass and optimization
					loss.backward()
					optimizer.step()

					# Update the tqdm progress bar with current loss
					t.set_postfix(loss=epoch_loss / (batch_idx + 1))
			
			avg_loss = epoch_loss / len(train_loader)
			print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.8f}")

		# model.eval()
		# with torch.no_grad():
		# 	y_pred_val = model(X_val_tensor)
		# 	# print(y_pred_val)
		# 	y_pred_val = (y_pred_val > 0.01).float()  # Convert logits to binary predictions (0 or 1)

		# 	f1 = f1_score(y_val_tensor.cpu(), y_pred_val.cpu(), average='micro')
		# 	precision = precision_score(y_val_tensor.cpu(), y_pred_val.cpu(), average='micro')
		# 	recall = recall_score(y_val_tensor.cpu(), y_pred_val.cpu(), average='micro')

		# 	print("=== VALIDATION RESULTS ===")
		# 	print(f"Micro F1-score: {f1:.4f}")
		# 	print(f"Micro Precision: {precision:.4f}")
		# 	print(f"Micro Recall: {recall:.4f}")
		# 	print(f"Predicted labels: {y_pred_val.sum()} / {y_val_tensor.sum()} actual")
		
		return model, (0, 0, 0), filtered_terms
		
	# Train and evaluate
	classifier, metrics, filtered_terms = train_safe_classifier(X_train, y_train.toarray(), terms)
		
	return classifier, metrics, filtered_terms, X_train, y_train

classifier, metrics, filtered_terms, X_train_after, y_train_after = training(
	filtered_train_df, test_df, term_counts,
	min_occurrence=1,
	batch_size=128
)


Training examples: 82404
Unique proteins: 82404
GO terms: 26125
Creating ProteinDataset and extracting embeddings...
cache file:  generated/esm_650m/esm_embeddings_train.pkl
Loading cached embeddings from generated/esm_650m/esm_embeddings_train.pkl
Feature matrix: X=(82404, 1280), Label matrix: y=(82404, 1542)
Removing 122 GO terms with no positives in training: ['GO:0051897', 'GO:0006325', 'GO:0051649', 'GO:0000165', 'GO:0007204', 'GO:0070588', 'GO:0042826', 'GO:0051402', 'GO:0061844', 'GO:0007519']...


Epoch [1/100]: 100%|██████████████████████████████| 516/516 [00:05<00:00, 90.88batch/s, loss=0.0363]


Epoch [1/100], Loss: 0.03632479


Epoch [2/100]: 100%|█████████████████████████████| 516/516 [00:05<00:00, 93.37batch/s, loss=0.00348]


Epoch [2/100], Loss: 0.00348301


Epoch [3/100]: 100%|█████████████████████████████| 516/516 [00:05<00:00, 94.22batch/s, loss=0.00333]


Epoch [3/100], Loss: 0.00333062


Epoch [4/100]: 100%|█████████████████████████████| 516/516 [00:05<00:00, 93.30batch/s, loss=0.00329]


Epoch [4/100], Loss: 0.00328700


Epoch [5/100]: 100%|█████████████████████████████| 516/516 [00:05<00:00, 91.91batch/s, loss=0.00328]


Epoch [5/100], Loss: 0.00327745


Epoch [6/100]: 100%|█████████████████████████████| 516/516 [00:05<00:00, 94.70batch/s, loss=0.00327]


Epoch [6/100], Loss: 0.00326977


Epoch [7/100]: 100%|█████████████████████████████| 516/516 [00:05<00:00, 94.91batch/s, loss=0.00327]


Epoch [7/100], Loss: 0.00326941


Epoch [8/100]: 100%|█████████████████████████████| 516/516 [00:05<00:00, 92.26batch/s, loss=0.00327]


Epoch [8/100], Loss: 0.00326982


Epoch [9/100]: 100%|█████████████████████████████| 516/516 [00:05<00:00, 93.49batch/s, loss=0.00326]


Epoch [9/100], Loss: 0.00326168


Epoch [10/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 88.85batch/s, loss=0.00325]


Epoch [10/100], Loss: 0.00325325


Epoch [11/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 96.48batch/s, loss=0.00325]


Epoch [11/100], Loss: 0.00324806


Epoch [12/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 94.89batch/s, loss=0.00324]


Epoch [12/100], Loss: 0.00323963


Epoch [13/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 93.29batch/s, loss=0.00323]


Epoch [13/100], Loss: 0.00322956


Epoch [14/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 90.48batch/s, loss=0.00322]


Epoch [14/100], Loss: 0.00321860


Epoch [15/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 92.04batch/s, loss=0.00321]


Epoch [15/100], Loss: 0.00320609


Epoch [16/100]: 100%|█████████████████████████████| 516/516 [00:05<00:00, 94.37batch/s, loss=0.0032]


Epoch [16/100], Loss: 0.00319516


Epoch [17/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 93.52batch/s, loss=0.00318]


Epoch [17/100], Loss: 0.00318158


Epoch [18/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 93.58batch/s, loss=0.00316]


Epoch [18/100], Loss: 0.00316289


Epoch [19/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 94.76batch/s, loss=0.00315]


Epoch [19/100], Loss: 0.00315301


Epoch [20/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 94.49batch/s, loss=0.00313]


Epoch [20/100], Loss: 0.00313218


Epoch [21/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 86.83batch/s, loss=0.00312]


Epoch [21/100], Loss: 0.00312107


Epoch [22/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 85.16batch/s, loss=0.00311]


Epoch [22/100], Loss: 0.00310551


Epoch [23/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 87.98batch/s, loss=0.00309]


Epoch [23/100], Loss: 0.00309168


Epoch [24/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 84.95batch/s, loss=0.00307]


Epoch [24/100], Loss: 0.00307046


Epoch [25/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 86.73batch/s, loss=0.00305]


Epoch [25/100], Loss: 0.00305088


Epoch [26/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 89.18batch/s, loss=0.00303]


Epoch [26/100], Loss: 0.00303221


Epoch [27/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 84.34batch/s, loss=0.00302]


Epoch [27/100], Loss: 0.00301817


Epoch [28/100]: 100%|██████████████████████████████| 516/516 [00:05<00:00, 88.84batch/s, loss=0.003]


Epoch [28/100], Loss: 0.00299918


Epoch [29/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 75.05batch/s, loss=0.00298]


Epoch [29/100], Loss: 0.00297872


Epoch [30/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 75.59batch/s, loss=0.00297]


Epoch [30/100], Loss: 0.00296866


Epoch [31/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 82.34batch/s, loss=0.00295]


Epoch [31/100], Loss: 0.00295135


Epoch [32/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 76.77batch/s, loss=0.00294]


Epoch [32/100], Loss: 0.00293739


Epoch [33/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 75.01batch/s, loss=0.00293]


Epoch [33/100], Loss: 0.00292724


Epoch [34/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 82.66batch/s, loss=0.00291]


Epoch [34/100], Loss: 0.00291285


Epoch [35/100]: 100%|█████████████████████████████| 516/516 [00:05<00:00, 90.50batch/s, loss=0.0029]


Epoch [35/100], Loss: 0.00290254


Epoch [36/100]: 100%|█████████████████████████████| 516/516 [00:06<00:00, 81.54batch/s, loss=0.0029]


Epoch [36/100], Loss: 0.00289925


Epoch [37/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 79.92batch/s, loss=0.00288]


Epoch [37/100], Loss: 0.00288321


Epoch [38/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 90.79batch/s, loss=0.00287]


Epoch [38/100], Loss: 0.00287334


Epoch [39/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 83.53batch/s, loss=0.00287]


Epoch [39/100], Loss: 0.00286686


Epoch [40/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 78.86batch/s, loss=0.00285]


Epoch [40/100], Loss: 0.00285262


Epoch [41/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 77.67batch/s, loss=0.00285]


Epoch [41/100], Loss: 0.00284603


Epoch [42/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 84.97batch/s, loss=0.00284]


Epoch [42/100], Loss: 0.00283819


Epoch [43/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 86.25batch/s, loss=0.00283]


Epoch [43/100], Loss: 0.00283119


Epoch [44/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 79.35batch/s, loss=0.00282]


Epoch [44/100], Loss: 0.00282219


Epoch [45/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 80.11batch/s, loss=0.00282]


Epoch [45/100], Loss: 0.00281581


Epoch [46/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 83.87batch/s, loss=0.00281]


Epoch [46/100], Loss: 0.00280975


Epoch [47/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 80.99batch/s, loss=0.00281]


Epoch [47/100], Loss: 0.00281039


Epoch [48/100]: 100%|█████████████████████████████| 516/516 [00:06<00:00, 84.75batch/s, loss=0.0028]


Epoch [48/100], Loss: 0.00279881


Epoch [49/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 88.69batch/s, loss=0.00279]


Epoch [49/100], Loss: 0.00279200


Epoch [50/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 74.09batch/s, loss=0.00278]


Epoch [50/100], Loss: 0.00278221


Epoch [51/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 78.21batch/s, loss=0.00278]


Epoch [51/100], Loss: 0.00278058


Epoch [52/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 77.42batch/s, loss=0.00278]


Epoch [52/100], Loss: 0.00277632


Epoch [53/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 78.38batch/s, loss=0.00277]


Epoch [53/100], Loss: 0.00276931


Epoch [54/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 78.89batch/s, loss=0.00277]


Epoch [54/100], Loss: 0.00276586


Epoch [55/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 80.78batch/s, loss=0.00276]


Epoch [55/100], Loss: 0.00276000


Epoch [56/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 79.06batch/s, loss=0.00275]


Epoch [56/100], Loss: 0.00275456


Epoch [57/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 78.75batch/s, loss=0.00275]


Epoch [57/100], Loss: 0.00274913


Epoch [58/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 78.09batch/s, loss=0.00275]


Epoch [58/100], Loss: 0.00275041


Epoch [59/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 78.82batch/s, loss=0.00275]


Epoch [59/100], Loss: 0.00274709


Epoch [60/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 77.88batch/s, loss=0.00275]


Epoch [60/100], Loss: 0.00274672


Epoch [61/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 81.11batch/s, loss=0.00274]


Epoch [61/100], Loss: 0.00273884


Epoch [62/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 77.51batch/s, loss=0.00273]


Epoch [62/100], Loss: 0.00273146


Epoch [63/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 76.69batch/s, loss=0.00273]


Epoch [63/100], Loss: 0.00273051


Epoch [64/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 75.63batch/s, loss=0.00272]


Epoch [64/100], Loss: 0.00272427


Epoch [65/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 77.07batch/s, loss=0.00272]


Epoch [65/100], Loss: 0.00272203


Epoch [66/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 74.69batch/s, loss=0.00272]


Epoch [66/100], Loss: 0.00271936


Epoch [67/100]: 100%|████████████████████████████| 516/516 [00:07<00:00, 72.48batch/s, loss=0.00272]


Epoch [67/100], Loss: 0.00271793


Epoch [68/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 74.41batch/s, loss=0.00272]


Epoch [68/100], Loss: 0.00271687


Epoch [69/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 74.83batch/s, loss=0.00272]


Epoch [69/100], Loss: 0.00272001


Epoch [70/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 75.97batch/s, loss=0.00272]


Epoch [70/100], Loss: 0.00271543


Epoch [71/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 79.20batch/s, loss=0.00271]


Epoch [71/100], Loss: 0.00270762


Epoch [72/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 78.69batch/s, loss=0.00271]


Epoch [72/100], Loss: 0.00270613


Epoch [73/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 75.56batch/s, loss=0.00271]


Epoch [73/100], Loss: 0.00270557


Epoch [74/100]: 100%|█████████████████████████████| 516/516 [00:06<00:00, 75.62batch/s, loss=0.0027]


Epoch [74/100], Loss: 0.00270233


Epoch [75/100]: 100%|█████████████████████████████| 516/516 [00:06<00:00, 76.59batch/s, loss=0.0027]


Epoch [75/100], Loss: 0.00269875


Epoch [76/100]: 100%|█████████████████████████████| 516/516 [00:06<00:00, 79.53batch/s, loss=0.0027]


Epoch [76/100], Loss: 0.00269706


Epoch [77/100]: 100%|█████████████████████████████| 516/516 [00:06<00:00, 78.54batch/s, loss=0.0027]


Epoch [77/100], Loss: 0.00269552


Epoch [78/100]: 100%|████████████████████████████| 516/516 [00:05<00:00, 86.03batch/s, loss=0.00269]


Epoch [78/100], Loss: 0.00269368


Epoch [79/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 83.18batch/s, loss=0.00269]


Epoch [79/100], Loss: 0.00269264


Epoch [80/100]: 100%|████████████████████████████| 516/516 [00:07<00:00, 73.23batch/s, loss=0.00269]


Epoch [80/100], Loss: 0.00268743


Epoch [81/100]: 100%|████████████████████████████| 516/516 [00:07<00:00, 73.24batch/s, loss=0.00269]


Epoch [81/100], Loss: 0.00268523


Epoch [82/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 79.91batch/s, loss=0.00269]


Epoch [82/100], Loss: 0.00268683


Epoch [83/100]: 100%|████████████████████████████| 516/516 [00:07<00:00, 70.17batch/s, loss=0.00269]


Epoch [83/100], Loss: 0.00268571


Epoch [84/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 74.56batch/s, loss=0.00268]


Epoch [84/100], Loss: 0.00268020


Epoch [85/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 79.31batch/s, loss=0.00268]


Epoch [85/100], Loss: 0.00268136


Epoch [86/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 76.60batch/s, loss=0.00268]


Epoch [86/100], Loss: 0.00267745


Epoch [87/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 77.71batch/s, loss=0.00267]


Epoch [87/100], Loss: 0.00267422


Epoch [88/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 84.40batch/s, loss=0.00268]


Epoch [88/100], Loss: 0.00267874


Epoch [89/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 79.89batch/s, loss=0.00267]


Epoch [89/100], Loss: 0.00267197


Epoch [90/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 78.81batch/s, loss=0.00268]


Epoch [90/100], Loss: 0.00267512


Epoch [91/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 85.02batch/s, loss=0.00267]


Epoch [91/100], Loss: 0.00266712


Epoch [92/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 85.68batch/s, loss=0.00267]


Epoch [92/100], Loss: 0.00266685


Epoch [93/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 85.31batch/s, loss=0.00267]


Epoch [93/100], Loss: 0.00266861


Epoch [94/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 79.81batch/s, loss=0.00266]


Epoch [94/100], Loss: 0.00266470


Epoch [95/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 77.66batch/s, loss=0.00267]


Epoch [95/100], Loss: 0.00266581


Epoch [96/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 74.66batch/s, loss=0.00267]


Epoch [96/100], Loss: 0.00266769


Epoch [97/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 83.09batch/s, loss=0.00266]


Epoch [97/100], Loss: 0.00266013


Epoch [98/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 78.95batch/s, loss=0.00266]


Epoch [98/100], Loss: 0.00266091


Epoch [99/100]: 100%|████████████████████████████| 516/516 [00:06<00:00, 79.75batch/s, loss=0.00266]


Epoch [99/100], Loss: 0.00266052


Epoch [100/100]: 100%|███████████████████████████| 516/516 [00:05<00:00, 86.09batch/s, loss=0.00266]


Epoch [100/100], Loss: 0.00265777


In [11]:
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm

def predict(classifier, test_df, term_list, batch_size=1000, esm_batch_size=8, output_file="predictions.csv"):
	"""
	Predict GO terms for a large set of proteins.
	Saves predictions incrementally to CSV to avoid memory issues.
			
	Parameters:
		classifier : trained multi-output classifier
		test_df   : DataFrame with 'entry_id' and 'sequence'
		term_list  : list of terms used during training
		batch_size : number of proteins per batch
		esm_batch_size : batch size for ESM embeddings
		output_file: path to save predictions CSV
	"""

	classifier.eval()
	# Initialize CSV with headers
	columns = ['entry_id', 'GO_term', 'probability']
	pd.DataFrame(columns=columns).to_csv(output_file, index=False)
	test_df = test_df.sort_values(by='sequence', key=lambda x: x.str.len(), ascending=False)
	# Process in batches
	num_batches = (len(test_df) - 1) // batch_size + 1
	for i in range(num_batches):
		batch_df = test_df.iloc[i*batch_size : (i+1)*batch_size]
		batch_dataset = ProteinDataset(batch_df)
		
		print(f"Processing batch {i+1}/{num_batches} ({len(batch_df)} proteins)...")
		
		# Extract embeddings
		cache_file = f"generated/esm_650m/esm_embeddings_test_{i}.pkl"
		batch_embeddings = get_esm_embeddings(batch_dataset, batch_size=esm_batch_size, show_progress=True, cache_file=cache_file)
		batch_proteins = list(batch_embeddings.keys())
		X_batch = np.array([batch_embeddings[pid] for pid in batch_proteins])
		
		# Predictions
		X_batch_tensor = torch.tensor(X_batch, dtype=torch.float32).to(device)
		y_pred_batch = classifier(X_batch_tensor)
		y_pred_batch = y_pred_batch.cpu()
		# Prepare rows to save
		rows = []
		for j, pid in enumerate(batch_proteins):
			prob_list = []
			for k, term in enumerate(term_list):
				prob = y_pred_batch[j, k].item()

				if (prob > 1e-4):
					# rows.append({'entry_id': pid, 'GO_term': term, 'probability': prob})
					prob_list.append((term, prob))
			sorted_prob_list = sorted(prob_list, key=lambda x: -x[1])[:25]
			for term, prob in sorted_prob_list:
					rows.append({'entry_id': pid, 'GO_term': term, 'probability': prob})

		# Append to CSV
		pd.DataFrame(rows).to_csv(output_file, mode='a', header=False, index=False)
		
	print(f"Predictions saved to {output_file}")
	return output_file

# Predict on all proteins (~225k)
predict(
    classifier=classifier,
    test_df=test_df,
    term_list=filtered_terms,
    batch_size=1000,      # process 1k proteins per loop
    esm_batch_size=4,     # ESM embeddings batch size for GPU
    output_file="protein_predictions1.csv"
)

Processing batch 1/225 (1000 proteins)...
cache file:  generated/esm_650m/esm_embeddings_test_0.pkl
Loading cached embeddings from generated/esm_650m/esm_embeddings_test_0.pkl
Processing batch 2/225 (1000 proteins)...
cache file:  generated/esm_650m/esm_embeddings_test_1.pkl
Loading cached embeddings from generated/esm_650m/esm_embeddings_test_1.pkl
Processing batch 3/225 (1000 proteins)...
cache file:  generated/esm_650m/esm_embeddings_test_2.pkl
Loading cached embeddings from generated/esm_650m/esm_embeddings_test_2.pkl
Processing batch 4/225 (1000 proteins)...
cache file:  generated/esm_650m/esm_embeddings_test_3.pkl
Loading cached embeddings from generated/esm_650m/esm_embeddings_test_3.pkl
Processing batch 5/225 (1000 proteins)...
cache file:  generated/esm_650m/esm_embeddings_test_4.pkl
Loading cached embeddings from generated/esm_650m/esm_embeddings_test_4.pkl
Processing batch 6/225 (1000 proteins)...
cache file:  generated/esm_650m/esm_embeddings_test_5.pkl
Loading cached embed

'protein_predictions1.csv'

In [12]:
import pandas as pd
import numpy as np

def prepare_cafa_submission(pred_csv, go_obo_file, output_file="submission.tsv", max_terms=1500):
    """
    - Loads raw prediction CSV (entry_id, GO_term, probability)
    - Propagates predictions up the GO hierarchy
    - Limits to max_terms per protein
    - Formats final CAFA submission file
    """
    print("\n=== Preparing CAFA Submission File ===")

    print("Loading predictions...")
    df = pd.read_csv(pred_csv)

    print("Loading GO ontology...")
    parents = {}

    # Parse go-basic.obo to build term → parents dictionary
    with open(go_obo_file, 'r') as f:
        current_term = None
        for line in f:
            line = line.strip()
            if line.startswith("id: GO:"):
                current_term = line.split("id: ")[1]
                parents[current_term] = []
            if line.startswith("is_a: GO:") and current_term:
                parent_term = line.split("is_a: ")[1].split()[0]
                parents[current_term].append(parent_term)

    print("Propagating predictions to parent terms...")
    propagated_rows = []

    for protein, group in df.groupby("entry_id"):
        term_scores = dict(zip(group.GO_term, group.probability))

        # BFS upward propagation
        queue = list(term_scores.keys())
        visited = set()

        while queue:
            term = queue.pop()
            if term in visited: 
                continue
            visited.add(term)

            if term in parents:
                for p in parents[term]:
                    new_score = term_scores.get(term, 0)
                    term_scores[p] = max(term_scores.get(p, 0), new_score)
                    queue.append(p)

        # Keep highest prob terms and limit to max_terms
        top_terms = sorted(term_scores.items(), key=lambda x: x[1], reverse=True)[:max_terms]

        for term, score in top_terms:
            if score > 0:  # CAFA requirement: no zeroes
                propagated_rows.append([protein, term, round(score, 3)])

    print("Saving formatted submission file...")
    sub = pd.DataFrame(propagated_rows, columns=["ProteinID", "GO_ID", "Score"])
    sub.to_csv(output_file, sep="\t", index=False, header=False)

    print(f"\nSubmission ready: {output_file}")
    print(f"Total predictions: {len(sub)}")
    return sub


In [13]:
prepare_cafa_submission(
    pred_csv="protein_predictions1.csv",
    go_obo_file="data/Train/go-basic.obo",
    output_file="submission.tsv"
)



=== Preparing CAFA Submission File ===
Loading predictions...
Loading GO ontology...
Propagating predictions to parent terms...
Saving formatted submission file...

Submission ready: submission.tsv
Total predictions: 18675818


Unnamed: 0,ProteinID,GO_ID,Score
0,A0A017SE81,GO:0005515,0.089
1,A0A017SE81,GO:0005488,0.089
2,A0A017SE81,GO:0003674,0.089
3,A0A017SE81,GO:0005886,0.055
4,A0A017SE81,GO:0016020,0.055
...,...,...,...
18675813,X6R8R1,GO:0019209,0.003
18675814,X6R8R1,GO:0019887,0.003
18675815,X6R8R1,GO:0019207,0.003
18675816,X6R8R1,GO:0005096,0.003


In [14]:
# kaggle competitions submit -c cafa-6-protein-function-prediction -f submission.tsv -m "esm 150m + nn 100x100"