In [None]:
# Install necessary packages
!pip install torch transformers scikit-learn pandas numpy biopython peft bitsandbytes requests optuna

In [None]:
# Install necessary tools in Google Colab
!apt-get install -y mafft
!apt-get install -y hmmer
!wget https://github.com/soedinglab/MMseqs2/releases/download/17-b804f/mmseqs-linux-gpu.tar.gz
!tar xvf mmseqs-linux-gpu.tar.gz
!chmod +x mmseqs

In [None]:
import pandas as pd
import torch
import requests
from transformers import AutoTokenizer, AutoModelForMaskedLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Ridge, Lasso, ElasticNet
from sklearn.metrics import mean_squared_error
import numpy as np
from Bio import AlignIO
from Bio import SeqIO
import optuna
from torch.utils.data import Dataset, DataLoader
import time
import xml.etree.ElementTree as ET
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import subprocess

In [None]:
import requests, time, subprocess
import numpy as np
import xml.etree.ElementTree as ET
from Bio import AlignIO

WT_SEQUENCE = "MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETEIFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNYPEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQLSLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEG"
BLAST_URL = "https://blast.ncbi.nlm.nih.gov/Blast.cgi"

# ======= Step 1: BLAST API for Homologs =======
def run_blast_search(wt_sequence, identity_threshold=90.0, max_retries=30, sleep_time=10, min_length=100, hitlist_size=300, request_timeout=60):
    """Run a remote blastp search and write the filtered hits to ``msa_input.fasta``.

    Args:
        wt_sequence: Query protein sequence; always written as the first FASTA record.
        identity_threshold: Keep only HSPs whose percent identity
            (identities / alignment length) is strictly below this value.
        max_retries: Maximum number of status polls before giving up.
        sleep_time: Seconds to sleep between status polls.
        min_length: Keep only hit sequences with more residues than this.
        hitlist_size: Value sent as ``HITLIST_SIZE`` in the submission.
        request_timeout: Per-request timeout (seconds) for every HTTP call.

    Returns:
        The path ``"msa_input.fasta"``.

    Raises:
        Exception: If no RID is returned, the search reports FAILED/UNKNOWN,
            or it is not ready after ``max_retries`` polls.
        requests.HTTPError: If submitting the job or downloading results fails.
    """
    params = {
        "CMD": "Put",
        "PROGRAM": "blastp",
        "DATABASE": "nr",
        "QUERY": wt_sequence,
        "FORMAT_TYPE": "XML",
        "EXPECT": "1e-2",
        "HITLIST_SIZE": str(hitlist_size)
    }
    response = requests.post(BLAST_URL, data=params, timeout=request_timeout)
    response.raise_for_status()
    response_text = response.text
    if "RID = " not in response_text:
        raise Exception("No RID found in BLAST response.")
    rid = response_text.split("RID = ")[-1].split("\n")[0].strip()
    print(f"BLAST RID: {rid}")

    # Wait for completion
    for attempt in range(max_retries):
        status = requests.get(
            BLAST_URL,
            params={"CMD":"Get", "FORMAT_OBJECT":"SearchInfo", "RID":rid},
            timeout=request_timeout,
        )
        status_text = status.text
        if "Status=READY" in status_text:
            print("BLAST complete.")
            break
        # A failed or expired search never becomes READY; stop instead of polling until timeout.
        if "Status=FAILED" in status_text or "Status=UNKNOWN" in status_text:
            raise Exception(f"BLAST search {rid} failed or expired.")
        print(f"Waiting... {attempt+1}/{max_retries}")
        time.sleep(sleep_time)
    else:
        raise Exception("BLAST timed out")

    # Download results
    result = requests.get(
        BLAST_URL,
        params={"CMD":"Get", "FORMAT_TYPE":"XML", "RID":rid},
        timeout=request_timeout,
    )
    result.raise_for_status()
    root = ET.fromstring(result.text)
    seqs = []
    for hit in root.findall(".//Hit"):
        for hsp in hit.findall(".//Hsp"):
            hseq_elem = hsp.find("Hsp_hseq")
            identity_elem = hsp.find("Hsp_identity")
            align_len_elem = hsp.find("Hsp_align-len")
            if hseq_elem is None or identity_elem is None or align_len_elem is None:
                continue
            # Hsp_hseq is the aligned subject string, so it carries '-' gap
            # characters; drop them so length filtering and de-duplication
            # operate on the actual residues.
            hseq = (hseq_elem.text or "").strip().replace("-", "")
            identity = int(identity_elem.text)
            align_len = int(align_len_elem.text)
            if align_len <= 0:
                continue
            identity_pct = 100 * identity / align_len
            if identity_pct < identity_threshold and len(hseq) > min_length:
                seqs.append(hseq)
    # Unique hits in first-seen order (a set would reorder records between runs),
    # with the wild type always first.
    seqs = [wt_sequence] + list(dict.fromkeys(s for s in seqs if s != wt_sequence))
    print(f"Total homologs: {len(seqs)}")
    # Save to FASTA
    with open("msa_input.fasta", "w") as f:
        for i, s in enumerate(seqs):
            f.write(f">seq{i}\n{s}\n")
    return "msa_input.fasta"

# ======= Step 2: Align with MAFFT =======
def run_mafft(input_fasta, output_fasta="msa_aligned.fasta"):
    """Align ``input_fasta`` with ``mafft --auto`` and write the result to ``output_fasta``.

    The command is passed as an argument list (no shell), so paths containing
    spaces or shell metacharacters are handled literally.

    Returns:
        ``output_fasta``.

    Raises:
        subprocess.CalledProcessError: If mafft exits with a non-zero status.
    """
    print("Running MAFFT alignment...")
    # mafft writes the alignment to stdout; send it straight to the output file.
    with open(output_fasta, "w") as aligned:
        subprocess.run(["mafft", "--auto", str(input_fasta)], stdout=aligned, check=True)
    print(f"Alignment written: {output_fasta}")
    return output_fasta

