# German Text Summary with the tranlated CNN Daily Mail Dataset
and T5 from Huggingface Pytorch

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from pathlib import Path
import torch
import re
import time

In [2]:
BATCH_SIZE = 16

SHUFFEL_SIZE = 1024

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

learning_rate = 3e-5

EPOCHS = 5

log_interval = 200

In [3]:
print(device)

cuda:0


## Define Model

In [4]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small').to(device)

task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
    model.config.update(task_specific_params.get("summarization", {}))
    

optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate, weight_decay=0.0001)

## Read the tranlated German Dataset

In [5]:
def get_dict_data(list_input):
    ret_dict = {}
    for input_item in list_input:
        data_point = input_item.split(";")
        data_id = data_point[0]
        data = ""
        for item in data_point[1:]:
            data += item
        ret_dict[int(data_id)] = data.strip()
    return ret_dict
    
def get_translated_ds(name):
    article_path = "../data/%s/articles_german" % name
    highlights_path = "../data/%s/highlights_german" % name

    articles = [x.rstrip() for x in open(article_path).readlines()]
    highlights = [x.rstrip() for x in open(highlights_path).readlines()]

    len_articles = len(articles)
    
    articles = get_dict_data(articles)
    highlights = get_dict_data(highlights)
    
    cleaned_articles = []
    cleaned_highlights = []

    for i in range(len_articles):
        if i in articles.keys() and i in highlights.keys():
            cleaned_articles.append(articles[i])
            cleaned_highlights.append(highlights[i])
        else:
            is_in_art = i in articles.keys()
            text = "is not in "
            if not is_in_art:
                text += "articles"
            else:
                text += "highlights"
            
            print(name, i, text)
            
    assert len(cleaned_articles) == len(cleaned_highlights)
    return cleaned_articles, cleaned_highlights

In [6]:
train_x, train_y = get_translated_ds("train")
test_x, test_y = get_translated_ds("test")
val_x, val_y = get_translated_ds("val")

train 287113 is not in articles
train 287114 is not in articles
train 287115 is not in articles
train 287116 is not in articles
train 287117 is not in articles
train 287118 is not in articles
train 287119 is not in articles
train 287120 is not in articles
train 287121 is not in articles
train 287122 is not in articles
train 287123 is not in articles
train 287124 is not in articles
train 287125 is not in articles
train 287126 is not in articles
test 11490 is not in highlights
test 11491 is not in highlights
test 11492 is not in highlights
test 11493 is not in highlights
test 11494 is not in highlights
test 11495 is not in highlights
test 11496 is not in highlights
test 11497 is not in highlights
test 11498 is not in highlights
test 11499 is not in articles
val 13368 is not in articles


## Show some Dataset Example 

In [7]:
import numpy as np
for i in range(5):
    rand_int = np.random.randint(len(train_x))
    print("\n--------")
    print(train_x[rand_int][:1500])
    print()
    print(train_y[rand_int])


--------
Eine deutsche Oper wurde nach einem Sturm der Kritik abgesagt, nachdem darin Darsteller in Nazi-Uniform, eine massakrierte jüdische Familie und eine brutale Vergewaltigungsszene zu sehen waren. Schon nach 30 Minuten verließ das Publikum Burkhard C. Kosminskis Interpretation von Wagners Klassiker "Tannhäuser" und einige mussten wegen "psychologischer Traumata" ärztlich behandelt werden. Die Rheinoper in Düsseldorf sagte, sie habe zwar gewusst, dass das Konzert "kontrovers" sein würde, aber nicht mit einer derart extremen Reaktion gerechnet. Kritik: Die Oper verärgerte einige Zuschauer, weil darin Darsteller in Nazi-Uniform zu sehen waren. Kontrovers: In einer Szene vergewaltigt die Figur Wolfram, links, Elisabeth, rechts brutal. Kosminksi sagte, er habe die antisemitische Haltung von Komponisten wie Wagner, der als einer der Favoriten Adolf Hitlers galt, "thematisieren" wollen. In einer Szene kamen nackte Darsteller in verrauchten Gaskammern auf die Bühne, um die als SS-Offizi

## Define Pytorch Dataset

