In [None]:
# Install necessary packages
!pip install torch transformers scikit-learn pandas numpy biopython peft bitsandbytes requests optuna

In [None]:
# Install necessary tools in Google Colab
# MAFFT is the aligner invoked later through `mafft --auto`.
!apt-get install -y mafft
# HMMER suite — presumably for the optional jackhmmer search path; TODO confirm it is used.
!apt-get install -y hmmer
# Download and unpack a prebuilt MMseqs2 release archive (tag 17-b804f, linux-gpu build).
!wget https://github.com/soedinglab/MMseqs2/releases/download/17-b804f/mmseqs-linux-gpu.tar.gz
!tar xvf mmseqs-linux-gpu.tar.gz
# NOTE(review): if the archive unpacks into a directory called `mmseqs` (binary under
# mmseqs/bin/), this chmod targets the directory rather than the executable — verify.
!chmod +x mmseqs

In [None]:
import pandas as pd
import torch
import requests
from transformers import AutoTokenizer, AutoModelForMaskedLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Ridge, Lasso, ElasticNet
from sklearn.metrics import mean_squared_error
import numpy as np
from Bio import SeqIO
from Bio import AlignIO
import optuna
from torch.utils.data import Dataset, DataLoader
import time
import xml.etree.ElementTree as ET
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import subprocess

In [None]:
import requests, time, subprocess
import numpy as np
import xml.etree.ElementTree as ET
from Bio import AlignIO

# Query protein (one-letter amino-acid string) submitted to BLAST by this cell.
# NOTE(review): a later cell reassigns WT_SEQUENCE to a different sequence; confirm that the
# MSA weights computed from this query are meant to be used with that other protein.
WT_SEQUENCE = "MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETEIFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNYPEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQLSLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEG"
# Endpoint used by run_blast_search for job submission, status polling and result download.
BLAST_URL = "https://blast.ncbi.nlm.nih.gov/Blast.cgi"

# ======= Step 1: BLAST API for Homologs =======
def _extract_blast_homologs(xml_text, identity_threshold, min_length):
    """Return gap-free subject segments from BLAST XML that pass the identity/length filters."""
    homologs = []
    root = ET.fromstring(xml_text)
    for hit in root.findall(".//Hit"):
        for hsp in hit.findall(".//Hsp"):
            hseq_elem = hsp.find("Hsp_hseq")
            identity_elem = hsp.find("Hsp_identity")
            align_len_elem = hsp.find("Hsp_align-len")
            if hseq_elem is None or identity_elem is None or align_len_elem is None:
                continue
            align_len = int(align_len_elem.text)
            if align_len <= 0:
                # Avoid a ZeroDivisionError on a degenerate HSP.
                continue
            identity_pct = 100 * int(identity_elem.text) / align_len
            # Hsp_hseq is the *aligned* subject string, so it can contain '-' gap
            # characters; remove them so the FASTA holds plain residues and the
            # length filter counts residues only.
            hseq = (hseq_elem.text or "").strip().replace("-", "")
            if identity_pct < identity_threshold and len(hseq) > min_length:
                homologs.append(hseq)
    return homologs


