# BoA Implementation

## Team Members: Amitesh Shekhar (22B0014), Toshan Achintya Golla (22B2234), Vatsal Melwani (22B0396)

### Importing Necessary Modules

In [None]:
import torch
import torch.nn as nn
from torch.nn.functional import cross_entropy
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from datasets import load_dataset
import numpy as np
from tqdm.auto import tqdm
import gc
import math
import time
import os
import evaluate

### Instantiating the model

In [None]:
# Hugging Face checkpoint that will be quantized
MODEL_ID = "facebook/opt-125m"

# Prefer a GPU when one is visible (e.g. Kaggle T4/P100); otherwise run on CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Calibration / quantization settings shared by the BoA and GPTQ code below
SEQ_LEN = 2048      # tokens per calibration window
NSAMPLES = 128      # number of calibration windows
Target_Bits = 2     # weight bit-width (e.g. 3 -> INT3)
Group_Size = 128    # consecutive weights sharing one scale/zero pair

# Quantization is very slow without a GPU, so warn loudly
if DEVICE == "cpu":
    print("⚠️ WARNING: You are running on CPU. For speed, enable GPU ")

print(f"Running on {DEVICE} with model {MODEL_ID}")

### Helper Functions

In [None]:
def get_wikitext2(tokenizer, nsamples, seqlen, seed=0):
    """
    Load calibration and evaluation data from WikiText-2.

    NOTE(review): this function is defined again later in this file; when the
    cells run top to bottom, that later definition shadows this one.

    We:
      - tokenize the whole training split as one long stream
      - slice `nsamples` windows of length `seqlen` at uniformly random starts
      - return those windows as calibration data

    Args:
      tokenizer : tokenizer called on the joined training text (return_tensors="pt")
      nsamples  : number of calibration windows
      seqlen    : tokens per window
      seed      : seed of the private RNG choosing window starts (default 0,
                  which reproduces the windows of `random.seed(0)`)

    Returns:
      trainloader : list of (inp, tar) pairs, each of shape (1, seqlen);
                    only inp is used in our code
      testdata    : Hugging Face Dataset for the test split

    Raises:
      ValueError : if the tokenized training text is not longer than `seqlen`
    """
    import random

    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    # Tokenize the entire training corpus as a single long stream
    trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
    total_tokens = trainenc.input_ids.shape[1]
    if total_tokens <= seqlen:
        raise ValueError(
            f"Training text has {total_tokens} tokens; need more than seqlen={seqlen}."
        )

    # Private RNG: draws the same indices as random.seed(seed) + random.randint
    # but does not reset the global `random` state of the caller/notebook.
    rng = random.Random(seed)

    trainloader = []
    for _ in range(nsamples):
        # randint is inclusive, so the window [i, i + seqlen) stays in range
        i = rng.randint(0, total_tokens - seqlen - 1)
        inp = trainenc.input_ids[:, i:i + seqlen]  # shape: (1, seqlen)
        tar = inp.clone()
        # This masking is not used later, but kept for completeness / compatibility
        tar[:, :-1] = -100                         # ignore all but the last token
        trainloader.append((inp, tar))

    return trainloader, testdata


def quantize_scalar(x, scale, zero, maxq):
    """
    Snap ``x`` onto a uniform integer grid and map it back to real values.

    NOTE(review): this function is defined again later in this file; when the
    cells run top to bottom, that later definition shadows this one.

    Integer code : clamp(round(x / scale) + zero, 0, maxq)
    Result       : scale * (code - zero)

    ``x``, ``scale`` and ``zero`` only need to broadcast against each other.
    """
    codes = (x / scale).round().add(zero).clamp(0, maxq)
    return (codes - zero) * scale


def find_params(w, bits=Target_Bits, group_size=Group_Size):
    """
    Compute per-group asymmetric min-max quantization parameters.

    NOTE(review): this function is defined again later in this file; when the
    cells run top to bottom, that later definition shadows this one.

    Args:
      w          : weight matrix (rows, cols)
      bits       : number of bits (e.g. 3 for INT3)
      group_size : number of consecutive weights per group (row-major order);
                   <= 0 means one group per row. When > 0, cols should be a
                   multiple of group_size so that groups do not straddle rows.

    Returns:
      scale : tensor of shape (num_groups, 1), strictly positive
      zero  : tensor of shape (num_groups, 1)
      maxq  : integer max quantized value (2^bits - 1)
    """
    dev = w.device
    maxq = 2**bits - 1

    if group_size > 0:
        # reshape into (num_groups, group_size)
        w = w.reshape(-1, group_size)

    # per-group min and max
    wmax = torch.amax(w, dim=1, keepdim=True)
    wmin = torch.amin(w, dim=1, keepdim=True)

    # Groups whose max is below 1e-5 (e.g. all zeros, all negative) get their
    # max lifted to 1e-5, which gives them a non-zero range.
    wmax = torch.max(wmax, torch.tensor(1e-5, device=w.device))

    scale = (wmax - wmin) / maxq
    # A constant group with value >= 1e-5 still has wmax == wmin, i.e. scale 0,
    # and x / scale would then yield inf/NaN during quantization. Give only those
    # groups a tiny positive scale: the constant is then reproduced (up to
    # rounding) because round(x/scale) and zero cancel.
    scale = torch.where(scale > 0, scale, torch.full_like(scale, 1e-8))
    zero = torch.round(-wmin / scale)

    return scale.to(dev), zero.to(dev), maxq

