# Project 3b: Natural Language Generation

In this part of the homework, you will implement decoding algorithms covered in class -- greedy decoding, random sampling, temperature sampling, top-k sampling, and top-p (nucleus) sampling. You will also learn how to use knowledge distillation to leverage strong teacher models to improve the performance of weaker student models. The knowledge distillation exercise is mandatory only for CSE 547M students.

*The section on Decoding algorithms is adapted from the code by Jiacheng Liu.*

## Section 0: Setup

In [1]:
"""set device and random seeds"""

######################################################
#  The following helper functions are given to you.
######################################################
import os
from tqdm.notebook import tqdm
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = device
print(f"device: {device}")


def set_seed(seed=19260817):
    """
    Seeds PyTorch's random number generators and sets the cuDNN
    determinism flags.
    inputs:
    - seed: int, value used to seed the CPU generator and all CUDA devices
    """
    # cuDNN flags: request deterministic kernels and disable auto-tuning.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # Seed the CPU generator first, then every CUDA device.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)


set_seed()

device: cuda


### 0.1 Dataset

In [2]:
"""load datasets"""

######################################################
#  The following helper code is given to you.
######################################################

from datasets import load_dataset

# Load the ROCStories dataset by its Hugging Face identifier
dataset = load_dataset("Ximing/ROCStories")
# Unpack the three splits by name
train_data, dev_data, test_data = (
    dataset["train"],
    dataset["validation"],
    dataset["test"],
)

# Print one training example to inspect its fields
print(train_data[0])

{'story_id': '080198fc-d0e7-42b3-8e63-b2144e59d816', 'prompt': 'On my way to work I stopped to get some coffee.', 'continuation': 'I went through the drive through and placed my order. I paid the cashier and patiently waited for my drink. When she handed me the drink, the lid came off and spilled on me. The coffee hurt and I had to go home and change clothes.', 'constraint_words': ['drive', 'order', 'drink', 'lid', 'coffee', 'hurt', 'home', 'change', 'clothes']}


### 0.2 Evaluation Metrics

In [3]:
"""prepare evaluation"""

######################################################
#  The following helper code is given to you.
######################################################

from evaluate import load
from transformers import RobertaForSequenceClassification, RobertaTokenizer

# Perplexity metric from the `evaluate` library; used by compute_perplexity()
perplexity_scorer = load("perplexity", module_type="metric")
# RoBERTa sequence classifier checkpoint (name indicates CoLA fine-tuning);
# its tokenizer and model are used by compute_fluency()
cola_model_name = "textattack/roberta-base-CoLA"
cola_tokenizer = RobertaTokenizer.from_pretrained(cola_model_name)
cola_model = RobertaForSequenceClassification.from_pretrained(cola_model_name).to(
    device
)


def batchify(data, batch_size):
    """
    Lazily groups the items of `data` into consecutive lists.
    inputs:
    - data: iterable of items
    - batch_size: int, must be positive (checked when iteration starts)
    outputs:
    - generator of lists; every list has `batch_size` items except possibly
      the last one, which holds the remainder. Nothing is yielded for empty
      `data`.
    """
    assert batch_size > 0

    pending = []
    for element in data:
        if len(pending) < batch_size:
            pending.append(element)
            continue
        # The current batch is full: hand it out and start a new one
        # with the element that did not fit.
        yield pending
        pending = [element]

    # Whatever is left over forms the final, possibly shorter, batch.
    if pending:
        yield pending

Some weights of the model checkpoint at textattack/roberta-base-CoLA were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
"""set up evaluation metric"""

######################################################
#  The following helper code is given to you.
######################################################


def compute_perplexity(texts, model="gpt2", batch_size=8):
    """
    Scores `texts` with the module-level `perplexity_scorer`.
    inputs:
    - texts: list of str to score
    - model: str, passed to the scorer as `model_id`
    - batch_size: int, passed through to the scorer
    outputs:
    - the "mean_perplexity" entry of the scorer's result
    """
    scorer_args = dict(
        predictions=texts,
        model_id=model,
        batch_size=batch_size,
        add_start_token=True,
    )
    result = perplexity_scorer.compute(**scorer_args)
    return result["mean_perplexity"]


def compute_fluency(texts, batch_size=8):
    """
    Averages the CoLA classifier's class-1 probability over `texts`.
    inputs:
    - texts: list of str to score
    - batch_size: int, number of texts sent through the classifier at once
    outputs:
    - float, mean over all texts of the softmax probability of class index 1
      (presumably the "acceptable" label of the CoLA checkpoint -- confirm
      against the model card)
    raises:
    - ZeroDivisionError: if `texts` is empty
    """
    scores = []
    for b_texts in batchify(texts, batch_size):
        # Tokenize only the current batch. Tokenizing the full `texts` list
        # here would re-score every text once per batch.
        inputs = cola_tokenizer(
            b_texts, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            logits = cola_model(**inputs).logits
        probs = logits.softmax(dim=-1)
        scores.extend(probs[:, 1].tolist())
    return sum(scores) / len(scores)


def compute_diversity(texts):
    """
    Computes distinct-n ratios (unique n-grams / total n-grams) for n = 1, 2, 3.
    inputs:
    - texts: list of str; each text is split into words on single spaces
    outputs:
    - tuple of three floats: (distinct-1, distinct-2, distinct-3), each pooled
      over all texts. A ratio is 0.0 when the texts contain no n-gram of that
      order (e.g. every text is a single word, or `texts` is empty).
    """
    tokenized = [text.split(" ") for text in texts]

    def distinct_ratio(n):
        # N-grams are stored as tuples so that words which themselves contain
        # "_" cannot make two different n-grams look identical.
        ngrams = [
            tuple(words[start : start + n])
            for words in tokenized
            for start in range(len(words) - n + 1)
        ]
        # Guard against texts too short to contain any n-gram of this order.
        return len(set(ngrams)) / len(ngrams) if ngrams else 0.0

    return distinct_ratio(1), distinct_ratio(2), distinct_ratio(3)


def evaluate(generations, experiment):
    """
    Computes perplexity, fluency and diversity for a set of generations and
    prints a short report.
    inputs:
    - generations: list of str; empty strings are dropped before scoring
    - experiment: label printed as the first line of the report
    """
    non_empty = [text for text in generations if text != ""]

    perplexity = compute_perplexity(non_empty)
    fluency = compute_fluency(non_empty)
    diversity = compute_diversity(non_empty)

    report = [
        experiment,
        f"perplexity = {perplexity:.2f}",
        f"fluency = {fluency:.2f}",
        f"diversity = {diversity[0]:.2f}, {diversity[1]:.2f}, {diversity[2]:.2f}",
        "",  # trailing blank line separates consecutive reports
    ]
    for line in report:
        print(line)


# Smoke test: run all three metrics on a few hand-written sentences
debug_sents = [
    "This restaurant is awesome",
    "My dog is cute and I love it.",
    "Today is sunny.",
]
evaluate(debug_sents, "debugging run")

  0%|          | 0/1 [00:00<?, ?it/s]

debugging run
perplexity = 178.64
fluency = 0.98
diversity = 0.87, 1.00, 1.00



### 0.3: Load Model

In [5]:
"""load model and tokenizer"""

######################################################
#  The following helper code is given to you.
######################################################

from transformers import GPT2LMHeadModel, GPT2Tokenizer

model_name = "gpt2"
# Register "<|endoftext|>" as the padding token so that decode() can pad
# batched prompts (it calls the tokenizer with padding=True)
tokenizer = GPT2Tokenizer.from_pretrained(model_name, pad_token="<|endoftext|>")
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
# Put the model in evaluation mode; as the cell's last expression, the
# returned model is also what the cell displays
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

## **Section 1: Decoding Algorithms**

In this section, you will implement a few basic decoding algorithms:
1. Greedy decoding
2. Vanilla sampling
3. Temperature sampling
4. Top-k sampling
5. Top-p sampling 

We have provided a wrapper function `decode()` that takes care of batching, controlling max length, and handling the EOS token.
You will be asked to implement the core function of each method: *given the pre-softmax logits of the next token, decide what the next token is.*

**The wrapper calls the core function of each decoding algorithm, which you will implement in the subsections below.**

In [6]:
"""decode main wrapper function"""

######################################################
#  The following helper code is given to you.
######################################################

def _update_model_kwargs_for_generation(
    outputs, model_kwargs, is_encoder_decoder: bool = False
):
    # update past
    if "past_key_values" in outputs:
        model_kwargs["past_key_values"] = outputs.past_key_values
    elif "mems" in outputs:
        model_kwargs["past_key_values"] = outputs.mems
    elif "past_buckets_states" in outputs:
        model_kwargs["past_key_values"] = outputs.past_buckets_states
    else:
        model_kwargs["past_key_values"] = None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

    # update attention mask
    if not is_encoder_decoder:
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

    return model_kwargs

def decode(prompts, max_len, method, **kwargs):
    """
    Generates a continuation for every prompt, one token per step, using the
    selected decoding algorithm.
    inputs:
    - prompts: list of str, tokenized together as one padded batch
    - max_len: int, maximum number of new tokens to generate
    - method: str, one of "greedy", "sample", "temperature", "topk", "topp"
    - kwargs: optional hyperparameters forwarded to the decoding function:
      `temperature` (default 0.8 for "temperature", 1.0 for "topk"/"topp"),
      `k` (default 20, "topk" only), `p` (default 0.7, "topp" only)
    outputs:
    - response_text: list of str, the decoded continuation of each prompt
      (the prompt itself is not included)
    raises:
    - ValueError: if `method` is not one of the supported names
    """
    supported_methods = ("greedy", "sample", "temperature", "topk", "topp")
    if method not in supported_methods:
        # Fail fast with a clear message rather than running a forward pass
        # and then tripping over an unassigned `next_tokens`.
        raise ValueError(
            f"Unknown decoding method {method!r}; expected one of {supported_methods}"
        )

    encodings_dict = tokenizer(prompts, return_tensors="pt", padding=True)
    input_ids = encodings_dict["input_ids"].to(device)
    attention_mask = encodings_dict["attention_mask"].to(device)

    model_kwargs = {"attention_mask": attention_mask}
    batch_size, input_seq_len = input_ids.shape

    # 1 while a sequence is still being generated, 0 once it has emitted EOS
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)

    for step in range(max_len):
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        with torch.no_grad():
            outputs = model(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

        if step == 0:
            # Prompts are padded to a common length, so read each row's logits
            # at its last attended position (assumes padding is on the right
            # -- TODO confirm against the tokenizer's padding side).
            last_non_masked_idx = torch.sum(attention_mask, dim=1) - 1
            next_token_logits = outputs.logits[
                range(batch_size), last_non_masked_idx, :
            ]
        else:
            next_token_logits = outputs.logits[:, -1, :]

        if method == "greedy":
            next_tokens = greedy(next_token_logits)
        elif method == "sample":
            next_tokens = sample(next_token_logits)
        elif method == "temperature":
            next_tokens = temperature(next_token_logits, temperature=kwargs.get("temperature", 0.8))
        elif method == "topk":
            next_tokens = topk(
                next_token_logits,
                k=kwargs.get("k", 20),
                temperature=kwargs.get("temperature", 1.0),
            )
        else:  # method == "topp" (validated above)
            next_tokens = topp(
                next_token_logits,
                p=kwargs.get("p", 0.7),
                temperature=kwargs.get("temperature", 1.0),
            )

        # finished sentences should have their next token be a padding token
        next_tokens = next_tokens * unfinished_sequences + tokenizer.pad_token_id * (
            1 - unfinished_sequences
        )

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        # carry the cache forward and extend the attention mask by one column
        model_kwargs = _update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
        )

        # if eos_token was found in one sentence, set sentence to finished
        unfinished_sequences = unfinished_sequences.mul(
            (next_tokens != tokenizer.eos_token_id).long()
        )

        # stop early once every sequence in the batch has finished
        if unfinished_sequences.max() == 0:
            break

    # drop the prompt tokens and decode only the newly generated part
    response_ids = input_ids[:, input_seq_len:]
    response_text = [
        tokenizer.decode(
            output, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        for output in response_ids
    ]

    return response_text

In [7]:
"""debug helper code"""

######################################################
#  The following helper code is given to you.
######################################################

# For debugging, we duplicate a single prompt 10 times so that we obtain 10 generations for the same prompt
dev_prompts = [dev_data[0]["prompt"]] * 10


def print_generations(prompts, generations):
    """
    Prints one "prompt ==> generation" line per pair.
    inputs:
    - prompts: iterable of str
    - generations: iterable of str, paired with `prompts` in order
    Each value is wrapped in a one-element list before formatting, so it is
    shown in its repr form (quotes and escaped newlines visible).
    """
    lines = [
        f"{[prompt]} ==> {[generation]}"
        for prompt, generation in zip(prompts, generations)
    ]
    for line in lines:
        print(line)

### 1.1: Greedy Decoding

In [8]:
def greedy(next_token_logits):
    """
    Applies greedy decoding to get the next token.
    inputs:
    - next_token_logits: Tensor(size = (B, V), dtype = float)
    outputs:
    - next_tokens: Tensor(size = (B), dtype = long)

    """
    # Softmax is monotonic, so the most probable token is simply the one with
    # the largest logit. Taking argmax on the raw logits skips a full pass
    # over the vocabulary and avoids float rounding in softmax collapsing
    # nearly-equal logits into a tie.
    return next_token_logits.argmax(dim=-1)

In [9]:
generations = decode(dev_prompts, max_len=20, method="greedy")
print_generations(dev_prompts, generations)

['Ryan was called by his friend to skip work one day.'] ==> ['\n\n"I was like, \'I\'m going to go to work tomorrow,\'" he said.']
['Ryan was called by his friend to skip work one day.'] ==> ['\n\n"I was like, \'I\'m going to go to work tomorrow,\'" he said.']
['Ryan was called by his friend to skip work one day.'] ==> ['\n\n"I was like, \'I\'m going to go to work tomorrow,\'" he said.']
['Ryan was called by his friend to skip work one day.'] ==> ['\n\n"I was like, \'I\'m going to go to work tomorrow,\'" he said.']
['Ryan was called by his friend to skip work one day.'] ==> ['\n\n"I was like, \'I\'m going to go to work tomorrow,\'" he said.']
['Ryan was called by his friend to skip work one day.'] ==> ['\n\n"I was like, \'I\'m going to go to work tomorrow,\'" he said.']
['Ryan was called by his friend to skip work one day.'] ==> ['\n\n"I was like, \'I\'m going to go to work tomorrow,\'" he said.']
['Ryan was called by his friend to skip work one day.'] ==> ['\n\n"I was like, \'I\'m goin

### 1.2: Vanilla Sampling and Temperature Sampling

In [10]:
def sample(next_token_logits):
    """
    Applies vanilla (ancestral) sampling to get the next token: one token is
    drawn per row from the softmax distribution over the logits.
    inputs:
    - next_token_logits: Tensor(size = (B, V), dtype = float)
    outputs:
    - next_tokens: Tensor(size = (B), dtype = long)
    """
    token_probs = next_token_logits.softmax(dim=-1)
    # multinomial returns shape (B, 1); keep only the sampled column
    drawn = torch.multinomial(token_probs, num_samples=1)
    return drawn.squeeze(-1)

In [11]:
set_seed()
generations = decode(dev_prompts, max_len=20, method="sample")
print_generations(dev_prompts, generations)

['Ryan was called by his friend to skip work one day.'] ==> ['\n\n"Let\'s walk through the window," the banner states (it didn\'t say that at']
['Ryan was called by his friend to skip work one day.'] ==> [' Tucked inside a basement window, he found a packed pack of cigarettes on a tin counter in the']
['Ryan was called by his friend to skip work one day.'] ==> [' He said he wanted to enroll in school next Thursday.\n\nPolice said a woman and her child']
['Ryan was called by his friend to skip work one day.'] ==> [' The person had taken him on a hike.\n\nWas filming normal? Yep, except that the']
['Ryan was called by his friend to skip work one day.'] ==> [' Employees of Target were telling him that Garrett wanted to take action on the education rights of low-income']
['Ryan was called by his friend to skip work one day.'] ==> [' He heard his friend was in trouble and put on a so-called militia, intending to die,']
['Ryan was called by his friend to skip work one day.'] ==> [' In a bit 

In [12]:
def temperature(next_token_logits, temperature):
    """
    Applies temperature sampling to get the next token.
    inputs:
    - next_token_logits: Tensor(size = (B, V), dtype = float)
    - temperature: Temperature parameter float; the logits are divided by it
      before sampling
    outputs:
    - next_tokens: Tensor(size = (B), dtype = long)
    """
    # Rescale the logits first, then fall back to vanilla sampling.
    rescaled_logits = next_token_logits / temperature
    return sample(rescaled_logits)

In [13]:
set_seed()
generations = decode(dev_prompts, max_len=20, method="temperature", temperature=0.8)
print_generations(dev_prompts, generations)

['Ryan was called by his friend to skip work one day.'] ==> ['\n\n"I told him if I listened to the job, (it would pay him) $']
['Ryan was called by his friend to skip work one day.'] ==> [' The hunger strike was held in a tunnel deep beneath Northeast Cleveland, which was hosting a public hearing.']
['Ryan was called by his friend to skip work one day.'] ==> [' He said he wanted to enroll in school.\n\n"I didn\'t feel safe and didn\'t']
['Ryan was called by his friend to skip work one day.'] ==> [' The person had taken him on a hike.\n\nWas filming normal?\n\nAnyone who saw']
['Ryan was called by his friend to skip work one day.'] ==> ['\n\n"I said, \'Well, how about you head to the cafeteria?\'" Williams said.']
['Ryan was called by his friend to skip work one day.'] ==> [' He heard his friend was in trouble and was on a business trip.\n\nWhen Bremming']
['Ryan was called by his friend to skip work one day.'] ==> [' In a conversation with his girlfriend, the former bluesman said that 

### 1.3: Top-k Sampling

Useful tips:
- Recall that in Top-k sampling, we only sample from the top-k tokens with the highest probabilities. To ensure this, we set all logits other than the top-k to -inf. You can use `float("-inf")` to represent negative infinity in Python.
- You will find `torch.topk()` useful for getting the top-k logits and indices. Check out the [documentation](https://pytorch.org/docs/stable/generated/torch.topk.html) for the function for more details.
- Do not forget to divide the logits by the temperature before applying softmax.

In [14]:
def topk(next_token_logits, k, temperature = 1):

    """
    Applies the top-k sampling decoding algorithm to get the next token.
    inputs:
    - next_token_logits: Tensor(size = (B, V), dtype = float); not modified
    - k: int, the number of top tokens to consider (values larger than V are
      treated as V)
    - temperature: Temperature parameter float
    outputs:
    - next_tokens: Tensor(size = (B), dtype = long)
    """
    # Apply the temperature once, and use the scaled logits both for picking
    # the top-k set and for sampling. Sampling from the unscaled logits would
    # silently ignore the temperature.
    scaled_logits = next_token_logits / temperature

    # torch.topk raises if k exceeds the vocabulary size
    k = min(k, scaled_logits.size(-1))
    top_values, top_indices = torch.topk(scaled_logits, k, dim=-1)

    # Everything outside the top-k gets -inf, i.e. probability zero
    masked_logits = torch.full_like(scaled_logits, float("-inf")).scatter(
        -1, top_indices, top_values
    )

    # Vanilla sampling over the masked distribution (full vocabulary width)
    probs = F.softmax(masked_logits, dim=-1)
    return torch.multinomial(probs, 1).squeeze(-1)

In [15]:
set_seed()
generations = decode(dev_prompts, max_len=20, method="topk", k=20)
print_generations(dev_prompts, generations)

['Ryan was called by his friend to skip work one day.'] ==> ['\n\n"I told him if I got to the job, he\'d tell me to skip lunch']
['Ryan was called by his friend to skip work one day.'] ==> [" The second time, however, he didn't make it because he'd lost interest in his job."]
['Ryan was called by his friend to skip work one day.'] ==> [' He said he wanted to come home as soon as possible to have coffee, a meal and to do']
['Ryan was called by his friend to skip work one day.'] ==> [' The police had taken him on a search warrant that had to be approved as part of the search for']
['Ryan was called by his friend to skip work one day.'] ==> ['\n\n"I said, \'Well, how about you get a job,\' " he said,']
['Ryan was called by his friend to skip work one day.'] ==> [" He didn't. But, at night, on a Saturday morning, he would stay awake at a"]
['Ryan was called by his friend to skip work one day.'] ==> [' In a conversation that led to his dismissal, they became good friends. In September 2008,

### 1.4: Top-p Sampling

In [16]:
def topp(next_token_logits, p, temperature = 1):
    """
    Applies the top-p sampling or nucleus sampling decoding algorithm to get the next token.
    inputs:
    - next_token_logits: Tensor(size = (B, V), dtype = float); not modified
    - p: float, the cutoff probability for the top-p sampling
    - temperature: Temperature parameter float
    outputs:
    - next_tokens: Tensor(size = (B), dtype = long)
    """
    # Apply the temperature once; the same scaled logits define the nucleus
    # and the distribution that is finally sampled from.
    scaled_logits = next_token_logits / temperature

    # Sort in descending order and accumulate probabilities over the sorted tokens
    sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True, dim=-1)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark tokens that fall outside the nucleus. The mask is shifted right by
    # one so that the first token crossing the threshold is still kept, and
    # the most probable token is never removed.
    sorted_indices_to_remove = cum_probs > p
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = False

    # Restore the mask to the original vocabulary order
    indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter(
        dim=1, index=sorted_indices, src=sorted_indices_to_remove
    )

    # masked_fill returns a new tensor, so the caller's logits are not
    # overwritten with -inf
    masked_logits = scaled_logits.masked_fill(indices_to_remove, float("-inf"))

    # Vanilla sampling over the masked distribution
    probs = F.softmax(masked_logits, dim=-1)
    return torch.multinomial(probs, 1).squeeze(-1)

In [17]:
set_seed()
generations = decode(dev_prompts, max_len=20, method="topp", p=0.7)
print_generations(dev_prompts, generations)

['Ryan was called by his friend to skip work one day.'] ==> ['\n\n"I told him I had a game plan for my game day and I told him,']
['Ryan was called by his friend to skip work one day.'] ==> [' The owner said, "This is the guy who shot us, who didn\'t want to do this']
['Ryan was called by his friend to skip work one day.'] ==> [" He said he wanted to skip school one day. He said he didn't want to waste time,"]
['Ryan was called by his friend to skip work one day.'] ==> [' The person had taken him to the hospital.\n\nZimmerman said the couple had planned to']
['Ryan was called by his friend to skip work one day.'] ==> ['\n\n"I said, \'Well, how about you get a job,\' " Williams said.']
['Ryan was called by his friend to skip work one day.'] ==> [" He didn't follow through with his contract because he wasn't going to get the money, but instead"]
['Ryan was called by his friend to skip work one day.'] ==> [' In a bit of an odd coincidence, he was having a car accident when, in the back se

### 1.5: Evaluate!

Run the following cell to obtain the evaluation results, which you should include in your writeup.
Also don't forget to answer the questions.

In [18]:
# Evaluate on the first 10 test prompts, generating several continuations per prompt
prompts = [item["prompt"] for item in test_data][:10]
GENERATIONS_PER_PROMPT = 10
MAX_LEN = 100

# No extra keyword arguments are passed to decode(), so every method runs with
# decode()'s defaults (temperature=0.8 for "temperature", k=20 for "topk",
# p=0.7 for "topp").
# NOTE(review): this cell does not call set_seed(), so the sampling-based
# results depend on the random state left behind by earlier cells.
for experiment in ["greedy", "sample", "temperature", "topk", "topp"]:
    generations = []
    for prompt in tqdm(prompts):
        # The prompt is repeated so that one batch yields all generations for it
        generations += decode(
            [prompt] * GENERATIONS_PER_PROMPT, max_len=MAX_LEN, method=experiment
        )
    evaluate(generations, experiment)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

greedy
perplexity = 2.08
fluency = 0.78
diversity = 0.01, 0.02, 0.03



  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

sample
perplexity = 68.17
fluency = 0.39
diversity = 0.44, 0.90, 0.99



  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

temperature
perplexity = 15.91
fluency = 0.65
diversity = 0.31, 0.77, 0.95



  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

topk
perplexity = 11.86
fluency = 0.73
diversity = 0.27, 0.74, 0.96



  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

topp
perplexity = 12.24
fluency = 0.71
diversity = 0.29, 0.77, 0.95



You should see the following values:

- For greedy:
perplexity = 2.08
fluency = 0.78
diversity = 0.01, 0.02, 0.03

- For sample:
perplexity = 61.54
fluency = 0.37
diversity = 0.42, 0.89, 0.99

- For temperature:
perplexity = 16.15
fluency = 0.66
diversity = 0.31, 0.77, 0.96

- For topk:
perplexity = 27.46
fluency = 0.70
diversity = 0.26, 0.74, 0.96

- For topp:
perplexity = 12.03
fluency = 0.72
diversity = 0.29, 0.76, 0.95

### *Do I always need to use all this code to generate text from a language model?*

The exercises above were to help you understand the underlying mechanisms of different decoding methods. In practice, you don't need to implement all these decoding methods from scratch. You can use the `generate()` method in the 🤗 Transformers library to generate text from a language model. Below we provide an example of how to use the `generate()` method to generate text from a language model. Please pay close attention to this especially if you are going to be attempting the optional section (mandatory for masters students) on Knowledge Distillation.

In [19]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Step 0: Load a pre-trained language model and its corresponding tokenizer from the Hugging Face model hub
# NOTE(review): this rebinds the global `model` and `tokenizer` that `decode()` reads. The tokenizer
# created here is not given the `pad_token` argument used in Section 0.3, so calling `decode()` after
# this cell may behave differently -- re-run the Section 0.3 cell first if you go back to Section 1.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left") # Setting padding_side=left is important when using these models for open-ended text generation

# Move the model to the device
model = model.to(device)

example_prompt = "Once upon a time"

# Step 1: Tokenize the input prompt
tokenized_input = tokenizer(example_prompt, return_tensors="pt").to(device)
print("Tokenizer output:")
print(tokenized_input)
print("*******************")

# Step 2: Generate text from the model (nucleus sampling with top_p=0.9, up to 100 new tokens)
output = model.generate(**tokenized_input, max_new_tokens=100, do_sample=True, top_p=0.9, temperature=1.0,)
print("Model Generate output:")
print(output)
print("*******************")

# Step 3: Convert the output ids to text
output_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Text output:")
print(output_text)
print("*******************")
# Step 3 (Optional): .generate() returns the model inputs as well. We can drop them by slicing off
# the first len(input_ids) tokens of the output
output_text = tokenizer.decode(output[0][len(tokenized_input.input_ids[0]):], skip_special_tokens=True)
print("Text output (ignoring model inputs):")
print(output_text)
print("*******************")

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Tokenizer output:
{'input_ids': tensor([[7454, 2402,  257,  640]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1]], device='cuda:0')}
*******************
Model Generate output:
tensor([[ 7454,  2402,   257,   640,    11,   345,   481,   307,   287,   257,
           995,   543,   345,   481,   691,   766,   530,   393,   517,   286,
            13,   887,   611,   345,   765,   284,  3802,   340,   757,    11,
           345,   481,   691,   307,  1498,   284,   766,  1115,   393,  1440,
           287,   262,  2003,   526,   357,    36,   527, 10424,    11,   279,
            13,  2681,  2014,   198,   198,  1722,   281,  1981,   356,   836,
           470,   761,   284,   766,   284,  3802,   656,   257,   995,   286,
           477,  1115,   286,   262,  7683, 16055,    11,   475,   644,   338,
           517,  1593,   621,   257,  1598,  2000,   318,   257,   880,    12,
         24071,  2858,   287,   543,   356,   460,   651,   284,   760,   290,
          1254,   530, 

Let's go over each of these one by one:

- **Step 1:** The `tokenizer()` function takes the input prompt and converts it into a format that the model can understand. This is essentially converting the input prompt into a sequence of tokens. Note that the tokenizer returns a dictionary with the tokenized input ids and attention mask. The `input_ids` are the token ids that the model will use as input. The `attention_mask` is a mask vector that indicates whether a particular token corresponds to padding. Padding is extremely important when we are dealing with variable-length sequences. Through padding, we can ensure that all the sequences in a batch are of the same size. When feeding sequences with padding to a transformer-based model, we need to make sure that the model doesn't attend to the padding tokens. The `attention_mask` is used to mark the tokens that are to be ignored in the attention operation. Note that here, since we only used a single input sequence, there was no need for padding, and that's why all the attention mask values are 1, i.e., none of the tokens in the sequence should be ignored by the attention blocks.

- **Step 2:** The `generate()` function takes the tokenized input and generates a sequence of tokens as output. The function takes in a number of arguments, the most important of which are `max_new_tokens` and `do_sample`. The `max_new_tokens` argument specifies the maximum number of new tokens to be generated. The `do_sample` argument is used to turn on sampling, if false, greedy decoding is used. If `do_sample` is set to `True`, then the `top_p` and `temperature` arguments are used to control the sampling process. The `top_p` argument is the value of $p$ parameter for top-p or nucleus sampling and the `temperature` argument is used to control the temperature sampling. Notice that the `generate()` function returns as output a sequence of token ids.

- **Step 3:** The `tokenizer.decode` method converts the token ids into text. The `skip_special_tokens=True` argument is used to ignore the special tokens (e.g. start of sequence, end of sequence, padding etc.) that are added by the tokenizer.

- **Step 3 (Optional):** The `generate()` function returns the model inputs as well. We can ignore that by slicing the output.

## [Optional for CSE 447]  **Section 2: Knowledge Distillation**

In this part of the homework, we will learn how we can use knowledge distillation from a larger teacher model to a smaller student model. Particularly, we will be focusing on the task of text summarization and using the CNN/Daily Mail dataset. We will use Qwen2.5-1.5B-Instruct as our teacher model, which is a 1.5B parameter decoder-only model pre-trained on 18T tokens of data and then further fine-tuned to follow instructions to perform different tasks (similar to something like ChatGPT). You can read more about Qwen2.5 models [here](https://qwenlm.github.io/blog/qwen2.5/). For the student model, we will be using the default GPT-2 model, which is a 124M parameter model.

Since Qwen2.5-1.5B-Instruct is a much bigger model and trained on a lot of data, it is more capable of generating better summaries as compared to the default GPT-2 model which is a smaller model and has seen less data. Knowledge distillation is a technique to transfer the knowledge of a larger teacher model to a smaller student model. In this way, we can leverage the large amount of data and compute resources used to train the teacher model to improve the performance of the student model.

This assignment will also make heavy use of the [🤗 Transformers Library](https://huggingface.co/docs/transformers/index). Don't worry if you are not familiar with the library, we will discuss its usage in detail.

In [20]:
# Load packages for this section
import os
from pprint import pprint
from datasets import load_from_disk, Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm


As always, we will start by loading the dataset.

In [21]:
# import os
# import requests
# from zipfile import ZipFile
# 
# # Create the data directory if it doesn't exist
# os.makedirs("data", exist_ok=True)
# 
# # Function to download a file
# def download_file(url, output_path):
#     print(f"Downloading {url} to {output_path}")
#     response = requests.get(url, stream=True)
#     response.raise_for_status()  # Raise an error if the download fails
#     with open(output_path, "wb") as file:
#         for chunk in response.iter_content(chunk_size=8192):
#             file.write(chunk)
# 
# # Function to unzip a file
# def unzip_file(zip_path, extract_to):
#     print(f"Unzipping {zip_path} into {extract_to}")
#     with ZipFile(zip_path, "r") as zip_ref:
#         zip_ref.extractall(extract_to)
# 
# # Define the dataset URLs and file paths
# datasets = [
#     ("https://homes.cs.washington.edu/~kahuja/cse447/project3/data/cnn_dm_cse447_dataset.hf.zip", 
#      "data/cnn_dm_cse447_dataset.hf.zip"),
#     ("https://homes.cs.washington.edu/~kahuja/cse447/project3/data/kd_dataset.hf.zip", 
#      "data/kd_dataset.hf.zip")
# ]
# 
# # Download and unzip the datasets
# for url, file_path in datasets:
#     download_file(url, file_path)
#     unzip_file(file_path, "data/")

In [22]:
######################################################
#  The following code is given to you. DO NOT MODIFY.
######################################################
# NOTE: "__file__" is a quoted string literal here, not the module attribute, so
# abspath() resolves it against the current working directory and dirname()
# returns that directory. The dataset is therefore looked up under ./data
# relative to the directory the kernel is running in.
parent_dir = os.path.dirname(os.path.abspath("__file__"))
data_dir = os.path.join(parent_dir, "data")
cnn_dm_cse447_dataset = load_from_disk(os.path.join(data_dir, "cnn_dm_cse447_dataset.hf"))
# Bare last expression so the notebook displays the DatasetDict summary.
cnn_dm_cse447_dataset

DatasetDict({
    train: Dataset({
        features: ['article', 'summary'],
        num_rows: 10000
    })
    val: Dataset({
        features: ['article', 'summary'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['article', 'summary'],
        num_rows: 1000
    })
})

In [23]:
# Preview the dataset: pretty-print the article and the gold summary of the
# first validation example.
print("Full article:")
pprint(cnn_dm_cse447_dataset["val"][0]["article"])
print("\n\n")
print("Gold summary:")
pprint(cnn_dm_cse447_dataset["val"][0]["summary"])

Full article:
('(CNN)French striker Bafetimbi Gomis, who has a history of fainting, said he '
 'is now "feeling well" after collapsing during Swansea\'s 3-2 loss at '
 'Tottenham in the Premier League on Wednesday. The worrying incident occurred '
 'in the first half at White Hart Lane -- after Tottenham scored in the '
 'seventh minute -- but the 29-year-old left the pitch conscious following '
 'about five minutes of treatment. The Guardian added that he was wearing an '
 'oxygen mask. Play was temporarily stopped before resuming. As the match '
 'progressed, Swansea tweeted that Gomis was "fine," with manager Garry Monk '
 "using the same word to describe Gomis' condition. Gomis spent the night in "
 'hospital as a precaution, Swansea said on its website. "I wanted to reassure '
 'you concerning my health," Gomis told the website. "It actually looks much '
 'scarier than it is physically dangerous, and I am feeling well now. "I have '
 "been under a great deal of stress and fatigue 

### 2.1: Student Model

In this exercise we will try to understand how well the student model, i.e. GPT-2, does on the summarization task out of the box. In later exercises we will try to improve the performance of the student model using knowledge distillation.

We start by loading the student model and tokenizer.

In [24]:
######################################################
#  The following code is given to you. DO NOT MODIFY.
######################################################
from transformers import AutoModelForCausalLM, AutoTokenizer

# Student = the pretrained "gpt2" checkpoint; its tokenizer is configured to pad on the left.
student_model = AutoModelForCausalLM.from_pretrained("gpt2")
student_tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")

student_model = student_model.to(device)
# Reuse the end-of-sequence id as the padding id.
# NOTE(review): presumably needed because the GPT-2 tokenizer has no dedicated pad token -- confirm.
student_tokenizer.pad_token_id = student_tokenizer.eos_token_id
# Bare last expression so the notebook displays the module tree.
student_model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

Notice GPT-2 is a decoder-only model and has 12 layers. This is the smallest model in the GPT-2 family i.e. with only 124M parameters.

### 2.1.1 Preparing Data for Student Model

Our student model is a language model and inherently a language model's job is to predict continuations of a sequence by predicting one token at a time. To perform specific tasks like summarization using language models, we need to prepare the data in such a format such that the possible continuation of the sequence is the output we want i.e. in this case the summary of the article.

The GPT-2 paper found that adding "TL;DR:" to the end of the article helps the model generate better summaries. Implement the `prepare_articles_for_student_model` function, which appends the string `"\nTL;DR:"` to the end of each input article and then tokenizes the articles.

Useful tips:
- While tokenizing the articles, by calling `tokenizer()` make sure to set `padding="max_length"`, `truncation=True`, and `max_length=max_len`  so that the articles are padded to the same length and truncated if they are longer than the maximum length. Also, make sure to set `return_tensors="pt"` so that the output is a PyTorch tensor.

In [25]:
def prepare_articles_for_student_model(articles, student_tokenizer, max_len=1024):
    """
    Build summarization prompts for the student model and tokenize them as one batch.

    Every article is extended with a newline followed by ``TL;DR:`` so that the
    continuation the language model produces is the summary.

    Inputs:
    - articles: An iterable of article strings to be summarized.
    - student_tokenizer: The tokenizer (callable) used to encode the prompts.
    - max_len: The length every prompt is padded or truncated to.

    Returns:
    - Whatever `student_tokenizer` returns for the batch; it is called with
      `padding="max_length"`, `truncation=True`, `max_length=max_len` and
      `return_tensors="pt"`.
    """
    prompts = []
    for article in articles:
        prompts.append(article + "\nTL;DR:")

    # NOTE(review): with truncation enabled, a prompt longer than `max_len` may lose
    # its trailing "TL;DR:" cue if the tokenizer truncates on the right -- confirm
    # the tokenizer's truncation side if long articles are expected.
    tokenizer_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": max_len,
        "return_tensors": "pt",
    }
    return student_tokenizer(prompts, **tokenizer_options)

In [26]:
######################################################
#  The following code is given to you. DO NOT MODIFY.
######################################################


def test_prepare_articles_for_student_model():
    """
    Sanity checks for `prepare_articles_for_student_model`.

    Loads its own left-padded "gpt2" tokenizer, then verifies the returned keys
    and tensor types, the batch dimension, the appended TL;DR suffix, and
    padding/truncation to 1024 tokens. Prints "All tests passed!" on success and
    raises AssertionError otherwise.
    """
    # Setup
    student_tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
    student_tokenizer.pad_token_id = student_tokenizer.eos_token_id

    # Test 1: Basic functionality
    articles = ["This is a test article."]
    output = prepare_articles_for_student_model(articles, student_tokenizer)
    # Debug print; in the recorded notebook output this shows a list of Encoding objects.
    print(output[:10])
    assert "input_ids" in output, "Output should contain input_ids"
    assert "attention_mask" in output, "Output should contain attention_mask"
    assert torch.is_tensor(output["input_ids"]), "input_ids should be a tensor"
    assert torch.is_tensor(
        output["attention_mask"]
    ), "attention_mask should be a tensor"

    # Test 2: Multiple articles
    articles = ["First article.", "Second article.", "Third article."]
    output = prepare_articles_for_student_model(articles, student_tokenizer)
    print(output)
    assert (
        output["input_ids"].shape[0] == 3
    ), "Batch size should match number of articles"
    assert (
        output["attention_mask"].shape[0] == 3
    ), "Batch size should match number of articles"

    # Test 3: TL;DR addition
    articles = ["Test article."]
    output = prepare_articles_for_student_model(articles, student_tokenizer)
    # Padding uses the EOS token (a special token), so skip_special_tokens drops it
    # and only the prompt text remains for the exact-match check below.
    decoded = student_tokenizer.decode(output["input_ids"][0], skip_special_tokens=True)
    assert "TL;DR:" in decoded, "TL;DR: should be added to the article"
    assert decoded == "Test article.\nTL;DR:", "Article format should be correct"

    # Test 4: Padding and truncation
    long_article = "This is a very " * 1000  # Create a very long article
    short_article = "Short."
    articles = [long_article, short_article]
    output = prepare_articles_for_student_model(articles, student_tokenizer)
    assert (
        output["input_ids"].shape[1] == 1024
    ), "Should be padded/truncated to max length"
    assert (output["attention_mask"][1] == 0).any(), "Short article should have padding"
    print("All tests passed!")

# Run the checks immediately when the cell executes.
test_prepare_articles_for_student_model()
######################################################

[Encoding(num_tokens=1024, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])]
{'input_ids': tensor([[50256, 50256, 50256,  ...,    26,  7707,    25],
        [50256, 50256, 50256,  ...,    26,  7707,    25],
        [50256, 50256, 50256,  ...,    26,  7707,    25]]), 'attention_mask': tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]])}
All tests passed!


### 2.1.2 Summarizing with Student Model

Use the `generate()` method to generate summaries from the student model. Refer to the end of the Section 1 of the notebook for an example of how to use the `generate()` method. You can also learn more about the `generate()` method [here](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationMixin.generate).

Helpful tips:
- Instead of generating summaries one by one, we recommend generating summaries in batches to speed up the process.

In [27]:
def summarize_wth_student_model(
    articles,
    student_model,
    student_tokenizer,
    batch_size=8,
    max_new_tokens=100,
    do_sample=True,
    p=0.9,
    temperature=1.0,
    device="cuda",
):

    """
    Generates a list of summaries for a list of articles using the student model.

    Inputs:
    - articles: A list of articles to be summarized.
    - student_model: The student model to use for summarization. It is switched to
      eval mode and is expected to already live on `device`.
    - student_tokenizer: The tokenizer corresponding to the student model.
    - batch_size: The batch size to use for summarization.
    - max_new_tokens: The maximum number of new tokens to generate.
    - do_sample: Whether to use sampling or greedy decoding.
    - p: The p parameter for top-p sampling (only used when `do_sample` is True).
    - temperature: The temperature for sampling (only used when `do_sample` is True).
    - device: The device the tokenized batches are moved to before generation.

    Returns:
    - summaries: A list of summaries for the articles, in input order.
    """

    # top_p / temperature only matter for sampling. Forwarding them together with
    # do_sample=False does not change the greedy output, but recent transformers
    # versions flag them as unused generation flags, so pass them only when sampling.
    sampling_kwargs = {"top_p": p, "temperature": temperature} if do_sample else {}

    student_model.eval()
    summaries = []
    with torch.no_grad():
        for i in tqdm(range(0, len(articles), batch_size)):
            # Tokenize the current batch of articles (adds the TL;DR prompt suffix).
            tokenized_articles = prepare_articles_for_student_model(
                articles[i : i + batch_size], student_tokenizer
            )

            # Move the data to the device
            tokenized_articles = {
                key: value.to(device) for key, value in tokenized_articles.items()
            }
            prompt_length = tokenized_articles["input_ids"].shape[1]

            generated_summaries = student_model.generate(
                tokenized_articles["input_ids"],
                attention_mask=tokenized_articles["attention_mask"],
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                pad_token_id=student_tokenizer.pad_token_id,
                eos_token_id=student_tokenizer.eos_token_id,
                **sampling_kwargs,
            )
            # `generate` returns prompt + continuation; keep only the newly generated tokens.
            generated_summaries = generated_summaries[:, prompt_length:]

            # Decode to text, dropping special tokens such as padding / end-of-sequence.
            decoded_summaries = student_tokenizer.batch_decode(
                generated_summaries, skip_special_tokens=True
            )

            # Add the decoded summaries to the list
            summaries.extend(decoded_summaries)

    return summaries

In [28]:
def test_summarize_wth_student_model():
    """
    Sanity checks for `summarize_wth_student_model`.

    Uses the notebook-level `student_model`, `student_tokenizer` and `device`.
    Checks the output type and length, batching, greedy vs. sampling behavior,
    that the prompt is not echoed back, and that no special tokens leak into the
    text. Prints "All tests passed!" on success.
    """
    # Test 1: Basic functionality
    articles = ["This is a test article about AI."]
    summaries = summarize_wth_student_model(
        articles, student_model, student_tokenizer, device=device
    )
    assert isinstance(summaries, list), "Output should be a list"
    assert len(summaries) == len(articles), "Should generate one summary per article"
    assert isinstance(summaries[0], str), "Each summary should be a string"
    assert len(summaries[0]) > 0, "Summaries should not be empty"

    # Test 2: Batch processing
    articles = [
        "First article.",
        "Second article.",
        "Third article.",
    ] * 4  # 12 articles
    batch_size = 4
    summaries = summarize_wth_student_model(
        articles, student_model, student_tokenizer, batch_size=batch_size, device=device
    )
    assert len(summaries) == len(articles), "Should generate summary for each article"

    # Test 3: Generation parameters
    # The two inputs are identical on purpose: greedy decoding must give the same
    # text for both, while sampling is expected to diverge.
    articles = [
        "Test article for parameter checking.",
        "Test article for parameter checking.",
    ]
    # Test with different generation parameters
    summaries_greedy = summarize_wth_student_model(
        articles, student_model, student_tokenizer, do_sample=False, device=device
    )
    summaries_sampling = summarize_wth_student_model(
        articles,
        student_model,
        student_tokenizer,
        do_sample=True,
        p=0.9,
        temperature=0.7,
        device=device,
    )

    assert summaries_greedy[0] == summaries_greedy[1], "Greedy decoding should give the same result"
    # NOTE(review): this check is probabilistic -- two sampled continuations could
    # in principle coincide.
    assert (
        summaries_sampling[0] != summaries_sampling[1]
    ), "Sampling should give different results"

    # Test 4: Check only generated tokens are in summary
    article = "This is a test article about artificial intelligence. It contains specific phrases that should not appear in the summary unless generated."
    summaries = summarize_wth_student_model(
        [article], student_model, student_tokenizer, device=device
    )
    # Verify the input article text isn't in the summary
    assert article not in summaries[0], "Summary should not contain the input article"

    # Test 5: Check for special tokens in output
    articles = ["Test article for special token checking."]
    summaries = summarize_wth_student_model(
        articles, student_model, student_tokenizer, device=device
    )

    # Check if any special tokens exist in the summary
    for special_token in student_tokenizer.all_special_tokens:
        assert (
            special_token not in summaries[0]
        ), f"Summary contains special token: {special_token}"
    print("All tests passed!")

# Run the checks immediately when the cell executes.
test_summarize_wth_student_model()

  0%|          | 0/1 [00:00<?, ?it/s]This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (1024). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.
100%|██████████| 1/1 [00:00<00:00,  1.41it/s]
100%|██████████| 3/3 [00:02<00:00,  1.29it/s]
100%|██████████| 1/1 [00:00<00:00,  1.48it/s]
100%|██████████| 1/1 [00:00<00:00,  1.43it/s]
100%|██████████| 1/1 [00:00<00:00,  1.44it/s]
100%|██████████| 1/1 [00:00<00:00,  1.40it/s]

All tests passed!





Let's generate summaries for the validation set using the student model.

In [29]:
# Re-seed so the sampled summaries (and the ROUGE scores computed from them) are reproducible.
set_seed()
# Use the notebook-level `device` (the one `student_model` was moved to) instead of a
# hard-coded "cuda", so the cell also runs on machines without a GPU.
pred_summaries = summarize_wth_student_model(
    cnn_dm_cse447_dataset["val"]["article"],
    student_model,
    student_tokenizer,
    p=0.9,
    temperature=1.0,
    device=device,
)

100%|██████████| 125/125 [01:39<00:00,  1.25it/s]


In [30]:
# Inspect the summaries: show the first validation article next to the student's
# generated summary and the gold summary.
print("Article:")
pprint(cnn_dm_cse447_dataset["val"][0]["article"])
print("***********************\n\n")
print("Generated summary:")
pprint(pred_summaries[0])
print("***********************\n\n")
print("Gold summary:")
pprint(cnn_dm_cse447_dataset["val"][0]["summary"])


Article:
('(CNN)French striker Bafetimbi Gomis, who has a history of fainting, said he '
 'is now "feeling well" after collapsing during Swansea\'s 3-2 loss at '
 'Tottenham in the Premier League on Wednesday. The worrying incident occurred '
 'in the first half at White Hart Lane -- after Tottenham scored in the '
 'seventh minute -- but the 29-year-old left the pitch conscious following '
 'about five minutes of treatment. The Guardian added that he was wearing an '
 'oxygen mask. Play was temporarily stopped before resuming. As the match '
 'progressed, Swansea tweeted that Gomis was "fine," with manager Garry Monk '
 "using the same word to describe Gomis' condition. Gomis spent the night in "
 'hospital as a precaution, Swansea said on its website. "I wanted to reassure '
 'you concerning my health," Gomis told the website. "It actually looks much '
 'scarier than it is physically dangerous, and I am feeling well now. "I have '
 "been under a great deal of stress and fatigue due t

As you can see, the generated summary is not very good, with a lot of completely irrelevant text. Let's try to quantify this by evaluating the summaries using the ROUGE score.

### Evaluating Student Model

We will use the ROUGE score to evaluate the summaries generated by the student model. ROUGE measures how similar a generated summary is to a reference (human-written) summary by comparing overlapping words and phrases. Think of it like a sophisticated "spot the differences" between texts.

Example:
- **Reference**: "The cat sat on the mat"
- **Generated**: "The cat lay on the mat"

ROUGE has three main variants:
1. **ROUGE-1**: Matches single words (In example: 5/6 words match)
2. **ROUGE-2**: Matches word pairs (In example: "the cat", "on the", "the mat" match)
3. **ROUGE-L**: Finds longest matching sequences in order

Scores range from 0 to 1, with higher being better. While widely used in summarization evaluation, ROUGE isn't perfect - it focuses on matching words rather than meaning.

Reference: Lin, C. Y. (2004). ROUGE: A package for automatic evaluation of summaries.

In [31]:
# Evaluate the summaries: ROUGE of the student's validation summaries against the
# gold summaries, with stemming enabled.
import evaluate
rouge = evaluate.load("rouge")
scores = rouge.compute(predictions=pred_summaries, references=cnn_dm_cse447_dataset["val"]["summary"], use_stemmer=True)
# Bare last expression so the notebook displays the score dictionary.
scores

{'rouge1': 0.1889532160203738,
 'rouge2': 0.030835822727504077,
 'rougeL': 0.12198681628431995,
 'rougeLsum': 0.16021233885997313}

You should see the following scores:
- rouge1: 0.19
- rouge2: 0.03
- rougeL: 0.12
- rougeLsum: 0.16

Differences of +/- 0.02 in ROUGE scores are acceptable.


The model performs relatively poorly on the dataset, especially considering the ROUGE-2 and ROUGE-L scores, which are very low. Let's see if we can improve this by using a teacher model.

### 2.2 Generate Data from a Teacher Model

We will now start with step 1 of knowledge distillation i.e. generating data from a teacher model. As we mentioned earlier, we will use the Qwen2.5-1.5B-Instruct model as our teacher model. Note that Qwen2.5-1.5B-Instruct is an instruction tuned model, which is different from a base language model like GPT-2. Instruction tuned models are obtained from base language models by fine-tuning them on a diverse set of instructions and their desired outputs. This enables the model to become better at following instructions and is the magic recipe for the success of recent large language models.

In [32]:
# Load a teacher model.
# Weights and tokenizer both come from the "Qwen/Qwen2.5-1.5B-Instruct" checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer
teacher_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-1.5B-Instruct"
)
# Pad on the left, the same setting used for the student tokenizer.
teacher_tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen2.5-1.5B-Instruct", padding_side="left"
)
teacher_model = teacher_model.to(device)
# Bare last expression so the notebook displays the module tree.
teacher_model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
      )
    )
    (norm): Qw

### 2.2.1 Preparing Data for Teacher Model

We need to prepare the data for the teacher model in a specific format. Since the teacher model is an instruction-tuned model, it has been adapted to follow instructions. Hence, instead of formatting the text as a completion problem, we need to format it as an instruction-following problem. We will do this by adding an instruction to the beginning of each article, i.e., "Summarize the following article:". Further, we will also instruct the model to output the summary in a specific format by appending a suffix to the end of each article, i.e., "Start your summary with 'TL;DR:'". In short, the articles should be formatted as follows:

`"Summarize the following article:\n\n<article>\n\nStart your summary with 'TL;DR:'"`

In [33]:
def prepare_articles_teacher(articles, teacher_tokenizer, max_len=1024):
    """
    Wrap each article in a summarization instruction, render it through the
    teacher tokenizer's chat template, and tokenize the resulting prompts.

    Inputs
    - articles: The articles to prepare and tokenize
    - teacher_tokenizer: The tokenizer for the teacher model; must provide
      `apply_chat_template` and be callable
    - max_len: The length every prompt is padded or truncated to

    Returns:
    - Whatever `teacher_tokenizer` returns for the batch; it is called with
      `padding="max_length"`, `truncation=True`, `max_length=max_len`,
      `return_tensors="pt"` and `add_special_tokens=False`
    """
    # One two-turn conversation (system + user) per article. Instruction-tuned
    # models expect chat-formatted input, see
    # https://huggingface.co/docs/transformers/main/en/chat_templating
    conversations = []
    for article in articles:
        user_prompt = f"Summarize the following article:\n\n{article}\n\nStart your summary with 'TL;DR:'"
        conversations.append(
            [
                {"role": "system", "content": "You are a helpful assistant and an expert at summarizing articles."},
                {"role": "user", "content": user_prompt},
            ]
        )

    # Render the conversations to plain strings; tokenization happens separately below.
    rendered_prompts = teacher_tokenizer.apply_chat_template(
        conversations, tokenize=False, add_generation_prompt=True
    )

    # The chat template has already inserted the special tokens, so the tokenizer
    # is told not to add them again.
    return teacher_tokenizer(
        rendered_prompts,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
        add_special_tokens=False,
    )

In [34]:
def test_prepare_articles_teacher():
    """
    Sanity checks for `prepare_articles_teacher`.

    Uses the notebook-level `teacher_tokenizer`. Verifies the returned keys and
    tensor types, that both instruction strings survive in the decoded prompt,
    the batch dimension, and padding/truncation to 1024 tokens. Prints
    "All tests passed!" on success.
    """

    # Test 1: Basic functionality
    articles = ["This is a test article."]
    output = prepare_articles_teacher(articles, teacher_tokenizer)

    assert "input_ids" in output, "Output should contain input_ids"
    assert "attention_mask" in output, "Output should contain attention_mask"
    assert torch.is_tensor(output["input_ids"]), "input_ids should be a tensor"
    assert torch.is_tensor(
        output["attention_mask"]
    ), "attention_mask should be a tensor"

    # Test 2: Instruction addition
    # Reuses the single-article `output` from Test 1.
    decoded = teacher_tokenizer.decode(output["input_ids"][0], skip_special_tokens=True)
    assert (
        "Summarize the following article:" in decoded
    ), "Should contain summarization instruction"
    assert (
        "Start your summary with 'TL;DR:'" in decoded
    ), "Should contain TL;DR instruction"

    # Test 3: Multiple articles
    articles = ["First article.", "Second article.", "Third article."]
    output = prepare_articles_teacher(articles, teacher_tokenizer)
    assert (
        output["input_ids"].shape[0] == 3
    ), "Batch size should match number of articles"
    assert (
        output["attention_mask"].shape[0] == 3
    ), "Batch size should match number of articles"

    # Test 4: Padding and truncation
    long_article = "This is a very " * 1000  # Create a very long article
    short_article = "Short."
    articles = [long_article, short_article]
    output = prepare_articles_teacher(articles, teacher_tokenizer, max_len=1024)
    assert (
        output["input_ids"].shape[1] == 1024
    ), "Should be padded/truncated to max length"
    assert (output["attention_mask"][1] == 0).any(), "Short article should have padding"
    
    print("All tests passed!")
    
# Run the checks immediately when the cell executes.
test_prepare_articles_teacher()

All tests passed!


Let's have a look at the tokenized articles.

In [35]:
# Round-trip the first two validation articles through the teacher preprocessing
# and decode them back to text to inspect the chat-formatted prompt.
tokenized_articles = prepare_articles_teacher(
    cnn_dm_cse447_dataset["val"]["article"][:2], teacher_tokenizer
)
# skip_special_tokens=True drops the special tokens from the decoded text.
decoded_articles = teacher_tokenizer.batch_decode(
    tokenized_articles["input_ids"], skip_special_tokens=True
)
pprint(decoded_articles[0])

('system\n'
 'You are a helpful assistant and an expert at summarizing articles.\n'
 'user\n'
 'Summarize the following article:\n'
 '\n'
 '(CNN)French striker Bafetimbi Gomis, who has a history of fainting, said he '
 'is now "feeling well" after collapsing during Swansea\'s 3-2 loss at '
 'Tottenham in the Premier League on Wednesday. The worrying incident occurred '
 'in the first half at White Hart Lane -- after Tottenham scored in the '
 'seventh minute -- but the 29-year-old left the pitch conscious following '
 'about five minutes of treatment. The Guardian added that he was wearing an '
 'oxygen mask. Play was temporarily stopped before resuming. As the match '
 'progressed, Swansea tweeted that Gomis was "fine," with manager Garry Monk '
 "using the same word to describe Gomis' condition. Gomis spent the night in "
 'hospital as a precaution, Swansea said on its website. "I wanted to reassure '
 'you concerning my health," Gomis told the website. "It actually looks much '
 'sc

Notice the system, user, and assistant tags in the prompt. The system tag is followed by an instruction to ground the model to be a helpful assistant and be an expert at summarizing articles. The user tag is followed by the article to be summarized and the instruction about the task. We have the assistant tag in the end and the model is expected to generate the summary after the assistant tag.

### 2.2.2. Generate summaries with teacher model

Similar to the student model, we will use the `generate()` method to generate summaries from the teacher model. Implement the `summarize_with_teacher_model` function below to do this.

In [36]:
def summarize_with_teacher_model(
    articles,
    teacher_model,
    teacher_tokenizer,
    batch_size=8,
    max_new_tokens=100,
    do_sample=True,
    p=0.9,
    temperature=1.0,
    device="cuda",
):

    """
    Generates summaries for a list of articles using the teacher model.
    Essentially the same as the function `summarize_wth_student_model`, but the
    batches are prepared with `prepare_articles_teacher` (chat-formatted prompts).

    Inputs:
    - articles: A list of articles to be summarized.
    - teacher_model: The teacher model to use for summarization. It is switched to
      eval mode and is expected to already live on `device`.
    - teacher_tokenizer: The tokenizer corresponding to the teacher model.
    - batch_size: The batch size to use for summarization.
    - max_new_tokens: The maximum number of new tokens to generate.
    - do_sample: Whether to use sampling or greedy decoding.
    - p: The p parameter for top-p sampling (only passed when `do_sample` is True).
    - temperature: The temperature for sampling (only passed when `do_sample` is True).
    - device: The device the tokenized batches are moved to before generation.

    Returns:
    - summaries: A list of summaries for the articles, in input order.
    """
    # top_p / temperature only matter for sampling. Forwarding them together with
    # do_sample=False does not change the greedy output, but recent transformers
    # versions flag them as unused generation flags, so pass them only when sampling.
    sampling_kwargs = {"top_p": p, "temperature": temperature} if do_sample else {}

    teacher_model.eval()
    summaries = []
    with torch.no_grad():
        for i in tqdm(range(0, len(articles), batch_size)):
            # Tokenize the current batch of articles (instruction + chat template).
            tokenized_articles = prepare_articles_teacher(
                articles[i : i + batch_size], teacher_tokenizer
            )

            # Move the data to the device
            tokenized_articles = {
                key: value.to(device) for key, value in tokenized_articles.items()
            }
            prompt_length = tokenized_articles["input_ids"].shape[1]

            generated_summaries = teacher_model.generate(
                tokenized_articles["input_ids"],
                attention_mask=tokenized_articles["attention_mask"],
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                pad_token_id=teacher_tokenizer.pad_token_id,
                eos_token_id=teacher_tokenizer.eos_token_id,
                **sampling_kwargs,
            )
            # `generate` returns prompt + continuation; keep only the newly generated tokens.
            generated_summaries = generated_summaries[:, prompt_length:]

            # Decode to text, dropping special tokens such as padding / end-of-sequence.
            decoded_summaries = teacher_tokenizer.batch_decode(
                generated_summaries, skip_special_tokens=True
            )

            # Add the decoded summaries to the list
            summaries.extend(decoded_summaries)

    return summaries

Let's first check how well the teacher model performs on the validation set.

In [37]:

# Re-seed so the sampled teacher summaries are reproducible.
set_seed()
# Since the teacher model is large, we will only evaluate on a small subset of the validation set.
# Use the notebook-level `device` (the one `teacher_model` was moved to) instead of a
# hard-coded "cuda", so the cell also runs on machines without a GPU.
val_summaries_generated = summarize_with_teacher_model(
    cnn_dm_cse447_dataset["val"]["article"][:100],
    teacher_model,
    teacher_tokenizer,
    batch_size=8,
    max_new_tokens=100,
    do_sample=True,
    p=0.9,
    temperature=1.0,
    device=device,
)

100%|██████████| 13/13 [00:34<00:00,  2.68s/it]


In [38]:
# Inspect the summaries: show the first validation article next to the teacher's
# generated summary and the gold summary.
print("Article:")
pprint(cnn_dm_cse447_dataset["val"][0]["article"])
print("***********************\n\n")
print("Generated summary:")
pprint(val_summaries_generated[0])
print("***********************\n\n")
print("Gold summary:")
pprint(cnn_dm_cse447_dataset["val"][0]["summary"])

Article:
('(CNN)French striker Bafetimbi Gomis, who has a history of fainting, said he '
 'is now "feeling well" after collapsing during Swansea\'s 3-2 loss at '
 'Tottenham in the Premier League on Wednesday. The worrying incident occurred '
 'in the first half at White Hart Lane -- after Tottenham scored in the '
 'seventh minute -- but the 29-year-old left the pitch conscious following '
 'about five minutes of treatment. The Guardian added that he was wearing an '
 'oxygen mask. Play was temporarily stopped before resuming. As the match '
 'progressed, Swansea tweeted that Gomis was "fine," with manager Garry Monk '
 "using the same word to describe Gomis' condition. Gomis spent the night in "
 'hospital as a precaution, Swansea said on its website. "I wanted to reassure '
 'you concerning my health," Gomis told the website. "It actually looks much '
 'scarier than it is physically dangerous, and I am feeling well now. "I have '
 "been under a great deal of stress and fatigue due t

You should see that the generated summary is much better than the one generated by the student model. Let's now evaluate the summaries generated by the teacher model.

In [39]:
# ROUGE of the teacher summaries against the first 100 gold validation summaries,
# i.e. the same slice of the validation set the teacher summarized.
set_seed()
teacher_rouge = evaluate.load("rouge")
scores = teacher_rouge.compute(
    predictions=val_summaries_generated,
    references=cnn_dm_cse447_dataset["val"]["summary"][:100],
    use_stemmer=True
)
# Bare last expression so the notebook displays the score dictionary.
scores  


{'rouge1': 0.2664340663954573,
 'rouge2': 0.08403782865692987,
 'rougeL': 0.18580286294970424,
 'rougeLsum': 0.21615203083136378}

You should see the following scores:
- rouge1: 0.261
- rouge2: 0.082
- rougeL: 0.181
- rougeLsum: 0.211

Differences of +/- 0.02 in ROUGE scores are acceptable.


While these are only on a small subset of the validation set, we also ran the teacher model on the entire validation set and got the following scores:
- rouge1: 0.284
- rouge2: 0.085
- rougeL: 0.188
- rougeLsum: 0.229

As you can see, the teacher model performs much better than the student model. While the numbers are not super high, note that our teacher model is only 1.5B parameters, which is considered a tiny model in the current LLM landscape. Bigger models with 8B, 30B, 70B, or even higher are expected to perform even better.

#### Generating Data for Distillation

We can now generate data for distillation. We will use the articles from the training dataset and generate summaries for those articles using the teacher model. Note that we do have the original summaries for the training set, but we are not using them for distillation. Our goal is to show you how distillation works and how you can use it for your own use cases where you might not have labelled training data available.

In [40]:
def generate_synthetic_data_for_distillation(
    articles,
    teacher_model,
    teacher_tokenizer,
    num_samples=1000,
    batch_size=8,
    max_new_tokens=100,
    do_sample=True,
    p=0.9,
    temperature=1.0,
    device="cuda",
):
    """
    Generates teacher summaries that serve as synthetic targets for distillation.

    Inputs:
    - articles: Sliceable collection of articles; only the first `num_samples` are summarized.
    - teacher_model: The teacher model, forwarded to `summarize_with_teacher_model`.
    - teacher_tokenizer: The teacher's tokenizer, forwarded to `summarize_with_teacher_model`.
    - num_samples: Number of leading articles to summarize.
    - batch_size, max_new_tokens, do_sample, p, temperature, device: Generation settings,
      forwarded unchanged to `summarize_with_teacher_model`.

    Returns:
    - summaries: Whatever `summarize_with_teacher_model` returns for the selected articles.
    """
    # Collect the generation settings once so the call below stays short.
    generation_settings = dict(
        batch_size=batch_size,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        p=p,
        temperature=temperature,
        device=device,
    )

    # Restrict the work to the first `num_samples` articles.
    selected_articles = articles[:num_samples]

    return summarize_with_teacher_model(
        selected_articles,
        teacher_model,
        teacher_tokenizer,
        **generation_settings,
    )

In [41]:
# set_seed()
# train_summaries_generated = generate_synthetic_data_for_distillation(
#     cnn_dm_cse447_dataset["train"]["article"],
#     teacher_model,
#     teacher_tokenizer,
#     num_samples=1000,
#     batch_size=8,
#     max_new_tokens=100,
#     do_sample=True,
#     p=0.9,
#     temperature=1.0,
#     device="cuda",
# )

In [42]:
# kd_dataset = {
#     "article": cnn_dm_cse447_dataset["train"]["article"][:1000],
#     "summary": train_summaries_generated,
#     "gold_summary": cnn_dm_cse447_dataset["train"]["summary"][:1000]
# }
# kd_dataset = Dataset.from_dict(kd_dataset)
# kd_dataset

In [43]:
# # Save the dataset to disk, so that you can load it later without re-generating the data.
# kd_dataset.save_to_disk(f"{data_dir}/kd_dataset.hf")

In case you are running into compute issues while generating data for distillation, we provide you with the distilled dataset in the `data` folder, which you can directly load using the `datasets` library.

```python
kd_dataset = load_from_disk(f"{data_dir}/kd_dataset.hf")
```

### 3. Fine-tuning student model with teacher model's generated data

#### 3.1 Preparing data for distillation

You will now implement the `prepare_data_for_distillation` function to prepare the data for distillation. This function will format the data in a specific way so that it can be used to fine-tune the student model. You will follow pretty much the same process as you did for the student model in `prepare_articles_for_student_model` with a few changes.

- First we will include the summaries in the input text along with the articles. This is done because we are now training the student model to generate summaries from the articles. Hence the format of the input text will be `<article>\nTL;DR:<summary>`.
- In the tokenization dictionary, we now need to add a new key, `labels`, which contains the labels to train the language model. For language models, the labels are the same as the input IDs since the model is expected to generate the next word in the sequence. However, while fine-tuning, we want the model to learn how to generate the summaries from the articles, and we do not care about the model learning to predict tokens in the original articles. Therefore, we replace the labels for the prompt tokens with -100, which is a special label value (the default `ignore_index` of PyTorch's cross-entropy loss) that tells the loss function to ignore those positions.


In [44]:
def prepare_data_for_distillation(article, summary, student_tokenizer, max_length=1024):
    """
    Tokenizes a single (article, summary) pair for fine-tuning the student model.

    The article, the TL;DR: marker and the summary are joined into one training text.
    The labels are a copy of the input IDs in which the leading prompt positions
    (article plus TL;DR: marker) are set to -100, so that the loss ignores them.

    Inputs:
    - article: The article to be summarized.
    - summary: Summary to be used as labels for distillation. It may already contain
      the TL;DR: marker, in which case the marker is not added a second time.
    - student_tokenizer: Tokenizer for the student model.
    - max_length: Maximum length of the tokenized input (longer inputs are truncated).

    Returns:
    - tokenized_data: The tokenizer output for the full text, holding `input_ids`,
      `attention_mask` and `labels`, each with the batch dimension removed.

    Note: This function handles one example at a time and does not pad, because
    batches are padded dynamically (to the longest example in the batch) at training time.
    """
    # The prompt is "<article>\nTL;DR:"; its token count decides how many labels get masked.
    article_prompt = article + "\nTL;DR:"

    # Avoid adding the marker twice when the summary already carries it.
    if "TL;DR:" in summary:
        article_wth_summary = article + "\n" + summary
    else:
        article_wth_summary = article_prompt + summary

    # Same tokenizer settings for both texts; padding is left to the data collator.
    tokenizer_settings = dict(
        padding=False,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    tokenized_data = student_tokenizer(article_wth_summary, **tokenizer_settings)
    tokenized_prompt = student_tokenizer(article_prompt, **tokenizer_settings)

    # The tokenizer returns tensors with a leading batch dimension; keep only the single sequence.
    for field in ("input_ids", "attention_mask"):
        tokenized_data[field] = tokenized_data[field][0]
    num_prompt_tokens = tokenized_prompt["input_ids"][0].shape[0]

    # Labels start as a copy of the input IDs, then the prompt positions are masked with -100.
    masked_labels = tokenized_data["input_ids"].clone()
    masked_labels[:num_prompt_tokens] = -100
    tokenized_data["labels"] = masked_labels

    return tokenized_data

In [45]:
def test_prepare_data_for_distillation():
    """
    Sanity checks for `prepare_data_for_distillation` using the GPT-2 tokenizer.

    Verifies that:
    1. the output holds tensor-valued `input_ids`, `attention_mask` and `labels`;
    2. the labels of the prompt (article plus TL;DR: marker) are -100 and the labels
       of the summary are not;
    3. long inputs are truncated to `max_length` and a short input has no zero entries
       anywhere in its attention mask (i.e. no padding was added).

    Raises:
    - AssertionError: if any of the checks fails.
    """
    # Setup
    student_tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="right")
    student_tokenizer.pad_token_id = student_tokenizer.eos_token_id

    # Test 1: Basic functionality
    article = "This is a test article."
    summary = "This is a test summary."
    output = prepare_data_for_distillation(article, summary, student_tokenizer)
    assert "input_ids" in output, "Output should contain input_ids"
    assert "attention_mask" in output, "Output should contain attention_mask"
    assert "labels" in output, "Output should contain labels"
    assert torch.is_tensor(output["input_ids"]), "input_ids should be a tensor"
    assert torch.is_tensor(
        output["attention_mask"]
    ), "attention_mask should be a tensor"
    assert torch.is_tensor(output["labels"]), "labels should be a tensor"

    # Test 2: Correct label masking (reuses the output of Test 1, same inputs)
    labels = output["labels"]

    # Ensure labels for the article part are -100
    article_length = len(student_tokenizer(article + "\nTL;DR:")["input_ids"])
    assert all(
        label == -100 for label in labels[:article_length]
    ), "Labels for article tokens should be -100"

    # Ensure labels for the summary part are not -100.
    # The length check keeps the `all(...)` below from passing vacuously on an empty slice.
    assert (
        labels.shape[0] > article_length
    ), "Output should contain summary tokens after the prompt"
    assert all(
        label != -100 for label in labels[article_length:]
    ), "Labels for summary tokens should not be -100"

    # Test 3: Padding and truncation
    long_article = "This is a very " * 1000  # Create a very long article
    long_summary = "Summary " * 1000
    short_article = "Short."
    short_summary = "Short."
    output = prepare_data_for_distillation(
        long_article, long_summary, student_tokenizer, max_length=1024
    )
    assert (
        output["input_ids"].shape[0] == 1024
    ), "Long inputs should be truncated to max length"

    output = prepare_data_for_distillation(
        short_article, short_summary, student_tokenizer, max_length=1024
    )
    # Check the whole attention mask, not just a single position.
    assert (
        output["attention_mask"] != 0
    ).all(), "Short article should no have padding"
    print("All tests passed!")

test_prepare_data_for_distillation()

All tests passed!


Let's now prepare the data for distillation.

In [46]:
# Load the distillation dataset (`kd_dataset.hf`) from `data_dir` instead of re-generating it.
kd_dataset = load_from_disk(f"{data_dir}/kd_dataset.hf")

In [47]:
# Tokenize the distillation examples one at a time (batched=False), using the `summary`
# column of `kd_dataset` as the training target.
# NOTE(review): per the dataset construction shown above, `summary` should hold the
# teacher-generated summaries (not the gold ones) -- confirm for the provided dataset.
train_tokenized_data = kd_dataset.map(
    lambda example: prepare_data_for_distillation(
        example["article"],
        example["summary"],
        student_tokenizer,
        max_length=1024,
    ),
    batched=False,
    # Drop the original columns so only the outputs of `prepare_data_for_distillation` remain.
    remove_columns=kd_dataset.column_names,
)
train_tokenized_data

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 1000
})

In [48]:
# Tokenize the validation split the same way as the training data. Here the `summary`
# column of the validation split itself provides the targets.
val_tokenized_data = cnn_dm_cse447_dataset["val"].map(
    lambda example: prepare_data_for_distillation(
        example["article"],
        example["summary"],
        student_tokenizer,
        max_length=1024,
    ),
    batched=False,
    # Drop the original columns so only the outputs of `prepare_data_for_distillation` remain.
    remove_columns=cnn_dm_cse447_dataset["val"].column_names,
)
val_tokenized_data

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 1000
})

In [49]:
# Let's have a look at the tokenized data
tokenized_data = train_tokenized_data[0]
print("Input IDs:")
print(tokenized_data["input_ids"])
print("Labels:")
print(tokenized_data["labels"])
print("***********************\n\n")

print("Input Text:")
pprint(student_tokenizer.decode(tokenized_data["input_ids"], skip_special_tokens=True))

Input IDs:
[43, 47383, 11, 4492, 357, 12637, 8, 1377, 5850, 14179, 3491, 7806, 5325, 33783, 8810, 1895, 284, 257, 2098, 4248, 1238, 1510, 7198, 3901, 13, 16, 1510, 8, 15807, 355, 339, 4962, 1248, 319, 3321, 11, 475, 339, 17424, 262, 1637, 1839, 470, 3350, 257, 4822, 319, 683, 13, 7806, 5325, 33783, 355, 5850, 14179, 287, 366, 18308, 14179, 290, 262, 8284, 286, 262, 9643, 1, 1675, 262, 18641, 286, 30914, 5721, 1023, 1088, 262, 995, 11, 262, 1862, 8674, 1139, 339, 468, 645, 3352, 284, 277, 799, 353, 465, 5003, 1497, 319, 3049, 5006, 11, 4144, 290, 16527, 4671, 13, 366, 40, 836, 470, 1410, 284, 307, 530, 286, 883, 661, 508, 11, 355, 2582, 355, 484, 1210, 1248, 11, 6451, 2822, 2405, 257, 4858, 5701, 1097, 4947, 393, 1223, 2092, 553, 339, 1297, 281, 6638, 39877, 2961, 428, 1227, 13, 366, 40, 836, 470, 892, 314, 1183, 307, 3573, 45997, 13, 366, 464, 1243, 314, 588, 7067, 389, 1243, 326, 1575, 546, 838, 8059, 1377, 3835, 290, 36731, 290, 35962, 526, 1629, 1248, 11, 5325, 33783, 481, 307, 1498

### Fine-tuning the student model

We are now ready to fine-tune the student model. While this might sound like a daunting task, we actually do not need to write much code for it. The Transformers library has a `Trainer` class that takes care of the fine-tuning process. We will use the `TrainingArguments` class to set the training arguments and the `Trainer` class to fine-tune the model. Below is the code to fine-tune the student model.

In [50]:
# Release unused GPU memory cached by PyTorch before starting fine-tuning.
torch.cuda.empty_cache()

In [51]:
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq

def fine_tune_student_model(
    student_model,
    train_tokenized_data,
    val_tokenized_data,
    student_tokenizer,
    output_dir="checkpoints/",
    logging_steps=10,
    eval_steps=50,
    logging_first_step=True,
    learning_rate=2e-5,
    warmup_steps=100,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=2,
    weight_decay=0.01,
    fp16=True,
    seed=42,
):
    """
    Fine-tunes the student (language) model with the Transformers `Trainer`.

    Inputs:
    - student_model: The student (language) model to be fine-tuned.
    - train_tokenized_data: The tokenized training data.
    - val_tokenized_data: The tokenized validation data.
    - student_tokenizer: The tokenizer for the student model (used by the data collator for padding).
    - output_dir: The output directory passed to `TrainingArguments`.
    - logging_steps: The number of steps between two logs of the training loss.
    - eval_steps: The number of steps between two evaluations on the validation data.
    - logging_first_step: Whether to also log at the first training step.
    - learning_rate: The learning rate for the optimizer.
    - warmup_steps: The number of steps to warm up the learning rate. This is used to gradually increase the learning rate during training.
    - per_device_train_batch_size: The batch size for training per device.
    - per_device_eval_batch_size: The batch size for evaluation per device.
    - gradient_accumulation_steps: The number of steps to accumulate gradients before updating the model weights.
      This is a trick to use a larger effective batch size when low on memory: gradients are accumulated over
      `gradient_accumulation_steps` batches before each optimizer step, which corresponds to an effective
      batch size of `gradient_accumulation_steps * per_device_train_batch_size` per device.
    - num_train_epochs: The number of training epochs.
    - weight_decay: The weight decay for the optimizer. This is a regularization technique to prevent overfitting.
    - fp16: Whether to use 16-bit precision. This is used to speed up the training process and reduce the memory usage.
    - seed: The random seed to use for training. This is used to ensure reproducibility of the training process.

    Returns:
    - None. `student_model` itself is the model handed to the `Trainer` for training.
    """

    # Set training arguments.
    # This involves things like learning rate, batch size, number of epochs, etc.
    # All values come from the function parameters so that callers can actually override them.
    training_args = TrainingArguments(
        output_dir=output_dir,
        logging_steps=logging_steps,
        eval_strategy="steps",
        eval_steps=eval_steps,
        logging_first_step=logging_first_step,
        learning_rate=learning_rate,
        warmup_steps=warmup_steps,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_train_epochs=num_train_epochs,
        weight_decay=weight_decay,
        fp16=fp16,
        seed=seed,
        report_to="none",
    )

    # Create a data collator for dynamic padding
    # A data collator is used to construct batches of data from the dataset.
    # padding="longest" pads every batch only up to its longest example.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=student_tokenizer, model=student_model, padding="longest"
    )

    # Create a trainer
    trainer = Trainer(
        model=student_model,
        args=training_args,
        train_dataset=train_tokenized_data,
        eval_dataset=val_tokenized_data,
        data_collator=data_collator,
    )

    # Fine-tune the model
    trainer.train()

# Distillation run: fine-tune a freshly loaded pretrained GPT-2 on the tokenized distillation data.
set_seed()
student_model = AutoModelForCausalLM.from_pretrained("gpt2")
student_model = student_model.to(device)
# The EOS token is reused as the padding token for the student tokenizer.
student_tokenizer.pad_token_id = student_tokenizer.eos_token_id
# The hyperparameters below are spelled out explicitly; they equal the function's defaults.
fine_tune_student_model(
    student_model,
    train_tokenized_data,
    val_tokenized_data,
    student_tokenizer,
    learning_rate=2e-5,
    warmup_steps=100,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=2,
    weight_decay=0.01,
    fp16=True,
    seed=42,
)

Step,Training Loss,Validation Loss
50,17.559,3.282049
100,12.3199,3.090302
150,11.7396,3.096331
200,11.2973,3.136682
250,10.7491,3.059554


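The reference training and validation losses are 2.61 and 3.01, respectively. Differences of +/- 0.2 are acceptable. Note that the training loss logged in the table above is roughly `gradient_accumulation_steps` (4) times the reference value (e.g. 10.75 / 4 ≈ 2.69); some versions of `transformers` may report the training loss summed over the accumulation steps, so divide the logged value by 4 before comparing it with the reference.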

Let's now generate summaries with the fine-tuned student model and evaluate the performance.

In [52]:
# Generate summaries for the validation articles with the fine-tuned student model.
set_seed()
# This mutates the shared tokenizer: every later cell that uses `student_tokenizer`
# sees left padding unless it is reset.
student_tokenizer.padding_side = "left" #Padding side left for generation 
pred_summaries_ft_student = summarize_wth_student_model(
    cnn_dm_cse447_dataset["val"]["article"],
    student_model,
    student_tokenizer,
    p=0.9,
    temperature=1.0,
    # NOTE(review): "cuda" is hard-coded here although the model was moved to the
    # notebook-level `device` variable -- confirm this cell is only meant to run on GPU.
    device="cuda",
)

100%|██████████| 125/125 [02:05<00:00,  1.00s/it]


In [53]:
# Inspect the summaries
print("Article:")
pprint(cnn_dm_cse447_dataset["val"][0]["article"])
print("***********************\n\n")
print("Generated summary:")
pprint(pred_summaries_ft_student[0])
print("***********************\n\n")
print("Gold summary:")
pprint(cnn_dm_cse447_dataset["val"][0]["summary"])

Article:
('(CNN)French striker Bafetimbi Gomis, who has a history of fainting, said he '
 'is now "feeling well" after collapsing during Swansea\'s 3-2 loss at '
 'Tottenham in the Premier League on Wednesday. The worrying incident occurred '
 'in the first half at White Hart Lane -- after Tottenham scored in the '
 'seventh minute -- but the 29-year-old left the pitch conscious following '
 'about five minutes of treatment. The Guardian added that he was wearing an '
 'oxygen mask. Play was temporarily stopped before resuming. As the match '
 'progressed, Swansea tweeted that Gomis was "fine," with manager Garry Monk '
 "using the same word to describe Gomis' condition. Gomis spent the night in "
 'hospital as a precaution, Swansea said on its website. "I wanted to reassure '
 'you concerning my health," Gomis told the website. "It actually looks much '
 'scarier than it is physically dangerous, and I am feeling well now. "I have '
 "been under a great deal of stress and fatigue due t

While not perfect, the summaries generated by the fine-tuned student model are much better than the ones generated by the student model before fine-tuning. Let's now evaluate the performance of the fine-tuned student model.

In [54]:
# Score the fine-tuned student's summaries against the validation summaries with ROUGE.
import evaluate
# NOTE(review): despite its name, `teacher_rouge` is just the ROUGE metric object;
# here it scores the fine-tuned student's predictions, not the teacher's.
teacher_rouge = evaluate.load("rouge")
scores = teacher_rouge.compute(
    predictions=pred_summaries_ft_student,
    references=cnn_dm_cse447_dataset["val"]["summary"],
    use_stemmer=True,
)
# Last expression of the cell: displays the dictionary of ROUGE scores.
scores

{'rouge1': 0.2314907258730331,
 'rouge2': 0.055315187761111995,
 'rougeL': 0.14569933195061446,
 'rougeLsum': 0.1839563766484451}

You should see the following scores:
- rouge1: 0.230
- rouge2: 0.055
- rougeL: 0.145
- rougeLsum: 0.185

Differences of +/- 0.02 in ROUGE scores are acceptable.

As you can see, the fine-tuned student model performs much better than the student model before fine-tuning. This is impressive because we used very small teacher and student models and only used 1000 examples for distillation.

In practice, you would want to use a much bigger teacher model and a bigger student model to get better performance. The purpose of this exercise is to give you an idea of how distillation works and how you can use it for your own use cases.

In [57]:
# GPT2 real-data fine-tuning
# select just the first 1000 examples
train_data_og = cnn_dm_cse447_dataset["train"].select(range(1000))
train_tokenized_data_og = train_data_og.map(
    lambda example: prepare_data_for_distillation(
        example["article"],
        example["summary"],
        student_tokenizer,
        max_length=1024,
    ),
    batched=False,
    remove_columns=cnn_dm_cse447_dataset["train"].column_names
)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [58]:
# Comparison run: fine-tune a second, freshly loaded pretrained GPT-2 on the real-data
# training set, with the same hyperparameters and validation data as the distillation run.
# NOTE(review): `student_tokenizer.padding_side` was set to "left" in the earlier generation
# cell and is not reset in between, so when the cells run in order this training run pads
# on the left -- confirm that this matches the padding used for the distillation run.
gpt2_real = AutoModelForCausalLM.from_pretrained("gpt2")
gpt2_real = gpt2_real.to(device)
fine_tune_student_model(
    gpt2_real,
    train_tokenized_data_og,
    val_tokenized_data,
    student_tokenizer,
    learning_rate=2e-5,
    warmup_steps=100,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=2,
    weight_decay=0.01,
    fp16=True,
    seed=42,
)

Step,Training Loss,Validation Loss
50,17.4371,3.052013
100,12.7231,2.340843
150,11.331,2.214711
200,10.607,2.189112
250,10.6586,2.176201


In [60]:
# Generate summaries for the validation articles with the GPT-2 fine-tuned on real data,
# using the same sampling settings as for the distilled student.
# NOTE(review): this cell does not set `student_tokenizer.padding_side` itself; it presumably
# relies on the earlier generation cell having set it to "left" -- confirm when re-running.
set_seed()
pred_summaries_gpt2_real = summarize_wth_student_model(
    cnn_dm_cse447_dataset["val"]["article"],
    gpt2_real,
    student_tokenizer,
    p=0.9,
    temperature=1.0,
    # NOTE(review): "cuda" is hard-coded although the model was moved to `device` -- confirm.
    device="cuda",
)

100%|██████████| 125/125 [02:08<00:00,  1.02s/it]


In [62]:
# Score the real-data GPT-2's summaries against the validation summaries with ROUGE.
# Reuses the `teacher_rouge` metric object created in an earlier cell.
scores_gpt2_real = teacher_rouge.compute(
    predictions=pred_summaries_gpt2_real,
    references=cnn_dm_cse447_dataset["val"]["summary"],
    use_stemmer=True,
)
print(scores_gpt2_real)

{'rouge1': 0.25105259258740115, 'rouge2': 0.06543180779171379, 'rougeL': 0.15868926385145793, 'rougeLsum': 0.23297995391604004}
