In [None]:
# Install necessary packages
!pip install torch transformers scikit-learn pandas numpy biopython peft bitsandbytes requests optuna

In [None]:
# Install necessary tools in Google Colab
!apt-get install -y mafft
!apt-get install -y hmmer
!wget https://github.com/soedinglab/MMseqs2/releases/download/17-b804f/mmseqs-linux-gpu.tar.gz
!tar xvf mmseqs-linux-gpu.tar.gz
!chmod +x mmseqs

In [None]:
import pandas as pd
import torch
import requests
from transformers import AutoTokenizer, AutoModelForMaskedLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Ridge, Lasso, ElasticNet
from sklearn.metrics import mean_squared_error
import numpy as np
from Bio import SeqIO
from Bio import AlignIO
import optuna
from torch.utils.data import Dataset, DataLoader
import time
import xml.etree.ElementTree as ET
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import subprocess

In [None]:
import requests, time, subprocess
import numpy as np
import xml.etree.ElementTree as ET
from Bio import AlignIO

# Wild-type query sequence for the BLAST homolog search / MSA in this cell.
# NOTE(review): the fitness-model cell below reassigns WT_SEQUENCE to a different
# sequence (starting "MSKGEELF..."). Confirm which protein this MSA is meant to describe.
WT_SEQUENCE = "MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETEIFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNYPEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQLSLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEG"
# NCBI BLAST URL-API endpoint; run_blast_search sends its Put/Get requests here.
BLAST_URL = "https://blast.ncbi.nlm.nih.gov/Blast.cgi"

# ======= Step 1: BLAST API for Homologs =======
def _blast_xml_homologs(xml_text, identity_threshold, min_length):
    """Return gap-free hit segments from a BLAST XML report that pass both filters.

    An HSP is kept when its percent identity (Hsp_identity / Hsp_align-len * 100)
    is strictly below ``identity_threshold`` and its hit segment, after removing
    alignment gaps, is strictly longer than ``min_length`` residues. Report order
    is preserved and duplicates are kept; the caller de-duplicates.
    """
    root = ET.fromstring(xml_text)
    segments = []
    for hsp in root.iter("Hsp"):
        hseq = hsp.findtext("Hsp_hseq")
        identity = hsp.findtext("Hsp_identity")
        align_len = hsp.findtext("Hsp_align-len")
        # findtext gives None for a missing element and "" for an empty one; skip both.
        if not (hseq and identity and align_len):
            continue
        align_len = int(align_len)
        if align_len <= 0:
            continue
        identity_pct = 100 * int(identity) / align_len
        # Hsp_hseq is the *aligned* hit segment: it contains '-' wherever the query
        # has residues the hit lacks. Strip them so MAFFT receives raw sequence and
        # the length / duplicate checks operate on residues only.
        hseq = hseq.strip().replace("-", "")
        if identity_pct < identity_threshold and len(hseq) > min_length:
            segments.append(hseq)
    return segments