def get_wikitext2(tokenizer, nsamples, seqlen, seed=0):
    """
    Load calibration and evaluation data from WikiText-2.

    We:
      - tokenize the whole training split as one long stream
      - slice `nsamples` windows of length `seqlen` at uniformly random starts
      - return those windows as calibration data

    Args:
      tokenizer : tokenizer called on the joined training text (return_tensors="pt")
      nsamples  : number of calibration windows
      seqlen    : tokens per window
      seed      : seed of the private RNG choosing window starts (default 0,
                  which reproduces the windows of `random.seed(0)`)

    Returns:
      trainloader : list of (inp, tar) pairs, each of shape (1, seqlen);
                    only inp is used in our code
      testdata    : Hugging Face Dataset for the test split

    Raises:
      ValueError : if the tokenized training text is not longer than `seqlen`
    """
    import random

    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    # Tokenize the entire training corpus as a single long stream
    trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt")
    total_tokens = trainenc.input_ids.shape[1]
    if total_tokens <= seqlen:
        raise ValueError(
            f"Training text has {total_tokens} tokens; need more than seqlen={seqlen}."
        )

    # Private RNG: draws the same indices as random.seed(seed) + random.randint
    # but does not reset the global `random` state of the caller/notebook.
    rng = random.Random(seed)

    trainloader = []
    for _ in range(nsamples):
        # randint is inclusive, so the window [i, i + seqlen) stays in range
        i = rng.randint(0, total_tokens - seqlen - 1)
        inp = trainenc.input_ids[:, i:i + seqlen]  # shape: (1, seqlen)
        tar = inp.clone()
        # This masking is not used later, but kept for completeness / compatibility
        tar[:, :-1] = -100                         # ignore all but the last token
        trainloader.append((inp, tar))

    return trainloader, testdata


def quantize_scalar(x, scale, zero, maxq):
    """
    Snap ``x`` onto a uniform integer grid and map it back to real values.

    Integer code : clamp(round(x / scale) + zero, 0, maxq)
    Result       : scale * (code - zero)

    ``x``, ``scale`` and ``zero`` only need to broadcast against each other.
    """
    codes = (x / scale).round().add(zero).clamp(0, maxq)
    return (codes - zero) * scale


def find_params(w, bits=Target_Bits, group_size=Group_Size):
    """
    Compute per-group asymmetric min-max quantization parameters.

    Args:
      w          : weight matrix (rows, cols)
      bits       : number of bits (e.g. 3 for INT3)
      group_size : number of consecutive weights per group (row-major order);
                   <= 0 means one group per row. When > 0, cols should be a
                   multiple of group_size so that groups do not straddle rows.

    Returns:
      scale : tensor of shape (num_groups, 1), strictly positive
      zero  : tensor of shape (num_groups, 1)
      maxq  : integer max quantized value (2^bits - 1)
    """
    dev = w.device
    maxq = 2**bits - 1

    if group_size > 0:
        # reshape into (num_groups, group_size)
        w = w.reshape(-1, group_size)

    # per-group min and max
    wmax = torch.amax(w, dim=1, keepdim=True)
    wmin = torch.amin(w, dim=1, keepdim=True)

    # Groups whose max is below 1e-5 (e.g. all zeros, all negative) get their
    # max lifted to 1e-5, which gives them a non-zero range.
    wmax = torch.max(wmax, torch.tensor(1e-5, device=w.device))

    scale = (wmax - wmin) / maxq
    # A constant group with value >= 1e-5 still has wmax == wmin, i.e. scale 0,
    # and x / scale would then yield inf/NaN during quantization. Give only those
    # groups a tiny positive scale: the constant is then reproduced (up to
    # rounding) because round(x/scale) and zero cancel.
    scale = torch.where(scale > 0, scale, torch.full_like(scale, 1e-8))
    zero = torch.round(-wmin / scale)

    return scale.to(dev), zero.to(dev), maxq

### BoA Quantizer Class (Relaxed BoA: Q/K use attention-aware Hessians)