def run_blast_search(wt_sequence, identity_threshold=90.0, max_retries=30, sleep_time=10, min_length=100, hitlist_size=300, timeout=120, output_fasta="msa_input.fasta"):
    """Run a remote blastp search and write the query plus filtered homologs to a FASTA file.

    The query is submitted to ``BLAST_URL`` against the ``nr`` database, the job is
    polled until it reports ``Status=READY``, and the XML result is parsed. Only the
    aligned subject segment of each HSP (``Hsp_hseq``) is kept, not the full database
    sequence.

    Args:
        wt_sequence: Query protein sequence; always written first as ``seq0``.
        identity_threshold: HSPs with percent identity (identities / alignment
            length) at or above this value are discarded.
        max_retries: Maximum number of status polls before giving up.
        sleep_time: Seconds to wait between status polls.
        min_length: A subject segment must be strictly longer than this (gaps excluded).
        hitlist_size: Value sent as ``HITLIST_SIZE``.
        timeout: Per-request HTTP timeout in seconds.
        output_fasta: Path of the FASTA file to write.

    Returns:
        The path of the written FASTA file.

    Raises:
        requests.HTTPError: If any HTTP request returns an error status.
        Exception: If no RID is returned, the job fails/expires, or polling times out.
    """
    params = {
        "CMD": "Put",
        "PROGRAM": "blastp",
        "DATABASE": "nr",
        "QUERY": wt_sequence,
        "FORMAT_TYPE": "XML",
        "EXPECT": "1e-2",
        "HITLIST_SIZE": str(hitlist_size)
    }
    response = requests.post(BLAST_URL, data=params, timeout=timeout)
    response.raise_for_status()
    response_text = response.text
    if "RID = " not in response_text:
        raise Exception("No RID found in BLAST response.")
    rid = response_text.split("RID = ")[-1].split("\n")[0].strip()
    if not rid:
        raise Exception("No RID found in BLAST response.")
    print(f"BLAST RID: {rid}")

    # Wait for completion
    for attempt in range(max_retries):
        status = requests.get(
            BLAST_URL,
            params={"CMD": "Get", "FORMAT_OBJECT": "SearchInfo", "RID": rid},
            timeout=timeout,
        )
        status.raise_for_status()
        if "Status=READY" in status.text:
            print("BLAST complete.")
            break
        # Fail fast instead of polling a job that can never become READY.
        if "Status=FAILED" in status.text:
            raise RuntimeError(f"BLAST search {rid} failed on the server.")
        if "Status=UNKNOWN" in status.text:
            raise RuntimeError(f"BLAST search {rid} is unknown to the server (expired or invalid RID).")
        print(f"Waiting... {attempt+1}/{max_retries}")
        time.sleep(sleep_time)
    else:
        raise Exception("BLAST timed out")

    # Download results
    result = requests.get(
        BLAST_URL,
        params={"CMD": "Get", "FORMAT_TYPE": "XML", "RID": rid},
        timeout=timeout,
    )
    result.raise_for_status()
    homologs = _extract_blast_homologs(result.text, identity_threshold, min_length)

    # De-duplicate while keeping first-seen order (a set would make the FASTA order,
    # and therefore the order of any per-sequence weights, vary between runs).
    unique_homologs = list(dict.fromkeys(s for s in homologs if s != wt_sequence))
    seqs = [wt_sequence] + unique_homologs  # unique, include WT
    print(f"Total homologs: {len(seqs)}")
    # Save to FASTA
    with open(output_fasta, "w") as f:
        for i, s in enumerate(seqs):
            f.write(f">seq{i}\n{s}\n")
    return output_fasta

# ======= Step 2: Align with MAFFT =======
def run_mafft(input_fasta, output_fasta="msa_aligned.fasta"):
    """Align ``input_fasta`` with ``mafft --auto`` and write the alignment to ``output_fasta``.

    MAFFT is started without a shell, with its standard output redirected to the
    output file, so paths containing spaces or shell metacharacters are handled safely.

    Args:
        input_fasta: Path of the unaligned FASTA file.
        output_fasta: Path where the aligned FASTA is written (overwritten if present).

    Returns:
        ``output_fasta``.

    Raises:
        FileNotFoundError: If the ``mafft`` executable is not on PATH.
        subprocess.CalledProcessError: If MAFFT exits with a non-zero status.
    """
    print("Running MAFFT alignment...")
    with open(output_fasta, "w") as aligned:
        subprocess.run(["mafft", "--auto", str(input_fasta)], stdout=aligned, check=True)
    print(f"Alignment written: {output_fasta}")
    return output_fasta

# ======= Step 3: Calculate Henikoff Weights =======
def henikoff_weights(msa_file, format="fasta"):
    """Compute normalised position-based sequence weights for an alignment file.

    For every alignment column, each sequence receives ``1 / (k * n)`` where ``k``
    is the number of distinct symbols in the column and ``n`` is how many sequences
    share this sequence's symbol. Every character found in a column is counted as a
    symbol (no character is treated specially). The per-sequence sums are divided
    by their total so the returned weights add up to 1.

    Args:
        msa_file: Path (or handle) of the alignment, passed to ``AlignIO.read``.
        format: Alignment format name understood by ``AlignIO.read``.

    Returns:
        A NumPy array with one weight per sequence, in alignment order.
    """
    msa = AlignIO.read(msa_file, format)
    n_columns = msa.get_alignment_length()
    seq_weights = np.zeros(len(msa))
    for col in range(n_columns):
        column = [rec.seq[col] for rec in msa]
        # Tally how often each symbol occurs in this column.
        tally = {}
        for symbol in column:
            tally[symbol] = tally.get(symbol, 0) + 1
        distinct = len(tally)
        for row, symbol in enumerate(column):
            seq_weights[row] += 1.0 / (distinct * tally[symbol])
    seq_weights /= seq_weights.sum()
    return seq_weights

