# Text Summary with T5 from Huggingface Pytorch

In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds
from pathlib import Path
import torch
import re
import time

In [3]:
BATCH_SIZE = 16

SHUFFEL_SIZE = 1024

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

learning_rate = 3e-5

In [4]:
print(device)

cuda:0


In [5]:
cnn_dailymail = tfds.load(name="cnn_dailymail")

In [6]:
train_tfds = cnn_dailymail['train']
test_tfds = cnn_dailymail['test']
val_tfds = cnn_dailymail['validation']

In [7]:
train_ds_iter = tfds.as_numpy(train_tfds)
val_ds_iter = tfds.as_numpy(val_tfds)
test_ds_iter = tfds.as_numpy(test_tfds)

In [8]:
def write_data(iter_dataset, name, path="data/"):
    
    articles_file = Path(path + name + "/article").open("w")
    highlights_file = Path(path + name + "/highlights").open("w")

    for item in iter_dataset:
        articles_file.write(item["article"].decode("utf-8") + "\n")
        articles_file.flush()
        highlights_file.write(item["highlights"].decode("utf-8").replace("\n", " ") + "\n")
        highlights_file.flush()

In [9]:
write_data(train_ds_iter, "train")

In [10]:
write_data(test_ds_iter, "test")
write_data(val_ds_iter, "val")

## Define Model

In [11]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small').to(device)

task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
    model.config.update(task_specific_params.get("summarization", {}))
    

optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate, weight_decay=0.0001)

## Define Pytorch Dataset

In [12]:
def read_files(name):
    article_path = "data/%s/article" % name
    highlights_path = "data/%s/highlights" % name
    
    articles = [x.rstrip() for x in open(article_path).readlines()]
    highlights = [x.rstrip() for x in open(highlights_path).readlines()]
    
    assert len(articles) == len(highlights)
    return articles, highlights

In [13]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, articles, highlights):
        self.x = articles
        self.y = highlights
        
    def __getitem__(self, index):
        x = tokenizer.encode_plus(model.config.prefix + self.transfrom(self.x[index]), max_length=512, return_tensors="pt", pad_to_max_length=True)
        y = tokenizer.encode(self.transfrom(self.y[index]), max_length=150, return_tensors="pt", pad_to_max_length=True)
        return x['input_ids'].view(-1), x['attention_mask'].view(-1), y.view(-1)
    
    @staticmethod
    def transfrom(x):
        x = x.lower()
        x = re.sub("'(.*)'", r"\1", x)
        return x
    
    def __len__(self):
        return len(self.x)

In [14]:
def get_dataset(name):
    articles, highlights = read_files(name)
    return MyDataset(articles, highlights)

In [15]:
train_ds = get_dataset("train")
test_ds = get_dataset("test")
val_ds = get_dataset("val")

In [16]:
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE)

## Define Step function

In [17]:
pad_token_id = tokenizer.pad_token_id
def step(inputs_ids, attention_mask, y):
    y_ids = y[:, :-1].contiguous()
    lm_labels = y[:, 1:].clone()
    lm_labels[y[:, 1:] == pad_token_id] = -100
    output = model(inputs_ids, attention_mask=attention_mask, decoder_input_ids=y_ids, lm_labels=lm_labels)
    return output[0] # loss

## Train

In [19]:
EPOCHS = 1
log_interval = 200
train_loss = []
val_loss = []
for epoch in range(EPOCHS):
    model.train() 
    start_time = time.time()
    for i, (inputs_ids, attention_mask, y) in enumerate(train_loader):
        inputs_ids = inputs_ids.to(device)
        attention_mask = attention_mask.to(device)
        y = y.to(device)
        
        
        optimizer.zero_grad()
        loss = step(inputs_ids, attention_mask, y)
        train_loss.append(loss.item())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
            
        if (i + 1) % log_interval == 0:
            with torch.no_grad():
                x, x_mask, y = next(iter(val_loader))
                x = x.to(device)
                x_mask = x_mask.to(device)
                y = y.to(device)
                
                v_loss = step(x, x_mask, y)
                v_loss = v_loss.item()
                
                
                elapsed = time.time() - start_time
                print('| epoch {:3d} | [{:5d}/{:5d}] | '
                  'ms/batch {:5.2f} | '
                  'loss {:5.2f} | val loss {:5.2f}'.format(
                    epoch, i, len(train_loader),
                    elapsed * 1000 / log_interval,
                    loss.item(), v_loss))
                start_time = time.time()
                val_loss.append(v_loss)
                
                