In [None]:
class BoAQuantizer:
    """
    Quantizer for a single nn.Linear layer: GPTQ plus a relaxed BoA row update.

    Usage: call `add_batch` once per calibration batch, then `execute` once.
    `execute` overwrites `layer.weight.data` with quantize-dequantized weights
    (same dtype as before).

    - Every layer: GPTQ column-wise quantization driven by the column Hessian
      H_col = sum over tokens of x x^T of the layer inputs.
    - Layers that received `other_input` (attention Q/K projections): rows are
      additionally processed head by head. After a row is quantized, its error
      is propagated to the not-yet-quantized rows of the same head through a
      per-head row Hessian built from `other_input`.
    """

    def __init__(self, layer, bits=None, group_size=None, num_heads=12):
        """
        layer      : the nn.Linear layer to be quantized
        bits       : bit-width for this layer (None -> global Target_Bits)
        group_size : quantization group size (None -> global Group_Size)
        num_heads  : attention heads used to split the rows of Q/K projections
                     for the row Hessians (12 for OPT-125M)
        """
        self.layer = layer
        self.dev = layer.weight.device
        self.rows, self.cols = layer.weight.shape

        # Store quantization configuration per instance
        self.bits = bits if bits is not None else Target_Bits
        self.group_size = group_size if group_size is not None else Group_Size
        self.num_heads = num_heads

        # Column Hessian for GPTQ: H_col = X^T X (accumulated in float32)
        self.H_col = torch.zeros((self.cols, self.cols), device=self.dev, dtype=torch.float32)

        # Row Hessians for BoA (one (head_dim, head_dim) matrix per head);
        # only built when add_batch receives other_input (Q/K projections)
        self.H_row = None
        self.nsamples = 0
        self.is_boa_layer = False

    def add_batch(self, inp, other_input=None):
        """
        Accumulate Hessian statistics from one batch.

        inp         : input activations of the linear layer, (..., cols)
        other_input : optional activations for the attention-aware row
                      Hessians, (..., rows), e.g. K when quantizing Q and Q when
                      quantizing K. Passing it marks this layer as BoA-aware.
        """
        # Hessians are accumulated in float32 on the layer's device
        inp = inp.to(device=self.dev, dtype=torch.float32).reshape(-1, inp.shape[-1])

        # GPTQ column Hessian H_col += X^T X
        self.H_col += inp.t() @ inp

        if other_input is not None:
            if self.rows % self.num_heads != 0:
                raise ValueError(
                    f"rows={self.rows} is not divisible by num_heads={self.num_heads}"
                )
            self.is_boa_layer = True
            head_dim = self.rows // self.num_heads

            # (tokens, num_heads, head_dim); reshape also copes with non-contiguous input
            other = other_input.to(device=self.dev, dtype=torch.float32)
            other = other.reshape(-1, self.num_heads, head_dim)

            if self.H_row is None:
                self.H_row = [
                    torch.zeros((head_dim, head_dim), device=self.dev, dtype=torch.float32)
                    for _ in range(self.num_heads)
                ]

            # Per head: H_row[h] += A_h^T A_h
            for h in range(self.num_heads):
                oh = other[:, h, :]  # (tokens, head_dim)
                self.H_row[h] += oh.t() @ oh

        self.nsamples += 1

    @staticmethod
    def _inverse_cholesky(H, percdamp=0.01):
        """
        Return the upper Cholesky factor U of (H + damp*I)^{-1}, i.e.
        U^T U = (H + damp*I)^{-1}, with damp = percdamp * mean(diag(H)).

        Returns None when the damped matrix is not positive definite.
        H itself is not modified.
        """
        n = H.shape[0]
        damp = percdamp * torch.mean(torch.diag(H))
        H_damped = H + damp * torch.eye(n, device=H.device, dtype=H.dtype)
        try:
            L = torch.linalg.cholesky(H_damped)
            return torch.linalg.cholesky(torch.cholesky_inverse(L), upper=True)
        except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
            return None

    def gptq_quantize_block(self, W_f, H_inv, bits=None):
        """
        Quantize the rows of W_f column by column with GPTQ error feedback.

        W_f   : (rows, cols) float weights. It is modified in place: the
                not-yet-quantized columns absorb the compensation updates.
        H_inv : (cols, cols) *upper Cholesky factor* U of the damped inverse
                column Hessian (U^T U = H^{-1}), as returned by
                `_inverse_cholesky`.
        bits  : bit-width (None -> self.bits)

        Returns:
          Q_f : (rows, cols) quantize-dequantized weights
          E_f : (rows, cols) scaled errors (w - q) / U[i, i] per column
        """
        if bits is None:
            bits = self.bits

        rows, cols = W_f.shape
        if self.group_size > 0 and cols % self.group_size != 0:
            raise ValueError(f"cols={cols} is not a multiple of group_size={self.group_size}")

        Q_f = torch.zeros_like(W_f)
        E_f = torch.zeros_like(W_f)

        # Per-group scale/zero from the block as given (row-major group order)
        scale, zero, maxq = find_params(W_f, bits=bits, group_size=self.group_size)
        groups_per_row = cols // self.group_size if self.group_size > 0 else 1
        scale = scale.reshape(rows, groups_per_row)
        zero = zero.reshape(rows, groups_per_row)

        for i in range(cols):
            w = W_f[:, i]  # current (already compensated) column i
            g_idx = i // self.group_size if self.group_size > 0 else 0

            q = quantize_scalar(w, scale[:, g_idx], zero[:, g_idx], maxq)
            Q_f[:, i] = q

            # GPTQ: scale the error by the Cholesky diagonal, then spread it
            # over the remaining columns along row i of the factor
            err = (w - q) / H_inv[i, i]
            E_f[:, i] = err
            if i < cols - 1:
                W_f[:, i + 1:] -= torch.outer(err, H_inv[i, i + 1:])

        return Q_f, E_f

    def execute(self):
        """
        Quantize the wrapped layer and overwrite its weight.

        Steps:
          1. Damp H_col and factor its inverse (upper Cholesky of H^{-1});
             fall back to identity if that fails.
          2. Not BoA-aware: GPTQ over the whole weight matrix.
          3. BoA-aware (Q/K): for j = 0..head_dim-1 quantize row j of every
             head with GPTQ, then propagate each head's row error to rows
             j+1.. of that head using the upper Cholesky factor of its damped
             inverse row Hessian.
        """
        # Float32 working copy (never aliases the layer's parameter)
        W = self.layer.weight.data.to(dtype=torch.float32, copy=True)
        rows, cols = W.shape

        U_col = self._inverse_cholesky(self.H_col)
        if U_col is None:
            print("Hessian singular! Using Identity for H_inv.")
            U_col = torch.eye(cols, device=self.dev)

        # Plain GPTQ for layers without row Hessians
        if not self.is_boa_layer or self.H_row is None:
            Q, _ = self.gptq_quantize_block(W, U_col)
            self.layer.weight.data = Q.to(self.layer.weight.dtype)
            return

        num_heads = self.num_heads
        head_dim = rows // num_heads

        U_rows = []
        for Hr in self.H_row:
            Ur = self._inverse_cholesky(Hr)
            U_rows.append(Ur if Ur is not None else torch.eye(head_dim, device=self.dev))

        Q_final = torch.zeros_like(W)

        for j in tqdm(range(head_dim), leave=False):
            # Row j of every head, stacked: (num_heads, cols). Advanced indexing
            # copies, so W_j keeps the pre-quantization values of these rows.
            indices = [j + h * head_dim for h in range(num_heads)]
            W_j = W[indices, :]

            Q_j, _ = self.gptq_quantize_block(W_j.clone(), U_col)
            Q_final[indices, :] = Q_j

            if j == head_dim - 1:
                continue

            # Block-OBS step for a Kronecker Hessian: row k > j of head h moves by
            # -(U[j, k] / U[j, j]) * (w_j - q_j)
            row_err = W_j - Q_j
            for h, Ur in enumerate(U_rows):
                coeff = Ur[j, j + 1:] / Ur[j, j]  # (head_dim - j - 1,)
                start = h * head_dim
                W[start + j + 1:start + head_dim, :] -= torch.outer(coeff, row_err[h])

        # Write back quantized weights to the original layer, preserving dtype
        self.layer.weight.data = Q_final.to(self.layer.weight.dtype)

