In [None]:
# ProtT5 LoRA Fine-tuning for Kaggle T4 GPU
# Updated for Kaggle notebook environment

# Install required packages (run this in a cell)
!pip install transformers datasets evaluate deepspeed scipy scikit-learn

In [None]:
import transformers
print(transformers.__version__)

In [None]:
# NOTE(review): this cell appears to be a quick check that the installed
# transformers version accepts these TrainingArguments keywords; the resulting
# `args` object is only printed here — confirm before relying on it elsewhere.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs",
    logging_strategy="epoch",  # log once per epoch
    save_strategy="no",        # do not write checkpoints
    num_train_epochs=3
)
print(args)


In [None]:
# ======================== 1. Required Imports ======================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
import re
import copy
import os
from scipy import stats
from tqdm import tqdm

from transformers import (
    T5EncoderModel, T5Tokenizer, T5PreTrainedModel, T5Config,
    TrainingArguments, Trainer, set_seed
)
from transformers.modeling_outputs import SequenceClassifierOutput
from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss
from torch.utils.data import DataLoader
from datasets import Dataset

# ======================== 2. Load GB1 Data ======================================
import requests
import zipfile
from io import BytesIO

url = 'https://github.com/J-SNACKKB/FLIP/raw/main/splits/gb1/splits.zip'
# Bound the download time and fail fast on HTTP errors, so an error page is
# never handed to ZipFile (which would raise a confusing BadZipFile instead).
response = requests.get(url, timeout=60)
response.raise_for_status()
zip_file = zipfile.ZipFile(BytesIO(response.content))
with zip_file.open('splits/three_vs_rest.csv') as file:
    df = pd.read_csv(file)
df = df.rename(columns={"target": "label"})
# Guarantee the column exists so the train/validation filters below always work.
if "validation" not in df.columns:
    df["validation"] = False

# Small debug subset: cap the number of rows taken from each split
# (useful for quick out-of-memory testing).
N_TRAIN = 64
N_VALID = 32
N_TEST = 32

# Rows of the "train" set are divided by the `validation` flag; the "test" set is used as-is.
my_train = df[(df["set"]=="train") & (df["validation"]!=True)][["sequence", "label"]].reset_index(drop=True).iloc[:N_TRAIN]
my_valid = df[(df["set"]=="train") & (df["validation"]==True)][["sequence", "label"]].reset_index(drop=True).iloc[:N_VALID]
my_test  = df[df["set"]=="test"][["sequence", "label"]].reset_index(drop=True).iloc[:N_TEST]

print(my_train.head())

# ======================== 3. Model & Tokenizer ========================
class LoRAConfig:
    """Hyper-parameters consumed by the LoRA injection and freezing code."""

    def __init__(self):
        # Rank of the additive low-rank update and the std-dev scale of its random init.
        self.lora_rank = 2
        self.lora_init_scale = 0.01
        # Regexes (used with re.fullmatch) selecting which modules, and which of
        # their direct children, are wrapped.
        self.lora_modules = r".*SelfAttention|.*EncDecAttention"
        self.lora_layers = r"q|k|v|o"
        # Regex (used with re.fullmatch) for parameter names that are made trainable.
        self.trainable_param_names = r".*layer_norm.*|.*lora_[ab].*"
        # Rank of the multiplicative (scaling) factors.
        self.lora_scaling_rank = 1

class LoRALinear(nn.Module):
    """Linear layer whose weight is rescaled and shifted by low-rank factors.

    The wrapped layer's ``weight`` and ``bias`` are reused as-is. The effective
    weight in ``forward`` is
    ``weight * (multi_lora_b @ multi_lora_a) / scaling_rank + (lora_b @ lora_a) / rank``,
    where each term is only applied when the corresponding rank is non-zero.

    Args:
        linear_layer: Layer providing ``in_features``, ``out_features``,
            ``weight`` and ``bias``.
        rank: Rank of the additive update.
        scaling_rank: Rank of the multiplicative factors.
        init_scale: Scale applied to the random-normal initialisation noise.
    """

    def __init__(self, linear_layer, rank, scaling_rank, init_scale):
        super().__init__()
        n_in = linear_layer.in_features
        n_out = linear_layer.out_features
        self.in_features = n_in
        self.out_features = n_out
        self.rank = rank
        self.scaling_rank = scaling_rank
        # Share the original tensors instead of copying them.
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias
        if self.rank > 0:
            # lora_b starts at zero, so the additive term is initially zero.
            self.lora_a = nn.Parameter(torch.randn(rank, n_in) * init_scale)
            self.lora_b = nn.Parameter(torch.zeros(n_out, rank))
        if self.scaling_rank:
            # Both factors start at (approximately) one.
            noise_free = torch.ones(self.scaling_rank, n_in)
            self.multi_lora_a = nn.Parameter(
                noise_free + torch.randn(self.scaling_rank, n_in) * init_scale
            )
            self.multi_lora_b = nn.Parameter(torch.ones(n_out, self.scaling_rank))

    def forward(self, input):
        """Apply the linear map using the adapted weight."""
        effective = self.weight
        if self.scaling_rank:
            scaled = effective * torch.matmul(self.multi_lora_b, self.multi_lora_a)
            effective = scaled / self.scaling_rank
        if self.rank:
            delta = torch.matmul(self.lora_b, self.lora_a) / self.rank
            effective = effective + delta
        return F.linear(input, effective, self.bias)