# ======= Step 3: Calculate Henikoff Weights =======
def henikoff_weights(msa_file, format="fasta"):
    """Compute position-based sequence weights for an alignment file.

    For every column, each sequence receives ``1 / (k * n)`` where ``k`` is the
    number of distinct symbols in that column and ``n`` is how many sequences
    share this sequence's symbol. Every character in the column is counted as
    a symbol (no filtering is applied). The per-sequence totals are normalised
    to sum to 1.

    Args:
        msa_file: Path of the alignment, read with ``AlignIO.read``.
        format: Alignment format name passed to ``AlignIO.read``.

    Returns:
        A 1-D numpy array with one weight per aligned sequence.
    """
    msa = AlignIO.read(msa_file, format)
    weights = np.zeros(len(msa))
    for col in range(msa.get_alignment_length()):
        column = [rec.seq[col] for rec in msa]
        # Tally how often each symbol occurs in this column.
        tally = {}
        for symbol in column:
            tally[symbol] = tally.get(symbol, 0) + 1
        distinct = len(tally)
        for row, symbol in enumerate(column):
            weights[row] += 1.0 / (distinct * tally[symbol])
    weights /= weights.sum()
    return weights

# ======= (Optional) Jackhmmer/MMseqs2 integration (not changed here) =======

# ==== MAIN WORKFLOW ====
method = "blast"  # "jackhmmer" or "mmseqs2" possible if implemented

# One zero-argument callable per search backend; nothing runs until the chosen
# entry is called. NOTE(review): run_jackhmmer_search / run_mmseqs2_search are
# not defined in the visible part of this notebook.
homolog_searches = {
    "blast": lambda: run_blast_search(WT_SEQUENCE, identity_threshold=90.0, hitlist_size=500),
    "jackhmmer": lambda: run_jackhmmer_search(WT_SEQUENCE),
    "mmseqs2": lambda: run_mmseqs2_search(WT_SEQUENCE),
}
if method not in homolog_searches:
    raise ValueError("Invalid method chosen. Please select 'blast', 'jackhmmer', or 'mmseqs2'.")
msa_input = homolog_searches[method]()

# Align the collected sequences, then weight them and persist the weights.
msa_aligned = run_mafft(msa_input)
weights = henikoff_weights(msa_aligned, "fasta")

print("Sequence weights:", weights)
np.save("msa_weights.npy", weights)


In [None]:
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForMaskedLM
from sklearn.linear_model import Ridge, Lasso, ElasticNet
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import optuna
import xgboost as xgb

# --- DEVICE ---
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Load ProtBERT ---
print("Loading ProtBERT...")
tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert")
base_model = AutoModelForMaskedLM.from_pretrained("Rostlab/prot_bert")
base_model = base_model.to(device)