### Main Quantization Loop

In [None]:
def run_quantization(bits=None):
    """
    Run relaxed BoA quantization over all OPT decoder layers and return the model.

    Strategy (sequential, layer by layer):
      - Load the model in fp16 and draw WikiText-2 calibration windows.
      - Run the model up to decoder layer 0, recording that layer's inputs and
        the `attention_mask` keyword the decoder passes to it. The mask is
        reused for every per-layer call below, so those calls see the same
        masking as inside the full model.
      - For each decoder layer:
          * move the layer to DEVICE
          * forward the calibration inputs while hooks collect column Hessians
            for q/k/v/out_proj, fc1, fc2, and feed LayerNorm(x) with K (for Q) /
            Q (for K) to build the attention-aware row Hessians
          * run BoAQuantizer.execute() for each of those Linear modules
          * forward the inputs through the quantized layer to get the next
            layer's inputs
          * move the layer back to CPU
      - Return the quantized model on DEVICE.

    bits : weight bit-width; None falls back to the global Target_Bits.
    """
    if bits is None:
        bits = Target_Bits

    print(f"\n[run_quantization] Quantizing with bit-width = {bits}")

    # Clean memory before starting
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    model.eval()  # always eval mode during quantization

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    print("Loading calibration data...")
    trainloader, _ = get_wikitext2(tokenizer, NSAMPLES, SEQ_LEN)

    layers = model.model.decoder.layers

    # Keep embeddings on GPU
    model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(DEVICE)
    model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(DEVICE)

    inps = []          # captured layer-0 inputs, fp16 on CPU
    layer_kwargs = {}  # keyword arguments the decoder passed to layer 0

    class Catcher(nn.Module):
        """Stand-in for decoder layer 0: records its input, then aborts the forward."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, *args, **kwargs):
            inp = args[0] if args else kwargs["hidden_states"]
            if isinstance(inp, tuple):
                inp = inp[0]
            # store activations as fp16 on CPU for memory efficiency
            inps.append(inp.detach().to("cpu", dtype=torch.float16))
            # Keep the mask the decoder built; all windows have length SEQ_LEN,
            # so the same mask applies to every calibration batch.
            layer_kwargs["attention_mask"] = kwargs.get("attention_mask")
            raise ValueError  # stop the forward after capturing

    # Replace first layer with catcher, temporarily (restored even on error)
    layers[0] = Catcher(layers[0])
    print("Capturing inputs...")
    try:
        with torch.no_grad():
            for batch in tqdm(trainloader, desc="Capturing"):
                try:
                    model(batch[0].to(DEVICE), use_cache=False)
                except ValueError:
                    # Raised intentionally by Catcher to break after first layer
                    pass
    finally:
        layers[0] = layers[0].module

    if not inps:
        raise RuntimeError("No calibration activations were captured.")

    # Stack captured activations: shape (total_tokens, hidden_size)
    inps = torch.cat(inps, dim=0).reshape(-1, inps[0].shape[-1])
    print(f"Captured Inputs Shape (tokens x hidden): {inps.shape}")

    attention_mask = layer_kwargs.get("attention_mask")

    def run_layer(decoder_layer, hidden):
        """Forward one decoder layer with the captured mask; return its hidden states."""
        out = decoder_layer(hidden, attention_mask=attention_mask)
        # Decoder layers usually return a tuple whose first item is the hidden states
        return out[0] if isinstance(out, tuple) else out

    # Number of SEQ_LEN windows in the captured activations
    num_batches = inps.shape[0] // SEQ_LEN

    print("Starting Quantization...")
    for i in tqdm(range(len(layers)), desc="Quantizing Layers"):
        # Move current layer to GPU
        layer = layers[i].to(DEVICE)

        # Submodules to quantize in each transformer block
        subset = {
            "self_attn.q_proj":   layer.self_attn.q_proj,
            "self_attn.k_proj":   layer.self_attn.k_proj,
            "self_attn.v_proj":   layer.self_attn.v_proj,
            "self_attn.out_proj": layer.self_attn.out_proj,
            "fc1":                layer.fc1,
            "fc2":                layer.fc2,
        }

        # One BoAQuantizer per linear submodule, with this run's bit-width
        quantizers = {
            name: BoAQuantizer(m, bits=bits, group_size=Group_Size)
            for name, m in subset.items()
        }

        def make_hook(name):
            """Forward hook feeding the module's input into its quantizer."""
            def hook(_module, inp, _out):
                quantizers[name].add_batch(inp[0].data)
            return hook

        handles = [m.register_forward_hook(make_hook(n)) for n, m in subset.items()]

        # First pass: column Hessians for all modules (via hooks), plus
        # BoA row Hessians for Q/K from LayerNorm(x), Q and K.
        # Note: calling q_proj/k_proj below also fires their hooks, so X_norm
        # is added to their H_col more than once per batch.
        with torch.no_grad():
            for j in range(num_batches):
                batch_inp = inps[j * SEQ_LEN:(j + 1) * SEQ_LEN].unsqueeze(0).to(DEVICE)  # (1, seqlen, hidden)
                run_layer(layer, batch_inp)

                X_norm = layer.self_attn_layer_norm(batch_inp)
                Q_mat = layer.self_attn.q_proj(X_norm)
                K_mat = layer.self_attn.k_proj(X_norm)

                quantizers["self_attn.q_proj"].add_batch(X_norm, other_input=K_mat)
                quantizers["self_attn.k_proj"].add_batch(X_norm, other_input=Q_mat)

                del X_norm, Q_mat, K_mat, batch_inp

        for h in handles:
            h.remove()

        # Quantize every module in place (BoA for Q/K, GPTQ otherwise)
        for quantizer in quantizers.values():
            quantizer.execute()

        del quantizers
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Second pass: outputs of the quantized layer become the next inputs
        outs = []
        with torch.no_grad():
            for j in range(num_batches):
                batch_inp = inps[j * SEQ_LEN:(j + 1) * SEQ_LEN].unsqueeze(0).to(DEVICE)
                out = run_layer(layer, batch_inp)
                outs.append(out.reshape(-1, out.shape[-1]).cpu())  # (seqlen, hidden)
                del batch_inp, out

        del inps
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        inps = torch.cat(outs, dim=0)  # (tokens, hidden)
        del outs

        # Move quantized layer back to CPU to save VRAM
        layers[i] = layer.cpu()
        del layer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Once all layers are quantized, move entire model back to DEVICE
    model.to(DEVICE)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    return model