def modify_with_lora(transformer, config):
    """Replace selected ``nn.Linear`` children with ``LoRALinear`` wrappers, in place.

    A module is selected when its dotted name fully matches
    ``config.lora_modules``; within it, a direct child is wrapped when its name
    fully matches ``config.lora_layers`` and it is an ``nn.Linear``.

    Returns:
        The same ``transformer`` object that was passed in.
    """
    # Snapshot both listings up front because children are swapped while looping.
    for module_name, parent in list(transformer.named_modules()):
        if re.fullmatch(config.lora_modules, module_name) is None:
            continue
        for child_name, child in list(parent.named_children()):
            name_matches = re.fullmatch(config.lora_layers, child_name) is not None
            if name_matches and isinstance(child, nn.Linear):
                wrapped = LoRALinear(
                    child, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale
                )
                setattr(parent, child_name, wrapped)
    return transformer

class ClassConfig:
    """Settings for the prediction head: dropout probability and output size."""

    def __init__(self, dropout=0.2, num_labels=1):
        self.dropout_rate, self.num_labels = dropout, num_labels

class T5EncoderClassificationHead(nn.Module):
    """Mean-pooled prediction head: dropout -> dense -> tanh -> dropout -> projection.

    Args:
        hidden_size: Size of the last dimension of the incoming hidden states.
        class_config: Object providing ``dropout_rate`` and ``num_labels``.
    """

    def __init__(self, hidden_size, class_config):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(class_config.dropout_rate)
        self.out_proj = nn.Linear(hidden_size, class_config.num_labels)

    def forward(self, hidden_states):
        """Average over dim 1 and map the pooled vector to ``num_labels`` outputs."""
        # The mean runs over every position along dim 1 (no masking is applied here).
        pooled = hidden_states.mean(dim=1)
        activated = torch.tanh(self.dense(self.dropout(pooled)))
        return self.out_proj(self.dropout(activated))

class T5EncoderForSimpleSequenceClassification(T5PreTrainedModel):
    """Wrap an existing T5 encoder with a pooled classification/regression head.

    Args:
        encoder_model: Encoder module called in ``forward``; its first output is
            used as the per-token hidden states.
        config: T5 config; ``d_model`` sizes the head and ``problem_type``
            selects the loss.
        class_config: Object providing ``num_labels`` and ``dropout_rate``.
    """

    def __init__(self, encoder_model, config: T5Config, class_config):
        super().__init__(config)
        self.num_labels = class_config.num_labels
        self.config = config
        self.encoder_model = encoder_model
        # Not used in forward(); kept so the module's attributes stay unchanged.
        self.dropout = nn.Dropout(class_config.dropout_rate)
        self.classifier = T5EncoderClassificationHead(config.d_model, class_config)
        self.post_init()
        self.model_parallel = False
        self.device_map = None

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the inputs, apply the head and, if ``labels`` is given, compute a loss.

        Returns:
            A ``SequenceClassifierOutput`` when ``return_dict`` resolves to true,
            otherwise a tuple ``(loss?, logits, *extra_encoder_outputs)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.encoder_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index instead of `.last_hidden_state`: with return_dict=False the
        # encoder returns a plain tuple, which has no such attribute. Integer
        # indexing works for both the tuple and the ModelOutput form.
        hidden_states = outputs[0]
        logits = self.classifier(hidden_states)
        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels / label dtype and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.squeeze(), labels.squeeze())
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + (outputs[1:] if isinstance(outputs, tuple) else ())
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(outputs, "hidden_states", None),
            attentions=getattr(outputs, "attentions", None),
        )

def PT5_classification_model(num_labels=1, half_precision=False):
    """Load ProtT5, attach the prediction head, inject LoRA and set trainable parameters.

    Args:
        num_labels: Output size of the head.
        half_precision: When true, load the float16 encoder-only checkpoint onto
            CUDA; otherwise load the ``prot_t5_xl_uniref50`` checkpoint.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: If ``half_precision`` is requested but CUDA is unavailable.
    """
    if half_precision:
        if not torch.cuda.is_available():
            raise ValueError('Half precision can be run on GPU only.')
        tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False, legacy=False)
        encoder_model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16).to(torch.device('cuda'))
    else:
        encoder_model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
        tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", legacy=False)

    head_settings = ClassConfig(num_labels=num_labels)
    class_model = T5EncoderForSimpleSequenceClassification(encoder_model, encoder_model.config, head_settings)
    config = LoRAConfig()
    class_model.encoder_model = modify_with_lora(class_model.encoder_model, config)

    # Freeze every encoder parameter first; head parameters are left untouched.
    for encoder_param in class_model.encoder_model.parameters():
        encoder_param.requires_grad = False
    # Then re-enable whatever the config's name pattern fully matches.
    trainable = re.compile(config.trainable_param_names)
    for param_name, param in class_model.named_parameters():
        if trainable.fullmatch(param_name):
            param.requires_grad = True
    return class_model, tokenizer