# ======= (Optional) Jackhmmer/MMseqs2 integration (not changed here) =======

# ==== MAIN WORKFLOW ====
# Pipeline: homolog search -> MAFFT alignment -> Henikoff weights -> save weights to disk.
method = "blast"  # "jackhmmer" or "mmseqs2" possible if implemented

# Step 1: collect homologs into a FASTA file using the selected search backend.
if method == "blast":
    msa_input = run_blast_search(WT_SEQUENCE, identity_threshold=90.0, hitlist_size=500)
elif method == "jackhmmer":
    # NOTE(review): run_jackhmmer_search is not defined in the visible part of this
    # notebook; selecting this method raises NameError unless it is defined elsewhere.
    msa_input = run_jackhmmer_search(WT_SEQUENCE)
elif method == "mmseqs2":
    # NOTE(review): run_mmseqs2_search is likewise not defined in the visible source.
    msa_input = run_mmseqs2_search(WT_SEQUENCE)
else:
    raise ValueError("Invalid method chosen. Please select 'blast', 'jackhmmer', or 'mmseqs2'.")

# Step 2: align the collected sequences.
msa_aligned = run_mafft(msa_input)

# Step 3: one normalised weight per aligned sequence; persisted for the training cell,
# which loads "msa_weights.npy".
weights = henikoff_weights(msa_aligned, "fasta")
print("Sequence weights:", weights)
np.save("msa_weights.npy", weights)


In [None]:
# ======================== 1. Required Imports ======================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
import re
import os
from scipy import stats
from tqdm import tqdm
from transformers import (
    T5EncoderModel, T5Tokenizer, T5PreTrainedModel, T5Config,
    TrainingArguments, Trainer, set_seed
)
from transformers.modeling_outputs import SequenceClassifierOutput
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from datasets import Dataset

# ======================== 2. Load and preprocess data =============================
# Use SSM fitness data (from figshare)
url = "https://figshare.com/ndownloader/files/7337543"
# The file is read as tab-separated text.
df = pd.read_csv(url, sep='\t')
# NOTE(review): assumes the file has 'mutation' and 'normalized_fitness' columns. rename()
# ignores missing keys by default, so a different header would only surface as a KeyError
# on the next line — verify against the downloaded file.
df.rename(columns={'mutation': 'mutation_string', 'normalized_fitness': 'fitness'}, inplace=True)
# Non-numeric fitness entries become NaN and those rows are dropped.
df['fitness'] = pd.to_numeric(df['fitness'], errors='coerce')
df.dropna(subset=['fitness'], inplace=True)

# --- Add weights ---
# Load the per-sequence MSA weights saved by the alignment cell ("msa_weights.npy")
# and take the weight of the first sequence.
msa_weights = np.load("msa_weights.npy")
wt_weight = msa_weights[0]  # First sequence is the wild-type (presumably — assumes the aligner kept the input order; verify)

# Add this weight to every row in your DataFrame:
# (every row receives the same scalar, so the weights do not differ between variants)
df['weight'] = wt_weight

# --- Generate mutated sequences from wildtype + mutation string ---
# NOTE(review): this reassigns WT_SEQUENCE to a different sequence than the one used as the
# BLAST query in the alignment cell — confirm which protein the MSA weights should describe.
WT_SEQUENCE = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLTYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
def generate_mutated_sequence(wt_sequence, mutation_string):
    """Apply comma-separated point substitutions to a wild-type sequence.

    Each token is expected to look like ``A12G``: one character, a 1-based
    position, and the replacement residue. The leading character is not checked
    against the wild-type residue. Tokens that are malformed, point outside the
    sequence, or end in ``*`` are skipped and leave the sequence unchanged.

    Args:
        wt_sequence: Wild-type sequence.
        mutation_string: Comma-separated mutation tokens. Missing values
            (``None``, NaN) or any other non-string yield the unchanged wild type.

    Returns:
        The mutated sequence as a string.
    """
    residues = list(wt_sequence)
    # pandas represents missing cells as NaN (a float); treat those as "no mutation"
    # instead of failing on .split().
    if not isinstance(mutation_string, str):
        return ''.join(residues)
    for token in mutation_string.split(','):
        token = token.strip()
        position_field = token[1:-1]
        # isdecimal() only accepts characters that int() can parse; isdigit() also
        # accepts e.g. superscript digits, which would make int() raise.
        if len(token) < 3 or not position_field.isdecimal():
            continue
        index = int(position_field) - 1  # positions are 1-based
        if 0 <= index < len(residues) and token[-1] != '*':
            residues[index] = token[-1]
    return ''.join(residues)
