In [12]:
%load_ext jupyter_black

In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"

# KAGGLE TOGGLE 
IS_KAGGLE = os.environ.get('KAGGLE_KERNEL_RUN_TYPE', None) is not None  # Detect if running on Kaggle
print(f"Running on Kaggle: {IS_KAGGLE}")
if IS_KAGGLE:
    os.system('pip install -r /kaggle/input/requirements/requirements.txt')

Running on Kaggle: False


In [2]:
from __future__ import annotations
import argparse, json, math, os, random, re, time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np, pandas as pd, matplotlib.pyplot as plt, evaluate, torch
from datasets import load_dataset
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_recall_fscore_support,
    confusion_matrix,
    matthews_corrcoef,
    mean_squared_error,
    mean_absolute_error,
    r2_score,
)
from transformers import (
    T5ForConditionalGeneration,
    T5TokenizerFast,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from peft import (
    get_peft_model,
    LoraConfig,
    PrefixTuningConfig,
    PromptTuningConfig,
    PromptTuningInit,
    TaskType,
)

import logging
import warnings
import json  # For saving log_history if needed

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# DEVICE DETECTION
DEVICE = (
    torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if IS_KAGGLE
    else (
        torch.device("mps")
        if torch.backends.mps.is_available()
        else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )
)
print(f"Using device: {DEVICE}")

# CONFIGURATION
MODEL_NAME = "t5-small"  # Switched from flan-t5-small to avoid config dim bug (num_heads=6 mismatch)
SUMMARIZATION_DATASET = "knkarthick/samsum"

BENCHAMARK_GLUE = "glue"
GLUE_DATASET_TASK_SC = "sst2"  # SST-2 for sentiment classification

DATASET_SIZE = 'full'  # or 'full'
RUN_ABLATIONS = False  # Toggle to enable/disable ablation study (modular flag)

RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

NUM_VIRTUAL_TOKENS = 20  # For truncation safety
MAX_POS = 512

Using device: cuda


In [4]:
import torch

In [5]:
print("Number of GPU: ", torch.cuda.device_count())
print("GPU Name: ", torch.cuda.get_device_name())


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Number of GPU:  1
GPU Name:  NVIDIA GeForce RTX 3050 6GB Laptop GPU
Using device: cuda


In [19]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
DEVICE

device(type='cuda')

### Helpers

In [7]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def limit_dataset_size(dataset, size):
    if size == 'full':
        return dataset
    if isinstance(size, int) and size > 0:
        return dataset.select(range(min(size, len(dataset))))
    raise ValueError(f"Invalid size: {size}")


def ensure_dir(p: Path):
    p.mkdir(parents=True, exist_ok=True)


def count_trainable_parameters(model) -> Tuple[int, int]:
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


def safe_cleanup():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    elif device.type == 'mps':
        torch.mps.empty_cache()

def readable(n: int) -> str:
    if n >= 1e9:
        return f"{n/1e9:.2f}B"
    if n >= 1e6:
        return f"{n/1e6:.2f}M"
    if n >= 1e3:
        return f"{n/1e3:.1f}K"
    return str(n)


def decode(tokenizer, ids) -> str:
    ids = [int(i) for i in ids if i >= 0]
    return tokenizer.decode(ids, skip_special_tokens=True).strip()

### Data Prep

In [8]:
# SST-2 text-to-text label mapping
SST2_LABELS = ["negative", "positive"]

def prepare_sst2(
    tokenizer,
    max_source_len=256,
    max_target_len=8,
    train_size='full',
    eval_size='full',
):
    # Build prompts/targets
    def build_prompt_target_sst2(ex):
        label_txt = "positive" if ex["label"] == 1 else "negative"
        return {
            "prompt": f"sst2: sentence: {ex['sentence']}\nSentiment:",
            "target": label_txt,
        }

    ds = load_dataset(BENCHAMARK_GLUE, GLUE_DATASET_TASK_SC)
    ds = ds.map(build_prompt_target_sst2)

    ds["train"] = limit_dataset_size(ds["train"], train_size)
    ds["validation"] = limit_dataset_size(ds["validation"], eval_size)
    eval_split = "validation"

    def tok_fn(batch):
        inputs = tokenizer(
            batch["prompt"],
            max_length=max_source_len,
            padding="max_length",
            truncation=True,
        )
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                batch["target"],
                max_length=max_target_len,
                padding="max_length",
                truncation=True,
            )
        inputs["labels"] = labels["input_ids"]
        return inputs

    tokenized = ds.map(tok_fn, batched=True, remove_columns=ds["train"].column_names)
    return tokenized, eval_split