### Evaluation Utilities (WikiText-2 Perplexity, ARC, SQuAD, and CNN/DailyMail)

In [None]:
# Evaluate WikiText PPL
def compute_ppl(model, tokenizer, desc="PPL Eval", stride=512):
    """
    Compute WikiText-2 test perplexity with a strided sliding window.

    Windows of up to `model.config.max_position_embeddings` tokens start every
    `stride` tokens. Within a window, only the tokens not already scored by the
    previous window are labels; the rest serve as context (label -100). Each
    window's mean NLL is weighted by its number of scored tokens, so every token
    contributes exactly once. Evaluation stops at the window that reaches the
    end of the text.

    Args:
      model     : causal LM returning `.loss` when given `labels`
      tokenizer : matching tokenizer
      desc      : label printed in the header
      stride    : step between window starts (default 512)

    Returns:
      0-dim tensor with the perplexity.
    """
    print(f"\n--- {desc} on WikiText-2 ---")

    # Load WikiText-2 test split
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    # Tokenize entire dataset at once (stored on CPU)
    enc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')

    seq_len = enc.input_ids.size(1)
    max_len = model.config.max_position_embeddings  # usually 2048 for OPT

    nll_sum = 0.0
    n_scored = 0
    prev_end = 0

    model.eval()
    with torch.no_grad():
        for begin in tqdm(range(0, seq_len, stride), desc="PPL Progress"):
            end = min(begin + max_len, seq_len)
            trg_len = end - prev_end  # tokens not scored by earlier windows

            inp = enc.input_ids[:, begin:end].to(DEVICE)
            tgt = inp.clone()
            tgt[:, :-trg_len] = -100  # earlier tokens are context only

            loss = model(inp, labels=tgt).loss.float()

            # The model shifts labels by one, so position 0 is never scored
            n_valid = int((tgt[:, 1:] != -100).sum())
            nll_sum = nll_sum + loss * n_valid
            n_scored += n_valid

            prev_end = end
            if end == seq_len:
                break

    if n_scored == 0:
        raise ValueError("Test text is too short to compute perplexity.")

    # Convert the token-averaged negative log-likelihood to perplexity
    ppl = torch.exp(nll_sum / n_scored)
    print("PPL =", float(ppl))
    return ppl


