In [1]:
! pip install datasets rouge-score nltk

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


# Fine-tuning a model on a summarization task

In this notebook, we show how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models on a summarization task. The dataset is the [XSum dataset](https://arxiv.org/pdf/1808.08745.pdf), which contains BBC news articles and their summaries.

We use 🤗 Datasets to load the data and the `Trainer` API to fine-tune the model.

Any sequence-to-sequence model from the [Model Hub](https://huggingface.co/models) can be used, for example [`t5-small`](https://huggingface.co/t5-small).

# Loading the dataset

In [1]:
from datasets import load_dataset, load_metric

raw_datasets = load_dataset("xsum")
metric = load_metric("rouge")

Using custom data configuration default
Reusing dataset xsum (/raid/wuxiaojun/.cache/huggingface/datasets/xsum/default/1.2.0/4957825a982999fbf80bca0b342793b01b2611e021ef589fb7c6250b3577b499)


In [2]:
# The downloaded dataset: 200k+ training examples and 11k+ validation and test examples each
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'id', 'summary'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'id', 'summary'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'id', 'summary'],
        num_rows: 11334
    })
})

To access an actual element, you need to select a split first, then give an index:

In [3]:
raw_datasets["train"][0]

{'document': 'Recent reports have linked some France-based players with returns to Wales.\n"I\'ve always felt - and this is with my rugby hat on now; this is not region or WRU - I\'d rather spend that money on keeping players in Wales," said Davies.\nThe WRU provides £2m to the fund and £1.3m comes from the regions.\nFormer Wales and British and Irish Lions fly-half Davies became WRU chairman on Tuesday 21 October, succeeding deposed David Pickering following governing body elections.\nHe is now serving a notice period to leave his role as Newport Gwent Dragons chief executive after being voted on to the WRU board in September.\nDavies was among the leading figures among Dragons, Ospreys, Scarlets and Cardiff Blues officials who were embroiled in a protracted dispute with the WRU that ended in a £60m deal in August this year.\nIn the wake of that deal being done, Davies said the £3.3m should be spent on ensuring current Wales-based stars remain there.\nIn recent weeks, Racing Metro fla

In [4]:
# Display a few random examples
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
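The rejection loop in `show_random_elements` draws distinct indices one at a time. As a sketch that is not part of the original notebook, Python's standard `random.sample` does the same in a single call; `pick_indices` is a hypothetical helper name.

```python
import random

def pick_indices(dataset_len, num_examples=5, seed=None):
    """Return num_examples distinct indices in [0, dataset_len)."""
    if num_examples > dataset_len:
        raise ValueError("Can't pick more elements than there are in the dataset.")
    rng = random.Random(seed)  # independent generator, so a fixed seed gives repeatable picks
    return rng.sample(range(dataset_len), num_examples)

picks = pick_indices(204045, num_examples=5, seed=0)
print(picks)
```

The resulting list can be used exactly like `picks` in the function above, e.g. `pd.DataFrame(dataset[picks])`.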

In [5]:
show_random_elements(raw_datasets["train"])

,document,id,summary
0,"Michael Cockerell told reporters about the plan at a press screening of his new series Inside the Commons.\n""I'm not fingering anyone by name,"" Mr Cockerell said, when asked who was involved in the plot.\nBut he did say they were ""right wing Tories... what Downing Street know as the berserkers, the naughty bench"".\nHe declined to name the cameraman who was the subject of the apparent skulduggery.\nIn the first episode of the four-part series, to be shown on Tuesday, the Conservative MP Bill Wiggin is seen complaining to the Speaker during a session in the Commons about the presence of camera crews in the chamber itself.\nMichael Cockerell said Mr Wiggin was not involved in the plot.\n""We heard of a plan to knock over the cameraman and cause the House to be suspended, and then they would blame it on us and suggest we shouldn't be there,"" he said, adding that Parliamentary staff had let them know about the plot and had managed to prevent it from happening.\nHe said there were ""very few"" opponents to the documentary, but ""in Parliament every day there are cunning plans, it is a place made for plotting and conspiracy"".\nThe documentary was filmed over the course of a year - after six years of attempting to persuade the parliamentary authorities to allow them the access they required.\nAtlantic Productions, the producers of the series for the BBC, gathered 600 hours of raw material for the four hours that will be broadcast throughout February.\nThe first episode is broadcast on Tuesday on BBC Two at 21:00 GMT.",31039104,"MPs plotted to knock over a BBC cameraman in the House of Commons - in the hope of stopping a new documentary on Westminster life, a film-maker says."
1,"£2.5m is being invested in new facilities at Liberton High where 12-year-old Keane Wallis Bennett died after the ""modesty"" wall fell on her.\nEdinburgh City Council said a permanent memorial was planned as well as a new extension to the PE block.\nThe Health and Safety Executive is continuing its investigation into the tragedy, which happened on 1 April.\nPaul Godzik, Edinburgh City Council's education convener, said: ""The overwhelming view from staff at the school, parents and the local community is that the gym should be demolished as soon as possible.\n""The proposal is that this will happen at the earliest opportunity, and assuming that the necessary consents are in place, we hope this will be able to take place over the summer break to minimise disruption to the school.\n""The tragic incident two months ago has obviously had a devastating effect on the local community and we are determined to work with them and other partners to ensure nothing similar can ever happen again.\n""Discussions about a suitable memorial at the school for Keane are continuing and we hope to make an announcement in the near future.""",28015018,The gym hall where an Edinburgh schoolgirl died when a wall collapsed is to be demolished this summer.
2,"Katrina O'Hara, 44, was stabbed at Jocks Barbers in East Street, Blandford Forum, on 7 January.\nDorset Police said it referred itself to the Independent Police Complaints Commission (IPCC) as it had ""prior contact with people involved"".\nStuart Thomas, 49, who has been charged with murdering Ms O'Hara, is due before Winchester Crown Court on 1 April.\nAn IPCC spokeswoman said: ""The IPCC has begun an independent investigation into previous contact between Dorset Police and Katrina O'Hara, and with Stuart Thomas, also known as George Thomas.""",35350406,A police watchdog is to investigate circumstances relating to the suspected murder of a Dorset hairdresser.
3,"Michael Cole, 29, of Newhaven, East Sussex, admitted inciting sexual activity with boys and possessing indecent images of children.\nHe was charged after police appealed to teenagers who may have been threatened online by a man posing as a woman.\nCole was bailed at Lewes Crown Court until sentencing on 19 June.\nDuring its investigation, Sussex Police said teenagers, mainly boys aged 13 to 17 in the Seaford, Newhaven and Peacehaven areas, were approached by a person calling themselves Jenny Lane and threatened if they did not send inappropriate images.\nDet Con Steve Shimmons said Cole, aka ""Jenny Lane"", used social media, including Facebook and Skype, to contact teenagers and sent out more than 100 friend requests.",32629489,"A man has pleaded guilty to sex offences against boys after he tricked teenagers into sending him ""inappropriate"" images online."
4,"Jean Brooks uses the hairdryer as a pretend speed gun outside her home in Hucknall, Nottinghamshire.\nA video of her has been viewed more than 33 million times on the BBC Radio Nottingham Facebook page.\nSince her new-found fame, she said she is asked for selfies and fans visit her home.\n""People come up to me and hug me and say 'thank you for making the street safer',"" said the 63-year-old.\n""As far as the local school kids are concerned, I'm a national hero.\n""If the traffic slows down nobody's going to get hurt. I've already got a cat that's only got three legs because he didn't know his green cross code.""\nThe video has since spawned memes quoting parts of the interview, in which Ms Brooks said: ""If we can't be safe in our own streets, how the hell are we going to be safe in the world?""\nPeople have also quoted the ""neooooow, neooooow, neoooow"" noise she made when imitating the sound of vehicles speeding past her home.\nA hitchhiker from Germany turned up at her home with a gift on Wednesday after seeing the hairdryer video.\n""I nearly burst into tears when I saw him,"" she said.\n""I hugged him so hard I thought I was going to crack his ribs, bless him.""\nUpdates on this story and more from the East Midlands\nBartek Zabel, from Hamburg, said: ""It's brilliant because you don't need expensive equipment to slow the drivers down.""\nHe has sent the video to his friends and said they loved it too.\n""I just had to visit her because she's like a local hero and like a celebrity,"" he added.\nA man who viewed the hairdryer video also delivered a van of nearly 200 toys to her home on Saturday night, which she gave to children on her estate.\n""It's not often I'm made speechless but I was speechless,"" said Ms Brooks.\nMs Brooks, who is a biker herself, contacted the BBC after it published a video of irresponsible quad bikers and motorcyclists who rode dangerously in Nottingham city centre.\nShe claimed the gang of bikers - who were caught on CCTV doing wheelies, driving the wrong way and weaving in and out of traffic - regularly went down her road.\nShe said she wanted the hairdryer to become symbolic of communities taking back their streets.\nShe encouraged people to vote - particularly young people - and said they should take a hairdryer with them when they do.\n""Don't say anything, don't do anything, just carry a hairdryer,"" she said.\nMs Brooks does other work for her community, including running a charity cafe in her garden to encourage people to talk to each other.\nIn 2015, she collected Easter eggs then delivered them to children's homes and hospitals with the help of a group of bikers.",40184757,"A hairdryer-wielding grandmother who became an internet sensation says schoolchildren in her area now consider her a ""national hero""."


# The ROUGE metric ([`datasets.Metric`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Metric))

In [6]:
# rouge1: unigram overlap
# rouge2: bigram overlap
# rougeL: longest common subsequence
metric

Metric(name: "rouge", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each predictions
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLSum"`: rougeLsum splits text using `"\n"`.
        See details in https://github.com/huggingface/datasets/issues/617
    use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
    use_agregator: Return aggregates if this is set to True
Retu
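To make the comments in the cell above concrete, here is a simplified, from-scratch sketch of what rouge1, rouge2 and rougeL measure for a single prediction/reference pair. It is an illustration only: it splits on whitespace and skips the text normalization and optional Porter stemming (`use_stemmer`) that the `rouge_score` package applies, and it does not cover the sentence-level rougeLsum, so its numbers will not match `metric.compute` exactly.

```python
from collections import Counter

def ngrams(tokens, n):
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

def f1(overlap, pred_total, ref_total):
    if overlap == 0:
        return 0.0
    precision, recall = overlap / pred_total, overlap / ref_total
    return 2 * precision * recall / (precision + recall)

def rouge_n(pred, ref, n):
    """ROUGE-N F1: clipped n-gram overlap between whitespace-tokenized strings."""
    pred_counts = Counter(ngrams(pred.split(), n))
    ref_counts = Counter(ngrams(ref.split(), n))
    overlap = sum((pred_counts & ref_counts).values())  # & keeps the min count per n-gram
    return f1(overlap, sum(pred_counts.values()), sum(ref_counts.values()))

def lcs_length(a, b):
    """Length of the longest common subsequence (dynamic programming)."""
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a, 1):
        for j, y in enumerate(b, 1):
            dp[i][j] = dp[i - 1][j - 1] + 1 if x == y else max(dp[i - 1][j], dp[i][j - 1])
    return dp[-1][-1]

def rouge_l(pred, ref):
    """ROUGE-L F1: LCS length relative to prediction and reference lengths."""
    p, r = pred.split(), ref.split()
    return f1(lcs_length(p, r), len(p), len(r))

pred, ref = "the cat sat on the mat", "the cat lay on the mat"
print(round(rouge_n(pred, ref, 1), 4))  # 0.8333
print(round(rouge_n(pred, ref, 2), 4))  # 0.6
print(round(rouge_l(pred, ref), 4))     # 0.8333
```

Five of six unigrams match, but only three of five bigrams survive the substituted word, which is why rouge2 is much lower than rouge1 for the same pair.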

You can call its `compute` method with your predictions and labels, which need to be lists of decoded strings:

In [7]:
fake_preds = ["hello there", "general kenobi"]
fake_labels = ["hello there", "general kenobi"]
metric.compute(predictions=fake_preds, references=fake_labels)

{'rouge1': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rouge2': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rougeL': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rougeLsum': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))}
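Each `AggregateScore` above holds `low`, `mid` and `high` estimates, and `compute_metrics` later keeps only `mid.fmeasure`. In the `rouge_score` package these are, by default, bootstrap estimates: the per-example scores are resampled with replacement, and low/mid/high are the 2.5th, 50th and 97.5th percentiles of the resampled means. The NumPy sketch below reproduces that idea; `bootstrap_interval`, the number of resamples and the seed are illustrative choices, not the library's code.

```python
import numpy as np

def bootstrap_interval(scores, n_samples=1000, confidence=0.95, seed=0):
    """Return (low, mid, high) bootstrap estimates of the mean of per-example scores."""
    scores = np.asarray(scores, dtype=float)
    rng = np.random.default_rng(seed)
    # Resample example indices with replacement and average each resample.
    idx = rng.integers(0, len(scores), size=(n_samples, len(scores)))
    means = scores[idx].mean(axis=1)
    alpha = (1.0 - confidence) / 2 * 100
    low, mid, high = np.percentile(means, [alpha, 50, 100 - alpha])
    return float(low), float(mid), float(high)

# Identical per-example scores give identical low/mid/high, as in the output above.
print(bootstrap_interval([1.0, 1.0]))  # (1.0, 1.0, 1.0)
print(bootstrap_interval([0.2, 0.4, 0.6, 0.8]))
```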

# Preprocessing the data

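By instantiating our tokenizer with `AutoTokenizer.from_pretrained`, we make sure that:

- we get a tokenizer that corresponds to the model architecture we want to use,
- we download the vocabulary used when pretraining this specific checkpoint.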


In [8]:
from transformers import AutoTokenizer

model_checkpoint = "t5-small"
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [9]:
tokenizer("Hello, this one sentence!")

{'input_ids': [8774, 6, 48, 80, 7142, 55, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

In [10]:
tokenizer(["Hello, this one sentence!", "This is another sentence."])

{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}

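To prepare the targets for our model, we need to tokenize them inside the `as_target_tokenizer` context manager. This makes sure the tokenizer uses the special tokens corresponding to the targets. For T5 the result is the same as for the inputs, as the output shows. In recent versions of 🤗 Transformers this context manager is deprecated in favor of passing the targets as `tokenizer(text_target=...)`: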

In [11]:
with tokenizer.as_target_tokenizer():
    print(tokenizer(["Hello, this one sentence!", "This is another sentence."]))

{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}


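T5 models need the prefix "summarize: " in front of each input; the same checkpoint can perform other tasks, such as translation, selected by a different prefix.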

In [12]:
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

In [13]:
max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True) # truncate documents longer than max_input_length tokens

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:

In [24]:
preprocess_function(raw_datasets['train'][:1])

{'input_ids': [[21603, 10, 17716, 2279, 43, 5229, 128, 1410, 18, 390, 1508, 28, 5146, 12, 10256, 5, 96, 196, 31, 162, 373, 1800, 3, 18, 11, 48, 19, 28, 82, 22209, 3, 547, 30, 230, 117, 48, 19, 59, 1719, 42, 549, 8503, 3, 18, 27, 31, 26, 1066, 1492, 24, 540, 30, 2627, 1508, 16, 10256, 976, 243, 28571, 5, 37, 549, 8503, 795, 17586, 51, 12, 8, 3069, 11, 3996, 13606, 51, 639, 45, 8, 6266, 5, 18263, 10256, 11, 2390, 11, 7262, 10371, 7, 3971, 18, 17114, 28571, 1632, 549, 8503, 13404, 30, 2818, 1401, 1797, 6, 7229, 53, 20, 12151, 1955, 8356, 49, 53, 826, 3, 19585, 643, 9768, 5, 216, 19, 230, 3122, 3, 9, 2103, 1059, 12, 1175, 112, 1075, 38, 24260, 350, 16103, 10282, 7, 5752, 4297, 227, 271, 3, 11060, 30, 12, 8, 549, 8503, 1476, 16, 1600, 5, 28571, 47, 859, 8, 1374, 5638, 859, 10282, 7, 6, 411, 7, 2026, 63, 7, 6, 14586, 7677, 11, 26911, 2419, 7, 4298, 113, 130, 10960, 52, 26786, 16, 3, 9, 813, 11674, 11044, 28, 8, 549, 8503, 24, 3492, 16, 3, 9, 3996, 3328, 51, 1154, 16, 1660, 48, 215, 5, 86, 8,
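The structure returned by `preprocess_function` can be reproduced without downloading anything. The sketch below swaps the real tokenizer for a hypothetical whitespace tokenizer (`toy_tokenize`, with made-up ids) but keeps the same shape: a `summarize: ` prefix, truncation to a maximum length that includes the end-of-sequence token (id 1 for T5, as in the outputs above), and one list of ids per example under each key.

```python
EOS_ID = 1  # T5's </s> id, visible at the end of every tokenized sequence above

def toy_tokenize(texts, vocab, max_length):
    """Hypothetical stand-in for the tokenizer: whitespace split, made-up ids, EOS appended."""
    input_ids = []
    for text in texts:
        ids = [vocab.setdefault(word, len(vocab) + 2) for word in text.split()]
        input_ids.append(ids[: max_length - 1] + [EOS_ID])  # truncate, keeping room for EOS
    attention_mask = [[1] * len(ids) for ids in input_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

def toy_preprocess(examples, vocab, max_input_length=8, max_target_length=4):
    inputs = ["summarize: " + doc for doc in examples["document"]]
    model_inputs = toy_tokenize(inputs, vocab, max_input_length)
    labels = toy_tokenize(examples["summary"], vocab, max_target_length)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

batch = {"document": ["one two three four five six seven eight nine", "short text"],
         "summary": ["a long summary that gets cut", "brief"]}
out = toy_preprocess(batch, vocab={})
print([len(ids) for ids in out["input_ids"]])  # [8, 4]
print([len(ids) for ids in out["labels"]])     # [4, 2]
```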

In [31]:
print("text:\n",tokenizer.decode(preprocess_function(raw_datasets['train'][:1])['input_ids'][0]))
print("labels:\n",tokenizer.decode(preprocess_function(raw_datasets['train'][:1])['labels'][0]))

ERROR! Session/line number was not unique in database. History logging moved to new session 145
text:
 summarize: Recent reports have linked some France-based players with returns to Wales. "I've always felt - and this is with my rugby hat on now; this is not region or WRU - I'd rather spend that money on keeping players in Wales," said Davies. The WRU provides £2m to the fund and £1.3m comes from the regions. Former Wales and British and Irish Lions fly-half Davies became WRU chairman on Tuesday 21 October, succeeding deposed David Pickering following governing body elections. He is now serving a notice period to leave his role as Newport Gwent Dragons chief executive after being voted on to the WRU board in September. Davies was among the leading figures among Dragons, Ospreys, Scarlets and Cardiff Blues officials who were embroiled in a protracted dispute with the WRU that ended in a £60m deal in August this year. In the wake of that deal being done, Davies said the £3.3m should be 

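To apply `preprocess_function` to the whole dataset, use the `map` method; called on the `DatasetDict`, it processes the train, validation and test splits in one go.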

In [15]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

Loading cached processed dataset at /raid/wuxiaojun/.cache/huggingface/datasets/xsum/default/1.2.0/4957825a982999fbf80bca0b342793b01b2611e021ef589fb7c6250b3577b499/cache-d2672c5c9532e4c6.arrow
Loading cached processed dataset at /raid/wuxiaojun/.cache/huggingface/datasets/xsum/default/1.2.0/4957825a982999fbf80bca0b342793b01b2611e021ef589fb7c6250b3577b499/cache-90eecd02bce010ce.arrow
Loading cached processed dataset at /raid/wuxiaojun/.cache/huggingface/datasets/xsum/default/1.2.0/4957825a982999fbf80bca0b342793b01b2611e021ef589fb7c6250b3577b499/cache-adae3b78978db3b1.arrow


To recompute the results instead of loading them from the cache, pass `load_from_cache_file=False` to `map`.
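With `batched=True`, `map` passes the function a slice of the dataset as a dict of lists (1,000 examples per call by default) rather than one example at a time, which is what lets the tokenizer process many texts at once. The pure-Python sketch below imitates that calling convention on an in-memory dict of lists; `map_batched` is a hypothetical helper, not part of 🤗 Datasets, and it ignores caching, multiprocessing and Arrow storage.

```python
def map_batched(columns, function, batch_size=1000):
    """Apply `function` to consecutive batches of a dict-of-lists and merge the results."""
    num_rows = len(next(iter(columns.values())))
    output = {}
    for start in range(0, num_rows, batch_size):
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        result = function(batch)
        for name, values in result.items():
            output.setdefault(name, []).extend(values)
    # Like Dataset.map without remove_columns, keep the original columns alongside the new ones.
    return {**columns, **output}

def add_lengths(batch):
    return {"doc_len": [len(doc.split()) for doc in batch["document"]]}

data = {"document": ["a b c", "d e", "f"], "summary": ["x", "y", "z"]}
print(map_batched(data, add_lengths, batch_size=2))
# {'document': ['a b c', 'd e', 'f'], 'summary': ['x', 'y', 'z'], 'doc_len': [3, 2, 1]}
```

This is also why `tokenized_datasets` still contains the `document`, `id` and `summary` columns next to `input_ids`, `attention_mask` and `labels`.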


# Fine-tuning the model

Since summarization is a sequence-to-sequence task, we load the model with the `AutoModelForSeq2SeqLM` class.

In [16]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

To instantiate a `Seq2SeqTrainer`, we first need to define its configuration with [`Seq2SeqTrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.Seq2SeqTrainingArguments):

In [17]:
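batch_size = 16
model_name = model_checkpoint.split("/")[-1]
args = Seq2SeqTrainingArguments(
    "test-summarization", # output directory for checkpoints
#     overwrite_output_dir = True,
    evaluation_strategy = "epoch", # evaluate at the end of every epoch (renamed eval_strategy in recent releases)
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01, # weight decay
    save_total_limit=3, # keep at most 3 checkpoints; older ones are deleted
    num_train_epochs=5,
    predict_with_generate=True, # run generate() during evaluation so ROUGE is computed on generated summaries
    fp16=True, # mixed-precision training (faster; needs a CUDA GPU)
)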

In [18]:
# Data collator: pads the inputs to the maximum length in the batch, and the labels as well
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
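`DataCollatorForSeq2Seq` pads each batch dynamically: `input_ids` are padded with the tokenizer's pad token (id 0 for T5) and `attention_mask` with 0, while `labels` are padded with -100 (its default `label_pad_token_id`) so that padded positions are ignored by the loss. Because `model` is passed, it also builds `decoder_input_ids` from the labels. The sketch below reproduces only the padding step in plain Python and returns lists rather than tensors; `pad_batch` is a hypothetical name, and right-padding is assumed.

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Right-pad a list of tokenized examples to the longest sequence in the batch."""
    max_in = max(len(f["input_ids"]) for f in features)
    max_lab = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_in - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)
        # -100 marks positions the cross-entropy loss should skip.
        batch["labels"].append(f["labels"] + [label_pad_token_id] * (max_lab - len(f["labels"])))
    return batch

features = [
    {"input_ids": [8774, 6, 1], "attention_mask": [1, 1, 1], "labels": [100, 1]},
    {"input_ids": [100, 1], "attention_mask": [1, 1], "labels": [8774, 6, 48, 1]},
]
print(pad_batch(features))
```

Padding per batch rather than to `max_input_length` keeps batches of short articles small.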

In [19]:
import nltk
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # rougeLsum expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Keep the median (mid) F1 score of each ROUGE type, scaled to 0-100
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}
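The two less obvious steps in `compute_metrics` are replacing -100 in the labels (the collator's label padding, which is not a valid token id and cannot be decoded) and counting non-pad tokens for `gen_len`. A small NumPy check of both on made-up arrays, assuming pad id 0 as in T5 (generated sequences start with it, since T5 uses the pad token as decoder start token):

```python
import numpy as np

PAD_ID = 0  # T5's pad_token_id

labels = np.array([[100, 19, 1, -100], [8774, 6, 48, 1]])
predictions = np.array([[0, 100, 19, 1], [0, 8774, 1, 0]])

# Same expression as in compute_metrics: keep real ids, swap -100 for the pad id.
clean_labels = np.where(labels != -100, labels, PAD_ID)
print(clean_labels.tolist())  # [[100, 19, 1, 0], [8774, 6, 48, 1]]

# Same counting as for gen_len: tokens that are not padding, per generated sequence.
prediction_lens = [int(np.count_nonzero(pred != PAD_ID)) for pred in predictions]
print(prediction_lens, np.mean(prediction_lens))  # [3, 2] 2.5
```

After the replacement, `batch_decode(..., skip_special_tokens=True)` drops the pad ids, so both arrays decode to clean strings.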

Then we just need to pass all of this along with our datasets to the `Seq2SeqTrainer`:

In [20]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [21]:
import nltk
nltk.download('punkt')  # sentence tokenizer used by compute_metrics; recent NLTK releases may also need 'punkt_tab'

ERROR! Session/line number was not unique in database. History logging moved to new session 142


[nltk_data] Downloading package punkt to /raid/wuxiaojun/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

We can now finetune our model by just calling the `train` method:

In [22]:
trainer.train()

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,2.7137,2.472508,28.3994,7.836,22.3626,22.3669,18.81
2,2.6854,2.450603,28.7622,8.0235,22.6615,22.661,18.8214
3,2.6676,2.438797,28.8666,8.0818,22.7343,22.7342,18.827
4,2.6584,2.435523,28.8977,8.1233,22.784,22.7869,18.8232
5,2.6584,2.435521,28.8945,8.125,22.786,22.789,18.8233




TrainOutput(global_step=7975, training_loss=2.68158392927116, metrics={'train_runtime': 4529.7799, 'train_samples_per_second': 225.226, 'train_steps_per_second': 1.761, 'total_flos': 4.0956649065630106e+17, 'train_loss': 2.68158392927116, 'epoch': 5.0})
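The `global_step=7975` in the output can be checked with a little arithmetic. Assuming the defaults of one gradient-accumulation step and no dropped last batch, the number of optimizer steps is epochs × ceil(training examples / (per-device batch size × number of devices)). With 204,045 examples, `batch_size = 16` and 5 epochs this gives 7,975 only for 8 devices, so the run most likely used 8 GPUs; the notebook does not say so, this is an inference from the logs. It is consistent with `train_samples_per_second × train_runtime` (225.226 × 4529.78 ≈ 1,020,224 ≈ 5 × 204,045).

```python
import math

NUM_TRAIN = 204045     # rows in raw_datasets["train"]
PER_DEVICE_BATCH = 16  # batch_size in the notebook
EPOCHS = 5             # num_train_epochs

def total_steps(num_devices):
    # Assumes gradient_accumulation_steps=1 and that the last partial batch of each
    # epoch is kept (both Trainer defaults).
    steps_per_epoch = math.ceil(NUM_TRAIN / (PER_DEVICE_BATCH * num_devices))
    return steps_per_epoch * EPOCHS

print(total_steps(1))  # 63765
print(total_steps(8))  # 7975, matching global_step in the TrainOutput above
```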

# Model inference

In [32]:
text_example = preprocess_function(raw_datasets['train'][:1])  # first training example; use raw_datasets['test'] for an unseen article

ERROR! Session/line number was not unique in database. History logging moved to new session 146


In [33]:
# Adapted from https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # the model and the inputs must be on the same device
model.eval()

generated_ids = model.generate(
    input_ids=torch.tensor(text_example['input_ids'], dtype=torch.long, device=device),
    attention_mask=torch.tensor(text_example['attention_mask'], dtype=torch.long, device=device),
    max_length=150,
    num_beams=2,
    repetition_penalty=2.5,
    length_penalty=1.0,
    early_stopping=True,
)

ERROR! Session/line number was not unique in database. History logging moved to new session 150


In [34]:
preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]


ERROR! Session/line number was not unique in database. History logging moved to new session 151


In [35]:
print(preds)

['Newport Gwent Dragons fly-half David Davies has said he would rather spend £3.3m on keeping players in Wales.']


# References

- Official Hugging Face example notebook: https://github.com/huggingface/notebooks/blob/master/examples/summarization.ipynb