In [9]:
def prepare_samsum(
    tokenizer,
    max_source_len=768,
    max_target_len=128,
    train_size='full',
    eval_size='full',
):
    def build_prompt_target_samsum(ex):
        return {
            "prompt": f"summarize: {ex['dialogue']}",
            "target": ex["summary"],
        }

    ds = load_dataset(SUMMARIZATION_DATASET)
    ds = ds.map(build_prompt_target_samsum)

    ds["train"] = limit_dataset_size(ds["train"], train_size)
    ds["validation"] = limit_dataset_size(ds["validation"], eval_size)
    eval_split = "validation"

    def tok_fn(batch):
        inputs = tokenizer(
            batch["prompt"],
            max_length=max_source_len,
            padding="max_length",
            truncation=True,
        )
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                batch["target"],
                max_length=max_target_len,
                padding="max_length",
                truncation=True,
            )
        inputs["labels"] = labels["input_ids"]
        return inputs

    tokenized = ds.map(tok_fn, batched=True, remove_columns=ds["train"].column_names)
    return tokenized, eval_split

### PEFT

In [10]:
def apply_prompt_tuning(model, tokenizer_name: str, num_virtual_tokens=20, init_text: str | None = None):
    cfg = PromptTuningConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        num_virtual_tokens=num_virtual_tokens,
        prompt_tuning_init=PromptTuningInit.TEXT if init_text else PromptTuningInit.RANDOM,
        prompt_tuning_init_text=init_text,
        tokenizer_name_or_path=tokenizer_name,
    )
    return get_peft_model(model, cfg)

### Metrics & Plots

In [11]:
def sst2_parse_label(text: str) -> str:
    t = text.lower()
    if "positive" in t and "negative" not in t: return "positive"
    if "negative" in t and "positive" not in t: return "negative"
    if t.startswith("pos"): return "positive"
    if t.startswith("neg"): return "negative"
    # fallback neutral heuristic
    return "positive" if "pos" in t else "negative"

def compute_sst2_metrics(preds_text: List[str], refs_text: List[str]) -> Dict[str, float]:
    # Map labels to 0/1
    pred = [1 if sst2_parse_label(p) == "positive" else 0 for p in preds_text]
    ref = [1 if (("positive" in r.lower()) or ("pos" in r.lower())) else 0 for r in refs_text]
    acc = accuracy_score(ref, pred)
    pr, rc, f1, _ = precision_recall_fscore_support(ref, pred, average="binary", zero_division=0)
    return {"accuracy": float(acc), "precision": float(pr), "recall": float(rc), "f1": float(f1)}, ref, pred

def plot_confusion_matrix(y_true, y_pred, out_path: Path, labels=("negative", "positive"), title="Confusion Matrix (SST-2)"):
    from itertools import product
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    fig = plt.figure()
    plt.imshow(cm, interpolation="nearest")
    plt.title(title)
    ticks = np.arange(len(labels))
    plt.xticks(ticks, labels, rotation=45)
    plt.yticks(ticks, labels)
    thresh = cm.max() / 2.0 if cm.max() > 0 else 0.5
    for i, j in product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], "d"), ha="center", va="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    plt.tight_layout()
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)

def plot_loss_curve(trainer, out_path: Path):
    logs = [e for e in trainer.state.log_history if "loss" in e]
    if not logs: return
    steps = [e.get("step", i) for i, e in enumerate(logs)]
    losses = [e["loss"] for e in logs]
    fig = plt.figure()
    plt.plot(steps, losses)
    plt.xlabel("Step")
    plt.ylabel("Training loss")
    plt.title("Loss Curve")
    plt.tight_layout()
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)