# --- Embedding extractor ---
def get_protbert_embedding(model, sequence):
    """Return the mean-pooled last-hidden-state embedding of one protein sequence.

    ProtBERT's vocabulary has one token per residue, so the tokenizer needs the
    residues separated by spaces. An unspaced sequence is seen as one very long
    word and collapses to a single unknown token, which would make every
    variant's embedding identical. The sequence is therefore re-spaced here, and
    the rare residues U, Z, O and B are mapped to X as the ProtBERT model card
    prescribes. Input that is already space-separated is accepted as well.

    Args:
        model: Model called with the tokenized input and
            ``output_hidden_states=True``.
        sequence: Amino-acid sequence, with or without spaces between residues.

    Returns:
        A numpy array holding the last hidden layer averaged over the token axis.
    """
    residues = "".join(sequence.split())  # tolerate input that is already spaced
    residues = residues.translate(str.maketrans("UZOB", "XXXX"))
    spaced_sequence = " ".join(residues)
    with torch.no_grad():
        tokens = tokenizer(spaced_sequence, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        outputs = model(**tokens, output_hidden_states=True)
        last_hidden_state = outputs.hidden_states[-1]
        # Average over the token axis, then drop the single-item batch axis.
        embedding = last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
        return embedding

# --- Mutation utility ---
def generate_mutated_sequence(wt_sequence, mutation_string):
    """Apply comma-separated point mutations (e.g. ``"A23G,T45C"``) to a sequence.

    Each mutation is read as ``<any char><1-based position><new residue>``.
    A mutation is silently skipped when it is malformed, when its position is
    outside ``1..len(wt_sequence)``, or when the new residue is ``*``.
    Whitespace around individual mutations is ignored. A missing or non-string
    ``mutation_string`` (e.g. NaN from pandas) yields the unchanged sequence.

    Args:
        wt_sequence: Wild-type sequence.
        mutation_string: Comma-separated mutations, or a non-string placeholder.

    Returns:
        The mutated sequence as a string.
    """
    seq_list = list(wt_sequence)
    if not isinstance(mutation_string, str):
        return ''.join(seq_list)
    for mut in mutation_string.split(','):
        mut = mut.strip()
        if len(mut) >= 3 and mut[1:-1].isdigit():
            pos = int(mut[1:-1]) - 1
            # The lower bound matters: position 0 would otherwise become index -1
            # and overwrite the last residue.
            if 0 <= pos < len(seq_list) and mut[-1] != '*':
                seq_list[pos] = mut[-1]
    return ''.join(seq_list)

# --- Load data ---
print("Loading dataset...")
url = "https://figshare.com/ndownloader/files/7337543"
df = pd.read_csv(url, sep='\t').rename(columns={'mutation': 'mutation_string', 'normalized_fitness': 'fitness'})
# Non-numeric fitness values become NaN and those rows are dropped.
df['fitness'] = pd.to_numeric(df['fitness'], errors='coerce')
df = df.dropna(subset=['fitness'])
df = df.sample(frac=1, random_state=42).reset_index(drop=True)   # SHUFFLE
# Outlier handling (optional, clip to [-5, 5] or domain-specific)
df['fitness'] = df['fitness'].clip(-5, 5)

WT_SEQUENCE = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLTYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
df['mutated_sequence'] = [generate_mutated_sequence(WT_SEQUENCE, mutations) for mutations in df['mutation_string']]

# --- Generate embeddings ---
print("Generating embeddings... (may take a while)")
embeddings, y = [], []
for mutated_seq, fitness_value in zip(df['mutated_sequence'], df['fitness']):
    try:
        emb = get_protbert_embedding(base_model, mutated_seq)
    except Exception as e:
        # A failing sequence is reported and left out of both X and y.
        print(f"Skipping sequence due to error: {e}")
    else:
        embeddings.append(emb)
        y.append(fitness_value)
X = np.vstack(embeddings)
y = np.array(y)

# --- Train/test split (done on indices first so scaling can use training rows only) ---
all_indices = np.arange(len(X))
X_train_index, X_test_index = train_test_split(all_indices, test_size=0.2, random_state=42)

# --- Scale features ---
# Fit the scaler on the training rows only; fitting on all rows would leak
# test-set statistics into the features the models are trained on. The same
# fitted transform is then applied to every row.
scaler = StandardScaler()
scaler.fit(X[X_train_index])
X_scaled = scaler.transform(X)

X_train, X_test = X_scaled[X_train_index], X_scaled[X_test_index]
y_train, y_test = y[X_train_index], y[X_test_index]

print(f"Train targets variance: {np.var(y_train):.3f}, Test targets variance: {np.var(y_test):.3f}")

# ---- Sample weights (optional, all ones) ----
weights = np.ones(len(X_scaled))
train_weights = weights[X_train_index]

# --- XGBoost + Optuna ---
def xgb_objective(trial):
    """Optuna objective: mean 5-fold CV MSE of an XGBoost regressor on the training split.

    Hyper-parameters are drawn from ``trial``; each fold is fit on
    ``X_train``/``y_train`` rows with the matching ``train_weights``.

    Returns:
        The mean of the per-fold mean squared errors.
    """
    use_gpu = torch.cuda.is_available()
    params = {
        'tree_method': 'hist',
        'device': 'cuda' if use_gpu else 'cpu',
        'n_estimators': trial.suggest_int('n_estimators', 100, 400),
        'max_depth': trial.suggest_int('max_depth', 3, 16),
        'learning_rate': trial.suggest_float('learning_rate', 0.005, 0.1, log=True),
        'subsample': trial.suggest_float('subsample', 0.6, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0),
        'reg_alpha': trial.suggest_float('reg_alpha', 0, 0.6),
        'reg_lambda': trial.suggest_float('reg_lambda', 0, 0.6)
    }

    def fold_mse(fit_rows, held_out_rows):
        # Fit on one fold's training rows and score on its held-out rows.
        regressor = xgb.XGBRegressor(**params)
        regressor.fit(X_train[fit_rows], y_train[fit_rows], sample_weight=train_weights[fit_rows])
        return mean_squared_error(y_train[held_out_rows], regressor.predict(X_train[held_out_rows]))

    splitter = KFold(n_splits=5, shuffle=True, random_state=42)
    return np.mean([fold_mse(fit_rows, held_out_rows) for fit_rows, held_out_rows in splitter.split(X_train)])

print("Running Optuna for XGBoost...")
# Seed the sampler so the hyper-parameter search (and therefore the final
# model and every reported metric) is reproducible across runs.
study = optuna.create_study(direction='minimize', sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(xgb_objective, n_trials=15)
xgb_params = dict(study.best_params)
# The fixed (non-tuned) settings are not part of best_params; add them back.
xgb_params['tree_method'] = 'hist'
xgb_params['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Best XGBoost params:", xgb_params)

# Refit on the whole training split with the best parameters and predict the test split.
xgb_model = xgb.XGBRegressor(**xgb_params)
xgb_model.fit(X_train, y_train, sample_weight=train_weights)
y_pred_xgb = xgb_model.predict(X_test)

print("\n--- Diagnostics: y_test and y_pred_xgb ---")
# Same summary line for targets and predictions (predictions are rounded before de-duplication).
for label, values, distinct in (
    ("y_test: min", y_test, np.unique(y_test)),
    ("y_pred_xgb: min", y_pred_xgb, np.unique(np.round(y_pred_xgb,3))),
):
    print(label, np.min(values), "max", np.max(values), "unique", distinct[:10], "std", np.std(values))

mse_xgb = mean_squared_error(y_test, y_pred_xgb)
# Pearson correlation is reported as 0 when either side is (numerically) constant.
if any(np.std(values) < 1e-8 for values in (y_pred_xgb, y_test)):
    pearson_xgb = 0.0
else:
    pearson_xgb, _ = pearsonr(y_test, y_pred_xgb)
print(f"XGBoost Test MSE: {mse_xgb:.4f}  Pearson: {pearson_xgb:.4f}")

fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(y_test, y_pred_xgb, color='green', alpha=0.6)
identity_line = [min(y_test), max(y_test)]  # y = x reference line
ax.plot(identity_line, identity_line, color='red', linestyle='--')
ax.set_xlabel('Actual Fitness')
ax.set_ylabel('Predicted Fitness')
ax.set_title('Actual vs Predicted Fitness (XGBoost)')
ax.grid(True)
plt.show()

# ---- Cross-validation: report mean Pearson/MSE on all data ----
print("\n--- Cross-validation on all data (XGBoost) ---")
kf = KFold(n_splits=5, shuffle=True, random_state=42)
pearson_scores, mse_scores = [], []
for fit_rows, eval_rows in kf.split(X_scaled):
    fold_model = xgb.XGBRegressor(**xgb_params)
    fold_model.fit(X_scaled[fit_rows], y[fit_rows], sample_weight=weights[fit_rows])
    fold_truth = y[eval_rows]
    fold_pred = fold_model.predict(X_scaled[eval_rows])
    mse_scores.append(mean_squared_error(fold_truth, fold_pred))
    # Pearson (0 when predictions or targets are numerically constant)
    if np.std(fold_pred) < 1e-8 or np.std(fold_truth) < 1e-8:
        fold_pearson = 0.0
    else:
        fold_pearson, _ = pearsonr(fold_truth, fold_pred)
    pearson_scores.append(fold_pearson)
print(f"Mean CV MSE: {np.mean(mse_scores):.4f}, Mean CV Pearson: {np.mean(pearson_scores):.4f}")

# ---- Linear Models: Lasso, Ridge, ElasticNet ----
def train_and_evaluate_model(model, param_grid, X_train, y_train, X_test, y_test, train_weights, model_name):
    """Cross-validate, grid-search, evaluate and plot one scikit-learn regressor.

    Steps: (1) 5-fold CV MSE of ``model`` with its current settings,
    (2) 5-fold ``GridSearchCV`` over ``param_grid``, (3) test-set MSE and
    Pearson correlation of the refit best estimator, (4) actual-vs-predicted
    scatter plot. Sample weights are forwarded to every ``fit``.

    Args:
        model: Unfitted estimator whose ``fit`` accepts ``sample_weight``.
        param_grid: Parameter grid for ``GridSearchCV``.
        X_train, y_train: Training features and targets.
        X_test, y_test: Held-out features and targets.
        train_weights: Per-sample weights aligned with ``X_train``.
        model_name: Label used in printed messages and the plot title.

    Returns:
        The fitted ``GridSearchCV`` object.
    """
    fit_kwargs = {'sample_weight': train_weights}
    # scikit-learn 1.4 renamed `fit_params` to `params` and 1.6 removed the old
    # name; try the current keyword first and fall back for older releases.
    try:
        cv_scores = cross_val_score(
            model, X_train, y_train, cv=5,
            scoring='neg_mean_squared_error',
            params=fit_kwargs
        )
    except TypeError:
        cv_scores = cross_val_score(
            model, X_train, y_train, cv=5,
            scoring='neg_mean_squared_error',
            fit_params=fit_kwargs
        )
    print(f"{model_name} - CV MSE: {-cv_scores.mean():.4f}")

    grid_search = GridSearchCV(model, param_grid, cv=5, scoring='neg_mean_squared_error')
    grid_search.fit(X_train, y_train, sample_weight=train_weights)
    print(f"Best parameters for {model_name}: {grid_search.best_params_}")

    y_pred = grid_search.predict(X_test)
    mse = mean_squared_error(y_test, y_pred)
    # Pearson correlation is reported as 0 when either side is numerically constant.
    if np.std(y_pred) < 1e-8 or np.std(y_test) < 1e-8:
        pearson_corr = 0.0
    else:
        pearson_corr, _ = pearsonr(y_test, y_pred)
    print(f"MSE - {model_name}: {mse:.4f}")
    print(f"Pearson Corr - {model_name}: {pearson_corr:.4f}")

    plt.figure(figsize=(8, 6))
    plt.scatter(y_test, y_pred, color='blue', alpha=0.6)
    plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], color='red', linestyle='--')
    plt.xlabel('Actual Fitness')
    plt.ylabel('Predicted Fitness')
    plt.title(f'Actual vs Predicted Fitness ({model_name})')
    plt.grid(True)
    plt.show()

    return grid_search

# Regularisation grids: the same alpha values for every model, plus the
# L1/L2 mixing ratio for ElasticNet.
param_grid_ridge = {'alpha': [0.01, 0.1, 1, 10, 100]}
param_grid_lasso = {'alpha': [0.01, 0.1, 1, 10, 100]}
param_grid_en = {'alpha': [0.01, 0.1, 1, 10, 100], 'l1_ratio': [0.1, 0.5, 0.9, 1.0]}

ridge_model, lasso_model, elastic_net_model = Ridge(), Lasso(), ElasticNet()

# Evaluate the three linear models in turn on the same train/test split.
for estimator, grid, label in (
    (ridge_model, param_grid_ridge, "Ridge"),
    (lasso_model, param_grid_lasso, "Lasso"),
    (elastic_net_model, param_grid_en, "ElasticNet"),
):
    train_and_evaluate_model(estimator, grid, X_train, y_train, X_test, y_test, train_weights, label)

print("Done!")


In [None]:
#import dependencies
import os.path
os.chdir("set a path here")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import DataLoader

import re
import numpy as np
import pandas as pd
import copy

import transformers, datasets
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from transformers import T5EncoderModel, T5Tokenizer
from transformers import TrainingArguments, Trainer, set_seed

from evaluate import load
from datasets import Dataset

from tqdm import tqdm
import random

from scipy import stats
from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt # Set environment variables to run Deepspeed from a notebook
# Single-process distributed settings (address, port, ranks, world size).
os.environ.update({
    "MASTER_ADDR": "localhost",
    "MASTER_PORT": "9994",  # modify if RuntimeError: Address already in use
    "RANK": "0",
    "LOCAL_RANK": "0",
    "WORLD_SIZE": "1",
})
# For this example we import the "three_vs_rest" GB1 dataset from https://github.com/J-SNACKKB/FLIP
# For details, see publication here: https://openreview.net/forum?id=p2dMLEwL8tF
import requests
import zipfile
from io import BytesIO

# Download the zip file from GitHub
url = 'https://github.com/J-SNACKKB/FLIP/raw/main/splits/gb1/splits.zip'
response = requests.get(url, timeout=60)
# Surface HTTP failures here rather than as an obscure zip-format error below.
response.raise_for_status()
zip_file = zipfile.ZipFile(BytesIO(response.content))

# Load the `three_vs_rest.csv` file into a pandas dataframe
with zip_file.open('splits/three_vs_rest.csv') as file:
    df = pd.read_csv(file)

# Drop test data
df = df[df["set"] == "train"]

# Get train and validation data (rows not flagged as validation go to training)
is_validation = df["validation"] == True
my_train = df[~is_validation].reset_index(drop=True)
my_valid = df[is_validation].reset_index(drop=True)

# Set column names to "sequence" and "label"
my_train.columns = ["sequence", "label"] + list(my_train.columns[2:])
my_valid.columns = ["sequence", "label"] + list(my_valid.columns[2:])

# Drop unneeded columns; copy so later column assignments act on independent frames
my_train = my_train[["sequence", "label"]].copy()
my_valid = my_valid[["sequence", "label"]].copy()
# Modifies an existing transformer and introduce the LoRA layers

class LoRAConfig:
    """Settings consumed when swapping linear layers for their LoRA variants."""

    def __init__(self):
        # Where to apply LoRA. lora_modules and lora_layers are specified with
        # regular expressions, see
        # https://www.w3schools.com/python/python_regex.asp for reference
        self.lora_modules = ".*SelfAttention|.*EncDecAttention"
        self.lora_layers = "q|k|v|o"
        # Which parameter names are switched back to trainable.
        self.trainable_param_names = ".*layer_norm.*|.*lora_[ab].*"
        # Sizes and initialisation scale handed to each LoRA layer.
        self.lora_rank = 4
        self.lora_scaling_rank = 1
        self.lora_init_scale = 0.01
        
class LoRALinear(nn.Module):
    """Linear layer that reuses an existing layer's weight/bias and adds LoRA terms.

    With ``rank > 0`` an additive low-rank term ``lora_b @ lora_a / rank`` is
    added to the weight. With ``scaling_rank > 0`` the weight is multiplied
    element-wise by ``multi_lora_b @ multi_lora_a / scaling_rank``. With a
    non-negative ``init_scale`` the ``*_b`` factors start deterministic (zeros
    for ``lora_b``, ones for ``multi_lora_b``); with a negative one they get
    random noise too.
    """

    def __init__(self, linear_layer, rank, scaling_rank, init_scale):
        super().__init__()
        n_in, n_out = linear_layer.in_features, linear_layer.out_features
        self.in_features = n_in
        self.out_features = n_out
        self.rank = rank
        self.scaling_rank = scaling_rank
        # Share (not copy) the wrapped layer's parameters.
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias
        noisy_b = init_scale < 0

        if self.rank > 0:
            self.lora_a = nn.Parameter(torch.randn(rank, n_in) * init_scale)
            lora_b_init = torch.randn(n_out, rank) * init_scale if noisy_b else torch.zeros(n_out, rank)
            self.lora_b = nn.Parameter(lora_b_init)

        if self.scaling_rank:
            multi_a_init = torch.ones(self.scaling_rank, n_in) + torch.randn(self.scaling_rank, n_in) * init_scale
            self.multi_lora_a = nn.Parameter(multi_a_init)
            multi_b_init = torch.ones(n_out, self.scaling_rank)
            if noisy_b:
                multi_b_init = multi_b_init + torch.randn(n_out, self.scaling_rank) * init_scale
            self.multi_lora_b = nn.Parameter(multi_b_init)

    def forward(self, input):
        scaling_only = self.scaling_rank == 1 and self.rank == 0
        if not scaling_only:
            # general implementation for lora (adding and scaling)
            effective_weight = self.weight
            if self.scaling_rank:
                scale = torch.matmul(self.multi_lora_b, self.multi_lora_a)
                effective_weight = effective_weight * scale / self.scaling_rank
            if self.rank:
                delta = torch.matmul(self.lora_b, self.lora_a)
                effective_weight = effective_weight + delta / self.rank
            return F.linear(input, effective_weight, self.bias)

        # parsimonious implementation for ia3 and lora scaling: scale the
        # input and/or output vectors instead of building a full weight matrix
        scaled_input = input * self.multi_lora_a.flatten() if self.multi_lora_a.requires_grad else input
        hidden = F.linear(scaled_input, self.weight, self.bias)
        if self.multi_lora_b.requires_grad:
            hidden = hidden * self.multi_lora_b.flatten()
        return hidden

    def extra_repr(self):
        has_bias = self.bias is not None
        return "in_features={}, out_features={}, bias={}, rank={}, scaling_rank={}".format(
            self.in_features, self.out_features, has_bias, self.rank, self.scaling_rank
        )


def modify_with_lora(transformer, config):
    """Replace matching ``nn.Linear`` sub-layers of ``transformer`` with ``LoRALinear``.

    A module is selected when its qualified name fully matches
    ``config.lora_modules``; within it, each direct child whose name fully
    matches ``config.lora_layers`` is swapped in place.

    Args:
        transformer: Model to modify (mutated in place).
        config: Object providing ``lora_modules``, ``lora_layers``,
            ``lora_rank``, ``lora_scaling_rank`` and ``lora_init_scale``.

    Returns:
        The same ``transformer`` object.

    Raises:
        AssertionError: If a selected child is not an ``nn.Linear``.
    """
    # Snapshot the module/child listings into dicts because children are
    # replaced while iterating.
    for m_name, module in dict(transformer.named_modules()).items():
        if re.fullmatch(config.lora_modules, m_name):
            for c_name, layer in dict(module.named_children()).items():
                if re.fullmatch(config.lora_layers, c_name):
                    assert isinstance(
                        layer, nn.Linear
                    ), f"LoRA can only be applied to torch.nn.Linear, but {layer} is {type(layer)}."
                    setattr(
                        module,
                        c_name,
                        LoRALinear(layer, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale),
                    )
    return transformer


class ClassConfig:
    """Settings for the classification head: dropout rate and number of labels."""

    def __init__(self, dropout=0.2, num_labels=1):
        self.dropout_rate = dropout
        self.num_labels = num_labels

class T5EncoderClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Averages the hidden states over dimension 1, then applies
    dropout -> dense -> tanh -> dropout -> output projection.
    """

    def __init__(self, config, class_config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(class_config.dropout_rate)
        self.out_proj = nn.Linear(width, class_config.num_labels)

    def forward(self, hidden_states):
        pooled = hidden_states.mean(dim=1)  # avg embedding
        activated = torch.tanh(self.dense(self.dropout(pooled)))
        return self.out_proj(self.dropout(activated))

class T5EncoderForSimpleSequenceClassification(T5PreTrainedModel):
    """T5 encoder stack followed by ``T5EncoderClassificationHead``.

    ``forward`` runs the encoder, feeds its first output to the head and, when
    ``labels`` are given, computes a loss chosen by ``config.problem_type``
    (inferred from ``num_labels`` and the label dtype when unset).
    """

    def __init__(self, config: T5Config, class_config):
        super().__init__(config)
        self.num_labels = class_config.num_labels
        self.config = config

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # The encoder gets its own config copy with caching and
        # encoder-decoder mode switched off.
        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        self.dropout = nn.Dropout(class_config.dropout_rate)
        self.classifier = T5EncoderClassificationHead(config, class_config)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        """Spread the encoder blocks over the available GPUs (or a given device map)."""
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.classifier = self.classifier.to(self.encoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        """Undo ``parallelize`` and move the encoder back to the CPU."""
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            # T5Stack keeps its layers in `block` (the code above already uses
            # `self.encoder.block`); the previous BERT-style `encoder.layer[...]`
            # path does not exist on it. The self-attention module sits in the
            # first sub-layer of each block, as in transformers' T5EncoderModel.
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the inputs, apply the head and optionally compute the loss.

        Returns a ``SequenceClassifierOutput`` when ``return_dict`` is truthy,
        otherwise a tuple ``(loss?, logits, *extra encoder outputs)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            # Infer the problem type once and store it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


def PT5_classification_model(num_labels, half_precision=False):
    """Build a ProtT5 encoder classifier with LoRA layers and return it with its tokenizer.

    Loads the ProtT5 checkpoint, moves its embedding and encoder weights into a
    ``T5EncoderForSimpleSequenceClassification``, inserts LoRA layers, freezes
    the embedding and encoder, and re-enables only the parameters whose names
    match ``LoRAConfig.trainable_param_names``. The classifier head is left
    trainable.

    Args:
        num_labels: Size of the head's output (1 for regression).
        half_precision: Load the float16 encoder-only checkpoint on the GPU.
            Defaults to ``False`` so callers that only pass ``num_labels`` work.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: If ``half_precision`` is requested without a CUDA device.
    """
    # Load PT5 and tokenizer
    # possible to load the half precision model (thanks to @pawel-rezo for pointing that out)
    if not half_precision:
        model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
        tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")
    elif half_precision and torch.cuda.is_available():
        tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
        model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16).to(torch.device('cuda'))
    else:
        raise ValueError('Half precision can be run on GPU only.')

    # Create new Classifier model with PT5 dimensions
    class_config = ClassConfig(num_labels=num_labels)
    class_model = T5EncoderForSimpleSequenceClassification(model.config, class_config)

    # Set encoder and embedding weights to checkpoint weights
    class_model.shared = model.shared
    class_model.encoder = model.encoder

    # Rebind `model` to the classifier; the checkpoint wrapper is no longer referenced
    model = class_model
    del class_model

    # Print number of trainable parameters
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("ProtT5_Classfier\nTrainable Parameter: "+ str(params))

    # Add model modification lora
    config = LoRAConfig()

    # Add LoRA layers
    model = modify_with_lora(model, config)

    # Freeze Embeddings and Encoder (except LoRA)
    for (param_name, param) in model.shared.named_parameters():
        param.requires_grad = False
    for (param_name, param) in model.encoder.named_parameters():
        param.requires_grad = False

    for (param_name, param) in model.named_parameters():
        if re.fullmatch(config.trainable_param_names, param_name):
            param.requires_grad = True

    # Print trainable Parameter
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("ProtT5_LoRA_Classfier\nTrainable Parameter: "+ str(params) + "\n")

    return model, tokenizer
# Deepspeed config for optimizer CPU offload

# DeepSpeed configuration dictionary: ZeRO stage 2 with the optimizer state
# offloaded to pinned CPU memory, an AdamW optimizer and a WarmupLR scheduler.
# NOTE(review): the "auto" entries are presumably filled in by the training
# framework that consumes this dict — confirm against the caller.
ds_config = {
    # Mixed-precision (fp16) settings.
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    # Optimizer type and hyper-parameters.
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    # Learning-rate schedule.
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },

    # ZeRO stage 2 with the optimizer offloaded to CPU.
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True
        },
        "allgather_partitions": True,
        "allgather_bucket_size": 2e8,
        "overlap_comm": True,
        "reduce_scatter": True,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": True
    },

    # Batch/accumulation/clipping values and logging options.
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": False
}
# Set random seeds for reproducibility of your trainings run
def set_seeds(s):
    """Seed torch, numpy, Python's ``random`` and transformers with the same value ``s``."""
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed, set_seed):
        seed_fn(s)

