# German Text Summary with the tranlated CNN Daily Mail Dataset
and T5 from Huggingface Pytorch

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from pathlib import Path
import torch
import re
import time

In [2]:
BATCH_SIZE = 4

SHUFFEL_SIZE = 1024

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

learning_rate = 3e-5

EPOCHS = 5

log_interval = 200

model_size = "t5-base"

In [3]:
print(device)

cuda:0


## Define Model

In [4]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained(model_size)
model = T5ForConditionalGeneration.from_pretrained(model_size).to(device)

task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
    model.config.update(task_specific_params.get("summarization", {}))
    

optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate, weight_decay=0.0001)

## Read the tranlated German Dataset

In [5]:
def get_translated_ds(name):
    article_path = "../data/%s/articles_german" % name
    highlights_path = "../data/%s/highlights_german" % name

    articles = [x.rstrip() for x in open(article_path).readlines()]
    highlights = [x.rstrip() for x in open(highlights_path).readlines()]
    
    assert len(articles) == len(highlights)
    return articles, highlights

In [6]:
train_x, train_y = get_translated_ds("train")
test_x, test_y = get_translated_ds("test")
val_x, val_y = get_translated_ds("val")

## Show some Dataset Example 

In [7]:
import numpy as np
for i in range(5):
    rand_int = np.random.randint(len(train_x))
    print("\n--------")
    print(train_x[rand_int][:1500])
    print()
    print(train_y[rand_int])


--------
120617; (CNN) - Ein unerwartetes und plötzliches Rampenlicht auf die Special Olympics, eine Organisation, die seit über 40 Jahren Menschen mit intellektuellen Behinderungen dient und ehrt, kommt weniger als zwei Wochen bevor die gemeinnützige Organisation eine neue Kampagne lanciert: Spread the Word to the End the Word. Special Olympics führt eine Kampagne durch, um die Verwendung des "R-Wortes" zu stoppen. Der 31. März wird zu einem "nationalen Tag des Bewusstseins" erklärt, ein Aufruf an die Amerikaner, ihre Verwendung des Wortes "verzögern" oder, wie die Organisation es vorziehen würde, des "R-Wortes" anzuerkennen und zu überdenken. "Die meisten Menschen betrachten dieses Wort nicht als Hassrede, aber genau so fühlt es sich für Millionen von Menschen mit intellektuellen Behinderungen, ihre Familien und Freunde an", heißt es in einer Erklärung über die Kampagne. "Dieses Wort ist genauso grausam und beleidigend wie jede andere Verunglimpfung". Der Vorstoß für mehr Respekt wi

In [8]:
import numpy as np
for i in range(5):
    rand_int = np.random.randint(len(test_x))
    print("\n--------")
    print(test_x[rand_int][:1500])
    print()
    print(test_y[rand_int])


--------
1591; Ein ehemaliger TSA-Agent behauptet, dass Tasten jeden Tag an Sicherheitskontrollpunkten in den Vereinigten Staaten passieren und Mitarbeiter routinemäßig die Regeln beugen, um Passagiere in ihre Hände zu bekommen. Jason Edward Harrington, der sechs Jahre lang mit der Regierungsbehörde zusammengearbeitet hat, fügte hinzu, dass Screenshots oft attraktive Reisende zur Seite ziehen, so dass sie sie unangemessen berühren können und Taschenkontrollen bei Menschen durchführen, die ihnen einfach nicht gefallen. Seine Kommentare folgen auf Enthüllungen, wonach zwei TSA-Mitarbeiter gefeuert wurden, weil sie Körperscanner in Denver manipuliert hatten, um Männer vor dem Einsteigen in ihre Flüge zu tätscheln. In einem Beitrag für die Zeitschrift TIME sagte er, der Fall sei verstörend, gab aber zu, dass er nicht überrascht worden sei. Vorwürfe: Jason Edward Harrington, der sechs Jahre lang mit der TSA zusammengearbeitet hat, behauptet, dass Taschen an Sicherheitskontrollen im ganzen 

## Define Pytorch Dataset