def run_blast_search(wt_sequence, identity_threshold=90.0, max_retries=30, sleep_time=10, min_length=100,
                     hitlist_size=300, output_fasta="msa_input.fasta", request_timeout=60):
    """Run a remote NCBI blastp search and write WT + homologs to a FASTA file.

    The query is submitted to the BLAST URL API at ``BLAST_URL`` and polled until
    it is READY. The XML report is then filtered by ``_blast_xml_homologs``.

    Args:
        wt_sequence: Query (wild-type) protein sequence; written first as ``>seq0``.
        identity_threshold: Keep HSPs whose percent identity is strictly below this.
        max_retries: Maximum number of status polls before giving up.
        sleep_time: Seconds to sleep between status polls.
        min_length: Keep hit segments strictly longer than this (gaps removed).
        hitlist_size: Maximum number of database hits requested from BLAST.
        output_fasta: Path of the FASTA file to write.
        request_timeout: Timeout in seconds for each HTTP request.

    Returns:
        ``output_fasta``.

    Raises:
        requests.HTTPError: if submission or result download returns an HTTP error.
        Exception: if no RID is returned, the search FAILED / expired (UNKNOWN),
            or it is not READY after ``max_retries`` polls.
    """
    params = {
        "CMD": "Put",
        "PROGRAM": "blastp",
        "DATABASE": "nr",
        "QUERY": wt_sequence,
        "FORMAT_TYPE": "XML",
        "EXPECT": "1e-2",
        "HITLIST_SIZE": str(hitlist_size)
    }
    response = requests.post(BLAST_URL, data=params, timeout=request_timeout)
    response.raise_for_status()
    response_text = response.text
    if "RID = " not in response_text:
        raise Exception("No RID found in BLAST response.")
    rid = response_text.split("RID = ")[-1].split("\n")[0].strip()
    print(f"BLAST RID: {rid}")

    # Poll SearchInfo until READY. FAILED and UNKNOWN (expired RID) are terminal,
    # so stop right away instead of polling until max_retries is exhausted.
    for attempt in range(max_retries):
        status = requests.get(
            BLAST_URL,
            params={"CMD": "Get", "FORMAT_OBJECT": "SearchInfo", "RID": rid},
            timeout=request_timeout,
        )
        if "Status=READY" in status.text:
            print("BLAST complete.")
            break
        if "Status=FAILED" in status.text:
            raise Exception(f"BLAST search {rid} failed.")
        if "Status=UNKNOWN" in status.text:
            raise Exception(f"BLAST search {rid} expired or is unknown.")
        print(f"Waiting... {attempt+1}/{max_retries}")
        time.sleep(sleep_time)
    else:
        raise Exception("BLAST timed out")

    # Download and filter results
    result = requests.get(
        BLAST_URL,
        params={"CMD": "Get", "FORMAT_TYPE": "XML", "RID": rid},
        timeout=request_timeout,
    )
    result.raise_for_status()
    homologs = _blast_xml_homologs(result.text, identity_threshold, min_length)

    # dict.fromkeys de-duplicates while keeping BLAST's rank order. This keeps the
    # FASTA (and therefore the MSA and its weights) identical between runs; a set
    # would give an order that depends on the interpreter's hash seed.
    seqs = [wt_sequence] + [s for s in dict.fromkeys(homologs) if s != wt_sequence]
    print(f"Total homologs: {len(seqs)}")
    with open(output_fasta, "w") as f:
        for i, s in enumerate(seqs):
            f.write(f">seq{i}\n{s}\n")
    return output_fasta

# ======= Step 2: Align with MAFFT =======
def run_mafft(input_fasta, output_fasta="msa_aligned.fasta"):
    """Align ``input_fasta`` with ``mafft --auto`` and write the result to ``output_fasta``.

    Args:
        input_fasta: Path to the unaligned FASTA file.
        output_fasta: Path that receives MAFFT's stdout (the aligned FASTA).

    Returns:
        ``output_fasta``.

    Raises:
        subprocess.CalledProcessError: if mafft exits with a non-zero status.
        FileNotFoundError: if the ``mafft`` executable is not on PATH.
    """
    print("Running MAFFT alignment...")
    # Pass the arguments as a list and redirect stdout from Python instead of
    # building a shell command string. Paths containing spaces or shell
    # metacharacters are then passed through verbatim rather than interpreted.
    with open(output_fasta, "w") as out:
        subprocess.run(["mafft", "--auto", input_fasta], stdout=out, check=True)
    print(f"Alignment written: {output_fasta}")
    return output_fasta