# Build the full variant sequence for every row from its mutation string.
df['sequence'] = df['mutation_string'].apply(lambda x: generate_mutated_sequence(WT_SEQUENCE, x))

# Format for T5: whitespace between each AA
def format_protein_sequence(sequence):
    """Return ``sequence`` with a single space inserted between consecutive characters."""
    spaced = ' '.join(sequence)
    return spaced
# Overwrites the column with the space-separated form. Not idempotent: running this line a
# second time would insert spaces between the already-spaced characters.
df['sequence'] = df['sequence'].apply(format_protein_sequence)

# ======================== 3. Train/val/test split ================================
# (If you want stratify: stratify=df['fitness'] > df['fitness'].median())
from sklearn.model_selection import train_test_split
# 70% train; the remaining 30% is split in half into validation and test (15% each).
# random_state is fixed so the partition is reproducible.
my_train, temp = train_test_split(df, test_size=0.3, random_state=42)
my_valid, my_test = train_test_split(temp, test_size=0.5, random_state=42)
print(f"Train: {len(my_train)}, Val: {len(my_valid)}, Test: {len(my_test)}")
# Preview of the columns consumed by the training code.
print(my_train[['sequence', 'fitness', 'weight']].head())

# ======================== 4. Model & Tokenizer (improved) ========================
class LoRAConfig:
    """Settings that control where LoRA adapters are injected and which parameters train.

    Attributes:
        lora_rank: Rank of the additive low-rank update.
        lora_init_scale: Scale of the random initialisation of adapter matrices.
        lora_modules: Regex matched (``re.fullmatch``) against module names to adapt.
        lora_layers: Regex matched against child-layer names inside those modules.
        trainable_param_names: Regex of parameter names left trainable.
        lora_scaling_rank: Rank of the multiplicative scaling factors.
    """

    def __init__(self):
        self.lora_rank, self.lora_init_scale = 2, 0.01
        self.lora_modules, self.lora_layers = (
            r".*SelfAttention|.*EncDecAttention",
            r"q|k|v|o",
        )
        # Besides the adapters this also matches layer norms, the classifier head
        # and every parameter whose name contains "block.23".
        self.trainable_param_names = r".*layer_norm.*|.*lora_[ab].*|.*classifier.*|.*block\.23.*"  # Unfreeze head & last block
        self.lora_scaling_rank = 1

class LoRALinear(nn.Module):
    """Linear layer that reuses an existing layer's weight/bias and adds trainable adapters.

    Two optional adapters modify the wrapped weight at forward time:

    * an additive low-rank term ``lora_b @ lora_a / rank`` (created when ``rank > 0``;
      ``lora_b`` starts at zero, so the term is initially zero);
    * a multiplicative term ``multi_lora_b @ multi_lora_a / scaling_rank`` (created
      when ``scaling_rank`` is truthy; initialised close to one).
    """

    def __init__(self, linear_layer, rank, scaling_rank, init_scale):
        super().__init__()
        n_in = linear_layer.in_features
        n_out = linear_layer.out_features
        self.in_features, self.out_features = n_in, n_out
        self.rank, self.scaling_rank = rank, scaling_rank
        # Share (not copy) the tensors of the wrapped layer.
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias
        if self.rank > 0:
            self.lora_a = nn.Parameter(torch.randn(rank, n_in) * init_scale)
            self.lora_b = nn.Parameter(torch.zeros(n_out, rank))
        if self.scaling_rank:
            noisy_ones = torch.ones(self.scaling_rank, n_in) + torch.randn(self.scaling_rank, n_in) * init_scale
            self.multi_lora_a = nn.Parameter(noisy_ones)
            self.multi_lora_b = nn.Parameter(torch.ones(n_out, self.scaling_rank))

    def forward(self, input):
        """Apply the adapted linear map to ``input``."""
        effective = self.weight
        if self.scaling_rank:
            # Element-wise rescaling of the frozen weight.
            effective = effective * (self.multi_lora_b @ self.multi_lora_a) / self.scaling_rank
        if self.rank:
            # Additive low-rank correction.
            effective = effective + (self.lora_b @ self.lora_a) / self.rank
        return F.linear(input, effective, self.bias)