# Dataset creation
def create_dataset(tokenizer,seqs,labels):
    """Tokenize ``seqs`` (truncated to 1024 tokens, no padding) and attach ``labels``.

    Returns:
        A ``Dataset`` built from the tokenizer output with an added
        ``"labels"`` column.
    """
    encoded = tokenizer(seqs, max_length=1024, padding=False, truncation=True)
    return Dataset.from_dict(encoded).add_column("labels", labels)
    
# Main training function
def train_per_protein(
        train_df,         #training data
        valid_df,         #validation data      
        num_labels= 1,    #1 for regression, >1 for classification
    
        # effective training batch size is batch * accum
        # we recommend an effective batch size of 8 
        batch= 4,         #for training
        accum= 2,         #gradient accumulation
    
        val_batch = 16,   #batch size for evaluation
        epochs= 10,       #training epochs
        lr= 3e-4,         #recommended learning rate
        seed= 42,         #random seed
        deepspeed= True,  #if gpu is large enough disable deepspeed for training speedup
        mixed= False,     #enable mixed precision training
        gpu= 1 ):         #gpu selection (1 for first gpu)

    # Set gpu device
    os.environ["CUDA_VISIBLE_DEVICES"]=str(gpu-1)
    
    # Set all random seeds
    set_seeds(seed)
    
    # load model
    model, tokenizer = PT5_classification_model(num_labels=num_labels)

    # Preprocess inputs
    # Replace uncommon AAs with "X"
    train_df["sequence"]=train_df["sequence"].str.replace('|'.join(["O","B","U","Z"]),"X",regex=True)
    valid_df["sequence"]=valid_df["sequence"].str.replace('|'.join(["O","B","U","Z"]),"X",regex=True)
    # Add spaces between each amino acid for PT5 to correctly use them
    train_df['sequence']=train_df.apply(lambda row : " ".join(row["sequence"]), axis = 1)
    valid_df['sequence']=valid_df.apply(lambda row : " ".join(row["sequence"]), axis = 1)

    # Create Datasets
    train_set=create_dataset(tokenizer,list(train_df['sequence']),list(train_df['label']))
    valid_set=create_dataset(tokenizer,list(valid_df['sequence']),list(valid_df['label']))

    # Huggingface Trainer arguments
    args = TrainingArguments(
        "./",
        evaluation_strategy = "epoch",
        logging_strategy = "epoch",
        save_strategy = "no",
        learning_rate=lr,
        per_device_train_batch_size=batch,
        per_device_eval_batch_size=val_batch,
        gradient_accumulation_steps=accum,
        num_train_epochs=epochs,
        seed = seed,
        deepspeed= ds_config if deepspeed else None,
        fp16 = mixed,
    ) 

    # Metric definition for validation data
    def compute_metrics(eval_pred):
        if num_labels>1:  # for classification
            metric = load("accuracy")
            predictions, labels = eval_pred
            predictions = np.argmax(predictions, axis=1)
        else:  # for regression
            metric = load("spearmanr")
            predictions, labels = eval_pred

        return metric.compute(predictions=predictions, references=labels)
    
    # Trainer          
    trainer = Trainer(
        model,
        args,
        train_dataset=train_set,
        eval_dataset=valid_set,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )    
    
    # Train model
    trainer.train()

    return tokenizer, model, trainer.state.log_history tokenizer, model, history = train_per_protein(my_train, my_valid, num_labels=1, batch=1, accum=8, epochs=20, seed=42) # Get loss, val_loss, and the computed metric from history