# ======= Step 3: Calculate Henikoff Weights =======
def henikoff_weights(msa_file, format="fasta"):
    """Compute normalized Henikoff position-based sequence weights for an MSA.

    In every alignment column, each sequence receives ``1 / (r * n)``. Here ``r`` is
    the number of distinct symbols in the column and ``n`` is how many sequences
    share this sequence's symbol there. The per-sequence sums are normalized to
    total 1. Gap characters count as an ordinary symbol.

    Args:
        msa_file: Path or handle accepted by ``Bio.AlignIO.read``.
        format: Alignment format name passed to ``AlignIO.read``.

    Returns:
        1-D float array with one weight per sequence (file order), summing to 1.

    Raises:
        ValueError: if the alignment has no sequences or no columns, in which case
            the normalization would be 0/0.
    """
    alignment = AlignIO.read(msa_file, format)
    n_seq = len(alignment)
    aln_len = alignment.get_alignment_length()
    if n_seq == 0 or aln_len == 0:
        raise ValueError(f"Alignment in {msa_file!r} is empty; cannot compute Henikoff weights.")

    # (n_seq, aln_len) character matrix: each column is counted by numpy rather
    # than by indexing Bio.Seq objects one residue at a time.
    columns = np.array([list(str(record.seq)) for record in alignment])
    weights = np.zeros(n_seq)
    for pos in range(aln_len):
        _, symbol_index, symbol_counts = np.unique(
            columns[:, pos], return_inverse=True, return_counts=True
        )
        # 1 / (distinct symbols in the column * sequences sharing this sequence's symbol)
        weights += 1.0 / (len(symbol_counts) * symbol_counts[symbol_index])
    weights /= weights.sum()
    return weights

# ======= (Optional) Jackhmmer/MMseqs2 search: not implemented in this notebook =======

# ==== MAIN WORKFLOW ====
# "jackhmmer" / "mmseqs2" work only once run_jackhmmer_search / run_mmseqs2_search are defined.
method = "blast"

if method == "blast":
    msa_input = run_blast_search(WT_SEQUENCE, identity_threshold=90.0, hitlist_size=500)
elif method in ("jackhmmer", "mmseqs2"):
    # Look the helper up at run time: it is not defined in this notebook, so fail
    # with a clear message instead of a bare NameError.
    search_fn = globals().get(f"run_{method}_search")
    if search_fn is None:
        raise NotImplementedError(
            f"run_{method}_search is not defined; implement it or use method='blast'."
        )
    msa_input = search_fn(WT_SEQUENCE)
else:
    raise ValueError("Invalid method chosen. Please select 'blast', 'jackhmmer', or 'mmseqs2'.")

msa_aligned = run_mafft(msa_input)

weights = henikoff_weights(msa_aligned, "fasta")
assert np.isclose(weights.sum(), 1.0), "Henikoff weights should be normalized to sum to 1"
print("Sequence weights:", weights)
np.save("msa_weights.npy", weights)


In [None]:
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from torch.optim import AdamW
from peft import LoraConfig, get_peft_model
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy.stats import pearsonr, spearmanr
from tqdm import tqdm
import matplotlib.pyplot as plt
import gc
import os

# --- MEMORY OPTIMIZATION SETTINGS ---
# PYTORCH_CUDA_ALLOC_CONF is read when the CUDA caching allocator starts up, so it
# generally has to be set before the first CUDA allocation in this kernel.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --- REPRODUCIBILITY ---
# cudnn.deterministic alone does not make runs repeatable. LoRA/head weight init,
# dropout and DataLoader shuffling all draw from the RNGs seeded here.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