# ======================== 5. Seeding & Dataset Helpers ==========================
def set_seeds(seed):
    """Seed torch, NumPy, ``random`` and transformers with the same value."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed, set_seed):
        seeder(seed)

def create_dataset(tokenizer, seqs, labels, max_length=128):
    """Tokenize sequences and bundle them with their labels in a ``Dataset``.

    Each sequence has the letters O, B, U and Z replaced by X, is split into
    space-separated characters, and is then padded/truncated to ``max_length``.

    Args:
        tokenizer: Callable tokenizer applied to the list of spaced sequences.
        seqs: Iterable of sequence strings.
        labels: Values stored in the ``"labels"`` column, one per sequence.
        max_length: Fixed tokenized length.

    Returns:
        A ``Dataset`` with the tokenizer's fields plus a ``"labels"`` column.
    """
    rare_to_x = str.maketrans("OBUZ", "XXXX")
    spaced = [" ".join(seq.translate(rare_to_x)) for seq in seqs]
    encodings = tokenizer(spaced, max_length=max_length, padding='max_length', truncation=True)
    return Dataset.from_dict(encodings).add_column("labels", labels)

# ======================== 6. Training Function ============================
def train_per_protein(
    train_df,
    valid_df,
    num_labels=1,
    batch=3,
    accum=1,
    val_batch=3,
    epochs=5,
    lr=3e-4,
    seed=42,
    mixed=False,
    gpu=0,
    use_deepspeed=False
):
    """Fine-tune the LoRA-adapted ProtT5 model and return its training history.

    Args:
        train_df: DataFrame with ``"sequence"`` and ``"label"`` columns.
        valid_df: DataFrame with the same columns, used for evaluation.
        num_labels: Output size of the prediction head.
        batch: Per-device training batch size.
        accum: Gradient accumulation steps.
        val_batch: Per-device evaluation batch size.
        epochs: Number of training epochs.
        lr: Learning rate.
        seed: Seed for all random number generators.
        mixed: Forwarded as ``half_precision`` to the model loader. Trainer
            ``fp16`` itself stays disabled.
        gpu: Currently unused; kept for call compatibility.
        use_deepspeed: Currently unused; kept for call compatibility.

    Returns:
        ``(tokenizer, model, log_history)`` where ``log_history`` is
        ``trainer.state.log_history``.
    """
    import inspect

    set_seeds(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tokenizer = PT5_classification_model(num_labels=num_labels, half_precision=mixed)
    model.to(device)
    train_set = create_dataset(tokenizer, list(train_df["sequence"]), list(train_df["label"]), max_length=128)
    valid_set = create_dataset(tokenizer, list(valid_df["sequence"]), list(valid_df["label"]), max_length=128)

    arg_kwargs = dict(
        output_dir="./outputs",
        learning_rate=lr,
        per_device_train_batch_size=batch,
        per_device_eval_batch_size=val_batch,
        gradient_accumulation_steps=accum,
        num_train_epochs=epochs,
        seed=seed,
        # Log the training loss once per epoch so the history contains 'loss'
        # entries even for short runs.
        logging_strategy="epoch",
        fp16=False,  # fp16 deliberately left off; enable later if desired
        dataloader_num_workers=0,
        report_to="none",
    )
    # Evaluate once per epoch so the history contains eval_loss / eval_spearmanr.
    # The keyword was renamed across transformers releases, so use whichever
    # name the installed version declares.
    ta_params = inspect.signature(TrainingArguments.__init__).parameters
    if "eval_strategy" in ta_params:
        arg_kwargs["eval_strategy"] = "epoch"
    elif "evaluation_strategy" in ta_params:
        arg_kwargs["evaluation_strategy"] = "epoch"
    args = TrainingArguments(**arg_kwargs)

    def compute_metrics(eval_pred):
        """Spearman correlation between predictions and labels (NaN mapped to 0.0)."""
        predictions, labels = eval_pred
        predictions = np.squeeze(predictions)
        labels = np.squeeze(labels)
        # squeeze() collapses a single sample to a 0-d array; restore one dimension.
        if predictions.ndim == 0:
            predictions = np.array([predictions])
        if labels.ndim == 0:
            labels = np.array([labels])
        correlation = stats.spearmanr(predictions, labels).correlation
        if np.isnan(correlation):
            correlation = 0.0
        return {"spearmanr": correlation}

    trainer_kwargs = dict(
        model=model,
        args=args,
        train_dataset=train_set,
        eval_dataset=valid_set,
        compute_metrics=compute_metrics,
    )
    # Newer transformers releases take the tokenizer as `processing_class`;
    # fall back to the older `tokenizer` keyword otherwise.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_kw = "processing_class" if "processing_class" in trainer_params else "tokenizer"
    trainer_kwargs[tokenizer_kw] = tokenizer
    trainer = Trainer(**trainer_kwargs)

    print("Starting training...")
    trainer.train()
    print("Training completed!")
    return tokenizer, model, trainer.state.log_history

# ======================== 7. Training Loop (Simple) ==========================
# Run the debug training configuration on the small train/validation subsets.
print("Starting PT5 training...")

tokenizer, model, history = train_per_protein(
    my_train, 
    my_valid, 
    num_labels=1, 
    batch=3,     # per-device training batch size (kept small to limit memory use)
    accum=1,
    val_batch=3,
    epochs=5,    # number of training epochs
    lr=3e-4,
    seed=42,
    mixed=False, 
    use_deepspeed=False 
)

# ======================== 8. Plot Results ==========================
# Each series keeps its own epoch axis: training-loss and evaluation entries are
# filtered from the history separately, so their lengths can differ.
loss = [x['loss'] for x in history if 'loss' in x]
epochs_list = [x['epoch'] for x in history if 'loss' in x]
val_loss = [x['eval_loss'] for x in history if 'eval_loss' in x]
val_epochs = [x['epoch'] for x in history if 'eval_loss' in x]
metric = [x['eval_spearmanr'] for x in history if 'eval_spearmanr' in x]
metric_epochs = [x['epoch'] for x in history if 'eval_spearmanr' in x]

if len(loss) > 0 or len(val_loss) > 0 or len(metric) > 0:
    fig, ax1 = plt.subplots(figsize=(8, 4))
    ax2 = ax1.twinx()
    line1 = ax1.plot(epochs_list, loss, label='train_loss') if len(loss) > 0 else []
    line2 = ax1.plot(val_epochs, val_loss, label='val_loss') if len(val_loss) > 0 else []
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    line3 = ax2.plot(metric_epochs, metric, color='red', label='val_metric') if len(metric) > 0 else []
    ax2.set_ylabel('Spearman r')
    # Keep the [0, 1] view unless a negative correlation would be clipped.
    ax2.set_ylim([min(0.0, min(metric)) if len(metric) > 0 else 0.0, 1])
    lines = line1 + line2 + line3
    labels = [line.get_label() for line in lines]
    ax1.legend(lines, labels, loc='lower left')
    plt.title("Training History")
    plt.show()

# ======================== 9. Save & Reload Model (LoRA weights only) ==========================
def save_model(model, filepath):
    """Write only the parameters with ``requires_grad=True`` to ``filepath``.

    The file holds a dict mapping parameter name to parameter, saved with
    ``torch.save``.
    """
    non_frozen_params = {}
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        non_frozen_params[param_name] = param
    torch.save(non_frozen_params, filepath)
    print(f"Model saved to {filepath}")

def load_model(filepath, num_labels=1, mixed=False):
    """Rebuild the model and restore the parameters stored in ``filepath``.

    Args:
        filepath: File holding a dict of parameter name -> tensor.
        num_labels: Output size of the prediction head.
        mixed: Forwarded as ``half_precision`` to the model loader.

    Returns:
        ``(tokenizer, model)``.

    Raises:
        RuntimeError: If a stored tensor cannot be copied into the parameter of
            the same name (for example, incompatible shapes).
    """
    import warnings

    model, tokenizer = PT5_classification_model(num_labels=num_labels, half_precision=mixed)
    non_frozen_params = torch.load(filepath, map_location='cpu')
    unused_keys = set(non_frozen_params)
    # copy_ writes into the existing parameter, keeping its dtype and device and
    # rejecting incompatible shapes instead of silently replacing the tensor.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in non_frozen_params:
                param.copy_(non_frozen_params[name])
                unused_keys.discard(name)
    if unused_keys:
        # Entries that match no parameter indicate a model/checkpoint mismatch.
        warnings.warn(
            f"{len(unused_keys)} checkpoint entries matched no model parameter: "
            f"{sorted(unused_keys)[:5]}"
        )
    print(f"Model loaded from {filepath}")
    return tokenizer, model

# Persist the trainable weights, then rebuild a fresh model from that file.
save_model(model, "./PT5_GB1_finetuned.pth")
tokenizer_reload, model_reload = load_model("./PT5_GB1_finetuned.pth", num_labels=1, mixed=False)

# ======================== 10. Inference on Test Set ==========================
print("Running inference on test set...")

test_set = create_dataset(
    tokenizer_reload, list(my_test["sequence"]), list(my_test["label"]), max_length=128
).with_format("torch")
# Small batches keep memory use low; order is preserved for the label comparison below.
test_dataloader = DataLoader(test_set, batch_size=3, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_reload.to(device)
model_reload.eval()

predictions = []
with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Testing"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        logits = model_reload(input_ids=input_ids, attention_mask=attention_mask).logits
        # squeeze() yields a 0-d array for a single sample; atleast_1d turns it back into a list of one.
        batch_predictions = np.atleast_1d(logits.cpu().numpy().squeeze()).tolist()
        predictions.extend(batch_predictions)

test_correlation = stats.spearmanr(predictions, my_test['label']).correlation
print(f"Test Spearman r: {test_correlation:.4f}")


# ======================== 11. Plot Test Predictions ==========================
import seaborn as sns

true_labels = my_test['label'].values
predictions_arr = np.array(predictions)

# Scatter of predicted vs. true values with the y = x reference line.
scatter_fig, scatter_ax = plt.subplots(figsize=(6, 6))
sns.scatterplot(x=true_labels, y=predictions_arr, s=80, color='royalblue', edgecolor='k', ax=scatter_ax)
label_span = [true_labels.min(), true_labels.max()]
scatter_ax.plot(label_span, label_span, 'r--', lw=2, label='Identity (y=x)')
scatter_ax.set_xlabel("True Fitness (Label)")
scatter_ax.set_ylabel("Predicted Fitness")
scatter_ax.set_title(f"Test Set Predictions vs True (Spearman r = {test_correlation:.3f})")
scatter_ax.legend()
scatter_fig.tight_layout()
plt.show()


In [None]:
# ======================== 1. Required Imports ======================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
import re
import os
from scipy import stats
from tqdm import tqdm

from transformers import (
    T5EncoderModel, T5Tokenizer, T5PreTrainedModel, T5Config,
    TrainingArguments, Trainer, set_seed
)
from transformers.modeling_outputs import SequenceClassifierOutput
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from datasets import Dataset

# ======================== 2. Load GB1 Data ======================================
import requests
import zipfile
from io import BytesIO

url = 'https://github.com/J-SNACKKB/FLIP/raw/main/splits/gb1/splits.zip'
# Bound the download time and fail fast on HTTP errors, so an error page is
# never handed to ZipFile (which would raise a confusing BadZipFile instead).
response = requests.get(url, timeout=60)
response.raise_for_status()
zip_file = zipfile.ZipFile(BytesIO(response.content))
with zip_file.open('splits/three_vs_rest.csv') as file:
    df = pd.read_csv(file)
df = df.rename(columns={"target": "label"})
# Guarantee the column exists so the train/validation filters below always work.
if "validation" not in df.columns:
    df["validation"] = False

# Use ALL data: rows of the "train" set are divided by the `validation` flag;
# the "test" set is used as-is.
my_train = df[(df["set"]=="train") & (df["validation"]!=True)][["sequence", "label"]].reset_index(drop=True)
my_valid = df[(df["set"]=="train") & (df["validation"]==True)][["sequence", "label"]].reset_index(drop=True)
my_test  = df[df["set"]=="test"][["sequence", "label"]].reset_index(drop=True)

print(f"Train: {len(my_train)}, Val: {len(my_valid)}, Test: {len(my_test)}")
print(my_train.head())

# ======================== 3. Model & Tokenizer (improved) ========================
class LoRAConfig:
    """Hyper-parameters consumed by the LoRA injection and freezing code."""

    def __init__(self):
        # Rank of the additive low-rank update and the std-dev scale of its random init.
        self.lora_rank = 2
        self.lora_init_scale = 0.01
        # Regexes (used with re.fullmatch) selecting which modules, and which of
        # their direct children, are wrapped.
        self.lora_modules = r".*SelfAttention|.*EncDecAttention"
        self.lora_layers = r"q|k|v|o"
        # Trainable names: layer norms, LoRA factors, the classifier head and
        # anything under "block.23".
        # NOTE(review): presumably block 23 is the encoder's last block — confirm.
        self.trainable_param_names = r".*layer_norm.*|.*lora_[ab].*|.*classifier.*|.*block\.23.*"  # Unfreeze head & last block
        # Rank of the multiplicative (scaling) factors.
        self.lora_scaling_rank = 1

class LoRALinear(nn.Module):
    """Linear layer whose weight is rescaled and shifted by low-rank factors.

    The wrapped layer's ``weight`` and ``bias`` are reused as-is. The effective
    weight in ``forward`` is
    ``weight * (multi_lora_b @ multi_lora_a) / scaling_rank + (lora_b @ lora_a) / rank``,
    where each term is only applied when the corresponding rank is non-zero.

    Args:
        linear_layer: Layer providing ``in_features``, ``out_features``,
            ``weight`` and ``bias``.
        rank: Rank of the additive update.
        scaling_rank: Rank of the multiplicative factors.
        init_scale: Scale applied to the random-normal initialisation noise.
    """

    def __init__(self, linear_layer, rank, scaling_rank, init_scale):
        super().__init__()
        n_in = linear_layer.in_features
        n_out = linear_layer.out_features
        self.in_features = n_in
        self.out_features = n_out
        self.rank = rank
        self.scaling_rank = scaling_rank
        # Share the original tensors instead of copying them.
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias
        if self.rank > 0:
            # lora_b starts at zero, so the additive term is initially zero.
            self.lora_a = nn.Parameter(torch.randn(rank, n_in) * init_scale)
            self.lora_b = nn.Parameter(torch.zeros(n_out, rank))
        if self.scaling_rank:
            # Both factors start at (approximately) one.
            noise_free = torch.ones(self.scaling_rank, n_in)
            self.multi_lora_a = nn.Parameter(
                noise_free + torch.randn(self.scaling_rank, n_in) * init_scale
            )
            self.multi_lora_b = nn.Parameter(torch.ones(n_out, self.scaling_rank))

    def forward(self, input):
        """Apply the linear map using the adapted weight."""
        effective = self.weight
        if self.scaling_rank:
            scaled = effective * torch.matmul(self.multi_lora_b, self.multi_lora_a)
            effective = scaled / self.scaling_rank
        if self.rank:
            delta = torch.matmul(self.lora_b, self.lora_a) / self.rank
            effective = effective + delta
        return F.linear(input, effective, self.bias)

def modify_with_lora(transformer, config):
    """Replace selected ``nn.Linear`` children with ``LoRALinear`` wrappers, in place.

    A module is selected when its dotted name fully matches
    ``config.lora_modules``; within it, a direct child is wrapped when its name
    fully matches ``config.lora_layers`` and it is an ``nn.Linear``.

    Returns:
        The same ``transformer`` object that was passed in.
    """
    # Snapshot both listings up front because children are swapped while looping.
    for module_name, parent in list(transformer.named_modules()):
        if re.fullmatch(config.lora_modules, module_name) is None:
            continue
        for child_name, child in list(parent.named_children()):
            name_matches = re.fullmatch(config.lora_layers, child_name) is not None
            if name_matches and isinstance(child, nn.Linear):
                wrapped = LoRALinear(
                    child, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale
                )
                setattr(parent, child_name, wrapped)
    return transformer

class ClassConfig:
    """Settings for the prediction head: dropout probability and output size."""

    def __init__(self, dropout=0.2, num_labels=1):
        self.dropout_rate, self.num_labels = dropout, num_labels

class T5EncoderClassificationHead(nn.Module):
    """First-position prediction head: dropout -> dense -> tanh -> dropout -> projection.

    Args:
        hidden_size: Size of the last dimension of the incoming hidden states.
        class_config: Object providing ``dropout_rate`` and ``num_labels``.
    """

    def __init__(self, hidden_size, class_config):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(class_config.dropout_rate)
        self.out_proj = nn.Linear(hidden_size, class_config.num_labels)

    def forward(self, hidden_states):
        """Map the hidden state at index 0 of dim 1 to ``num_labels`` outputs."""
        # Uses the first position only, instead of averaging over positions.
        first_position = hidden_states[:, 0]
        activated = torch.tanh(self.dense(self.dropout(first_position)))
        return self.out_proj(self.dropout(activated))

class T5EncoderForSimpleSequenceRegression(T5PreTrainedModel):
    """Wrap an existing T5 encoder with a prediction head trained with MSE loss.

    Args:
        encoder_model: Encoder module called in ``forward``; its first output is
            used as the per-token hidden states.
        config: T5 config; ``d_model`` sizes the head.
        class_config: Object providing ``num_labels`` and ``dropout_rate``.
    """

    def __init__(self, encoder_model, config: T5Config, class_config):
        super().__init__(config)
        self.num_labels = class_config.num_labels
        self.config = config
        self.encoder_model = encoder_model
        # Not used in forward(); kept so the module's attributes stay unchanged.
        self.dropout = nn.Dropout(class_config.dropout_rate)
        self.classifier = T5EncoderClassificationHead(config.d_model, class_config)
        self.post_init()
        self.model_parallel = False
        self.device_map = None

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the inputs, apply the head and, if ``labels`` is given, compute MSE loss.

        Returns:
            A ``SequenceClassifierOutput`` when ``return_dict`` resolves to true,
            otherwise a tuple ``(loss?, logits, *extra_encoder_outputs)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.encoder_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index instead of `.last_hidden_state`: with return_dict=False the
        # encoder returns a plain tuple, which has no such attribute. Integer
        # indexing works for both the tuple and the ModelOutput form.
        hidden_states = outputs[0]
        logits = self.classifier(hidden_states)
        loss = None
        if labels is not None:
            loss_fct = MSELoss()
            loss = loss_fct(logits.squeeze(), labels.squeeze())
        if not return_dict:
            output = (logits,) + (outputs[1:] if isinstance(outputs, tuple) else ())
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(outputs, "hidden_states", None),
            attentions=getattr(outputs, "attentions", None),
        )

def PT5_regression_model(num_labels=1, half_precision=False):
    """Load ProtT5, attach the regression head, inject LoRA and set trainable parameters.

    Args:
        num_labels: Output size of the head.
        half_precision: When true, load the float16 encoder-only checkpoint onto
            CUDA; otherwise load the ``prot_t5_xl_uniref50`` checkpoint.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: If ``half_precision`` is requested but CUDA is unavailable.
    """
    if half_precision:
        if not torch.cuda.is_available():
            raise ValueError('Half precision can be run on GPU only.')
        tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False, legacy=False)
        encoder_model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16).to(torch.device('cuda'))
    else:
        encoder_model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
        tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", legacy=False)

    head_settings = ClassConfig(num_labels=num_labels)
    class_model = T5EncoderForSimpleSequenceRegression(encoder_model, encoder_model.config, head_settings)
    config = LoRAConfig()
    class_model.encoder_model = modify_with_lora(class_model.encoder_model, config)

    # A parameter is trainable exactly when its full name matches the config pattern.
    trainable = re.compile(config.trainable_param_names)
    for param_name, param in class_model.named_parameters():
        param.requires_grad = True if trainable.fullmatch(param_name) else False
    return class_model, tokenizer

def set_seeds(seed):
    """Seed torch, NumPy, ``random`` and transformers with the same value."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed, set_seed):
        seeder(seed)