# Training and validation loss, one value per logged entry
loss = [entry['loss'] for entry in history if 'loss' in entry]
val_loss = [entry['eval_loss'] for entry in history if 'eval_loss' in entry]

# Validation metric: Spearman correlation (regression) when it was logged,
# otherwise accuracy (classification)
metric = [entry['eval_spearmanr'] for entry in history if 'eval_spearmanr' in entry]
if not metric:
    metric = [entry['eval_accuracy'] for entry in history if 'eval_accuracy' in entry]

# x-axis values: the epoch recorded with each training-loss entry
epochs = [entry['epoch'] for entry in history if 'loss' in entry]

# One figure, two y-axes sharing the x-axis: losses left, metric right
fig, ax1 = plt.subplots(figsize=(10, 5))
ax2 = ax1.twinx()

# Left axis: train and validation loss
line1 = ax1.plot(epochs, loss, label='train_loss')
line2 = ax1.plot(epochs, val_loss, label='val_loss')
ax1.set(xlabel='Epoch', ylabel='Loss')

# Right axis: validation metric, displayed on a fixed 0..1 range
line3 = ax2.plot(epochs, metric, color='red', label='val_metric')
ax2.set(ylabel='Metric', ylim=(0, 1))

# Merge the handles of both axes into a single legend
lines = line1 + line2 + line3
labels = [handle.get_label() for handle in lines]
ax1.legend(lines, labels, loc='lower left')