def force_cleanup():
    """Run Python garbage collection and, if CUDA exists, release cached GPU memory.

    gc runs first so that tensors only reachable from garbage are released before
    the CUDA cache is emptied. Afterwards the call waits for queued GPU work.
    On CPU-only machines only the gc step runs.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

force_cleanup()

# --- DEVICE ---
# Prefer the first CUDA GPU when one is visible; otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

if device.type == "cuda":
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {gpu_props.name}")
    print(f"Total memory: {gpu_props.total_memory / 1024**3:.2f} GB")

# --- Load and preprocess data ---
url = "https://figshare.com/ndownloader/files/7337543"
N_SAMPLES = 1000  # moderate subset for stable training; set to None to use every row

df = pd.read_csv(url, sep='\t').rename(
    columns={'mutation': 'mutation_string', 'normalized_fitness': 'fitness'}
)
df['fitness'] = pd.to_numeric(df['fitness'], errors='coerce')
df = df.dropna(subset=['fitness'])

# DataFrame.sample(n=...) raises ValueError when n exceeds the row count, so cap it.
n_keep = len(df) if N_SAMPLES is None else min(N_SAMPLES, len(df))
df = df.sample(n=n_keep, random_state=42).reset_index(drop=True)
print(f"Using {len(df)} samples for training")

# Check fitness distribution
print(f"Fitness range: {df['fitness'].min():.4f} to {df['fitness'].max():.4f}")
print(f"Fitness mean: {df['fitness'].mean():.4f}, std: {df['fitness'].std():.4f}")

# z-score the target. fitness_mean / fitness_std are reused in the final
# evaluation to map predictions back to the original fitness scale.
# NOTE(review): the statistics are computed on all rows, including the later test
# split. Computing them on the training split only would avoid this mild leakage.
fitness_mean = df['fitness'].mean()
fitness_std = df['fitness'].std()
if not np.isfinite(fitness_std) or fitness_std == 0:
    raise ValueError(f"Cannot normalize fitness: std is {fitness_std} (need >= 2 distinct values).")
df['fitness_normalized'] = (df['fitness'] - fitness_mean) / fitness_std
print(f"Normalized fitness range: {df['fitness_normalized'].min():.4f} to {df['fitness_normalized'].max():.4f}")

# Wild-type sequence that the dataset's mutation strings are applied to
# (positions are 1-based; see generate_mutated_sequence).
# NOTE(review): this overwrites the WT_SEQUENCE global from the MSA cell above with a
# different sequence, so that MSA / msa_weights.npy was built for another sequence.
# Confirm this one matches the numbering used in the dataset's mutation strings.
WT_SEQUENCE = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLTYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"

def generate_mutated_sequence(wt_sequence, mutation_string, check_wt=False):
    """Apply comma-separated point mutations such as ``"A12G,K45R"`` to ``wt_sequence``.

    Each token is ``<wt residue><1-based position><new residue>``. A token leaves
    its position at the wild-type residue when it:
    - is malformed,
    - points outside the sequence, or
    - encodes a stop (``*``).

    Args:
        wt_sequence: Reference sequence the positions refer to.
        mutation_string: Mutation list. A missing value (NaN/None from pandas) or
            an empty string yields the unmodified wild type.
        check_wt: If True, raise ``ValueError`` when a token's stated wild-type
            residue differs from ``wt_sequence`` at that position. This catches
            numbering/offset mismatches between the data and the reference.

    Returns:
        The mutated sequence, same length as ``wt_sequence``.
    """
    if not isinstance(mutation_string, str):
        # pandas reads empty cells as float NaN: treat as "no mutations".
        return wt_sequence
    residues = list(wt_sequence)
    for token in mutation_string.split(','):
        token = token.strip()
        position_text = token[1:-1]
        # isdecimal rather than isdigit: characters like '²' pass isdigit but make int() fail.
        if len(token) < 3 or not position_text.isdecimal():
            continue
        index = int(position_text) - 1
        new_residue = token[-1]
        if not 0 <= index < len(residues) or new_residue == '*':
            continue
        if check_wt and wt_sequence[index] != token[0]:
            raise ValueError(
                f"Mutation {token!r}: reference has {wt_sequence[index]!r} at position {index + 1}"
            )
        residues[index] = new_residue
    return ''.join(residues)

# Build each variant's full sequence by applying its mutation list to the wild type.
df['mutated_sequence'] = df['mutation_string'].map(
    lambda mutations: generate_mutated_sequence(WT_SEQUENCE, mutations)
)

# ProtBERT input formatting
def format_protein_sequence(sequence):
    """Format a protein sequence the way the Rostlab ProtBERT tokenizer expects.

    Residues are separated by single spaces (one token per residue). The rare or
    ambiguous residues U, Z, O and B are mapped to X, as in the model card's usage
    example. Example: ``"MUKB"`` -> ``"M X K X"``.
    """
    cleaned = sequence.translate(str.maketrans("UZOB", "XXXX"))
    return ' '.join(cleaned)

# Space-separate residues for the ProtBERT tokenizer.
df['formatted_sequence'] = df['mutated_sequence'].map(format_protein_sequence)

# --- Load tokenizer and model ---
MODEL_NAME = "Rostlab/prot_bert"
print("Loading tokenizer...")
# do_lower_case=False, as in the ProtBERT model card: residues are upper-case tokens.
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)

force_cleanup()

print("Loading model...")
base_model = BertModel.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True
)

# --- LoRA config ---
# LoRA updates are scaled by lora_alpha / r = 32 / 8 = 4.
lora_config = LoraConfig(
    r=8,  # rank of the low-rank adapter matrices
    lora_alpha=32,  # scaling numerator (effective scale lora_alpha / r)
    lora_dropout=0.1,
    target_modules=["query", "key", "value"],  # self-attention Q/K/V projection layers
    bias="none"
)

print("Applying LoRA...")
base_model = get_peft_model(base_model, lora_config)
# Confirm that only the adapter weights remain trainable after wrapping.
base_model.print_trainable_parameters()
base_model = base_model.to(device)

force_cleanup()

# --- Dataset ---
class ProteinFitnessDataset(Dataset):
    """Pairs of (tokenized protein sequence, scalar float32 target).

    Reads the space-separated ``formatted_sequence`` column and ``target_col`` from
    ``dataframe`` once at construction. Tokenization happens per item in
    ``__getitem__``, padded/truncated to ``max_length`` tokens.
    """

    def __init__(self, dataframe, tokenizer, target_col, max_length=512):
        self.sequences = dataframe['formatted_sequence'].tolist()
        self.targets = dataframe[target_col].values.astype(np.float32)
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Debug summary; guarded so an empty split does not crash on [0] / min().
        if self.sequences:
            print(f"First sequence example: {self.sequences[0][:100]}...")
            print(f"Target range: {self.targets.min():.4f} to {self.targets.max():.4f}")
        else:
            print("Warning: empty dataset")

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return a dict with 1-D ``input_ids``, ``attention_mask`` and 0-d float ``labels``."""
        seq = self.sequences[idx]

        tokenized = self.tokenizer(
            seq,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            add_special_tokens=True  # include [CLS] and [SEP]
        )

        # squeeze(0) drops the batch dimension the tokenizer adds for a single string.
        return {
            'input_ids': tokenized['input_ids'].squeeze(0),
            'attention_mask': tokenized['attention_mask'].squeeze(0),
            'labels': torch.tensor(self.targets[idx], dtype=torch.float32)
        }