In [9]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, articles, highlights):
        self.x = articles
        self.y = highlights
        
    def __getitem__(self, index):
        x = tokenizer.encode_plus(model.config.prefix + self.transfrom(self.x[index]), max_length=512, return_tensors="pt", pad_to_max_length=True)
        y = tokenizer.encode(self.transfrom(self.y[index]), max_length=150, return_tensors="pt", pad_to_max_length=True)
        return x['input_ids'].view(-1), x['attention_mask'].view(-1), y.view(-1)
    
    @staticmethod
    def transfrom(x):
        x = x.lower()
        x = re.sub("'(.*)'", r"\1", x)
        return x
    
    def __len__(self):
        return len(self.x)

In [10]:
train_ds = MyDataset(train_x, train_y) 
val_ds = MyDataset(val_x, val_y)
test_ds = MyDataset(test_x, test_y)

In [11]:
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE)

In [12]:
x, x_mask, y = next(iter(val_loader))
x.shape, x_mask.shape, y.shape

(torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 150]))

## Define Step function

In [13]:
pad_token_id = tokenizer.pad_token_id
def step(inputs_ids, attention_mask, y):
    y_ids = y[:, :-1].contiguous()
    lm_labels = y[:, 1:].clone()
    lm_labels[y[:, 1:] == pad_token_id] = -100
    output = model(inputs_ids, attention_mask=attention_mask, decoder_input_ids=y_ids, lm_labels=lm_labels)
    return output[0] # loss

## Train

In [14]:
train_loss = []
val_loss = []
for epoch in range(EPOCHS):
    model.train() 
    start_time = time.time()
    for i, (inputs_ids, attention_mask, y) in enumerate(train_loader):
        inputs_ids = inputs_ids.to(device)
        attention_mask = attention_mask.to(device)
        y = y.to(device)
        
        
        optimizer.zero_grad()
        loss = step(inputs_ids, attention_mask, y)
        train_loss.append(loss.item())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
            
        if (i + 1) % log_interval == 0:
            with torch.no_grad():
                x, x_mask, y = next(iter(val_loader))
                x = x.to(device)
                x_mask = x_mask.to(device)
                y = y.to(device)
                
                v_loss = step(x, x_mask, y)
                v_loss = v_loss.item()
                
                
                elapsed = time.time() - start_time
                print('| epoch {:3d} | [{:5d}/{:5d}] | '
                  'ms/batch {:5.2f} | '
                  'loss {:5.2f} | val loss {:5.2f}'.format(
                    epoch, i, len(train_loader),
                    elapsed * 1000 / log_interval,
                    loss.item(), v_loss))
                start_time = time.time()
                val_loss.append(v_loss)
                
                

| epoch   0 | [  199/71779] | ms/batch 330.42 | loss  2.47 | val loss  2.36
| epoch   0 | [  399/71779] | ms/batch 335.57 | loss  2.05 | val loss  2.20
| epoch   0 | [  599/71779] | ms/batch 326.07 | loss  2.22 | val loss  2.13
| epoch   0 | [  799/71779] | ms/batch 324.17 | loss  2.61 | val loss  2.05
| epoch   0 | [  999/71779] | ms/batch 325.72 | loss  2.32 | val loss  1.94
| epoch   0 | [ 1199/71779] | ms/batch 326.15 | loss  2.35 | val loss  1.88
| epoch   0 | [ 1399/71779] | ms/batch 325.92 | loss  2.38 | val loss  1.92
| epoch   0 | [ 1599/71779] | ms/batch 321.90 | loss  2.26 | val loss  1.87
| epoch   0 | [ 1799/71779] | ms/batch 323.25 | loss  2.26 | val loss  1.88
| epoch   0 | [ 1999/71779] | ms/batch 323.06 | loss  2.46 | val loss  1.89
| epoch   0 | [ 2199/71779] | ms/batch 324.19 | loss  1.64 | val loss  1.80
| epoch   0 | [ 2399/71779] | ms/batch 325.02 | loss  2.24 | val loss  1.86
| epoch   0 | [ 2599/71779] | ms/batch 324.96 | loss  1.74 | val loss  1.77
| epoch   0 