In [8]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, articles, highlights):
        self.x = articles
        self.y = highlights
        
    def __getitem__(self, index):
        x = tokenizer.encode_plus(model.config.prefix + self.transfrom(self.x[index]), max_length=512, return_tensors="pt", pad_to_max_length=True)
        y = tokenizer.encode(self.transfrom(self.y[index]), max_length=150, return_tensors="pt", pad_to_max_length=True)
        return x['input_ids'].view(-1), x['attention_mask'].view(-1), y.view(-1)
    
    @staticmethod
    def transfrom(x):
        x = x.lower()
        x = re.sub("'(.*)'", r"\1", x)
        return x
    
    def __len__(self):
        return len(self.x)

In [9]:
train_ds = MyDataset(train_x, train_y) 
val_ds = MyDataset(val_x, val_y)
test_ds = MyDataset(test_x, test_y)

In [10]:
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE)

In [11]:
x, x_mask, y = next(iter(val_loader))
x.shape, x_mask.shape, y.shape

(torch.Size([16, 512]), torch.Size([16, 512]), torch.Size([16, 150]))

## Define Step function

In [12]:
pad_token_id = tokenizer.pad_token_id
def step(inputs_ids, attention_mask, y):
    y_ids = y[:, :-1].contiguous()
    lm_labels = y[:, 1:].clone()
    lm_labels[y[:, 1:] == pad_token_id] = -100
    output = model(inputs_ids, attention_mask=attention_mask, decoder_input_ids=y_ids, lm_labels=lm_labels)
    return output[0] # loss

## Train

In [13]:
train_loss = []
val_loss = []
for epoch in range(EPOCHS):
    model.train() 
    start_time = time.time()
    for i, (inputs_ids, attention_mask, y) in enumerate(train_loader):
        inputs_ids = inputs_ids.to(device)
        attention_mask = attention_mask.to(device)
        y = y.to(device)
        
        
        optimizer.zero_grad()
        loss = step(inputs_ids, attention_mask, y)
        train_loss.append(loss.item())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
            
        if (i + 1) % log_interval == 0:
            with torch.no_grad():
                x, x_mask, y = next(iter(val_loader))
                x = x.to(device)
                x_mask = x_mask.to(device)
                y = y.to(device)
                
                v_loss = step(x, x_mask, y)
                v_loss = v_loss.item()
                
                
                elapsed = time.time() - start_time
                print('| epoch {:3d} | [{:5d}/{:5d}] | '
                  'ms/batch {:5.2f} | '
                  'loss {:5.2f} | val loss {:5.2f}'.format(
                    epoch, i, len(train_loader),
                    elapsed * 1000 / log_interval,
                    loss.item(), v_loss))
                start_time = time.time()
                val_loss.append(v_loss)
                
                

| epoch   0 | [  199/17945] | ms/batch 408.75 | loss  3.19 | val loss  2.82
| epoch   0 | [  399/17945] | ms/batch 404.60 | loss  3.05 | val loss  2.73
| epoch   0 | [  599/17945] | ms/batch 403.42 | loss  3.15 | val loss  2.66
| epoch   0 | [  799/17945] | ms/batch 405.32 | loss  2.73 | val loss  2.67
| epoch   0 | [  999/17945] | ms/batch 403.11 | loss  2.93 | val loss  2.62
| epoch   0 | [ 1199/17945] | ms/batch 405.05 | loss  2.61 | val loss  2.58
| epoch   0 | [ 1399/17945] | ms/batch 404.82 | loss  2.62 | val loss  2.57
| epoch   0 | [ 1599/17945] | ms/batch 407.41 | loss  2.73 | val loss  2.52
| epoch   0 | [ 1799/17945] | ms/batch 401.32 | loss  2.59 | val loss  2.51
| epoch   0 | [ 1999/17945] | ms/batch 403.44 | loss  2.63 | val loss  2.47
| epoch   0 | [ 2199/17945] | ms/batch 407.13 | loss  2.44 | val loss  2.48
| epoch   0 | [ 2399/17945] | ms/batch 409.26 | loss  2.90 | val loss  2.45
| epoch   0 | [ 2599/17945] | ms/batch 407.93 | loss  2.51 | val loss  2.46
| epoch   0 