# Split data stratified on fitness quintiles so train and test span the same fitness
# range. Quantile bins (qcut) split rows into roughly equal-sized groups. Equal-width
# bins (pd.cut) can leave an outlier bin holding a single row, and train_test_split
# raises ValueError when any stratum has fewer than 2 members.
fitness_strata = pd.qcut(df['fitness_normalized'], q=5, labels=False, duplicates='drop')
train_df, test_df = train_test_split(df, test_size=0.3, random_state=42, stratify=fitness_strata)
print(f"Train size: {len(train_df)}, Test size: {len(test_df)}")

# Create datasets using normalized fitness
train_dataset = ProteinFitnessDataset(train_df, tokenizer, 'fitness_normalized')
test_dataset = ProteinFitnessDataset(test_df, tokenizer, 'fitness_normalized')

# Small batches (4) to limit GPU memory; every item is padded to the dataset's max_length.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=4, num_workers=0)

# --- Regression Head ---
class ProtBERTRegressionHead(torch.nn.Module):
    """Masked mean-pooled encoder embedding -> MLP -> one scalar per sequence.

    ``encoder`` must expose ``config.hidden_size``. Called with ``input_ids`` and
    ``attention_mask``, it must return an object with a ``last_hidden_state``
    attribute, as Hugging Face ``BertModel`` does.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

        hidden_size = self.encoder.config.hidden_size
        self.dropout = torch.nn.Dropout(0.3)

        self.regressor = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(512, 128),
            torch.nn.LayerNorm(128),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(128, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1)
        )

        # NOTE(review): kaiming_normal_ with mode='fan_out' gives the final 32->1
        # layer (fan_out = 1) weights with std sqrt(2), i.e. large initial outputs.
        # Consider PyTorch's default Linear init for that layer.
        for module in self.regressor:
            if isinstance(module, torch.nn.Linear):
                torch.nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return ``(predictions, loss)``.

        ``predictions`` has shape (batch,). ``loss`` is the Smooth-L1 (beta=1)
        loss against ``labels``, or None when ``labels`` is None.
        """
        # Mixed precision for the encoder only when running on CUDA. This avoids
        # the deprecated torch.cuda.amp.autocast and its warning on CPU-only machines.
        with torch.autocast(device_type=input_ids.device.type, enabled=input_ids.is_cuda):
            # Only last_hidden_state is used, so don't request every layer's
            # hidden states (output_hidden_states=True kept them all in memory).
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

        hidden_states = encoder_outputs.last_hidden_state

        # Masked mean pooling: padded positions add nothing to the sum or the count.
        # The mask is float32, so the pooled sum is computed in float32 even
        # when hidden_states is float16.
        attention_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        sum_embeddings = torch.sum(hidden_states * attention_mask_expanded, 1)
        sum_mask = torch.clamp(attention_mask_expanded.sum(1), min=1e-9)
        pooled_output = sum_embeddings / sum_mask

        pooled_output = self.dropout(pooled_output)
        predictions = self.regressor(pooled_output).squeeze(-1)

        loss = None
        if labels is not None:
            # Smooth L1 is less sensitive to outlier targets than MSE.
            loss = torch.nn.functional.smooth_l1_loss(predictions, labels, beta=1.0)

        return predictions, loss