def modify_with_lora(transformer, config):
    """Replace selected ``nn.Linear`` children of ``transformer`` with ``LoRALinear`` wrappers.

    A child layer is wrapped when its parent module's full name matches
    ``config.lora_modules``, its own name matches ``config.lora_layers`` (both via
    ``re.fullmatch``) and it is an ``nn.Linear``. The model is modified in place.

    Args:
        transformer: Module tree to modify.
        config: Object providing ``lora_modules``, ``lora_layers``, ``lora_rank``,
            ``lora_scaling_rank`` and ``lora_init_scale``.

    Returns:
        The same ``transformer`` object.
    """
    # Snapshot the module list first because children are replaced during the walk.
    for module_name, parent in list(transformer.named_modules()):
        if not re.fullmatch(config.lora_modules, module_name):
            continue
        for child_name, child in list(parent.named_children()):
            if not re.fullmatch(config.lora_layers, child_name):
                continue
            if not isinstance(child, nn.Linear):
                continue
            adapted = LoRALinear(child, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale)
            setattr(parent, child_name, adapted)
    return transformer

class ClassConfig:
    """Settings for the prediction head: dropout probability and number of outputs."""

    def __init__(self, dropout=0.2, num_labels=1):
        # Stored under `dropout_rate`, the attribute name the head reads.
        self.dropout_rate, self.num_labels = dropout, num_labels

class T5EncoderClassificationHead(nn.Module):
    """Prediction head: first-position pooling -> dropout -> dense -> tanh -> dropout -> projection.

    Args:
        hidden_size: Width of the incoming hidden states.
        class_config: Object providing ``dropout_rate`` and ``num_labels``.
    """

    def __init__(self, hidden_size, class_config):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(class_config.dropout_rate)
        self.out_proj = nn.Linear(hidden_size, class_config.num_labels)

    def forward(self, hidden_states):
        """Map hidden states indexed as ``[batch, position, feature]`` to one output row per item."""
        # Only the vector at position 0 is used; all other positions are ignored.
        first_position = hidden_states[:, 0, :]
        activated = torch.tanh(self.dense(self.dropout(first_position)))
        return self.out_proj(self.dropout(activated))