| epoch   1 | [ 3999/17945] | ms/batch 400.61 | loss  2.81 | val loss  2.56
| epoch   1 | [ 4199/17945] | ms/batch 402.63 | loss  2.76 | val loss  2.58
| epoch   1 | [ 4399/17945] | ms/batch 403.30 | loss  2.86 | val loss  2.56
| epoch   1 | [ 4599/17945] | ms/batch 403.00 | loss  2.79 | val loss  2.57
| epoch   1 | [ 4799/17945] | ms/batch 401.96 | loss  2.74 | val loss  2.61
| epoch   1 | [ 4999/17945] | ms/batch 403.38 | loss  2.83 | val loss  2.59
| epoch   1 | [ 5199/17945] | ms/batch 402.37 | loss  3.05 | val loss  2.59
| epoch   1 | [ 5399/17945] | ms/batch 403.24 | loss  2.77 | val loss  2.59
| epoch   1 | [ 5599/17945] | ms/batch 405.12 | loss  2.57 | val loss  2.60
| epoch   1 | [ 5799/17945] | ms/batch 417.72 | loss  2.86 | val loss  2.62
| epoch   1 | [ 5999/17945] | ms/batch 407.31 | loss  3.12 | val loss  2.59
| epoch   1 | [ 6199/17945] | ms/batch 406.59 | loss  2.63 | val loss  2.62
| epoch   1 | [ 6399/17945] | ms/batch 403.91 | loss  2.82 | val loss  2.67
| epoch   1 

| epoch   2 | [ 7799/17945] | ms/batch 404.14 | loss  3.34 | val loss  2.88
| epoch   2 | [ 7999/17945] | ms/batch 403.84 | loss  3.12 | val loss  2.91
| epoch   2 | [ 8199/17945] | ms/batch 403.05 | loss  3.13 | val loss  2.95
| epoch   2 | [ 8399/17945] | ms/batch 399.09 | loss  3.35 | val loss  2.88
| epoch   2 | [ 8599/17945] | ms/batch 404.01 | loss  3.12 | val loss  2.92
| epoch   2 | [ 8799/17945] | ms/batch 403.11 | loss  2.73 | val loss  2.92
| epoch   2 | [ 8999/17945] | ms/batch 401.30 | loss  3.07 | val loss  2.89
| epoch   2 | [ 9199/17945] | ms/batch 404.62 | loss  3.21 | val loss  2.88
| epoch   2 | [ 9399/17945] | ms/batch 402.01 | loss  2.91 | val loss  2.87
| epoch   2 | [ 9599/17945] | ms/batch 402.04 | loss  2.93 | val loss  2.91
| epoch   2 | [ 9799/17945] | ms/batch 403.43 | loss  3.01 | val loss  2.90
| epoch   2 | [ 9999/17945] | ms/batch 403.10 | loss  2.89 | val loss  2.94
| epoch   2 | [10199/17945] | ms/batch 403.67 | loss  3.20 | val loss  2.89
| epoch   2 

| epoch   3 | [11599/17945] | ms/batch 404.77 | loss  3.28 | val loss  3.02
| epoch   3 | [11799/17945] | ms/batch 406.24 | loss  2.97 | val loss  3.04
| epoch   3 | [11999/17945] | ms/batch 402.52 | loss  2.87 | val loss  2.98
| epoch   3 | [12199/17945] | ms/batch 403.72 | loss  3.19 | val loss  3.05
| epoch   3 | [12399/17945] | ms/batch 403.16 | loss  3.30 | val loss  3.07
| epoch   3 | [12599/17945] | ms/batch 405.17 | loss  3.03 | val loss  2.98
| epoch   3 | [12799/17945] | ms/batch 404.35 | loss  3.14 | val loss  3.02
| epoch   3 | [12999/17945] | ms/batch 410.91 | loss  3.17 | val loss  3.02
| epoch   3 | [13199/17945] | ms/batch 402.60 | loss  3.18 | val loss  2.98
| epoch   3 | [13399/17945] | ms/batch 402.95 | loss  3.33 | val loss  3.00
| epoch   3 | [13599/17945] | ms/batch 402.59 | loss  3.48 | val loss  3.01
| epoch   3 | [13799/17945] | ms/batch 403.45 | loss  3.34 | val loss  3.00
| epoch   3 | [13999/17945] | ms/batch 403.32 | loss  2.73 | val loss  3.01
| epoch   3 

