In [1]:
# Pin the versions this notebook was run with (see the install log below): it relies on
# APIs that later releases deprecated or removed (datasets.load_metric,
# tokenizer.as_target_tokenizer, gradio.inputs / gradio.outputs).
# nltk is left unpinned because its version is not recorded in the log.
%pip install -q datasets==1.14.0 transformers==4.12.0 seqeval==1.2.2 rouge-score==0.0.4 nltk gradio==2.4.1

Collecting datasets
  Downloading datasets-1.14.0-py3-none-any.whl (290 kB)
[K     |████████████████████████████████| 290 kB 4.4 MB/s 
[?25hCollecting transformers
  Downloading transformers-4.12.0-py3-none-any.whl (3.1 MB)
[K     |████████████████████████████████| 3.1 MB 40.9 MB/s 
[?25hCollecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[K     |████████████████████████████████| 43 kB 1.7 MB/s 
[?25hCollecting rouge-score
  Downloading rouge_score-0.0.4-py2.py3-none-any.whl (22 kB)
Collecting gradio
  Downloading gradio-2.4.1-py3-none-any.whl (2.0 MB)
[K     |████████████████████████████████| 2.0 MB 42.4 MB/s 
Collecting aiohttp
  Downloading aiohttp-3.7.4.post0-cp37-cp37m-manylinux2014_x86_64.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 48.7 MB/s 
Collecting xxhash
  Downloading xxhash-2.0.2-cp37-cp37m-manylinux2010_x86_64.whl (243 kB)
[K     |████████████████████████████████| 243 kB 63.2 MB/s 
[?25hCollecting fsspec[http]>=2021.05.0
  Downloading

## Imports

In [2]:
import numpy as np
from datasets import load_dataset, load_metric
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML
import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification
from transformers import pipeline
import gradio as gr

## Load Data

In [4]:
# Seq2seq checkpoint used for fine-tuning below; XSum is the corpus and
# ROUGE the evaluation metric, both fetched through `datasets`.
model_checkpoint = "t5-small"
metric = load_metric("rouge")
raw_datasets = load_dataset("xsum")

# Peek at one training example (includes "document" and "summary" fields).
raw_datasets["train"][100]

Downloading:   0%|          | 0.00/1.93k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/954 [00:00<?, ?B/s]

Using custom data configuration default


Downloading and preparing dataset xsum/default (download: 245.38 MiB, generated: 507.60 MiB, post-processed: Unknown size, total: 752.98 MiB) to /root/.cache/huggingface/datasets/xsum/default/1.2.0/4957825a982999fbf80bca0b342793b01b2611e021ef589fb7c6250b3577b499...


  0%|          | 0/2 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/255M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.00M [00:00<?, ?B/s]

  0%|          | 0/2 [00:00<?, ?it/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset xsum downloaded and prepared to /root/.cache/huggingface/datasets/xsum/default/1.2.0/4957825a982999fbf80bca0b342793b01b2611e021ef589fb7c6250b3577b499. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

{'document': 'Samsung said: "Shipments of the Galaxy Note 7 are being temporarily delayed for additional quality assurance inspections."\nThere are reports in South Korea and the US of the Galaxy Note 7 "exploding" either during or just after charging.\nHowever, it is unclear whether the delay is because of these reports.\nPictures and videos shared online depict charred and burnt handsets.\nShares fell as much as 3.5% during trade in Seoul before making a partial recovery to close 2% down on the day.\nSister company Samsung SDI told Reuters that while it was a supplier of Galaxy Note 7 batteries, it had received no information to suggest the batteries were faulty.\nA YouTube user who says they live in the US uploaded a video of a Galaxy Note 7 with burnt rubber casing and damaged screen under the name Ariel Gonzalez on 29 August.\nHe said the handset "caught fire" shortly after he unplugged the official Samsung charger, less than a fortnight after purchasing it.\n"I came home after wo

In [5]:
# T5 is prompted with a task prefix; inputs and targets are truncated separately.
prefix = "summarize: "
max_input_length = 768
max_target_length = 128

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)


def preprocess_function(examples):
    """Tokenize a batch of XSum examples for seq2seq training.

    Args:
        examples: batched mapping with "document" and "summary" lists, as
            passed by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer encoding of the prefixed documents (truncated to
        ``max_input_length``) with an added "labels" entry holding the token
        ids of the summaries (truncated to ``max_target_length``).
    """
    prefixed_docs = [prefix + document for document in examples["document"]]
    encoded = tokenizer(prefixed_docs, max_length=max_input_length, truncation=True)

    # Targets are tokenized inside as_target_tokenizer(), the transformers API
    # for tokenizers that process targets differently from sources.
    with tokenizer.as_target_tokenizer():
        summary_ids = tokenizer(
            examples["summary"], max_length=max_target_length, truncation=True
        )["input_ids"]

    encoded["labels"] = summary_ids
    return encoded



Downloading:   0%|          | 0.00/1.17k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.32M [00:00<?, ?B/s]

In [13]:
# Tokenize every split of raw_datasets in batches with preprocess_function.
tokenized_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
)

  0%|          | 0/205 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

## Fine-Tuning the Model

In [6]:
from transformers import (
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# Pretrained encoder-decoder weights to fine-tune (same checkpoint as the tokenizer).
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

Downloading:   0%|          | 0.00/231M [00:00<?, ?B/s]

In [10]:
batch_size = 32
# Only the last path component of the checkpoint id goes into the output dir name.
model_name = model_checkpoint.rsplit("/", 1)[-1]

args = Seq2SeqTrainingArguments(
    output_dir=f"{model_name}-finetuned-xsum",
    evaluation_strategy="epoch",  # evaluate at the end of every epoch
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    save_total_limit=3,  # keep at most 3 checkpoints on disk
    predict_with_generate=True,  # eval uses generate(), so compute_metrics gets token ids
)

In [11]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)


def compute_metrics(eval_pred):
    """Compute ROUGE and mean generated length for a Seq2SeqTrainer evaluation.

    Args:
        eval_pred: ``(predictions, labels)`` token-id arrays produced with
            ``predict_with_generate=True``; padded positions may hold -100.

    Returns:
        dict mapping each ROUGE key to its mid F-measure (x100), plus
        ``"gen_len"``: the mean number of non-pad tokens per prediction.
        All values are rounded to 4 decimals.
    """
    # Imported locally so this cell is self-contained.
    import nltk

    # sent_tokenize needs the punkt sentence model (punkt_tab on newer nltk);
    # fetch it once if it is missing.
    try:
        nltk.sent_tokenize("Probe sentence.")
    except LookupError:
        for resource in ("punkt", "punkt_tab"):
            nltk.download(resource, quiet=True)

    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # -100 is used as padding (the trainer's cross-batch padding for predictions,
    # the collator's label padding); it is not a valid token id and cannot be decoded.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # ROUGE-Lsum expects a newline after each sentence.
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Keep only the mid (median bootstrap) F-measure of each ROUGE variant, as a percentage.
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    # Mean generated length, counting only non-pad tokens.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [14]:
# Wire model, arguments, data and metrics into a Seq2SeqTrainer.
# NOTE(review): no trainer.train() call is visible in this notebook.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [15]:
import re
def clean_text(text):
    """Normalize free text before summarization.

    Drops every non-ASCII character (e.g. CJK text, accented letters, curly
    quotes), turns line breaks and tabs into spaces, collapses runs of spaces
    into one and strips both ends.

    Args:
        text: raw input string.

    Returns:
        The cleaned single-line string ("" if nothing remains).
    """
    # NOTE(review): ASCII encoding is lossy by design here - "cafe\u0301"/"caf\u00e9" lose the accent letter.
    text = text.encode("ascii", errors="ignore").decode("ascii")
    # One pass covers \n, \r (Windows line endings in pasted text) and \t;
    # a separate "\n\n" pass is unnecessary because every "\n" is replaced here.
    text = re.sub(r"[\r\n\t]", " ", text)
    # Collapse multiple spaces into one and trim surrounding whitespace.
    return re.sub(" +", " ", text).strip()

In [16]:
pipeline_summ = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",  # switch out to "t5-small" etc. if you wish
    tokenizer="facebook/bart-large-cnn",  # keep in sync with `model`
    framework="pt",
)


# First of 2 summarization functions
def fb_summarizer(text):
    """Summarize ``text`` with the facebook/bart-large-cnn pipeline.

    Args:
        text: raw user input (e.g. from the Gradio textbox); None is treated as "".

    Returns:
        The generated summary string, or "" when nothing is left after cleaning.
    """
    input_text = clean_text(text or "")
    if not input_text:
        # Nothing to summarize; skip generation on an empty prompt.
        return ""
    # Inputs longer than the model's maximum input length otherwise raise an
    # error during the forward pass; truncate them instead.
    results = pipeline_summ(input_text, truncation=True)
    return results[0]["summary_text"]

# First of 2 Gradio apps that we'll put in "parallel"
# First of 2 Gradio apps that we'll put in "parallel": one textbox in,
# the BART summary out.
summary_input = gr.inputs.Textbox()
summary_output = gr.outputs.Textbox(label="Summary")
summary1 = gr.Interface(
    fn=fb_summarizer,
    inputs=summary_input,
    outputs=summary_output,
)

https://huggingface.co/facebook/bart-large-cnn/resolve/main/config.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmpy4mypy65


Downloading:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

storing https://huggingface.co/facebook/bart-large-cnn/resolve/main/config.json in cache at /root/.cache/huggingface/transformers/199ab6c0f28e763098fd3ea09fd68a0928bb297d0f76b9f3375e8a1d652748f9.930264180d256e6fe8e4ba6a728dd80e969493c23d4caa0a6f943614c52d34ab
creating metadata file for /root/.cache/huggingface/transformers/199ab6c0f28e763098fd3ea09fd68a0928bb297d0f76b9f3375e8a1d652748f9.930264180d256e6fe8e4ba6a728dd80e969493c23d4caa0a6f943614c52d34ab
loading configuration file https://huggingface.co/facebook/bart-large-cnn/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/199ab6c0f28e763098fd3ea09fd68a0928bb297d0f76b9f3375e8a1d652748f9.930264180d256e6fe8e4ba6a728dd80e969493c23d4caa0a6f943614c52d34ab
Model config BartConfig {
  "_num_labels": 3,
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_final_layer_norm": false,
  "architectures": [
    "BartForConditionalGeneration"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classif_dr

Downloading:   0%|          | 0.00/1.51G [00:00<?, ?B/s]

storing https://huggingface.co/facebook/bart-large-cnn/resolve/main/pytorch_model.bin in cache at /root/.cache/huggingface/transformers/4ccdf4cdc01b790f9f9c636c7695b5d443180e8dbd0cbe49e07aa918dda1cef0.fa29468c10a34ef7f6cfceba3b174d3ccc95f8d755c3ca1b829aff41cc92a300
creating metadata file for /root/.cache/huggingface/transformers/4ccdf4cdc01b790f9f9c636c7695b5d443180e8dbd0cbe49e07aa918dda1cef0.fa29468c10a34ef7f6cfceba3b174d3ccc95f8d755c3ca1b829aff41cc92a300
loading weights file https://huggingface.co/facebook/bart-large-cnn/resolve/main/pytorch_model.bin from cache at /root/.cache/huggingface/transformers/4ccdf4cdc01b790f9f9c636c7695b5d443180e8dbd0cbe49e07aa918dda1cef0.fa29468c10a34ef7f6cfceba3b174d3ccc95f8d755c3ca1b829aff41cc92a300
All model checkpoint weights were used when initializing BartForConditionalGeneration.

All the weights of BartForConditionalGeneration were initialized from the model checkpoint at facebook/bart-large-cnn.
If your task is similar to the task the model of th

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

storing https://huggingface.co/facebook/bart-large-cnn/resolve/main/vocab.json in cache at /root/.cache/huggingface/transformers/4d8eeedc3498bc73a4b72411ebb3219209b305663632d77a6f16e60790b18038.d67d6b367eb24ab43b08ad55e014cf254076934f71d832bbab9ad35644a375ab
creating metadata file for /root/.cache/huggingface/transformers/4d8eeedc3498bc73a4b72411ebb3219209b305663632d77a6f16e60790b18038.d67d6b367eb24ab43b08ad55e014cf254076934f71d832bbab9ad35644a375ab
https://huggingface.co/facebook/bart-large-cnn/resolve/main/merges.txt not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmp9j3cskbn


Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

storing https://huggingface.co/facebook/bart-large-cnn/resolve/main/merges.txt in cache at /root/.cache/huggingface/transformers/0ddddd3ca9e107b17a6901c92543692272af1c3238a8d7549fa937ba0057bbcf.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b
creating metadata file for /root/.cache/huggingface/transformers/0ddddd3ca9e107b17a6901c92543692272af1c3238a8d7549fa937ba0057bbcf.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b
https://huggingface.co/facebook/bart-large-cnn/resolve/main/tokenizer.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmpuqsom46j


Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

storing https://huggingface.co/facebook/bart-large-cnn/resolve/main/tokenizer.json in cache at /root/.cache/huggingface/transformers/55c96bd962ce1d360fde4947619318f1b4eb551430de678044699cbfeb99de6a.fc9576039592f026ad76a1c231b89aee8668488c671dfbe6616bab2ed298d730
creating metadata file for /root/.cache/huggingface/transformers/55c96bd962ce1d360fde4947619318f1b4eb551430de678044699cbfeb99de6a.fc9576039592f026ad76a1c231b89aee8668488c671dfbe6616bab2ed298d730
loading file https://huggingface.co/facebook/bart-large-cnn/resolve/main/vocab.json from cache at /root/.cache/huggingface/transformers/4d8eeedc3498bc73a4b72411ebb3219209b305663632d77a6f16e60790b18038.d67d6b367eb24ab43b08ad55e014cf254076934f71d832bbab9ad35644a375ab
loading file https://huggingface.co/facebook/bart-large-cnn/resolve/main/merges.txt from cache at /root/.cache/huggingface/transformers/0ddddd3ca9e107b17a6901c92543692272af1c3238a8d7549fa937ba0057bbcf.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b
loading fi

In [None]:
# Start the Gradio app. In Colab, debug=True keeps this cell running so that
# errors and logs stay visible in the notebook.
summary1.launch(debug=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
This share link will expire in 72 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted
Running on External URL: https://25270.gradio.app


Your max_length is set to 142, but you input_length is only 31. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)
