# HyFlexPIM BERT Notebook

This notebook uses functions and classes from `hyflex_utils.py`.

In [None]:
from transformers import AutoTokenizer, LlamaForCausalLM, DataCollatorForLanguageModeling, get_scheduler, LlamaTokenizerFast, DataCollatorWithPadding
from transformers import BertTokenizer, BertForSequenceClassification, DataCollatorWithPadding
from datasets import Value
from datasets import load_dataset
from accelerate import Accelerator
from torch.utils.data import DataLoader
from itertools import chain
import torch
import math
from tqdm import tqdm
import os
from torch import nn
import torch.nn.functional as F
from peft import get_peft_model, LoraConfig, TaskType
import torch.nn as nn
import time
import copy
from hyflex_utils import *  # Import all necessary functions/classes
import matplotlib.pyplot as plt

In [None]:
# Expose GPUs 0 and 1 to this process. This only has an effect if CUDA has not yet been
# initialised in this kernel, so run this cell before any other CUDA call.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

if torch.cuda.is_available():
    # Prefer the second visible GPU, but fall back to GPU 0 on single-GPU machines
    # instead of failing with an invalid device ordinal.
    torch.cuda.set_device(1 if torch.cuda.device_count() > 1 else 0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("Available GPUs:", torch.cuda.device_count())
print('Device:', device)
if torch.cuda.is_available():
    # current_device() raises on CPU-only installs, so only query it when CUDA exists.
    print('Current cuda device:', torch.cuda.current_device())
print('Count of using GPUs:', torch.cuda.device_count())

Step 1. Load the model, tokenizer, and dataset (run all cells in this step)

In [None]:
# ========== Configuration ==========
model_ = "bert-large-uncased"
# model_ = "bert-base-uncased"
task = "mrpc"  # Change to "sst2", "cola", "qnli", "qqp", "mrpc", "stsb", "rte"

cfg = {
    "seed": 42,
    "batch_size": 16,
    "lr": 2e-5,
    "n_epochs": 3,
    "warmup": 0.1,
    "save_steps": 100,
    "max_length": 128,  # tokenizer truncation/padding length, shared by every task
}

torch.manual_seed(cfg["seed"])
accelerator = Accelerator()

# ========== Tokenizer & Model ==========
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding

# Task-specific num_labels (STS-B is a single-output regression task).
task_to_num_labels = {
    "sst2": 2,
    "cola": 2,
    "qnli": 2,
    "qqp": 2,
    "mrpc": 2,
    "rte": 2,
    "stsb": 1,  # Regression
}

# Text column(s) fed to the tokenizer per task: (first segment, second segment or None).
task_to_text_keys = {
    "sst2": ("sentence", None),
    "cola": ("sentence", None),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "mrpc": ("sentence1", "sentence2"),
    "stsb": ("sentence1", "sentence2"),
    "rte": ("sentence1", "sentence2"),
}

if task not in task_to_num_labels:
    raise ValueError(f"Unsupported task {task!r}; expected one of {sorted(task_to_num_labels)}")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_)

model = AutoModelForSequenceClassification.from_pretrained(
    model_,
    num_labels=task_to_num_labels[task],
    problem_type="regression" if task == "stsb" else "single_label_classification",
)


# ========== Preprocessing Functions ==========
def _make_preprocess(first_key, second_key=None):
    """Return a batched ``datasets.map`` function that tokenizes one task's text column(s).

    All tasks share the same settings: pad and truncate to ``cfg["max_length"]`` tokens.
    ``second_key`` names the second segment for sentence-pair tasks; ``None`` means the
    task has a single sentence.
    """
    def preprocess(examples):
        text_pair = examples[second_key] if second_key is not None else None
        return tokenizer(
            examples[first_key],
            text_pair=text_pair,
            padding="max_length",
            truncation=True,
            max_length=cfg["max_length"],
        )
    return preprocess