class T5EncoderForSimpleSequenceRegression(T5PreTrainedModel):
    """Wraps an encoder and a ``T5EncoderClassificationHead`` for sequence-level regression.

    Args:
        encoder_model: Encoder module whose first output is the last hidden state.
        config: ``T5Config`` of the encoder; ``config.d_model`` sizes the head.
        class_config: Object providing ``num_labels`` and ``dropout_rate``.
    """

    def __init__(self, encoder_model, config: T5Config, class_config):
        super().__init__(config)
        self.num_labels = class_config.num_labels
        self.config = config
        self.encoder_model = encoder_model
        # Kept for compatibility; forward() does not use this module (the head has its own dropout).
        self.dropout = nn.Dropout(class_config.dropout_rate)
        self.classifier = T5EncoderClassificationHead(config.d_model, class_config)
        # NOTE(review): post_init() runs with the pretrained encoder already attached —
        # confirm the installed transformers version does not re-initialise its weights.
        self.post_init()
        self.model_parallel = False
        self.device_map = None

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        weights=None,
    ):
        """Encode the inputs, predict one value per label, and optionally compute the loss.

        Args:
            input_ids, attention_mask, head_mask, inputs_embeds, output_attentions,
            output_hidden_states: Forwarded unchanged to the encoder.
            labels: Regression targets; when given, a loss is returned.
            return_dict: Return a ``SequenceClassifierOutput`` instead of a tuple;
                defaults to ``config.use_return_dict``.
            weights: Optional per-sample weights (one per batch item). When given,
                the loss is the weighted mean of the per-sample squared errors,
                normalised by the sum of the weights, so equal weights reproduce
                the plain mean squared error.

        Returns:
            ``SequenceClassifierOutput`` or, when ``return_dict`` is false, a tuple
            ``(loss?, logits, *extra_encoder_outputs)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.encoder_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 0 is the last hidden state for both return styles; attribute access
        # (`.last_hidden_state`) fails when return_dict=False yields a plain tuple.
        hidden_states = outputs[0]
        logits = self.classifier(hidden_states)
        loss = None
        if labels is not None:
            predictions = logits.squeeze()
            targets = labels.squeeze()
            if weights is None:
                loss = MSELoss()(predictions, targets)
            else:
                batch_size = logits.shape[0]
                squared_error = F.mse_loss(predictions, targets, reduction="none")
                # Average over labels first so there is exactly one error per sample.
                per_sample = squared_error.reshape(batch_size, -1).mean(dim=1)
                sample_weights = weights.to(per_sample.dtype).reshape(-1)
                # Guard against a batch whose weights sum to zero.
                normaliser = sample_weights.sum().clamp_min(torch.finfo(per_sample.dtype).eps)
                loss = (sample_weights * per_sample).sum() / normaliser
        if not return_dict:
            output = (logits,) + (outputs[1:] if isinstance(outputs, tuple) else ())
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(outputs, "hidden_states", None),
            attentions=getattr(outputs, "attentions", None),
        )

def PT5_regression_model(num_labels=1, half_precision=False):
    """Build the LoRA-adapted regression model and its tokenizer.

    Loads the ``Rostlab/prot_t5_xl_uniref50`` encoder and tokenizer, wraps the encoder
    in ``T5EncoderForSimpleSequenceRegression``, injects LoRA layers, and leaves only
    parameters whose names fully match ``LoRAConfig().trainable_param_names`` trainable.

    Args:
        num_labels: Number of regression outputs.
        half_precision: Accepted for interface compatibility; not used by this function.

    Returns:
        ``(model, tokenizer)``.
    """
    checkpoint = "Rostlab/prot_t5_xl_uniref50"
    encoder_model = T5EncoderModel.from_pretrained(checkpoint)
    tokenizer = T5Tokenizer.from_pretrained(checkpoint, legacy=False)
    class_model = T5EncoderForSimpleSequenceRegression(
        encoder_model, encoder_model.config, ClassConfig(num_labels=num_labels)
    )
    lora_cfg = LoRAConfig()
    class_model.encoder_model = modify_with_lora(class_model.encoder_model, lora_cfg)
    # Freeze everything except the parameters selected by the regex.
    trainable_pattern = lora_cfg.trainable_param_names
    for param_name, param in class_model.named_parameters():
        param.requires_grad = re.fullmatch(trainable_pattern, param_name) is not None
    return class_model, tokenizer

def set_seeds(seed):
    """Seed torch, NumPy and Python's ``random``, then call ``set_seed`` with the same value."""
    # Same call order as before: torch, numpy, random, then the library helper.
    for seeder in (torch.manual_seed, np.random.seed, random.seed, set_seed):
        seeder(seed)

def create_dataset(tokenizer, seqs, labels, weights=None, max_length=128):
    """Tokenize sequences and bundle them with labels (and optional weights) in a ``Dataset``.

    Every sequence is padded to ``max_length`` and truncated if longer.

    Args:
        tokenizer: Callable tokenizer applied to the list of sequences.
        seqs: Iterable of input strings.
        labels: Values stored in the ``"labels"`` column.
        weights: Optional values stored in a ``"weights"`` column.
        max_length: Fixed tokenized length.

    Returns:
        A ``Dataset`` with the tokenizer outputs plus ``labels`` (and ``weights``).
    """
    encoded = tokenizer(list(seqs), max_length=max_length, padding='max_length', truncation=True)
    dataset = Dataset.from_dict(encoded).add_column("labels", labels)
    if weights is None:
        return dataset
    return dataset.add_column("weights", weights)