# Show the plot
plt.title("Training History")
plt.show()


def save_model(model,filepath):
    """Save all parameters that were changed during finetuning.

    Only parameters with ``requires_grad=True`` are written, as a dict mapping
    parameter name to tensor.

    Args:
        model: module whose trainable parameters are saved.
        filepath: destination passed to ``torch.save``.
    """
    # Detach and move to CPU so the file holds plain tensors with no autograd
    # or GPU binding; it can then be loaded on a machine without CUDA.
    non_frozen_params = {
        param_name: param.detach().cpu()
        for param_name, param in model.named_parameters()
        if param.requires_grad
    }

    # Save only the finetuned parameters
    torch.save(non_frozen_params, filepath)

    
def load_model(filepath, num_labels=1, mixed = False):
    """Create a new PT5 model and load the finetuned weights from a file.

    Args:
        filepath: file written by ``save_model`` (dict of parameter name to tensor).
        num_labels: forwarded to ``PT5_classification_model``.
        mixed: forwarded to ``PT5_classification_model`` as ``half_precision``.

    Returns:
        Tuple ``(tokenizer, model)`` with the saved parameters copied into the model.

    Raises:
        ValueError: if a saved tensor's shape differs from the model parameter's.
    """
    # load a new model
    model, tokenizer = PT5_classification_model(num_labels=num_labels, half_precision=mixed)

    # map_location="cpu" lets a checkpoint written from a GPU run load on a
    # CPU-only machine as well.
    non_frozen_params = torch.load(filepath, map_location="cpu")

    # Copy values into the existing parameters instead of rebinding .data, so
    # every parameter keeps the device and dtype of the freshly built model.
    with torch.no_grad():
        for param_name, param in model.named_parameters():
            if param_name not in non_frozen_params:
                continue
            saved = non_frozen_params[param_name]
            if saved.shape != param.shape:
                raise ValueError(
                    f"Shape mismatch for {param_name}: "
                    f"file has {tuple(saved.shape)}, model has {tuple(param.shape)}"
                )
            param.copy_(saved)

    return tokenizer, model


