In [None]:
!pip install --no-deps evaluate "protobuf<4.0"

In [None]:
import os

# Process-wide environment flags, set in one place before the heavy library
# imports in the following cell run.
# NOTE(review): presumably these turn off tokenizer worker parallelism and the
# Weights & Biases integration — confirm against the libraries that read them.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "WANDB_DISABLED": "true",
    }
)

In [None]:
import torch
import numpy as np
import pandas as pd
import evaluate
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
import traceback

from sklearn.metrics import confusion_matrix
from collections import Counter

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    AutoConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)

from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
    PeftModel
)
import logging
import warnings
import json # For saving log_history if needed

# Suppress every warning category for the rest of the session.
warnings.simplefilter("ignore")

# Configure root logging at INFO level and expose a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
# DEVICE DETECTION
# Use the GPU when CUDA reports itself available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# CONFIGURATION

# --- Model and datasets ---
MODEL_NAME = "google/flan-t5-small"
SUMMARIZATION_DATASET = "knkarthick/samsum"

BENCHMARK_GLUE = "glue"
GLUE_DATASET_TASK_SC = "sst2"  # SST-2: the sentiment-classification task

# --- Run identity and scale ---
PROGRAM_NAME = "ift-lora"

# Either an integer sample count (e.g. 100 or 500) or the string 'full'.
# Caution: 100 examples is only enough for a smoke test — performance will be
# close to random chance and is not suitable for a real comparison. Use 'full'
# or a larger count (e.g. 5000) for a meaningful benchmark.
DATASET_SIZE = "full"

# Modular switch that enables or disables the ablation study.
RUN_ABLATIONS = False

# --- Reproducibility: one seed shared by NumPy and PyTorch ---
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# --- Sizes ---
# Virtual-token count for prefix/prompt tuning. Raised from 20 to 50: per the
# author's note, longer token sequences allow stronger task-specific tuning and
# addressed weak/flat metrics in the prefix/prompt runs.
NUM_VIRTUAL_TOKENS = 50
# NOTE(review): meaning is not shown in this cell — presumably a maximum
# position / sequence length; confirm at the point of use.
MAX_POS = 512

# Root directory for run artifacts. It defaults to the original hard-coded
# location; set the IFT_OUTPUT_ROOT environment variable to redirect outputs
# (e.g. on machines where /home/outputs is not writable) without editing the
# notebook. An unset or empty variable falls back to the default.
OUTPUT_ROOT = os.environ.get("IFT_OUTPUT_ROOT") or "/home/outputs"

# Per-program output directory: <root>/<PROGRAM_NAME>. Trailing slashes on the
# root are dropped so the default yields exactly '/home/outputs/<PROGRAM_NAME>'.
OUTPUT_DIR = f"{OUTPUT_ROOT.rstrip('/')}/{PROGRAM_NAME}"

# exist_ok=True keeps this cell safe to re-run when the directory already exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Run banner: summarises the active configuration between two 60-character rules.
_rule = "=" * 60
_banner = [
    _rule,
    "LoRA FT",
    _rule,
    f"Dataset size: {DATASET_SIZE}",
    f"Model: {MODEL_NAME} for {SUMMARIZATION_DATASET} and {GLUE_DATASET_TASK_SC}",
    "Methods: LoRA",
]

# The ablation notes appear only when the ablation study is switched on.
if RUN_ABLATIONS:
    _banner += [
        "Ablations Enabled: Including ablated variants for study",
        "Note: For LoRA ablation, using lora_alpha=0 to nullify adapter effect",
    ]

# The trailing empty entry reproduces the blank line that follows the closing rule.
_banner += [_rule, ""]
print(*_banner, sep="\n")