# Evaluate ARC-QA (Easy / Challenge)
def evaluate_arc(model, tokenizer, subset="ARC-Easy", max_samples=200):
    """
    Zero-shot ARC accuracy from per-choice log-likelihoods.

    For every answer choice the question is used as context and only the
    continuation " " + choice is scored: the question tokens get label -100,
    so the loss is the mean NLL of the choice tokens given the question. The
    choice with the lowest such loss is predicted. Scoring only the answer
    tokens keeps the question's own likelihood out of the comparison.

    Args:
      subset      : "ARC-Easy" or "ARC-Challenge"
      max_samples : evaluate at most this many test questions

    Returns:
      (accuracy, number_correct, number_evaluated)
    """
    print(f"\n--- ARC Eval: {subset} ---")

    # Load the chosen ARC subset (Easy or Challenge)
    ds = load_dataset("ai2_arc", subset, split="test")

    # Limit dataset for faster eval
    ds = ds.select(range(min(max_samples, len(ds))))
    correct = 0

    model.eval()
    with torch.no_grad():
        for item in tqdm(ds, desc="ARC"):
            choices = item["choices"]["text"]
            labels = item["choices"]["label"]
            ans = item["answerKey"]

            # Context tokens are shared by all choices of this question
            ctx_ids = tokenizer(item["question"], return_tensors="pt").input_ids
            n_ctx = ctx_ids.shape[1]

            scores = []
            for choice in choices:
                cont_ids = tokenizer(
                    " " + choice, add_special_tokens=False, return_tensors="pt"
                ).input_ids
                if cont_ids.shape[1] == 0:
                    scores.append(float("-inf"))  # nothing to score
                    continue

                input_ids = torch.cat([ctx_ids, cont_ids], dim=1).to(DEVICE)
                target = input_ids.clone()
                target[:, :n_ctx] = -100  # score the answer tokens only

                loss = model(input_ids, labels=target).loss.item()
                scores.append(-loss)  # higher = better

            # Pick the highest-scoring choice
            pred = labels[int(np.argmax(scores))]
            if pred == ans:
                correct += 1

    acc = correct / len(ds)
    print("Accuracy =", acc)
    return acc, correct, len(ds)

# Evaluate SQuAD-v1.1 F1 (zero-shot, generative)
def evaluate_squad(model, tokenizer, max_samples=200, max_new_tokens=32):
    """
    Zero-shot generative QA on SQuAD v1.1 (validation), scored with the
    HuggingFace `evaluate` SQuAD metric (F1).

    A causal LM has no start/end span logits, so the question-answering
    pipeline cannot run on it. Instead the model is prompted with
    "Context: ...\\nQuestion: ...\\nAnswer:" and greedily generates up to
    `max_new_tokens` tokens. The first non-empty line of the continuation is
    the prediction. All gold answers of an item are passed to the metric.

    Returns:
      (f1, number_of_examples)
    """
    print("\n--- SQuAD Evaluation (F1) ---")

    squad = load_dataset("squad", split="validation")
    squad = squad.select(range(min(max_samples, len(squad))))

    # Metric from HuggingFace evaluate package
    metric = evaluate.load("squad")

    # Leave room for the generated tokens inside the position budget
    max_prompt_len = model.config.max_position_embeddings - max_new_tokens

    predictions = []
    references = []
    failures = 0

    model.eval()
    with torch.no_grad():
        for item in tqdm(squad, desc="SQuAD"):
            prompt = f"Context: {item['context']}\nQuestion: {item['question']}\nAnswer:"
            # If too long, keep the end of the prompt (question + "Answer:")
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids[:, -max_prompt_len:]
            input_ids = input_ids.to(DEVICE)

            try:
                out = model.generate(
                    input_ids=input_ids,
                    attention_mask=torch.ones_like(input_ids),
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id,
                )
                completion = tokenizer.decode(out[0, input_ids.shape[1]:], skip_special_tokens=True)
                lines = [ln.strip() for ln in completion.split("\n") if ln.strip()]
                pred_text = lines[0] if lines else ""
            except RuntimeError as err:  # e.g. CUDA OOM: count it and keep going
                failures += 1
                print(f"[SQuAD] generation failed for id={item['id']}: {err}")
                pred_text = ""

            predictions.append({"id": item["id"], "prediction_text": pred_text})
            references.append({"id": item["id"], "answers": item["answers"]})

    if failures:
        print(f"[SQuAD] {failures} generation failures scored as empty answers")

    # Compute F1
    scores = metric.compute(predictions=predictions, references=references)
    print("SQuAD-F1 =", scores["f1"])
    return scores["f1"], len(squad)