| epoch   0 | [21799/71779] | ms/batch 323.85 | loss  3.27 | val loss  2.47
| epoch   0 | [21999/71779] | ms/batch 326.98 | loss  3.40 | val loss  2.53
| epoch   0 | [22199/71779] | ms/batch 326.39 | loss  2.72 | val loss  2.45
| epoch   0 | [22399/71779] | ms/batch 324.89 | loss  2.48 | val loss  2.46
| epoch   0 | [22599/71779] | ms/batch 325.60 | loss  2.84 | val loss  2.54
| epoch   0 | [22799/71779] | ms/batch 325.64 | loss  3.22 | val loss  2.49
| epoch   0 | [22999/71779] | ms/batch 325.53 | loss  3.16 | val loss  2.59
| epoch   0 | [23199/71779] | ms/batch 325.90 | loss  2.95 | val loss  2.56
| epoch   0 | [23399/71779] | ms/batch 324.86 | loss  2.95 | val loss  2.57
| epoch   0 | [23599/71779] | ms/batch 326.53 | loss  3.35 | val loss  2.52
| epoch   0 | [23799/71779] | ms/batch 329.70 | loss  2.95 | val loss  2.56
| epoch   0 | [23999/71779] | ms/batch 325.92 | loss  2.60 | val loss  2.56
| epoch   0 | [24199/71779] | ms/batch 326.80 | loss  2.72 | val loss  2.60
| epoch   0 

| epoch   0 | [43399/71779] | ms/batch 324.87 | loss  3.72 | val loss  2.92
| epoch   0 | [43599/71779] | ms/batch 326.27 | loss  3.24 | val loss  2.90
| epoch   0 | [43799/71779] | ms/batch 329.10 | loss  2.74 | val loss  2.86
| epoch   0 | [43999/71779] | ms/batch 326.55 | loss  3.32 | val loss  2.90
| epoch   0 | [44199/71779] | ms/batch 325.72 | loss  3.53 | val loss  2.92
| epoch   0 | [44399/71779] | ms/batch 324.11 | loss  3.12 | val loss  2.84
| epoch   0 | [44599/71779] | ms/batch 324.49 | loss  3.15 | val loss  3.03
| epoch   0 | [44799/71779] | ms/batch 324.83 | loss  3.20 | val loss  2.85
| epoch   0 | [44999/71779] | ms/batch 330.46 | loss  3.26 | val loss  2.85
| epoch   0 | [45199/71779] | ms/batch 327.01 | loss  3.28 | val loss  2.86
| epoch   0 | [45399/71779] | ms/batch 332.98 | loss  3.40 | val loss  2.89
| epoch   0 | [45599/71779] | ms/batch 327.17 | loss  3.09 | val loss  2.98
| epoch   0 | [45799/71779] | ms/batch 324.36 | loss  4.11 | val loss  3.00
| epoch   0 

| epoch   0 | [64999/71779] | ms/batch 326.83 | loss  3.19 | val loss  2.97
| epoch   0 | [65199/71779] | ms/batch 325.66 | loss  3.59 | val loss  3.07
| epoch   0 | [65399/71779] | ms/batch 323.43 | loss  3.06 | val loss  3.06
| epoch   0 | [65599/71779] | ms/batch 321.69 | loss  2.92 | val loss  2.93
| epoch   0 | [65799/71779] | ms/batch 327.13 | loss  3.35 | val loss  3.08
| epoch   0 | [65999/71779] | ms/batch 323.91 | loss  3.93 | val loss  3.06
| epoch   0 | [66199/71779] | ms/batch 325.47 | loss  3.55 | val loss  3.08
| epoch   0 | [66399/71779] | ms/batch 325.11 | loss  3.50 | val loss  3.06
| epoch   0 | [66599/71779] | ms/batch 325.06 | loss  3.31 | val loss  3.07
| epoch   0 | [66799/71779] | ms/batch 325.17 | loss  3.40 | val loss  3.07
| epoch   0 | [66999/71779] | ms/batch 325.58 | loss  3.03 | val loss  3.05
| epoch   0 | [67199/71779] | ms/batch 325.58 | loss  3.58 | val loss  3.12
| epoch   0 | [67399/71779] | ms/batch 324.63 | loss  3.61 | val loss  2.98
| epoch   0 

