In [1]:
# !pip install -U transformers
# !pip install -U datasets
# !pip install tensorboard
# !pip install sentencepiece
# !pip install accelerate

In [2]:
from datasets import load_dataset
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    AutoModelForCausalLM
)

import os

## Dataset Preparation

In [3]:
dataset = load_dataset('gopalkalpande/bbc-news-summary', split='train')
full_dataset = dataset.train_test_split(test_size=0.2, shuffle=True)
dataset_train = full_dataset['train']
dataset_valid = full_dataset['test']
 
print(dataset_train)
print(dataset_valid)

Dataset({
    features: ['File_path', 'Articles', 'Summaries'],
    num_rows: 1779
})
Dataset({
    features: ['File_path', 'Articles', 'Summaries'],
    num_rows: 445
})


In [4]:
print(dataset_train[0])

{'File_path': 'sport', 'Articles': 'Newry to fight cup exit in courts..Newry City are expected to discuss legal avenues on Friday regarding overturning their ejection from the Nationwide Irish Cup...The IFA upheld its original decision to throw Newry out of the cup following the Andy Crawford registration row. \'\'A law firm will put a case forward for Newry FC, and see what the legal implications of all this is are,\'\' said Newry boss Roy McCreadie. \'\'This is a big issue, now that we have an appeal pending,\'\' On Wednesday, a fresh IFA hearing into Crawford registration saga, ruled that last week\'s original verdict had been correct. It meant that Bangor, beaten 5-1 by Newry on the field, will take on Portadown in the sixth round. Newry had claimed they had uncovered "fresh evidence", in respect of the dates relating to the registration. But McCreadie is not further annoyed that full details of Wednesday\'s meeting was not relayed to the club. \'\'Even to this day, we have as much

## Configurations

In [5]:
MODEL = 'distilgpt2'
BATCH_SIZE = 4
NUM_PROCS = os.cpu_count()
EPOCHS = 5
OUT_DIR = 'results_distilgpt2_bbc_news_summary'
MAX_LENGTH = 512 # Maximum context length to consider while preparing dataset.

## Tokenization

In [6]:
tokenizer = AutoTokenizer.from_pretrained('distilgpt2')

In [7]:
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [8]:
text = dataset_train[10]['Articles'] + ' TL;DR: ' + dataset_train[10]['Summaries']

In [9]:
text

'Stevens named in England line-up..England have named Bath prop Matt Stevens in the starting line-up for their Six Nations match against Ireland at Lansdowne Road on Sunday...Fellow Bath prop Duncan Bell will start on the bench, as coach Andy Robinson makes just one change to the team that was beaten by France. It will be Stevens\' first start after two caps as a replacement against the All Blacks last year. Leicester duo Ollie Smith and Andy Goode have been drafted onto the bench. Stevens takes over from Phil Vickery, who suffered a broken arm playing for Gloucester last weekend. "I\'m confident Matt will grasp this opportunity and make his mark against Ireland," said Robinson..."All three players have shown outstanding form of late, most recently in the England A win against France A and for their club," added Robinson. "Selection beckons when players demonstrate such consistent ability. "This game against Ireland will be massive. We recognise it\'s a must-win game for us this season

In [10]:
def preprocess_function(example):
    context = f"{example['Articles'] + ' TL;DR: ' + example['Summaries']}"
    final_tokens = tokenizer(context, max_length=MAX_LENGTH, padding='max_length')
    return final_tokens

In [11]:
tokenized_train = dataset_train.map(
    preprocess_function,
#     batched=True,
    num_proc=NUM_PROCS,
    remove_columns=dataset_train.column_names,
)
tokenized_valid = dataset_valid.map(
    preprocess_function,
#     batched=True,
    num_proc=NUM_PROCS,
    remove_columns=dataset_valid.column_names,
)

Map (num_proc=16):   0%|          | 0/1779 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/445 [00:00<?, ? examples/s]

In [12]:
print(tokenized_train[0])

{'input_ids': [3791, 563, 284, 1907, 6508, 8420, 287, 8028, 492, 3791, 563, 2254, 389, 2938, 284, 2112, 2742, 34265, 319, 3217, 5115, 17586, 278, 511, 22189, 295, 422, 262, 49232, 8685, 5454, 986, 464, 314, 7708, 24816, 663, 2656, 2551, 284, 3714, 968, 563, 503, 286, 262, 6508, 1708, 262, 12382, 24499, 9352, 5752, 13, 10148, 32, 1099, 4081, 481, 1234, 257, 1339, 2651, 329, 968, 563, 10029, 11, 290, 766, 644, 262, 2742, 10939, 286, 477, 428, 318, 389, 14004, 531, 968, 563, 6478, 9817, 5108, 961, 494, 13, 10148, 1212, 318, 257, 1263, 2071, 11, 783, 326, 356, 423, 281, 5198, 13310, 14004, 1550, 3583, 11, 257, 4713, 314, 7708, 4854, 656, 24499, 9352, 26784, 11, 8879, 326, 938, 1285, 338, 2656, 15593, 550, 587, 3376, 13, 632, 4001, 326, 9801, 273, 11, 13125, 642, 12, 16, 416, 968, 563, 319, 262, 2214, 11, 481, 1011, 319, 4347, 324, 593, 287, 262, 11695, 2835, 13, 968, 563, 550, 4752, 484, 550, 18838, 366, 48797, 2370, 1600, 287, 2461, 286, 262, 9667, 11270, 284, 262, 9352, 13, 887, 5108, 96

In [13]:
def group_texts(examples):
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    if total_length >= MAX_LENGTH:
        total_length = (total_length // MAX_LENGTH) * MAX_LENGTH
    result = {
        k: [t[i : i + MAX_LENGTH] for i in range(0, total_length, MAX_LENGTH)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

In [14]:
lm_dataset_train = tokenized_train.map(
    group_texts, num_proc=NUM_PROCS, batched=True
)
lm_dataset_valid = tokenized_valid.map(
    group_texts, num_proc=NUM_PROCS, batched=True
)

Map (num_proc=16):   0%|          | 0/1779 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/445 [00:00<?, ? examples/s]

In [15]:
print(lm_dataset_train[0])

{'input_ids': [3791, 563, 284, 1907, 6508, 8420, 287, 8028, 492, 3791, 563, 2254, 389, 2938, 284, 2112, 2742, 34265, 319, 3217, 5115, 17586, 278, 511, 22189, 295, 422, 262, 49232, 8685, 5454, 986, 464, 314, 7708, 24816, 663, 2656, 2551, 284, 3714, 968, 563, 503, 286, 262, 6508, 1708, 262, 12382, 24499, 9352, 5752, 13, 10148, 32, 1099, 4081, 481, 1234, 257, 1339, 2651, 329, 968, 563, 10029, 11, 290, 766, 644, 262, 2742, 10939, 286, 477, 428, 318, 389, 14004, 531, 968, 563, 6478, 9817, 5108, 961, 494, 13, 10148, 1212, 318, 257, 1263, 2071, 11, 783, 326, 356, 423, 281, 5198, 13310, 14004, 1550, 3583, 11, 257, 4713, 314, 7708, 4854, 656, 24499, 9352, 26784, 11, 8879, 326, 938, 1285, 338, 2656, 15593, 550, 587, 3376, 13, 632, 4001, 326, 9801, 273, 11, 13125, 642, 12, 16, 416, 968, 563, 319, 262, 2214, 11, 481, 1011, 319, 4347, 324, 593, 287, 262, 11695, 2835, 13, 968, 563, 550, 4752, 484, 550, 18838, 366, 48797, 2370, 1600, 287, 2461, 286, 262, 9667, 11270, 284, 262, 9352, 13, 887, 5108, 96

## Model

In [16]:
model = AutoModelForCausalLM.from_pretrained('distilgpt2')

In [17]:
# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

81,912,576 total parameters.
81,912,576 training parameters.


## Training

In [18]:
training_args = TrainingArguments(
    output_dir=OUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir=OUT_DIR,
    logging_steps=10,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    save_total_limit=1,
    report_to='tensorboard',
    learning_rate=0.00005,
    fp16=True,
    dataloader_num_workers=NUM_PROCS
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_dataset_train,
    eval_dataset=lm_dataset_valid,
)

history = trainer.train()

Epoch,Training Loss,Validation Loss
1,3.146,2.998546
2,2.8118,2.944553
3,2.9421,2.91839
4,2.719,2.913258
5,2.8716,2.9134


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


In [19]:
model.save_pretrained(f"{OUT_DIR}/final_model")

In [20]:
tokenizer.save_pretrained(f"{OUT_DIR}/final_model")

('results_distilgpt2_bbc_news_summary/final_model/tokenizer_config.json',
 'results_distilgpt2_bbc_news_summary/final_model/special_tokens_map.json',
 'results_distilgpt2_bbc_news_summary/final_model/vocab.json',
 'results_distilgpt2_bbc_news_summary/final_model/merges.txt',
 'results_distilgpt2_bbc_news_summary/final_model/added_tokens.json',
 'results_distilgpt2_bbc_news_summary/final_model/tokenizer.json')

## Inference

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer

import torch

In [2]:
model_path = 'results_distilgpt2_bbc_news_summary/final_model'
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [3]:
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [4]:
prompt = """Architecturally, the school has a Catholic character.\
Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of \
the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \
'Venite Ad Me Omnes'. Next to the Main Building is the Basilica of the Sacred Heart. Immediately \
behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the \
grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in \
1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), \
is a simple, modern stone statue of Mary"""

In [5]:
print(prompt)

Architecturally, the school has a Catholic character.Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend 'Venite Ad Me Omnes'. Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary


In [8]:
def summarize_text(text, model, tokenizer, max_length=512, num_beams=5):
    # Preprocess the text
    inputs = tokenizer.encode(
        text + ' TL;DR: ',
        return_tensors='pt',
        max_length=max_length,
        truncation=True
    )

    with torch.no_grad():
        # Generate the summary
        summary_ids = model.generate(
            inputs,
            max_length=512,
            num_beams=num_beams,
            early_stopping=True,
        )

    # Decode and return the summary
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

In [9]:
summarize_text(prompt, model, tokenizer, num_beams=2)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


"Architecturally, the school has a Catholic character.Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend 'Venite Ad Me Omnes'. Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary TL;DR: Architecturally, the school has a Catholic character.Atop the Main Building's gold dome is a golden statue of the Virgin Mary.Atop the Main Building's gold dome is a golden statue of the Virgin Mary.Atop the Main Building's gold dome is a golden statue of the Virgin Mary.Atop the Main Building's g