# CNN/DailyMail Summarization (BLEU + ROUGE-L)
def evaluate_cnn_dm(model, tokenizer, max_samples=50, max_new_tokens=128):
    """
    Zero-shot summarization on CNN/DailyMail (3.0.0, validation).

    The summarization pipeline targets seq2seq models. For a decoder-only model,
    its output would contain the prompt (the article) as well. Here each
    article (first 1500 characters) is followed by "\\nTL;DR:", the model
    greedily generates up to `max_new_tokens` tokens, and only the newly
    generated text is used as the summary.

    BLEU is computed on whitespace-normalized strings; ROUGE-L on raw strings.

    Returns:
      (bleu, rougeL, number_of_examples)
    """
    print("\n--- CNN/DailyMail Summarization ---")

    ds = load_dataset("cnn_dailymail", "3.0.0", split="validation")
    ds = ds.select(range(min(max_samples, len(ds))))

    bleu = evaluate.load("bleu")
    rouge = evaluate.load("rouge")

    # Leave room for the generated tokens inside the position budget
    max_prompt_len = model.config.max_position_embeddings - max_new_tokens

    preds_str = []  # generated summaries (new tokens only)
    refs_str = []   # gold highlights
    failures = 0

    model.eval()
    with torch.no_grad():
        for item in tqdm(ds, desc="Summarization"):
            # Limit article length to avoid excessive generation time
            article = item["article"][:1500]
            reference = item["highlights"]

            prompt = f"{article}\nTL;DR:"
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids[:, -max_prompt_len:]
            input_ids = input_ids.to(DEVICE)

            try:
                out = model.generate(
                    input_ids=input_ids,
                    attention_mask=torch.ones_like(input_ids),
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id,
                )
                pred = tokenizer.decode(out[0, input_ids.shape[1]:], skip_special_tokens=True).strip()
            except RuntimeError as err:  # e.g. CUDA OOM: count it and keep going
                failures += 1
                print(f"[CNN/DM] generation failed: {err}")
                pred = ""

            preds_str.append(pred)
            refs_str.append(reference)

    if failures:
        print(f"[CNN/DM] {failures} generation failures scored as empty summaries")

    # BLEU inputs: whitespace-normalized strings, one reference per prediction
    bleu_predictions = [" ".join(p.split()) for p in preds_str]
    bleu_refs = [[" ".join(r.split())] for r in refs_str]

    if any(bleu_predictions):
        bleu_score = bleu.compute(
            predictions=bleu_predictions,
            references=bleu_refs
        )["bleu"]
    else:
        # Guard: BLEU's brevity penalty needs a non-zero hypothesis length
        bleu_score = 0.0

    rouge_score = rouge.compute(
        predictions=preds_str,
        references=refs_str
    )["rougeL"]

    print("BLEU:", bleu_score)
    print("ROUGE-L:", rouge_score)
    return bleu_score, rouge_score, len(ds)

### Saving and Checking the Quantized Model

In [None]:
# Persist the quantized model
def save_quantized_model(model, tokenizer, out_dir="quantized_model"):
    """
    Create `out_dir` (including missing parents) and call `save_pretrained`
    on the model and then on the tokenizer with that directory.
    """
    os.makedirs(out_dir, exist_ok=True)
    print(f"Saving quantized model to: {out_dir}")
    for artifact in (model, tokenizer):
        artifact.save_pretrained(out_dir)

# Sample Text Generation
def test_generation(model, tokenizer, prompt="What would humans do if aliens attacked the earth?"):
    """
    Print one sampled continuation of `prompt` for a qualitative check.

    Sampling settings: up to 64 new tokens, top-p 0.95, temperature 0.8,
    a single sequence, EOS used as the padding token.
    """
    print("\nGeneration demo (quantized model):")
    model.eval()

    sampling_kwargs = dict(
        max_new_tokens=64,
        do_sample=True,
        top_p=0.95,
        temperature=0.8,
        num_return_sequences=1,
        pad_token_id=model.config.eos_token_id,
    )
    encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        generated = model.generate(**encoded, **sampling_kwargs)

    print(tokenizer.decode(generated[0], skip_special_tokens=True))

### Full-Precision vs Quantized Model Comparison

In [None]:
# 1) Load FP16 baseline model and tokenizer
print("\n=== Baseline: Loading FP16 Full-Precision Model ===")
fp_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto"
)
fp_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# 2) Evaluate FP16 on WikiText-2 PPL
fp_ppl = compute_ppl(fp_model, fp_tokenizer, desc="FP16 Baseline")

# 3) Evaluate FP16 on ARC-Easy and ARC-Challenge
fp_arc_easy_acc, fp_arc_easy_correct, fp_arc_easy_total = evaluate_arc(
    fp_model, fp_tokenizer, subset="ARC-Easy", max_samples=200
)
fp_arc_ch_acc, fp_arc_ch_correct, fp_arc_ch_total = evaluate_arc(
    fp_model, fp_tokenizer, subset="ARC-Challenge", max_samples=200
)

# 4) Evaluate FP16 on SQuAD (F1)
fp_squad_f1, fp_squad_n = evaluate_squad(
    fp_model, fp_tokenizer, max_samples=200
)

# 5) Evaluate FP16 on CNN/DailyMail (BLEU + ROUGE-L)
fp_bleu, fp_rougeL, fp_cnn_n = evaluate_cnn_dm(
    fp_model, fp_tokenizer, max_samples=50
)

# 6) Compute theoretical FP16 model size (in MB)
total_params = sum(p.numel() for p in fp_model.parameters())
fp16_size_mb = total_params * 16 / 8 / 1e6  # 16 bits => 2 bytes