print("Creating regression model...")
regression_model = ProtBERTRegressionHead(base_model).to(device)

epochs = 10  # upper bound; early stopping in the training loop may end sooner

# Small learning rate for stability. Only parameters that require grad (LoRA
# adapters + regression head) go to the optimizer. AdamW skips parameters whose
# .grad is None, so leaving out the frozen encoder weights does not change the updates.
trainable_params = [p for p in regression_model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=1e-5, weight_decay=0.01, eps=1e-8)

# Cosine decay across the whole run; train_epoch steps the scheduler once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-7)

# Loss scaling only matters for CUDA float16 autocast. Passing `enabled` explicitly
# keeps CPU runs from warning that CUDA is unavailable.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

def train_epoch(model, loader, optimizer, scheduler, scaler):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    On CUDA the forward pass runs under float16 autocast with ``scaler``. Gradients
    are clipped to norm 1.0, and ``scheduler`` is stepped once at the end of the
    epoch. Batches with a NaN/Inf loss are skipped. A RuntimeError inside a batch
    (e.g. CUDA out-of-memory) is reported, its gradients are dropped, and the
    batch is skipped.

    Returns:
        Mean loss over successful batches, or NaN when no batch succeeded. NaN
        lets the caller's NaN check stop training rather than see a fake 0.0.
    """
    model.train()
    use_amp = device.type == "cuda"
    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(loader, desc="Training"):
        try:
            optimizer.zero_grad()

            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # torch.autocast replaces the deprecated torch.cuda.amp.autocast;
            # disabled on CPU so no warning is emitted every batch.
            with torch.autocast(device_type=device.type, enabled=use_amp):
                predictions, loss = model(input_ids, attention_mask, labels)

            if not torch.isfinite(loss):
                print(f"NaN/Inf loss detected: {loss.item()}")
                continue

            scaler.scale(loss).backward()

            # Unscale before clipping so max_norm applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            num_batches += 1

        except RuntimeError as e:
            print(f"Error in training batch: {e}")
            # Release partial gradients first; otherwise the parameters keep them
            # alive and empty_cache() cannot reclaim that memory.
            optimizer.zero_grad(set_to_none=True)
            force_cleanup()
            continue

    scheduler.step()
    if num_batches == 0:
        print("Warning: no training batch completed this epoch.")
        return float("nan")
    return total_loss / num_batches

def eval_epoch(model, loader):
    """Run ``model`` over ``loader`` without gradients and compute regression metrics.

    Returns:
        ``(mse, mae, r2, pearson, spearman, y_true, y_pred)``.
        - y_true / y_pred are Python lists of labels and predictions. Pairs
          containing NaN/Inf are removed once at least two samples were collected.
        - All five metrics are 0 when fewer than two (finite) pairs are available.
        - The correlations are 0.0 when either side is (near-)constant.
        - All metrics are 0.0 if metric computation raises.
    """
    model.eval()
    use_amp = device.type == "cuda"
    y_true, y_pred = [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            try:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"]

                # Same autocast policy as train_epoch: float16 on CUDA only.
                with torch.autocast(device_type=device.type, enabled=use_amp):
                    predictions, _ = model(input_ids, attention_mask)

                y_true.extend(labels.cpu().numpy().tolist())
                y_pred.extend(predictions.float().cpu().numpy().tolist())

            except RuntimeError as e:
                print(f"Error in evaluation batch: {e}")
                continue

    if len(y_true) < 2:
        return 0, 0, 0, 0, 0, y_true, y_pred

    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Drop pairs with a NaN/Inf on either side so one bad value can't poison every metric.
    finite = np.isfinite(y_true) & np.isfinite(y_pred)
    y_true = y_true[finite]
    y_pred = y_pred[finite]

    if len(y_true) < 2:
        return 0, 0, 0, 0, 0, y_true.tolist(), y_pred.tolist()

    try:
        mse = mean_squared_error(y_true, y_pred)
        mae = mean_absolute_error(y_true, y_pred)
        r2 = r2_score(y_true, y_pred)

        # Correlations are undefined for constant input; report 0.0 instead of NaN.
        if np.std(y_pred) > 1e-10 and np.std(y_true) > 1e-10:
            pearson = pearsonr(y_true, y_pred)[0]
            spearman = spearmanr(y_true, y_pred)[0]
        else:
            pearson = spearman = 0.0

    except Exception as e:
        print(f"Error calculating metrics: {e}")
        mse = mae = r2 = pearson = spearman = 0.0

    return mse, mae, r2, pearson, spearman, y_true.tolist(), y_pred.tolist()

# --- Training loop with early stopping on Pearson correlation ---
# NOTE(review): test_loader doubles as the validation set for early stopping, so the
# reported best-epoch metrics are optimistically biased. A separate validation split
# would give an unbiased test estimate.
print("Starting training...")
best_pearson = -1.0
patience = 3
patience_counter = 0
best_epoch = None
best_state = None    # CPU copy of trainable weights (LoRA adapters + head) at the best epoch
best_results = None  # eval_epoch output at the best epoch
# Start empty so the final evaluation never reuses y_true / y_pred left in the
# kernel by an earlier run when no epoch gets evaluated in this one.
y_true, y_pred = [], []

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")

    train_loss = train_epoch(regression_model, train_loader, optimizer, scheduler, scaler)

    if np.isnan(train_loss) or np.isinf(train_loss):
        print("Training loss is NaN/Inf. Stopping training.")
        break

    results = eval_epoch(regression_model, test_loader)
    mse, mae, r2, pearson, spearman, y_true, y_pred = results

    print(f"Train Loss: {train_loss:.4f}")
    print(f"Test MSE: {mse:.4f}, MAE: {mae:.4f}, R²: {r2:.4f}")
    print(f"Pearson: {pearson:.4f}, Spearman: {spearman:.4f}")
    print(f"Learning Rate: {scheduler.get_last_lr()[0]:.2e}")

    if pearson > best_pearson:
        best_pearson = pearson
        best_epoch = epoch + 1
        best_results = results
        best_state = {
            name: param.detach().cpu().clone()
            for name, param in regression_model.named_parameters()
            if param.requires_grad
        }
        patience_counter = 0
        print(f"New best Pearson correlation: {best_pearson:.4f}")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping: no improvement for {patience} epochs")
            break

    force_cleanup()

# Early stopping is only useful if the best epoch is kept. Restore its weights and
# metrics so the model in memory and the final evaluation both describe the best
# checkpoint rather than the last (worse) epoch.
if best_state is not None:
    with torch.no_grad():
        for name, param in regression_model.named_parameters():
            if name in best_state:
                param.copy_(best_state[name])
    mse, mae, r2, pearson, spearman, y_true, y_pred = best_results
    print(f"Restored best model from epoch {best_epoch} (Pearson {best_pearson:.4f})")

# --- Final evaluation and plotting ---
print("\n" + "="*50)
print("FINAL EVALUATION")
print("="*50)


def _parity_plot(ax, truth, pred, xlabel, ylabel, title):
    """Scatter ``pred`` against ``truth`` on ``ax`` with a dashed y = x reference line."""
    ax.scatter(truth, pred, alpha=0.6, s=30)
    lo = min(np.min(truth), np.min(pred))
    hi = max(np.max(truth), np.max(pred))
    ax.plot([lo, hi], [lo, hi], 'r--', alpha=0.8)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


if 'y_true' in locals() and len(y_true) > 1:
    # Undo the z-scoring of the targets so errors are reported in fitness units.
    y_true_orig = np.asarray(y_true) * fitness_std + fitness_mean
    y_pred_orig = np.asarray(y_pred) * fitness_std + fitness_mean

    mse_orig = mean_squared_error(y_true_orig, y_pred_orig)
    mae_orig = mean_absolute_error(y_true_orig, y_pred_orig)
    r2_orig = r2_score(y_true_orig, y_pred_orig)
    # pearsonr is undefined (NaN plus a warning) for constant input, e.g. a collapsed
    # model that predicts a single value. Report 0.0 then, as eval_epoch does.
    if np.std(y_true_orig) > 1e-10 and np.std(y_pred_orig) > 1e-10:
        pearson_orig = pearsonr(y_true_orig, y_pred_orig)[0]
    else:
        pearson_orig = 0.0

    print("Final Results (Original Scale):")
    print(f"  Test samples: {len(y_true_orig)}")
    print(f"  MSE: {mse_orig:.4f}")
    print(f"  MAE: {mae_orig:.4f}")
    print(f"  R²: {r2_orig:.4f}")
    print(f"  Pearson: {pearson_orig:.4f}")

    fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    ax_orig, ax_norm, ax_resid, ax_hist = axes.ravel()

    _parity_plot(ax_orig, y_true_orig, y_pred_orig, "True Fitness", "Predicted Fitness",
                 f"Original Scale (R²={r2_orig:.3f}, ρ={pearson_orig:.3f})")
    _parity_plot(ax_norm, y_true, y_pred, "True Fitness (Normalized)", "Predicted Fitness (Normalized)",
                 f"Normalized Scale (R²={r2:.3f}, ρ={pearson:.3f})")

    # Residuals vs. truth: systematic trends reveal bias (e.g. shrinkage toward the mean).
    residuals = y_pred_orig - y_true_orig
    ax_resid.scatter(y_true_orig, residuals, alpha=0.6, s=30)
    ax_resid.axhline(y=0, color='r', linestyle='--', alpha=0.8)
    ax_resid.set_xlabel("True Fitness")
    ax_resid.set_ylabel("Residuals")
    ax_resid.set_title("Residual Plot")
    ax_resid.grid(True, alpha=0.3)

    ax_hist.hist(y_true_orig, bins=20, alpha=0.5, label='True', density=True)
    ax_hist.hist(y_pred_orig, bins=20, alpha=0.5, label='Predicted', density=True)
    ax_hist.set_xlabel("Fitness")
    ax_hist.set_ylabel("Density")
    ax_hist.set_title("Distribution Comparison")
    ax_hist.legend()
    ax_hist.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.show()

    print("\nExample predictions (original scale):")
    # Seeded generator so the same examples are listed on every re-run.
    rng = np.random.default_rng(42)
    indices = rng.choice(len(y_true_orig), size=min(10, len(y_true_orig)), replace=False)
    for i in indices:
        print(f"  True: {y_true_orig[i]:.4f}, Pred: {y_pred_orig[i]:.4f}, Error: {abs(y_true_orig[i] - y_pred_orig[i]):.4f}")
else:
    print("No valid results to display.")

force_cleanup()
# Report how much GPU memory tensors still occupy after cleanup.
if torch.cuda.is_available():
    print(f"Final GPU memory usage: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
else:
    print("CPU mode")