# Write the finetuned weights first: nothing else in this notebook section
# creates the file that is reloaded below.
save_model(model, "./PT5_GB1_finetuned.pth")

tokenizer, model_reload = load_model("./PT5_GB1_finetuned.pth", num_labels=1, mixed=False)

# Put both models to the same device
model=model.to("cpu")
model_reload=model_reload.to("cpu")

# Walk both parameter lists in lockstep; stop at the first pair that differs
if all(
    torch.equal(original.data, reloaded.data)
    for original, reloaded in zip(model.parameters(), model_reload.parameters())
):
    print("Models have identical weights")
else:
    print("Models have different weights")

# For this we import the "three_vs_rest" GB1 dataset again
# from https://github.com/J-SNACKKB/FLIP

import requests
import zipfile
from io import BytesIO

# Download the zip file from GitHub
url = 'https://github.com/J-SNACKKB/FLIP/raw/main/splits/gb1/splits.zip'
# The timeout keeps the cell from hanging on a stalled connection, and
# raise_for_status() reports an HTTP error directly instead of a confusing
# BadZipFile raised on an error page.
response = requests.get(url, timeout=60)
response.raise_for_status()

# Load the `three_vs_rest.csv` file into a pandas dataframe
# (both the archive and the member file are closed when the block exits)
with zipfile.ZipFile(BytesIO(response.content)) as zip_file, \
        zip_file.open('splits/three_vs_rest.csv') as file:
    df = pd.read_csv(file)

