In [34]:
!pip -q install -U "transformers>=4.44" "datasets>=2.20" "peft>=0.12" accelerate evaluate rouge-score scikit-learn sentencepiece "pyarrow<20.0.0a0"

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [35]:
# Imports and config
import os
import re
from tqdm.auto import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
import evaluate

## Load model and tokenizer
We use `google/flan-t5-small` (about 77M parameters) to keep things light.
- The tokenizer converts text ↔ token ids.
- The model generates output token ids from the input token ids.

In [36]:
MODEL_NAME = "google/flan-t5-small"

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print("device:", DEVICE)

device: cuda


In [37]:
print("Loading model and tokenizer... This may take a minute")
from transformers import logging as hf_logging

# Keep the transformers logger at ERROR level so only real failures are shown.
hf_logging.set_verbosity_error()

# Tokenizer and seq2seq weights for the configured model id.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model = model.to(DEVICE)  # nn.Module.to returns the same module
# Inference mode (disables dropout); the module repr is this cell's displayed output.
model.eval()

Loading model and tokenizer... This may take a minute


T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=384, bias=False)
              (k): Linear(in_features=512, out_features=384, bias=False)
              (v): Linear(in_features=512, out_features=384, bias=False)
              (o): Linear(in_features=384, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 6)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=512, out_features=1024, bias=False)
              (wi_1): Linear(in_features=512, out_features=1024, bias=False)
              (wo): 

In [38]:
# Dump the full T5Config: layer sizes, special-token ids and per-task generation presets.
print(str(model.config))

T5Config {
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 1024,
  "d_kv": 64,
  "d_model": 512,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "dtype": "float32",
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 8,
  "num_heads": 6,
  "num_layers": 8,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      "min_length": 30,
      "no_repeat_ngram_size": 3,
      "num_beams": 4,
      "prefix": "summarize: "
    },
    "translation_en_to_de": {
      "early_stopping": true,
      "max_length": 300,
     

In [39]:
# Stack depth and width
print("Hidden size (d_model):", model.config.d_model)
print("Encoder layers:", model.config.num_layers)
print("Decoder layers:", model.config.num_decoder_layers)

# Attention geometry: the q/k/v projections map d_model -> num_heads * d_kv
print("Number of attention heads:", model.config.num_heads)
print("Key-value dimension per head:", model.config.d_kv)
print("Total Q/K/V dimension:", model.config.num_heads * model.config.d_kv)

Hidden size (d_model): 512
Encoder layers: 8
Decoder layers: 8
Number of attention heads: 6
Key-value dimension per head: 64
Total Q/K/V dimension: 384


In [40]:
# See all parameter names
# List the query-projection weight of every self-attention layer (encoder and decoder).
for name, param in model.named_parameters():
    # Match the exact module-path suffix; a bare `'q' in name` substring test
    # would also select any other parameter whose path happens to contain a 'q'.
    if name.endswith("SelfAttention.q.weight"):
        print(f"{name}: {param.shape}")
        

encoder.block.0.layer.0.SelfAttention.q.weight: torch.Size([384, 512])
encoder.block.1.layer.0.SelfAttention.q.weight: torch.Size([384, 512])
encoder.block.2.layer.0.SelfAttention.q.weight: torch.Size([384, 512])
encoder.block.3.layer.0.SelfAttention.q.weight: torch.Size([384, 512])
encoder.block.4.layer.0.SelfAttention.q.weight: torch.Size([384, 512])
encoder.block.5.layer.0.SelfAttention.q.weight: torch.Size([384, 512])
encoder.block.6.layer.0.SelfAttention.q.weight: torch.Size([384, 512])
encoder.block.7.layer.0.SelfAttention.q.weight: torch.Size([384, 512])
decoder.block.0.layer.0.SelfAttention.q.weight: torch.Size([384, 512])
decoder.block.1.layer.0.SelfAttention.q.weight: torch.Size([384, 512])
decoder.block.2.layer.0.SelfAttention.q.weight: torch.Size([384, 512])
decoder.block.3.layer.0.SelfAttention.q.weight: torch.Size([384, 512])
decoder.block.4.layer.0.SelfAttention.q.weight: torch.Size([384, 512])
decoder.block.5.layer.0.SelfAttention.q.weight: torch.Size([384, 512])
decode

In [41]:
# Total parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

# Trainable parameters = those with requires_grad=True. Nothing has been frozen
# yet, so this equals the total (76,961,152 for flan-t5-small).
trainable = sum(p.numel() for p in model.parameters()
                if p.requires_grad)

print(f"trainable parameters: {trainable:,}")

Total parameters: 76,961,152
trainable parameters: 76,961,152


In [42]:
# Check a specific attention layer
encoder_attn = model.encoder.block[0].layer[0].SelfAttention

# nn.Linear stores weight as (out_features, in_features):
# q/k/v project d_model (512) -> heads*d_kv (384); o projects back 384 -> 512.
print("Query weight shape:", encoder_attn.q.weight.shape)  # (384, 512)
print("Key weight shape:", encoder_attn.k.weight.shape)    # (384, 512)
print("Value weight shape:", encoder_attn.v.weight.shape)  # (384, 512)
print("Output weight shape:", encoder_attn.o.weight.shape) # (512, 384)


Query weight shape: torch.Size([384, 512])
Key weight shape: torch.Size([384, 512])
Value weight shape: torch.Size([384, 512])
Output weight shape: torch.Size([512, 384])


- Loads SST-2 and SAMSum from Hugging Face datasets.
- Runs zero-shot classification on SST-2 using google/flan-t5-small (prompting the model to return exactly one label).
- Runs zero-shot summarization on SAMSum (prompting the model for 1–2 sentence summaries).
- Evaluates classification (accuracy) and summarization (ROUGE).
- Uses small subsets by default (the first `max_examples` = 200 examples of each split) so we can iterate quickly on CPU or GPU.

In [43]:
def classify(texts, max_new_tokens=10, prompt_template="sst2: {}"):
    """Quick demo: generate a short answer for each text with a task-prefix prompt.

    Args:
        texts: a single string or a list of strings.
        max_new_tokens: generation budget per input.
        prompt_template: format string wrapped around each text; the default
            reproduces the bare ``sst2:`` T5-style prefix.

    Returns:
        list[str]: one stripped, lower-cased decoded output per input text
        (empty list for empty input).
    """
    if isinstance(texts, str):
        texts = [texts]
    if not texts:
        return []
    prompts = [prompt_template.format(t) for t in texts]
    # truncation=True caps inputs at the tokenizer's model_max_length.
    enc = tokenizer(prompts, return_tensors='pt', padding=True, truncation=True).to(DEVICE)
    with torch.no_grad():
        out = model.generate(**enc, max_new_tokens=max_new_tokens)
    decoded = tokenizer.batch_decode(out, skip_special_tokens=True)
    # Normalize a bit for readability
    return [d.strip().lower() for d in decoded]

examples = [
    "I absolutely loved this movie. It was fantastic!",
    "The plot was predictable and the acting was bad.",
    "Not great, not terrible."
]
preds = classify(examples)
for t, p in zip(examples, preds):
    print(f"Text: {t}\nPrediction: {p}\n")

Text: I absolutely loved this movie. It was fantastic!
Prediction: i loved this movie! it was a great

Text: The plot was predictable and the acting was bad.
Prediction: sst2: the plot was predictable

Text: Not great, not terrible.
Prediction: sst2: not great, not



In [44]:

# Evaluate on the first `max_examples` rows of each split; set to None (or 0) for the full split.
max_examples = 200

# Generation settings: deterministic beam search.
# `temperature` is deliberately omitted: it is ignored when do_sample=False
# (transformers warns about it), and 0.0 would be rejected if sampling were enabled.
GEN_KWARGS_CLASS = {
    "max_length": 16,      # answers are a single label word
    "num_beams": 5,
    "early_stopping": True,
    "do_sample": False,
}

GEN_KWARGS_SUM = {
    "max_length": 120,     # room for a 1-2 sentence summary
    "num_beams": 4,
    "early_stopping": True,
    "do_sample": False,
}

In [45]:
# Utility: normalize model-generated text
import unicodedata

def normalize_text(s: str) -> str:
    """Canonicalize model output for label matching.

    Lower-cases, applies Unicode NFKD normalization, removes every character
    that is neither a word character nor whitespace (punctuation, plus the
    combining accents that NFKD splits off), collapses whitespace runs to a
    single space and trims both ends. ``None`` yields ``""``.
    """
    if s is None:
        return ""
    s = unicodedata.normalize("NFKD", s.lower())
    # remove punctuation except spaces
    s = re.sub(r"[^\w\s]", "", s)
    # Trim last: removing punctuation can expose edge spaces
    # (e.g. "positive ." -> "positive "), which would break exact matches.
    return re.sub(r"\s+", " ", s).strip()


## Zero-shot classification (SST-2 style)
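FLAN-T5 is instruction-tuned, so we can ask it for a sentiment label directly. A bare `sst2: <text>` prefix (the original T5 task tag) does not reliably yield `positive` or `negative` with flan-t5-small; the `classify` demo above mostly echoed its input.
The helper below therefore uses an explicit instruction that lists the allowed labels. It maps each answer back onto one of those labels and reports accuracy over the dataset.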

In [46]:
def _build_sst2_prompt(text, labels):
    """Return the instruction prompt asking the model to label `text` with one of `labels`."""
    return (
        "Classify the sentiment of the text as one of the following labels: "
        + ", ".join(labels)
        + ".\n\n"
        + f"Text: \"{text}\"\n\nAnswer with exactly one word: "
    )


def _map_to_label(out_text, labels):
    """Map a raw generation onto one of `labels`, or return None if nothing matches.

    Matching compares `normalize_text` forms, in decreasing strictness: exact
    match, then substring match in either direction, then the label's first
    word appearing in the output. Earlier labels win ties.
    """
    out_norm = normalize_text(out_text)
    if not out_norm:
        # "" is a substring of every label, so without this guard an empty
        # generation would be silently accepted as labels[0].
        return None
    normed = [(lab, normalize_text(lab)) for lab in labels]
    for lab, lab_norm in normed:
        if lab_norm == out_norm:
            return lab
    for lab, lab_norm in normed:
        if lab_norm and (lab_norm in out_norm or out_norm in lab_norm):
            return lab
    for lab, lab_norm in normed:
        words = lab_norm.split()
        if words and words[0] in out_norm:
            return lab
    return None


def zero_shot_sst2_classify(ds, labels=("positive", "negative")):
    """Zero-shot sentiment classification of an SST-2 split with the global `model`.

    Each sentence is wrapped in an instruction prompt, decoded with
    `GEN_KWARGS_CLASS`, and mapped onto one of `labels`. Outputs that cannot
    be mapped fall back to ``labels[0]`` with a warning.

    Args:
        ds: iterable of examples with "sentence" and integer "label" fields
            (0 -> "negative", anything else -> "positive"). Iterated twice.
        labels: candidate label strings. Gold labels are always spelled
            "negative"/"positive", so accuracy is only meaningful when
            `labels` uses those exact spellings.

    Returns:
        dict with "sentence", "preds", "trues" (lists aligned by index) and
        "accuracy" (float; NaN when `ds` is empty).

    Raises:
        ValueError: if `labels` is empty.
    """
    labels = list(labels)
    if not labels:
        raise ValueError("labels must contain at least one label")

    true_labels = ["negative" if ex["label"] == 0 else "positive" for ex in ds]
    preds = []
    sentences = []

    for ex in tqdm(ds, desc="SST-2 zero-shot"):
        text = ex["sentence"]
        prompt = _build_sst2_prompt(text, labels)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
        out = model.generate(**inputs, **GEN_KWARGS_CLASS)
        out_text = tokenizer.decode(out[0], skip_special_tokens=True)

        mapped = _map_to_label(out_text, labels)
        if mapped is None:
            mapped = labels[0]
            # tqdm.write prints without corrupting the progress bar.
            tqdm.write(f"Warning: couldn't map output: {out_text!r} -> falling back to {mapped}")

        preds.append(mapped)
        sentences.append(text)

    n = len(preds)
    acc = sum(p == t for p, t in zip(preds, true_labels)) / n if n else float("nan")
    print(f"SST-2 zero-shot accuracy on {n} examples: {acc:.4f}")
    return {"sentence": sentences, "preds": preds, "trues": true_labels, "accuracy": acc}

ds = load_dataset("glue", "sst2", split="validation")
if max_examples:
    # Keep only the first `max_examples` rows for fast iteration.
    ds = ds.select(range(min(len(ds), max_examples)))

# Run classification (change max_examples, defined earlier in the notebook, to resize the subset)
sst2_res = zero_shot_sst2_classify(ds, labels=["positive", "negative"])



SST-2 zero-shot: 100%|██████████| 200/200 [00:13<00:00, 14.78it/s]

SST-2 zero-shot accuracy on 200 examples: 0.8600





In [47]:
# Show a few classification examples (capped so smaller subsets don't raise IndexError)
for i in range(min(20, len(sst2_res["preds"]))):
    print(i, "sentence: ", sst2_res["sentence"][i], "pred:", sst2_res["preds"][i], "true:", sst2_res["trues"][i])

0 sentence:  it 's a charming and often affecting journey .  pred: positive true: positive
1 sentence:  unflinchingly bleak and desperate  pred: negative true: negative
2 sentence:  allows us to hope that nolan is poised to embark a major career as a commercial yet inventive filmmaker .  pred: positive true: positive
3 sentence:  the acting , costumes , music , cinematography and sound are all astounding given the production 's austere locales .  pred: positive true: positive
4 sentence:  it 's slow -- very , very slow .  pred: negative true: negative
5 sentence:  although laced with humor and a few fanciful touches , the film is a refreshingly serious look at young women .  pred: positive true: positive
6 sentence:  a sometimes tedious film .  pred: negative true: negative
7 sentence:  or doing last year 's taxes with your ex-wife .  pred: negative true: negative
8 sentence:  you do n't have to know about music to appreciate the film 's easygoing blend of comedy and romance .  pred: p

## Zero-shot summarization
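For summarization, we give the model a natural-language instruction ("Summarize the following conversation in 1-2 sentences:"), followed by the dialogue and a trailing `Summary:` cue. This replaces the bare `summarize: ` prefix listed under `task_specific_params` in the model config.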

In [49]:
# Cell: Zero-shot summarization on SAMSum

def zero_shot_samsum_summarization(ds_samsum, summary_sentences=(1,2)):
    """Summarize each SAMSum dialogue zero-shot and collect predictions with references.

    Args:
        ds_samsum: iterable of examples with "dialogue" and "summary" fields.
        summary_sentences: indexable pair; items 0 and 1 are interpolated into
            the prompt as "<first>-<second> sentences".

    Returns:
        dict with "preds" (generated summaries) and "refs" (reference
        summaries), both whitespace-stripped and aligned by index.
    """
    predictions, references = [], []

    for example in tqdm(ds_samsum, desc="SAMSum zero-shot"):
        dialogue = example["dialogue"]
        instruction = f"Summarize the following conversation in {summary_sentences[0]}-{summary_sentences[1]} sentences:\n\n"
        prompt = instruction + dialogue + "\n\nSummary:"
        # NOTE(review): the tokenizer truncates from the right by default, so a
        # dialogue longer than 1024 tokens would lose the trailing "Summary:" cue.
        encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
        generated = model.generate(**encoded, **GEN_KWARGS_SUM)
        text = tokenizer.decode(generated[0], skip_special_tokens=True)
        predictions.append(text.strip())
        references.append(example["summary"].strip())

    return {"preds": predictions, "refs": references}


# SAMSum test split, optionally cut down to the first `max_examples` dialogues.
ds_samsum = load_dataset("knkarthick/samsum", split="test")
if max_examples:
    ds_samsum = ds_samsum.select(range(min(max_examples, len(ds_samsum))))

samsum_res = zero_shot_samsum_summarization(ds_samsum, summary_sentences=(1, 2))



SAMSum zero-shot: 100%|██████████| 200/200 [01:50<00:00,  1.82it/s]


In [50]:
# Show a few summarization examples (capped so smaller subsets don't raise IndexError)
for i in range(min(20, len(samsum_res["preds"]))):
    print("REF:", samsum_res["refs"][i])
    print("PRED:", samsum_res["preds"][i])
    print("-" * 60)

REF: Hannah needs Betty's number but Amanda doesn't have it. She needs to contact Larry.
PRED: Larry called Hannah last time they were at the park together. Hannah doesn't know Larry well.
------------------------------------------------------------
REF: Eric and Rob are going to watch a stand-up on youtube.
PRED: Eric and Rob are watching a stand-up on YouTube.
------------------------------------------------------------
REF: Lenny can't decide which trousers to buy. Bob advised Lenny on that topic. Lenny goes with Bob's advice to pick the trousers that are of best quality.
PRED: Bob sends Lenny photos of his trousers. Lenny will buy the first pair or the third pair.
------------------------------------------------------------
REF: Emma will be home soon and she will let Will know.
PRED: Will is going to pick Emma up. Emma will be home soon.
------------------------------------------------------------
REF: Jane is in Warsaw. Ollie and Jane has a party. Jane lost her calendar. They wil

In [52]:
rouge = evaluate.load("rouge")
rouge_res = rouge.compute(predictions=samsum_res["preds"], references=samsum_res["refs"])


def _rouge_fmeasure(score):
    """Return the F-measure from one ROUGE entry, whichever form `compute` produced.

    With the default aggregation `evaluate` returns a plain float per key
    (already the F-measure); older metric APIs returned objects exposing
    `.mid.fmeasure` (aggregate) or `.fmeasure` (single score).
    """
    if hasattr(score, "mid"):
        return score.mid.fmeasure
    if hasattr(score, "fmeasure"):
        return score.fmeasure
    return float(score)


print("ROUGE results (f-measure):")
for k, v in rouge_res.items():
    print(f"  {k}: {_rouge_fmeasure(v):.4f}")

ROUGE results (f-measure):
rouge1 0.4452036070844183
rouge2 0.19870507399527154
rougeL 0.3655029081019473
rougeLsum 0.3654140327156863


In [58]:
# Save predictions to disk for later analysis
import json



def save_outputs(df, dir, file_name):
    """Write a JSON-serializable object to ``dir/file_name`` as pretty-printed UTF-8 JSON.

    Despite its name, `df` is whatever object should be dumped (here, a results
    dict). `dir` is created if missing; an existing file is overwritten.
    Non-ASCII characters are written as-is rather than \\u-escaped.
    """
    target = os.path.join(dir, file_name)
    os.makedirs(dir, exist_ok=True)
    with open(target, mode="w", encoding="utf-8") as fh:
        json.dump(df, fh, indent=2, ensure_ascii=False)

def format_sst2_readable(res):
    """Render SST-2 results as one tab-separated line per example.

    Each line is ``"<idx>\\tPRED=<pred>\\tTRUE=<true>\\tSENT=<sentence>"``, built
    from the aligned "sentence", "preds" and "trues" lists of `res`.
    """
    rows = zip(res["sentence"], res["preds"], res["trues"])
    formatted = []
    for idx, (sent, pred, gold) in enumerate(rows):
        formatted.append(f"{idx}\tPRED={pred}\tTRUE={gold}\tSENT={sent}")
    return formatted

def format_samsum_readable(res):
    """Render SAMSum results as one multi-line block per example.

    Each block is the index, a ``REF:`` line, a ``PRED:`` line, then a
    60-dash separator, built from the aligned "refs" and "preds" lists of `res`.
    """
    separator = "-" * 60
    blocks = []
    for idx, (ref, pred) in enumerate(zip(res["refs"], res["preds"])):
        blocks.append(f"{idx}\nREF: {ref}\nPRED: {pred}\n{separator}")
    return blocks

def write_text(lines, dir, file_name):
    """Write `lines` joined by newlines (no trailing newline) to ``dir/file_name`` as UTF-8.

    `dir` is created if missing; an existing file is overwritten.
    """
    target = os.path.join(dir, file_name)
    os.makedirs(dir, exist_ok=True)
    with open(target, mode="w", encoding="utf-8") as fh:
        fh.write("\n".join(lines))

# Human-readable line lists for both result sets
sst2_readable = format_sst2_readable(sst2_res)
samsum_readable = format_samsum_readable(samsum_res)

OUTPUT_DIR = "outputs"

# Plain-text dumps for eyeballing
write_text(sst2_readable, OUTPUT_DIR, "sst2_preds-zeroshot.txt")
write_text(samsum_readable, OUTPUT_DIR, "samsum_preds-zeroshot.txt")

# JSON dumps for reloading the raw results later
save_outputs(sst2_res, OUTPUT_DIR, "sst2_preds-zeroshot.json")
save_outputs(samsum_res, OUTPUT_DIR, "samsum_preds-zeroshot.json")

print("Saved readable .txt and JSON files in ./outputs/")


Saved readable .txt and JSON files in ./outputs/