print(f"\nFP16 Model Parameters: {total_params/1e6:.2f} M")
print(f"Theoretical FP16 Model Size: {fp16_size_mb:.2f} MB")

# 7) Free FP model before heavy quantization to save VRAM
del fp_model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# 8) Run quantization for current Target_Bits
print(f"\n=== Running BoA Quantization (INT{Target_Bits}) ===")
q_model = run_quantization(bits=Target_Bits)
q_tokenizer = fp_tokenizer  # same tokenizer

# 9) Evaluate quantized model on WikiText-2 PPL
q_ppl = compute_ppl(q_model, q_tokenizer, desc=f"INT{Target_Bits} Quantized")

# 10) Evaluate quantized model on ARC-Easy and ARC-Challenge
q_arc_easy_acc, q_arc_easy_correct, q_arc_easy_total = evaluate_arc(
    q_model, q_tokenizer, subset="ARC-Easy", max_samples=200
)
q_arc_ch_acc, q_arc_ch_correct, q_arc_ch_total = evaluate_arc(
    q_model, q_tokenizer, subset="ARC-Challenge", max_samples=200
)

# 11) Evaluate quantized model on SQuAD (F1)
q_squad_f1, q_squad_n = evaluate_squad(
    q_model, q_tokenizer, max_samples=200
)

# 12) Evaluate quantized model on CNN/DailyMail (BLEU + ROUGE-L)
q_bleu, q_rougeL, q_cnn_n = evaluate_cnn_dm(
    q_model, q_tokenizer, max_samples=50
)

# 13) Theoretical quantized size.
# run_quantization quantizes only the six Linear weights of each decoder layer
# (q/k/v/out_proj, fc1, fc2); all remaining parameters (embeddings, LayerNorm
# weights, biases, ...) stay FP16. Each quantization group also stores a scale
# and a zero point, counted here as 16 bits each.
quant_bits = Target_Bits
quantized_weight_params = sum(
    linear.weight.numel()
    for decoder_layer in q_model.model.decoder.layers
    for linear in (
        decoder_layer.self_attn.q_proj,
        decoder_layer.self_attn.k_proj,
        decoder_layer.self_attn.v_proj,
        decoder_layer.self_attn.out_proj,
        decoder_layer.fc1,
        decoder_layer.fc2,
    )
)
unquantized_params = total_params - quantized_weight_params
group_param_bits = (2 * 16 / Group_Size) if Group_Size > 0 else 0.0  # per quantized weight
quant_size_mb = (
    quantized_weight_params * (quant_bits + group_param_bits)
    + unquantized_params * 16
) / 8 / 1e6
compression_ratio = fp16_size_mb / quant_size_mb

print("\n=== Size and Compression Summary ===")
print(
    f"Quantized Linear weights: {quantized_weight_params/1e6:.2f} M params at INT{Target_Bits}; "
    f"{unquantized_params/1e6:.2f} M params kept in FP16"
)
print(f"INT{Target_Bits} Theoretical Size: {quant_size_mb:.2f} MB")
print(f"Theoretical Compression Ratio (FP16 -> INT{Target_Bits}): {compression_ratio:.2f}x")

# 14) Save quantized model
out_dir = f"boa_opt125m_int{Target_Bits}"
save_quantized_model(q_model, q_tokenizer, out_dir=out_dir)

# 15) Quick sanity check: generate some text from quantized model
test_generation(q_model, q_tokenizer)

### Final summary of all metrics

In [None]:
# Side-by-side report of every metric collected above (FP16 baseline vs. quantized)
print("\n=== Final Metrics Summary ===")
print(
    f"FP16 PPL:                 {fp_ppl:.4f}",
    f"INT{Target_Bits} PPL:           {q_ppl:.4f}",
    sep="\n",
)
print(
    f"\nFP16 ARC-Easy:            acc={fp_arc_easy_acc:.4f}  ({fp_arc_easy_correct}/{fp_arc_easy_total})",
    f"FP16 ARC-Challenge:       acc={fp_arc_ch_acc:.4f}   ({fp_arc_ch_correct}/{fp_arc_ch_total})",
    f"INT{Target_Bits} ARC-Easy:      acc={q_arc_easy_acc:.4f}  ({q_arc_easy_correct}/{q_arc_easy_total})",
    f"INT{Target_Bits} ARC-Challenge: acc={q_arc_ch_acc:.4f}   ({q_arc_ch_correct}/{q_arc_ch_total})",
    sep="\n",
)
print(
    f"\nFP16 SQuAD-F1:            {fp_squad_f1:.4f}  (N={fp_squad_n})",
    f"INT{Target_Bits} SQuAD-F1:      {q_squad_f1:.4f}  (N={q_squad_n})",
    sep="\n",
)
print(
    f"\nFP16 CNN-DM BLEU:         {fp_bleu:.4f}  (N={fp_cnn_n})",
    f"FP16 CNN-DM ROUGE-L:      {fp_rougeL:.4f}",
    f"INT{Target_Bits} CNN-DM BLEU:   {q_bleu:.4f}  (N={q_cnn_n})",
    f"INT{Target_Bits} CNN-DM ROUGE-L:{q_rougeL:.4f}",
    sep="\n",
)