preprocess_functions = {name: _make_preprocess(*keys) for name, keys in task_to_text_keys.items()}

# Per-task names kept so any code that referenced them directly keeps working.
sst2_preprocess = preprocess_functions["sst2"]
cola_preprocess = preprocess_functions["cola"]
qnli_preprocess = preprocess_functions["qnli"]
qqp_preprocess = preprocess_functions["qqp"]
mrpc_preprocess = preprocess_functions["mrpc"]
stsb_preprocess = preprocess_functions["stsb"]
rte_preprocess = preprocess_functions["rte"]

# ========== Load and Preprocess Dataset ==========
dataset = load_dataset("glue", task)

# Regression (STS-B) needs float targets; classification tasks need int64 class ids.
label_dtype = "float32" if task == "stsb" else "int64"


def _prepare_split(split):
    """Tokenize a split, cast ``label`` to ``label_dtype`` and expose the model inputs as torch tensors."""
    split = split.map(preprocess_functions[task], batched=True)
    split = split.cast_column("label", Value(label_dtype))
    split.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    return split


train_dataset = _prepare_split(dataset["train"])
val_dataset = _prepare_split(dataset["validation"])
# Fall back to the validation split when the task has no test split.
# NOTE(review): GLUE test splits on the Hugging Face Hub carry no gold labels (label == -1),
# so test_dataset / test_dataloader are only usable for prediction, not for scoring.
test_dataset = _prepare_split(dataset["test"]) if "test" in dataset else val_dataset

# ========== DataLoaders ==========
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=cfg["batch_size"])
eval_dataloader = DataLoader(val_dataset, shuffle=False, collate_fn=data_collator, batch_size=cfg["batch_size"])
test_dataloader = DataLoader(test_dataset, shuffle=False, collate_fn=data_collator, batch_size=cfg["batch_size"])

# ========== Prepare with Accelerator ==========
# The test loader is prepared as well so its batches are placed like the other loaders'.
model, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
    model, train_dataloader, eval_dataloader, test_dataloader
)
model.to(torch.float32)

print(f"✔ Loaded task: {task}, train size: {len(train_dataset)}, val size: {len(val_dataset)}")


In [None]:
# Inference (You can skip this part): evaluate the model as loaded in Step 1 on
# eval_dataloader, i.e. the GLUE validation split, using evaluate_model_bert from hyflex_utils.

evaluate_model_bert(model, tokenizer, eval_dataloader, task=task)


Step 2. Train the original model

In [None]:
# Train the original (dense) model; SVD is applied later to the best trained weights.
# Learning rate and epoch count come from `cfg` so the run is configured in one place
# (change them there). Note: torch.optim.AdamW's own default lr is 1e-3, far too high
# for BERT fine-tuning, so lr is always passed explicitly.
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg["lr"])
train_model_bert(model, tokenizer, optimizer, train_dataloader, eval_dataloader, accelerator, epochs=cfg["n_epochs"], grad_accum_steps=1)

In [None]:
# Save the fine-tuned weights (not necessary, but recommended so the best model can be reloaded).
save_dir = f'./model_{task}'
os.makedirs(save_dir, exist_ok=True)
# Build the path from save_dir so the directory name is defined in exactly one place.
torch.save(model.state_dict(), os.path.join(save_dir, 'finetuned_best'))
print("Model weights saved successfully!")

In [None]:
#  Load the model (Optional)
# model.load_state_dict(torch.load(f'./model_{task}/finetuned_best'))

Step 3. SVD decomposition

In [None]:
# Replace linear layers with SVD-decomposed, trainable layers via replace_linear_layer_bert
# (from hyflex_utils). Its return value is ignored, so it presumably modifies `model`
# in place — confirm in hyflex_utils.py.

replace_linear_layer_bert(model)

In [None]:
# Inference (optional): check how much validation accuracy drops right after the SVD
# replacement, before any further fine-tuning.