# Select only test data
# .copy() gives my_test its own data, so the assignments below never write
# into a slice of df (avoids pandas' SettingWithCopyWarning)
my_test=df[df["set"]=="test"].copy()

# Set column names to "sequence" and "label"
# (renames the first two columns by position, keeps the remaining names)
my_test.columns=["sequence","label"]+list(my_test.columns[2:])

# Drop unneeded columns; copy again because a column subset is also a slice
my_test=my_test[["sequence","label"]].copy()
print(my_test.head(5))

# Preprocess sequences
# Replace uncommon AAs with "X"
my_test["sequence"]=my_test["sequence"].str.replace('|'.join(["O","B","U","Z"]),"X",regex=True)
# Add spaces between each amino acid, as done for the training data
my_test['sequence']=my_test.apply(lambda row : " ".join(row["sequence"]), axis = 1)

#Use reloaded model
model = model_reload
del model_reload

# Pick the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Tokenize the test sequences and expose them as torch tensors on `device`
# so the dataset can be fed to a torch DataLoader
test_set = create_dataset(
    tokenizer,
    list(my_test['sequence']),
    list(my_test['label']),
).with_format("torch", device=device)

# Batches are served in dataset order so predictions line up with my_test
# NOTE(review): no collate_fn is given and create_dataset does not pad, so
# this presumably relies on all tokenized sequences having equal length --
# confirm before reusing with another dataset.
test_dataloader = DataLoader(test_set, shuffle=False, batch_size=16)

# Inference only: evaluation mode, no gradient tracking
model.eval()

predictions = []
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # collect this batch's logits as plain Python lists
        batch_output = model(input_ids, attention_mask=attention_mask)
        predictions.extend(batch_output.logits.tolist())

# Regression
print(stats.spearmanr(a=predictions, b=my_test.label, axis=0))

# Classification
# we need to determine the prediction class from the logit output
# predictions= [item.argmax() for item in np.array(predictions)]
# print("Accuracy: ", accuracy_score(my_test.label, predictions))

# TODO: Please update the whole code so it runs in a Kaggle notebook on a T4 GPU.