def compute_summarization_metrics(preds: List[str], refs: List[str]) -> Dict[str, float]:
    rouge = evaluate.load("rouge")
    sacrebleu = evaluate.load("sacrebleu")
    chrf = evaluate.load("chrf")

    rouge_scores = rouge.compute(predictions=preds, references=refs, use_stemmer=True)
    sacrebleu_score = sacrebleu.compute(predictions=preds, references=[[r] for r in refs])["score"]
    chrf_score = chrf.compute(predictions=preds, references=refs)["score"]

    gen_lens = np.array([len(p.split()) for p in preds], dtype=float)
    ref_lens = np.array([len(r.split()) for r in refs], dtype=float)
    length_ratio = float(np.mean(gen_lens / np.maximum(ref_lens, 1)))
    return {
        "rouge1": float(rouge_scores["rouge1"]),
        "rouge2": float(rouge_scores["rouge2"]),
        "rougeL": float(rouge_scores["rougeL"]),
        "rougeLsum": float(rouge_scores.get("rougeLsum", rouge_scores["rougeL"])),
        "sacrebleu": float(sacrebleu_score),
        "chrf": float(chrf_score),
        "avg_gen_len": float(np.mean(gen_lens)),
        "avg_ref_len": float(np.mean(ref_lens)),
        "length_ratio": length_ratio,
    }

def plot_length_hist(lengths, out_path: Path, title="Generated Summary Lengths (words)"):
    fig = plt.figure()
    plt.hist(lengths, bins=30)
    plt.xlabel("Length (words)")
    plt.ylabel("Count")
    plt.title(title)
    plt.tight_layout()
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)

def plot_ratio_hist(ratios, out_path: Path, title="Generated/Reference Length Ratio"):
    fig = plt.figure()
    plt.hist(ratios, bins=30)
    plt.xlabel("Length Ratio")
    plt.ylabel("Count")
    plt.title(title)
    plt.tight_layout()
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)

def plot_scatter(x, y, out_path: Path, xlabel: str, ylabel: str, title: str):
    fig = plt.figure()
    plt.scatter(x, y, alpha=0.5)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.tight_layout()
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)

In [12]:
@dataclass
class RunConfig:
    task: str               # "samsum" or "sst2"
    method: str             # "prompt_tuning", "lora", "full_finetune"
    # Models: summarization must use knkarthick/samsum; classification default t5-small
    model_name: str | None = None
    epochs: int = 1
    batch_size: int = 8
    lr: float = 5e-4
    weight_decay: float = 0.0
    warmup_ratio: float = 0.03
    gradient_accumulation_steps: int = 1
    max_source_len: int = 768
    max_target_len: int = 128
    gen_max_new_tokens: int = 128
    prompt_tokens: int = 20
    prompt_init_text: str | None = None
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    max_train_samples: int | None = None
    seed: int = 42
    train_size: int|'full' = 'full'
    eval_size: int|'full' = 'full'
    output_root: Path = Path("outputs_two_tasks")

def build_model_and_tokenizer(cfg: RunConfig) -> Tuple[T5ForConditionalGeneration, T5TokenizerFast]:
    # Choose model per task if not specified
    model_name = cfg.model_name
    if cfg.task == "samsum" and model_name is None:
        model_name = "knkarthick/samsum"
    if cfg.task == "sst2" and model_name is None:
        model_name = "t5-small"

    tokenizer = T5TokenizerFast.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name).to(DEVICE)

    if cfg.method == "prompt_tuning":
        model = apply_prompt_tuning(model, tokenizer_name=model_name,
                                    num_virtual_tokens=cfg.prompt_tokens,
                                    init_text=cfg.prompt_init_text)
    elif cfg.method == "lora":
        model = apply_lora(model, r=cfg.lora_r, alpha=cfg.lora_alpha, dropout=cfg.lora_dropout)
    elif cfg.method == "full_finetune":
        for p in model.parameters():
            p.requires_grad = True
    else:
        raise ValueError("Unknown method: " + cfg.method)

    return model, tokenizer