| epoch   4 | [15399/17945] | ms/batch 401.65 | loss  3.16 | val loss  3.03
| epoch   4 | [15599/17945] | ms/batch 404.04 | loss  3.33 | val loss  3.01
| epoch   4 | [15799/17945] | ms/batch 402.67 | loss  2.86 | val loss  3.00
| epoch   4 | [15999/17945] | ms/batch 403.41 | loss  3.10 | val loss  3.04
| epoch   4 | [16199/17945] | ms/batch 399.90 | loss  3.21 | val loss  3.02
| epoch   4 | [16399/17945] | ms/batch 399.19 | loss  2.99 | val loss  3.02
| epoch   4 | [16599/17945] | ms/batch 402.67 | loss  3.21 | val loss  3.04
| epoch   4 | [16799/17945] | ms/batch 400.43 | loss  3.42 | val loss  3.05
| epoch   4 | [16999/17945] | ms/batch 401.25 | loss  3.48 | val loss  3.03
| epoch   4 | [17199/17945] | ms/batch 398.50 | loss  3.00 | val loss  3.02
| epoch   4 | [17399/17945] | ms/batch 402.09 | loss  3.45 | val loss  2.98
| epoch   4 | [17599/17945] | ms/batch 406.05 | loss  3.10 | val loss  3.01
| epoch   4 | [17799/17945] | ms/batch 406.62 | loss  3.57 | val loss  3.05


## Evaluate

In [14]:
from rouge_score import rouge_scorer
from rouge_score import scoring

class RougeScore:
    '''
    mostly from https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/evaluation/metrics.py 
    '''
    
    def __init__(self, score_keys=None)-> None:
        super().__init__()
        if score_keys is None:  
            self.score_keys = ["rouge1", "rouge2", "rougeLsum"]
        
        self.scorer = rouge_scorer.RougeScorer(self.score_keys)
        self.aggregator = scoring.BootstrapAggregator()
        
        
    @staticmethod
    def prepare_summary(summary):
            # Make sure the summary is not bytes-type
            # Add newlines between sentences so that rougeLsum is computed correctly.
            summary = summary.replace(" . ", " .\n")
            return summary
    
    def __call__(self, target, prediction):
        """Computes rouge score.''
        Args:
        targets: string
        predictions: string
        """

        target = self.prepare_summary(target)
        prediction = self.prepare_summary(prediction)
        
        self.aggregator.add_scores(self.scorer.score(target=target, prediction=prediction))

        return 
    
    def reset_states(self):
        self.rouge_list = []

    def result(self):
        result = self.aggregator.aggregate()
        
        for key in self.score_keys:
            score_text = "%s = %.2f, 95%% confidence [%.2f, %.2f]"%(
                key,
                result[key].mid.fmeasure*100,
                result[key].low.fmeasure*100,
                result[key].high.fmeasure*100
            )
            print(score_text)
        
        return {key: result[key].mid.fmeasure*100 for key in self.score_keys}

In [15]:
rouge_score = RougeScore()
predictions = []
for i, (input_ids, attention_mask, y) in enumerate(test_loader):
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    y = y.to(device)
        
    summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask)
    pred = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
    real = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in y]
    for pred_sent, real_sent in zip(pred, real):
        rouge_score(pred_sent, real_sent)
        predictions.append(str("pred sentence: " + pred_sent + "\n\n real sentence: " + real_sent))
    if i > 40:
        break
    
rouge_score.result()

rouge1 = 25.95, 95% confidence [25.37, 26.58]
rouge2 = 7.55, 95% confidence [7.15, 7.92]
rougeLsum = 14.76, 95% confidence [14.40, 15.09]


{'rouge1': 25.953370416274623,
 'rouge2': 7.5470369582583405,
 'rougeLsum': 14.763113105472748}

In [16]:
for pred in predictions[:10]:
    print("------")
    print(pred)
    print("------")    

------
pred sentence: flugzeugzeuge, ob ein der schrumpfenpfende platz in sicherheit gefährdet hat. cynthia corbertt, forscherin, krachenden zoll, sagt, dass es sei, wie schnell passagiere ein fluggesellschaften hat, sagt er. faa für eine humane faktoren, der bei einigen giere in rückenlehne, während eine behandlung der gesundheit und sicherheit, aber keine mind mindestestmaßmaße auf raum und nahrungsgruppe. ernsteren. viele flugflugzeugen. charlie leochacorcorbertt die kosten von 31 zolllen, fußußußernnsterstersteren problemen, aber die detroit news,, hat die 

 real sentence: experten bezweifeln, ob überfüllte flugzeuge passagiere gefährden. die us-verbraucherberatungsgruppe sagt, dass mindestabstände vorgeschrieben werden müssen. sicherheitstests in flugzeugen mit mehr beinfreiheit als von fluggesellschaften angeboten, werden durchgeführt.
------
------
pred sentence: rahul kumar, 17, kletterte in das löwengehehehege in ahmedabad und rannte auf die tiere. der berauschte, er hätte ei

In [18]:
result_path = "../data/result_german.txt"
for pred in predictions:
    with open(result_path, "a") as file:
        file.write(pred + "\n")