def create_dataset(tokenizer, seqs, labels, max_length=128):
    """Tokenize sequences and bundle them with their labels in a ``Dataset``.

    Each sequence has the letters O, B, U and Z replaced by X, is split into
    space-separated characters, and is then padded/truncated to ``max_length``.

    Args:
        tokenizer: Callable tokenizer applied to the list of spaced sequences.
        seqs: Iterable of sequence strings.
        labels: Values stored in the ``"labels"`` column, one per sequence.
        max_length: Fixed tokenized length.

    Returns:
        A ``Dataset`` with the tokenizer's fields plus a ``"labels"`` column.
    """
    rare_to_x = str.maketrans("OBUZ", "XXXX")
    spaced = [" ".join(seq.translate(rare_to_x)) for seq in seqs]
    encodings = tokenizer(spaced, max_length=max_length, padding='max_length', truncation=True)
    return Dataset.from_dict(encodings).add_column("labels", labels)

def train_per_protein(
    train_df,
    valid_df,
    num_labels=1,
    batch=4,
    accum=1,
    val_batch=16,
    epochs=10,
    lr=2e-4,
    wd=0.01,
    seed=42,
    mixed=False,
):
    """Fine-tune the LoRA-adapted ProtT5 regression model and return its history.

    Args:
        train_df: DataFrame with ``"sequence"`` and ``"label"`` columns.
        valid_df: DataFrame with the same columns, used for evaluation.
        num_labels: Output size of the prediction head.
        batch: Per-device training batch size.
        accum: Gradient accumulation steps.
        val_batch: Per-device evaluation batch size.
        epochs: Number of training epochs.
        lr: Learning rate.
        wd: Weight decay.
        seed: Seed for all random number generators.
        mixed: Forwarded as ``half_precision`` to the model loader. Trainer
            ``fp16`` itself stays disabled.

    Returns:
        ``(tokenizer, model, log_history)`` where ``log_history`` is
        ``trainer.state.log_history``.
    """
    import inspect

    set_seeds(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tokenizer = PT5_regression_model(num_labels=num_labels, half_precision=mixed)
    model.to(device)
    train_set = create_dataset(tokenizer, list(train_df["sequence"]), list(train_df["label"]), max_length=128)
    valid_set = create_dataset(tokenizer, list(valid_df["sequence"]), list(valid_df["label"]), max_length=128)

    arg_kwargs = dict(
        output_dir="./outputs",
        do_train=True,
        do_eval=True,
        per_device_train_batch_size=batch,
        per_device_eval_batch_size=val_batch,
        gradient_accumulation_steps=accum,
        num_train_epochs=epochs,
        weight_decay=wd,
        learning_rate=lr,
        seed=seed,
        fp16=False,
        dataloader_num_workers=2,
        report_to="none",
        save_total_limit=1,
        # Log the training loss once per epoch so it lines up with the
        # per-epoch evaluation entries in the history.
        logging_strategy="epoch",
    )
    # Evaluate once per epoch so the history contains eval_loss / eval_spearmanr.
    # The keyword was renamed across transformers releases, so use whichever
    # name the installed version declares.
    ta_params = inspect.signature(TrainingArguments.__init__).parameters
    if "eval_strategy" in ta_params:
        arg_kwargs["eval_strategy"] = "epoch"
    elif "evaluation_strategy" in ta_params:
        arg_kwargs["evaluation_strategy"] = "epoch"
    args = TrainingArguments(**arg_kwargs)

    def compute_metrics(eval_pred):
        """Spearman correlation between predictions and labels (NaN mapped to 0.0)."""
        predictions, labels = eval_pred
        predictions = np.squeeze(predictions)
        labels = np.squeeze(labels)
        # squeeze() collapses a single sample to a 0-d array; restore one dimension.
        if predictions.ndim == 0:
            predictions = np.array([predictions])
        if labels.ndim == 0:
            labels = np.array([labels])
        correlation = stats.spearmanr(predictions, labels).correlation
        if np.isnan(correlation):
            correlation = 0.0
        return {"spearmanr": correlation}

    trainer_kwargs = dict(
        model=model,
        args=args,
        train_dataset=train_set,
        eval_dataset=valid_set,
        compute_metrics=compute_metrics,
    )
    # Newer transformers releases take the tokenizer as `processing_class`;
    # fall back to the older `tokenizer` keyword otherwise.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_kw = "processing_class" if "processing_class" in trainer_params else "tokenizer"
    trainer_kwargs[tokenizer_kw] = tokenizer
    trainer = Trainer(**trainer_kwargs)

    print("Starting training...")
    trainer.train()
    print("Training completed!")
    return tokenizer, model, trainer.state.log_history

# ======================== 7. Training Loop (Simple) ==========================
# Run the regression fine-tuning on the full train/validation splits.
print("Starting PT5 regression training...")

tokenizer, model, history = train_per_protein(
    my_train,
    my_valid,
    num_labels=1,
    batch=4,       # per-device training batch size
    accum=1,
    val_batch=16,  # per-device evaluation batch size
    epochs=10,
    lr=2e-4,
    wd=0.01,       # weight decay
    seed=42,
    mixed=False,
)

# ======================== 8. Plot Results ==========================
# Each series keeps its own epoch axis: training-loss and evaluation entries are
# filtered from the history separately, so their lengths can differ.
loss = [x['loss'] for x in history if 'loss' in x]
epochs_list = [x['epoch'] for x in history if 'loss' in x]
val_loss = [x['eval_loss'] for x in history if 'eval_loss' in x]
val_epochs = [x['epoch'] for x in history if 'eval_loss' in x]
metric = [x['eval_spearmanr'] for x in history if 'eval_spearmanr' in x]
metric_epochs = [x['epoch'] for x in history if 'eval_spearmanr' in x]

if len(loss) > 0 or len(val_loss) > 0 or len(metric) > 0:
    fig, ax1 = plt.subplots(figsize=(8, 4))
    ax2 = ax1.twinx()
    line1 = ax1.plot(epochs_list, loss, label='train_loss') if len(loss) > 0 else []
    line2 = ax1.plot(val_epochs, val_loss, label='val_loss') if len(val_loss) > 0 else []
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    line3 = ax2.plot(metric_epochs, metric, color='red', label='val_metric') if len(metric) > 0 else []
    ax2.set_ylabel('Spearman r')
    # Keep the [0, 1] view unless a negative correlation would be clipped.
    ax2.set_ylim([min(0.0, min(metric)) if len(metric) > 0 else 0.0, 1])
    lines = line1 + line2 + line3
    labels = [line.get_label() for line in lines]
    ax1.legend(lines, labels, loc='lower left')
    plt.title("Training History")
    plt.show()

# ====================== 9. Save & Reload Model (LoRA weights only) =========================
def save_model(model, filepath):
    """Write only the parameters with ``requires_grad=True`` to ``filepath``.

    The file holds a dict mapping parameter name to parameter, saved with
    ``torch.save``.
    """
    non_frozen_params = {}
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        non_frozen_params[param_name] = param
    torch.save(non_frozen_params, filepath)
    print(f"Model saved to {filepath}")

def load_model(filepath, num_labels=1, mixed=False):
    """Rebuild the regression model and restore the parameters stored in ``filepath``.

    Args:
        filepath: File holding a dict of parameter name -> tensor.
        num_labels: Output size of the prediction head.
        mixed: Forwarded as ``half_precision`` to the model loader.

    Returns:
        ``(tokenizer, model)``.

    Raises:
        RuntimeError: If a stored tensor cannot be copied into the parameter of
            the same name (for example, incompatible shapes).
    """
    import warnings

    model, tokenizer = PT5_regression_model(num_labels=num_labels, half_precision=mixed)
    non_frozen_params = torch.load(filepath, map_location='cpu')
    unused_keys = set(non_frozen_params)
    # copy_ writes into the existing parameter, keeping its dtype and device and
    # rejecting incompatible shapes instead of silently replacing the tensor.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in non_frozen_params:
                param.copy_(non_frozen_params[name])
                unused_keys.discard(name)
    if unused_keys:
        # Entries that match no parameter indicate a model/checkpoint mismatch.
        warnings.warn(
            f"{len(unused_keys)} checkpoint entries matched no model parameter: "
            f"{sorted(unused_keys)[:5]}"
        )
    print(f"Model loaded from {filepath}")
    return tokenizer, model

# Persist the trainable weights, then rebuild a fresh model from that file.
save_model(model, "./PT5_GB1_finetuned.pth")
tokenizer_reload, model_reload = load_model("./PT5_GB1_finetuned.pth", num_labels=1, mixed=False)

# ======================== 10. Inference on Test Set ==========================
print("Running inference on test set...")

test_set = create_dataset(
    tokenizer_reload, list(my_test["sequence"]), list(my_test["label"]), max_length=128
).with_format("torch")
# Order is preserved so predictions line up with the labels below.
test_dataloader = DataLoader(test_set, batch_size=8, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_reload.to(device)
model_reload.eval()

predictions = []
with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Testing"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        logits = model_reload(input_ids=input_ids, attention_mask=attention_mask).logits
        # squeeze() yields a 0-d array for a single sample; atleast_1d turns it back into a list of one.
        batch_predictions = np.atleast_1d(logits.cpu().numpy().squeeze()).tolist()
        predictions.extend(batch_predictions)

true_labels = my_test['label'].values
predictions_arr = np.array(predictions)
test_correlation = stats.spearmanr(predictions_arr, true_labels).correlation
pearson_corr = stats.pearsonr(predictions_arr, true_labels)[0]

print(f"Test Spearman r: {test_correlation:.4f}")
print(f"Test Pearson r:  {pearson_corr:.4f}")
# A variance near zero means the model predicts almost the same value everywhere.
print(f"Variance of predictions: {np.var(predictions_arr):.6f}")

# Print table for manual inspection
print("\nTrue vs Predicted on Test Set:")
comparison = pd.DataFrame({"True": true_labels, "Pred": predictions_arr})
print(comparison.head(20))

# ======================== 11. Plot Test Predictions ==========================
import seaborn as sns

# Scatter of predicted vs. true values with the y = x reference line.
scatter_fig, scatter_ax = plt.subplots(figsize=(6, 6))
sns.scatterplot(x=true_labels, y=predictions_arr, s=80, color='royalblue', edgecolor='k', ax=scatter_ax)
label_span = [true_labels.min(), true_labels.max()]
scatter_ax.plot(label_span, label_span, 'r--', lw=2, label='Identity (y=x)')
scatter_ax.set_xlabel("True Fitness (Label)")
scatter_ax.set_ylabel("Predicted Fitness")
scatter_ax.set_title(f"Test Set Predictions vs True (Spearman r = {test_correlation:.3f})")
scatter_ax.legend()
scatter_fig.tight_layout()
plt.show()