def train_and_eval(cfg: RunConfig) -> Dict[str, float]:
    set_seed(cfg.seed)
    model, tokenizer = build_model_and_tokenizer(cfg)

    # Data
    if cfg.task == "samsum":
        tokenized, eval_split = prepare_samsum(
        tokenizer,
        max_source_len=cfg.max_source_len,
        max_target_len=cfg.max_target_len,
        train_size=cfg.train_size,
        eval_size=cfg.eval_size,
    )
    elif cfg.task == "sst2":
        tokenized, eval_split = prepare_sst2(
        tokenizer,
        max_source_len=256,
        max_target_len=8,
        train_size=cfg.train_size,
        eval_size=cfg.eval_size,
    )
    else:
        raise ValueError("task must be 'samsum' or 'sst2'")

    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    out_dir = cfg.output_root / f"{cfg.task}__{cfg.method}__seed{cfg.seed}"
    ensure_dir(out_dir)

    args = Seq2SeqTrainingArguments(
        output_dir=str(out_dir),
        per_device_train_batch_size=cfg.batch_size,
        per_device_eval_batch_size=max(1, cfg.batch_size // 2),
        num_train_epochs=cfg.epochs,
        learning_rate=cfg.lr,
        weight_decay=cfg.weight_decay,
        warmup_ratio=cfg.warmup_ratio,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        logging_steps=20,
        eval_strategy="epoch",
        save_strategy="no",
        predict_with_generate=True,
        generation_max_length=cfg.gen_max_new_tokens,
        report_to=[],
        bf16=torch.cuda.is_available(),
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized[eval_split],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Parameter stats
    trainable, total = count_trainable_parameters(model)
    stats = {
        "task": cfg.task,
        "method": cfg.method,
        "model_name": "t5-small",
        "seed": cfg.seed,
        "device": DEVICE.type,
        "trainable_params": int(trainable),
        "total_params": int(total),
        "trainable_readable": readable(trainable),
        "total_readable": readable(total),
    }

    # Train
    start = time.time()
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    trainer.train()
    elapsed = time.time() - start
    stats["train_time_sec"] = float(elapsed)
    if torch.cuda.is_available():
        stats["peak_gpu_mem_bytes"] = int(torch.cuda.max_memory_allocated())
        stats["peak_gpu_mem_gb"] = float(stats["peak_gpu_mem_bytes"] / (1024 ** 3))
    else:
        stats["peak_gpu_mem_bytes"] = 0
        stats["peak_gpu_mem_gb"] = 0.0

    # Predictions
    pred_output = trainer.predict(tokenized[eval_split])
    preds_ids = pred_output.predictions[0] if isinstance(pred_output.predictions, tuple) else pred_output.predictions
    decoded_preds = [decode(tokenizer, seq) for seq in preds_ids]
    decoded_refs = [decode(tokenizer, np.array(row["labels"])) for row in tokenized[eval_split]]

    # Task-specific metrics & plots
    if cfg.task == "sst2":
        metrics, y_true, y_pred = compute_sst2_metrics(decoded_preds, decoded_refs)
        stats.update(metrics)
        # Plots
        plot_loss_curve(trainer, out_dir / "loss_curve.png")
        plot_confusion_matrix(y_true, y_pred, out_dir / "confusion_matrix.png")
    else:
        mets = compute_summarization_metrics(decoded_preds, decoded_refs)
        stats.update(mets)
        # Plots
        plot_loss_curve(trainer, out_dir / "loss_curve.png")
        gen_lens = np.array([len(p.split()) for p in decoded_preds], dtype=float)
        ref_lens = np.array([len(r.split()) for r in decoded_refs], dtype=float)
        ratios = gen_lens / np.maximum(ref_lens, 1)
        plot_length_hist(gen_lens, out_dir / "gen_length_hist.png")
        plot_ratio_hist(ratios, out_dir / "length_ratio_hist.png")
        # ROUGE-Lsum vs length ratio (as an example diagnostic)
        # Recompute per-example ROUGE-L if desired; for speed, we correlate length ratio with gen length
        plot_scatter(gen_lens, ratios, out_dir / "length_vs_ratio_scatter.png",
                     xlabel="Generated length (words)", ylabel="Length ratio (gen/ref)",
                     title="Generated Length vs Length Ratio")

    print(stats)
    # Save metrics JSON
    with open(out_dir / "metrics.json", "w") as f:
        json.dump(stats, f, indent=2)

    return stats

In [10]:
def make_comparison_plots(results: List[Dict], out_dir=Path("comparison_two_tasks")):
    ensure_dir(out_dir)
    df = pd.DataFrame(results)
    df.to_csv(out_dir / "summary_by_run.csv", index=False)

    # Per-task bar charts for primary metrics
    # Summarization -> ROUGE-Lsum; Classification -> Accuracy
    for task in ["samsum", "sst2"]:
        sub = df[df["task"] == task]
        if sub.empty: continue
        if task == "samsum":
            primary = "rougeLsum"
            secondary = "sacrebleu"
            title = "SAMSum: ROUGE-Lsum & SacreBLEU by Method"
            ylab = "Score"
        else:
            primary = "accuracy"
            secondary = "f1"
            title = "SST-2: Accuracy & F1 by Method"
            ylab = "Score"

        # Bar chart (primary/secondary)
        x = np.arange(len(sub))
        fig = plt.figure()
        width = 0.35
        plt.bar(x - width/2, sub[primary], width, label=primary)
        if secondary in sub:
            plt.bar(x + width/2, sub[secondary], width, label=secondary)
        plt.xticks(x, sub["method"], rotation=0)
        plt.ylabel(ylab)
        plt.title(title)
        plt.legend()
        plt.tight_layout()
        fig.savefig(out_dir / f"{task}_metrics_bar.png", dpi=150, bbox_inches="tight")
        plt.close(fig)

        # Params vs primary metric (log scale on X)
        fig = plt.figure()
        plt.scatter(sub["trainable_params"], sub[primary])
        for _, r in sub.iterrows():
            plt.annotate(r["method"], (r["trainable_params"], r[primary]))
        plt.xscale("log")
        plt.xlabel("Trainable params (log)")
        plt.ylabel(primary)
        plt.title(f"{task.upper()}: Params vs {primary}")
        plt.tight_layout()
        fig.savefig(out_dir / f"{task}_params_vs_primary.png", dpi=150, bbox_inches="tight")
        plt.close(fig)


def main():
    ap = argparse.ArgumentParser(description="Two-task PEFT Benchmark (SAMSum + SST-2)")
    ap.add_argument("--tasks", type=str, default="samsum,sst2", help="comma-separated from {samsum,sst2}")
    ap.add_argument("--methods", type=str, default="prompt_tuning", help="comma-separated")
    ap.add_argument("--epochs", type=int, default=1)
    ap.add_argument("--batch_size", type=int, default=8)
    ap.add_argument("--lr", type=float, default=5e-4)
    ap.add_argument("--weight_decay", type=float, default=0.0)
    ap.add_argument("--warmup_ratio", type=float, default=0.03)
    ap.add_argument("--gradient_accumulation_steps", type=int, default=1)
    ap.add_argument("--max_source_len", type=int, default=768)
    ap.add_argument("--max_target_len", type=int, default=128)
    ap.add_argument("--gen_max_new_tokens", type=int, default=128)
    ap.add_argument("--prompt_tokens", type=int, default=20)
    ap.add_argument("--prompt_init_text", type=str, default=None)
    ap.add_argument("--lora_r", type=int, default=8)
    ap.add_argument("--lora_alpha", type=int, default=16)
    ap.add_argument("--lora_dropout", type=float, default=0.05)
    ap.add_argument("--max_train_samples", type=int, default=None)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--output_root", type=str, default="outputs_two_tasks")
    ap.add_argument("--plots_only", action="store_true", help="Aggregate existing metrics and make comparison plots")
    args = ap.parse_args()

    tasks = [t.strip() for t in args.tasks.split(",") if t.strip()]
    methods = [m.strip() for m in args.methods.split(",") if m.strip()]

    results = []
    if args.plots_only:
        # Collect existing metrics
        for mf in Path(args.output_root).rglob("metrics.json"):
            with open(mf) as f:
                results.append(json.load(f))
    else:
        for task in tasks:
            for method in methods:
                cfg = RunConfig(
                    task=task,
                    method=method,
                    model_name=None,  # auto: samsum->knkarthick/samsum, sst2->t5-small
                    epochs=args.epochs,
                    batch_size=args.batch_size,
                    lr=args.lr,
                    weight_decay=args.weight_decay,
                    warmup_ratio=args.warmup_ratio,
                    gradient_accumulation_steps=args.gradient_accumulation_steps,
                    max_source_len=args.max_source_len,
                    max_target_len=args.max_target_len if task == "samsum" else 8,
                    gen_max_new_tokens=args.gen_max_new_tokens if task == "samsum" else 4,
                    prompt_tokens=args.prompt_tokens,
                    prompt_init_text=(None if args.prompt_init_text in (None, "", "None") else args.prompt_init_text),
                    lora_r=args.lora_r,
                    lora_alpha=args.lora_alpha,
                    lora_dropout=args.lora_dropout,
                    max_train_samples=args.max_train_samples,
                    seed=args.seed,
                    output_root=Path(args.output_root),
                )
                stats = train_and_eval(cfg)
                results.append(stats)

    # Save summary and make comparison plots
    comp_dir = Path("comparison_two_tasks")
    ensure_dir(comp_dir)
    pd.DataFrame(results).to_csv(comp_dir / "summary_all.csv", index=False)
    make_comparison_plots(results, out_dir=comp_dir)

    # Print compact table
    cols = ["task", "method", "trainable_readable", "train_time_sec"]
    extra = []
    if any(r.get("accuracy") is not None for r in results): extra += ["accuracy", "f1"]
    if any(r.get("rougeLsum") is not None for r in results): extra += ["rougeLsum", "sacrebleu"]
    cols += [c for c in extra if c in results[0].keys()]
    df = pd.DataFrame(results)
    print(df[ [c for c in cols if c in df.columns] ].sort_values(["task", "method"]).to_string(index=False))


In [13]:
cfg_sst2_small = RunConfig(
    task="sst2",
    method="prompt_tuning",
    model_name=MODEL_NAME,
    epochs=5,
    batch_size=8,
    train_size=DATASET_SIZE,
    eval_size=DATASET_SIZE,
)

cfg_sam_small = RunConfig(
    task="samsum",
    method="prompt_tuning",
    model_name=MODEL_NAME,
    epochs=5,
    batch_size=8,
    train_size=DATASET_SIZE,
    eval_size=DATASET_SIZE,
)

In [14]:
res_cls = train_and_eval(cfg_sst2_small)

Map: 100%|██████████| 67349/67349 [00:09<00:00, 6991.63 examples/s]
Map: 100%|██████████| 872/872 [00:00<00:00, 5824.99 examples/s]


Epoch,Training Loss,Validation Loss
1,1.9478,0.571611
2,1.2692,0.358719
3,1.1596,0.253496
4,0.9911,0.183593
5,1.0311,0.164494


{'task': 'sst2', 'method': 'prompt_tuning', 'model_name': 't5-small', 'seed': 42, 'device': 'cuda', 'trainable_params': 20480, 'total_params': 60527104, 'trainable_readable': '20.5K', 'total_readable': '60.53M', 'train_time_sec': 4659.218475818634, 'peak_gpu_mem_bytes': 866616832, 'peak_gpu_mem_gb': 0.8070998191833496, 'accuracy': 0.4908256880733945, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0}


In [15]:
res_cls

{'task': 'sst2',
 'method': 'prompt_tuning',
 'model_name': 't5-small',
 'seed': 42,
 'device': 'cuda',
 'trainable_params': 20480,
 'total_params': 60527104,
 'trainable_readable': '20.5K',
 'total_readable': '60.53M',
 'train_time_sec': 4659.218475818634,
 'peak_gpu_mem_bytes': 866616832,
 'peak_gpu_mem_gb': 0.8070998191833496,
 'accuracy': 0.4908256880733945,
 'precision': 0.0,
 'recall': 0.0,
 'f1': 0.0}

In [16]:
res_sum = train_and_eval(cfg_sam_small)

Map: 100%|██████████| 14731/14731 [00:09<00:00, 1487.02 examples/s]
Map: 100%|██████████| 818/818 [00:00<00:00, 924.64 examples/s]


Epoch,Training Loss,Validation Loss
1,1.9146,0.886336
2,1.5496,0.877559


KeyboardInterrupt: 