evaluate_model_bert(model, tokenizer, eval_dataloader, task=task)

Step 4. Fine-tuning & gradient redistribution

In [None]:
# Optimizer for the SVD-decomposed model; the learning rate comes from cfg["lr"] (change it there).
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg["lr"])
# Re-prepare after the layer replacement, presumably so newly created modules are placed
# on the accelerator's device.
model, train_dataloader, eval_dataloader = accelerator.prepare(model, train_dataloader, eval_dataloader)


# Trainable parameters: report both the number of parameter tensors and the total
# number of scalar weights they contain.
trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
trainable_params = [name for name, _ in trainable]
print(f" Trainable Parameters Count: {len(trainable_params)}")
print(f" Trainable Parameters (elements): {sum(param.numel() for _, param in trainable):,}")

In [None]:
# Train the SVD-decomposed model with gradient redistribution
# (train_model_bert_gradient_saving from hyflex_utils). The epoch count comes from cfg["n_epochs"].

train_model_bert_gradient_saving(model, tokenizer, optimizer, train_dataloader, eval_dataloader, accelerator, epochs=cfg["n_epochs"], grad_accum_steps=1)

In [None]:
# Save the trained (SVD + gradient redistribution) model!
# Create the directory here as well: the earlier save cell is optional, so ./model_{task}
# may not exist yet, and torch.save would then fail with FileNotFoundError.
save_dir = f'./model_{task}'
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(save_dir, 'finetuned_best_after_svd'))

print("Model weights after svd saved successfully!")

In [None]:
#  Load model (optional)

# model.load_state_dict(torch.load(f'./model_{task}/finetuned_best_after_svd'))

Step 5. Noise-injection simulation
 : You can change the list of thresholds (each threshold is an SLC rate, in %);
 e.g., thresholds = [0, 20] simulates SLC rates of 0% and 20% and plots the resulting validation accuracies.

In [None]:
# Inference with a noise-injected model, sweeping the threshold (SLC rate, in %).
# For every threshold: restore the clean trained weights, run load_gradients_bert, inject
# noise with standard deviation `std` via apply_noise_to_bert, then evaluate on the
# validation split.

std = 0.025
thresholds = [0, 5, 10, 30, 40, 50, 100]
model_path = f'./model_{task}/finetuned_best_after_svd'

# Read the checkpoint once rather than once per threshold. load_state_dict copies these
# values into the model's parameters, so this clean copy is never modified by the noise.
clean_state = torch.load(model_path, map_location="cpu")
target_device = "cuda" if torch.cuda.is_available() else "cpu"

accuracies = []

for th in thresholds:
    print(f"\n==== [Threshold: {th}%] ====")
    model.load_state_dict(clean_state)
    model.to(target_device)
    load_gradients_bert(model)
    # Re-seed so every threshold draws the same random stream and the sweep is reproducible.
    # Assumes apply_noise_to_bert samples from torch's RNG — TODO confirm in hyflex_utils.py.
    torch.manual_seed(cfg["seed"])
    model = apply_noise_to_bert(model, std, th)
    acc = evaluate_model_bert(model, tokenizer, eval_dataloader, task=task)
    accuracies.append(acc)
    print(f"[Threshold={th}%] Validation Accuracy: {acc:.4f}")

# Restore the clean weights so later cells do not silently run on the last noisy model.
model.load_state_dict(clean_state)
del clean_state

# Assumes evaluate_model_bert returns a fraction in [0, 1].
accuracy_percent = [a * 100 for a in accuracies]

plt.figure(figsize=(8, 5))
plt.bar([f"{t}%" for t in thresholds], accuracy_percent, color='skyblue')
plt.title("Noise Injection Threshold vs Validation Accuracy (BERT)")
plt.xlabel("Noise Injection Threshold (%)")
plt.ylabel("Validation Accuracy (%)")
plt.ylim(0, 100)
plt.grid(axis='y')
plt.tight_layout()
plt.show()