def train_per_protein(
    train_df,
    valid_df,
    num_labels=1,
    batch=4,
    accum=1,
    val_batch=16,
    epochs=10,
    lr=2e-4,
    wd=0.01,
    seed=42,
    mixed=False,
    max_length=None,
):
    """Fine-tune the PT5 regression model on one protein's variant data.

    Args:
        train_df: DataFrame with ``sequence`` (space-separated residues), ``fitness``
            and ``weight`` columns used for training.
        valid_df: DataFrame with the same columns, evaluated after every epoch.
        num_labels: Number of regression outputs.
        batch: Per-device training batch size.
        accum: Gradient accumulation steps.
        val_batch: Per-device evaluation batch size.
        epochs: Number of training epochs.
        lr: Learning rate.
        wd: Weight decay.
        seed: Random seed.
        mixed: Forwarded to the model builder as ``half_precision``; the trainer
            itself always runs with ``fp16=False``.
        max_length: Tokenized length used for padding/truncation. ``None`` (default)
            uses the longest sequence in ``train_df``/``valid_df`` plus one, so no
            residue is truncated away.

    Returns:
        ``(tokenizer, model, log_history)`` where ``log_history`` is the trainer's
        list of logged metric dictionaries.
    """
    import inspect

    set_seeds(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tokenizer = PT5_regression_model(num_labels=num_labels, half_precision=mixed)
    model.to(device)

    train_seqs = list(train_df["sequence"])
    valid_seqs = list(valid_df["sequence"])
    if max_length is None:
        # A fixed 128-token limit silently cut longer proteins, hiding every mutation
        # beyond the cut. Assumes one token per whitespace-separated residue plus one
        # end-of-sequence token — TODO confirm for other tokenizers.
        longest = max((len(seq.split()) for seq in train_seqs + valid_seqs), default=0)
        max_length = longest + 1
    train_set = create_dataset(tokenizer, train_seqs, list(train_df["fitness"]), list(train_df["weight"]), max_length=max_length)
    valid_set = create_dataset(tokenizer, valid_seqs, list(valid_df["fitness"]), list(valid_df["weight"]), max_length=max_length)

    # `do_eval=True` alone does not schedule evaluation during training; an explicit
    # strategy is required. The keyword was renamed across transformers releases,
    # so pick whichever name the installed version accepts.
    supported_args = inspect.signature(TrainingArguments.__init__).parameters
    eval_key = "eval_strategy" if "eval_strategy" in supported_args else "evaluation_strategy"

    args = TrainingArguments(
        output_dir="./outputs",
        do_train=True,
        do_eval=True,
        per_device_train_batch_size=batch,
        per_device_eval_batch_size=val_batch,
        gradient_accumulation_steps=accum,
        num_train_epochs=epochs,
        weight_decay=wd,
        learning_rate=lr,
        seed=seed,
        fp16=False,
        dataloader_num_workers=2,
        report_to="none",
        save_total_limit=1,
        # Log the training loss once per epoch so it lines up with the evaluation metrics.
        logging_strategy="epoch",
        **{eval_key: "epoch"},
    )

    def compute_metrics(eval_pred):
        """Spearman correlation between predictions and labels (0.0 when undefined)."""
        predictions, labels = eval_pred
        predictions = np.squeeze(predictions)
        labels = np.squeeze(labels)
        correlation = stats.spearmanr(predictions, labels).correlation
        if np.isnan(correlation):
            correlation = 0.0
        return {"spearmanr": correlation}

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_set,
        eval_dataset=valid_set,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    print("Starting training...")
    trainer.train()
    print("Training completed!")
    return tokenizer, model, trainer.state.log_history

# ======================== 7. Training Loop (Simple) ==========================
print("Starting PT5 regression training...")

# Fine-tune on the training split, passing the validation split for evaluation.
# All hyper-parameters below repeat the defaults of train_per_protein and are spelled
# out so they are easy to tune here. `history` is the trainer's log history.
tokenizer, model, history = train_per_protein(
    my_train,
    my_valid,
    num_labels=1,
    batch=4,
    accum=1,
    val_batch=16,
    epochs=10,
    lr=2e-4,
    wd=0.01,
    seed=42,
    mixed=False,
)

# ======================== 8. Plot Results ==========================
# Training-loss and evaluation entries are logged as separate records, possibly at
# different cadences, so each curve gets its own x-values taken from the records
# that actually contain it.
loss = [x['loss'] for x in history if 'loss' in x]
val_loss = [x['eval_loss'] for x in history if 'eval_loss' in x]
metric = [x['eval_spearmanr'] for x in history if 'eval_spearmanr' in x]
epochs_list = [x['epoch'] for x in history if 'loss' in x]
val_epochs = [x['epoch'] for x in history if 'eval_loss' in x]
metric_epochs = [x['epoch'] for x in history if 'eval_spearmanr' in x]

if len(loss) > 0:
    fig, ax1 = plt.subplots(figsize=(8, 4))
    ax2 = ax1.twinx()
    line1 = ax1.plot(epochs_list, loss, label='train_loss')
    # Plotting validation values against the training x-values raises a ValueError
    # whenever the two lists differ in length.
    line2 = ax1.plot(val_epochs, val_loss, label='val_loss') if len(val_loss) > 0 else []
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    line3 = ax2.plot(metric_epochs, metric, color='red', label='val_metric') if len(metric) > 0 else []
    ax2.set_ylabel('Spearman r')
    # Spearman r can be negative; extend the axis downwards only when needed.
    lower_bound = min(0.0, min(metric)) if len(metric) > 0 else 0.0
    ax2.set_ylim([lower_bound, 1])
    lines = line1 + line2 + line3
    labels = [line.get_label() for line in lines]
    ax1.legend(lines, labels, loc='lower left')
    plt.title("Training History")
    plt.show()

# ====================== 9. Save & Reload Model (LoRA weights only) =========================
def save_model(model, filepath):
    """Save only the parameters that require gradients, keyed by parameter name.

    Args:
        model: Module whose trainable parameters are stored.
        filepath: Destination passed to ``torch.save``.
    """
    trainable = {}
    for param_name, param in model.named_parameters():
        if param.requires_grad:
            trainable[param_name] = param
    torch.save(trainable, filepath)
    print(f"Model saved to {filepath}")

def load_model(filepath, num_labels=1, mixed=False):
    """Rebuild the model and overwrite its parameters with the values stored in ``filepath``.

    Parameters whose names are absent from the file keep the values of the freshly
    built model. The checkpoint is loaded onto the CPU.

    Args:
        filepath: File written by ``save_model``.
        num_labels: Number of regression outputs of the rebuilt model.
        mixed: Forwarded to the model builder as ``half_precision``.

    Returns:
        ``(tokenizer, model)``.
    """
    model, tokenizer = PT5_regression_model(num_labels=num_labels, half_precision=mixed)
    saved_params = torch.load(filepath, map_location='cpu')
    for param_name, param in model.named_parameters():
        if param_name not in saved_params:
            continue
        # Copy so the model does not share storage with the loaded checkpoint.
        param.data = saved_params[param_name].data.clone()
    print(f"Model loaded from {filepath}")
    return tokenizer, model

# Persist the trainable parameters, then rebuild a fresh model from that file.
# The reloaded copy is the one used for test-set inference below.
save_model(model, "./PT5_finetuned.pth")
tokenizer_reload, model_reload = load_model("./PT5_finetuned.pth", num_labels=1, mixed=False)

# ======================== 10. Inference on Test Set ==========================
print("Running inference on test set...")

# Size the tokenized length from the data instead of a fixed 128: longer sequences
# were silently truncated, so mutations beyond the cut never reached the model.
# Assumes one token per whitespace-separated residue plus one end-of-sequence
# token — TODO confirm for other tokenizers.
test_sequences = list(my_test["sequence"])
test_max_length = max((len(seq.split()) for seq in test_sequences), default=0) + 1
test_set = create_dataset(tokenizer_reload, test_sequences, list(my_test["fitness"]), list(my_test["weight"]), max_length=test_max_length)
test_set = test_set.with_format("torch")
# shuffle=False keeps predictions aligned with the row order of my_test.
test_dataloader = DataLoader(test_set, batch_size=8, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_reload.to(device)
model_reload.eval()

predictions = []
with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Testing"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        logits = model_reload(input_ids=input_ids, attention_mask=attention_mask).logits
        # Flatten to one value per sample; also covers a final batch of size one.
        predictions.extend(logits.cpu().numpy().reshape(-1).tolist())

true_labels = my_test['fitness'].values
predictions_arr = np.array(predictions)
test_correlation = stats.spearmanr(predictions_arr, true_labels).correlation
pearson_corr = stats.pearsonr(predictions_arr, true_labels)[0]

print(f"Test Spearman r: {test_correlation:.4f}")
print(f"Test Pearson r:  {pearson_corr:.4f}")
# A variance near zero means the model predicts almost the same value for every variant.
print(f"Variance of predictions: {np.var(predictions_arr):.6f}")

print("\nTrue vs Predicted on Test Set:")
print(pd.DataFrame({
    "True": true_labels,
    "Pred": predictions_arr
}).head(20))

# ======================== 11. Plot Test Predictions ==========================
import seaborn as sns

# Scatter of predicted vs. true fitness for the test split.
plt.figure(figsize=(6, 6))
sns.scatterplot(x=true_labels, y=predictions_arr, s=80, color='royalblue', edgecolor='k')
# Dashed y = x reference line spanning the range of the true labels; points on it are
# perfect predictions.
plt.plot([true_labels.min(), true_labels.max()],
         [true_labels.min(), true_labels.max()],
         'r--', lw=2, label='Identity (y=x)')
plt.xlabel("True Fitness (Label)")
plt.ylabel("Predicted Fitness")
plt.title(f"Test Set Predictions vs True (Spearman r = {test_correlation:.3f})")
plt.legend()
plt.tight_layout()
plt.show()