KeyboardInterrupt: 

## Evaluate

In [15]:
from rouge_score import rouge_scorer
from rouge_score import scoring

class RougeScore:
    '''
    mostly from https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/evaluation/metrics.py 
    '''
    
    def __init__(self, score_keys=None)-> None:
        super().__init__()
        if score_keys is None:  
            self.score_keys = ["rouge1", "rouge2", "rougeLsum"]
        
        self.scorer = rouge_scorer.RougeScorer(self.score_keys)
        self.aggregator = scoring.BootstrapAggregator()
        
        
    @staticmethod
    def prepare_summary(summary):
            # Make sure the summary is not bytes-type
            # Add newlines between sentences so that rougeLsum is computed correctly.
            summary = summary.replace(" . ", " .\n")
            return summary
    
    def __call__(self, target, prediction):
        """Computes rouge score.''
        Args:
        targets: string
        predictions: string
        """

        target = self.prepare_summary(target)
        prediction = self.prepare_summary(prediction)
        
        self.aggregator.add_scores(self.scorer.score(target=target, prediction=prediction))

        return 
    
    def reset_states(self):
        self.rouge_list = []

    def result(self):
        result = self.aggregator.aggregate()
        
        for key in self.score_keys:
            score_text = "%s = %.2f, 95%% confidence [%.2f, %.2f]"%(
                key,
                result[key].mid.fmeasure*100,
                result[key].low.fmeasure*100,
                result[key].high.fmeasure*100
            )
            print(score_text)
        
        return {key: result[key].mid.fmeasure*100 for key in self.score_keys}

In [16]:
rouge_score = RougeScore()
predictions = []
for i, (input_ids, attention_mask, y) in enumerate(test_loader):
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    y = y.to(device)
        
    summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask)
    pred = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
    real = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in y]
    for pred_sent, real_sent in zip(pred, real):
        rouge_score(pred_sent, real_sent)
        predictions.append(str("pred sentence: " + pred_sent + "\n\n real sentence: " + real_sent))
    if i > 40:
        break
    
rouge_score.result()

rouge1 = 21.60, 95% confidence [20.65, 22.56]
rouge2 = 4.72, 95% confidence [4.19, 5.25]
rougeLsum = 12.84, 95% confidence [12.27, 13.34]


{'rouge1': 21.601131331699854,
 'rouge2': 4.72085075525385,
 'rougeLsum': 12.835751381295408}

In [17]:
for pred in predictions[:10]:
    print("------")
    print(pred)
    print("------")    

------
pred sentence: cynthia und faa, krakral platz. flugzeuge von flugflugzeugzeugen, die die armlehne, ob ein derart steste platz in der ffährdete. die armarmarmarmlelehne ernsteren problemen, dass die gesundheit und sicherheit gefähndet wurde, während es nicht nicht nicht nur nur unangeangegriffen. er schrumrumte, dass er von der stötet wurde. sie sagt, er sei sei ein fluggesellschaften, um die detrote flugbebebeteckt. der flug der ddächtigte den platz, um flugfächte zu hausen. "icht", sagt, dass sie sie in den flug

 real sentence: 0; experten bezweifeln, ob überfüllte flugzeuge passagiere gefährden. die us-verbraucherberatungsgruppe sagt, dass mindestabstände vorgeschrieben werden müssen. sicherheitstests in flugzeugen mit mehr beinfreiheit als von fluggesellschaften angeboten, werden durchgeführt.
------
------
pred sentence: rahul kumar, 17, kletterte über den zaun des kamla nehrtes zoos in ahmedabad, rannte, schriete auf die löwengegehege töte. kukukumar kumte in der nähe des 

In [18]:
result_path = "../data/result_german_t5_base.txt"
for pred in predictions:
    with open(result_path, "a") as file:
        file.write(pred + "\n")