| epoch   0 | [  199/17945] | ms/batch 362.94 | loss  2.24 | val loss  2.08
| epoch   0 | [  399/17945] | ms/batch 360.18 | loss  2.43 | val loss  2.00
| epoch   0 | [  599/17945] | ms/batch 363.08 | loss  2.33 | val loss  2.02
| epoch   0 | [  799/17945] | ms/batch 364.22 | loss  2.06 | val loss  1.93
| epoch   0 | [  999/17945] | ms/batch 365.22 | loss  2.28 | val loss  1.97
| epoch   0 | [ 1199/17945] | ms/batch 364.49 | loss  1.95 | val loss  1.96
| epoch   0 | [ 1399/17945] | ms/batch 361.50 | loss  1.91 | val loss  1.92
| epoch   0 | [ 1599/17945] | ms/batch 364.73 | loss  1.93 | val loss  1.97
| epoch   0 | [ 1799/17945] | ms/batch 362.19 | loss  1.69 | val loss  1.95
| epoch   0 | [ 1999/17945] | ms/batch 365.62 | loss  2.03 | val loss  1.94
| epoch   0 | [ 2199/17945] | ms/batch 364.76 | loss  2.04 | val loss  1.91
| epoch   0 | [ 2399/17945] | ms/batch 365.65 | loss  2.30 | val loss  1.93
| epoch   0 | [ 2599/17945] | ms/batch 365.40 | loss  1.87 | val loss  1.92
| epoch   0 

KeyboardInterrupt: 

## Evaluate

In [21]:
from rouge_score import rouge_scorer
from rouge_score import scoring

class RougeScore:
    '''
    mostly from https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/evaluation/metrics.py 
    '''
    
    def __init__(self, score_keys=None)-> None:
        super().__init__()
        if score_keys is None:  
            self.score_keys = ["rouge1", "rouge2", "rougeLsum"]
        
        self.scorer = rouge_scorer.RougeScorer(self.score_keys)
        self.aggregator = scoring.BootstrapAggregator()
        
        
    @staticmethod
    def prepare_summary(summary):
            # Make sure the summary is not bytes-type
            # Add newlines between sentences so that rougeLsum is computed correctly.
            summary = summary.replace(" . ", " .\n")
            return summary
    
    def __call__(self, target, prediction):
        """Computes rouge score.''
        Args:
        targets: string
        predictions: string
        """

        target = self.prepare_summary(target)
        prediction = self.prepare_summary(prediction)
        
        self.aggregator.add_scores(self.scorer.score(target=target, prediction=prediction))

        return 
    
    def reset_states(self):
        self.rouge_list = []

    def result(self):
        result = self.aggregator.aggregate()
        
        for key in self.score_keys:
            score_text = "%s = %.2f, 95%% confidence [%.2f, %.2f]"%(
                key,
                result[key].mid.fmeasure*100,
                result[key].low.fmeasure*100,
                result[key].high.fmeasure*100
            )
            print(score_text)
        
        return {key: result[key].mid.fmeasure*100 for key in self.score_keys}

In [25]:
rouge_score = RougeScore()
predictions = []
for i, (input_ids, attention_mask, y) in enumerate(test_loader):
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    y = y.to(device)
        
    summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask)
    pred = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
    real = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in y]
    for pred_sent, real_sent in zip(pred, real):
        rouge_score(pred_sent, real_sent)
        predictions.append(str("pred sentence: " + pred_sent + "\n\n real sentence: " + real_sent))
    if i > 40:
        break
    
rouge_score.result()

rouge1 = 38.52, 95% confidence [37.74, 39.39]
rouge2 = 17.51, 95% confidence [16.71, 18.34]
rougeLsum = 35.93, 95% confidence [35.14, 36.73]


{'rouge1': 38.52036439543923,
 'rouge2': 17.507818219729863,
 'rougeLsum': 35.92564132743401}

In [27]:
for pred in predictions[:10]:
    print("------")
    print(pred)
    print("------")    

------
pred sent:faa use planes with 31 inches between each row of seats, a standard which on some airlines has decreased . many airlines have 30 inches of space, while some airlines offer as little as 28 inches . british airways airways has 29 inches, thomson's short haul seat pitch is 28 inches and virgin atlantics is 30 inches below the pitch . some airlines have a seat pitch of 31 inches and air asia offers 29 inches and spirit airlines offers 30-31 . the department of transportation says a u.s standard on aeroplanes with a 31 inches from one point on each row . tests conducted by the government, while united airlines have 31 inches of room, while most airlines have just 28 inches in space . but some airlines stick to a pitch, while gulf air economy seats are 29 inches shorter, air atlantic has 30 inches, and a short haul seats
 real sent:experts question if packed out planes are putting passengers at risk . u.s consumer advisory group says minimum space must be stipulated . safety