# German Text Summary with the tranlated CNN Daily Mail Dataset
This time We will try out a Seq2Seq mBart Model which is pretrained on German data 

In [1]:
from pathlib import Path
import torch
import re
import time

from transformers import BertTokenizer, EncoderDecoderModel

In [2]:
BATCH_SIZE = 6

SHUFFEL_SIZE = 1024

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

learning_rate = 3e-5

EPOCHS = 1

log_interval = 200

## Model

In [3]:
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-german-cased", "bert-base-german-cased").to(device)

tokenizer = BertTokenizer.from_pretrained("bert-base-german-cased")


# model = EncoderDecoderModel.from_encoder_decoder_pretrained("distilbert-base-german-cased", "distilbert-base-german-cased").to(device)

# tokenizer = BertTokenizer.from_pretrained("distilbert-base-german-cased")

In [4]:
# CLS token will work as BOS token
tokenizer.bos_token = tokenizer.cls_token

# SEP token will work as EOS token
tokenizer.eos_token = tokenizer.sep_token

# set decoding params
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.max_length = 142
model.config.min_length = 56
model.config.no_repeat_ngram_size = 3
model.early_stopping = True
model.length_penalty = 2.0
model.num_beams = 4

In [5]:
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate, weight_decay=0.0001)

## Load Dataset

In [6]:
def get_translated_ds(name):
    article_path = "../data/%s/articles_german" % name
    highlights_path = "../data/%s/highlights_german" % name

    articles = [x.rstrip() for x in open(article_path).readlines()]
    highlights = [x.rstrip() for x in open(highlights_path).readlines()]
    
    assert len(articles) == len(highlights)
    return articles, highlights

In [7]:
train_x, train_y = get_translated_ds("train")
test_x, test_y = get_translated_ds("test")
val_x, val_y = get_translated_ds("val")

In [8]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, articles, highlights):
        self.x = articles
        self.y = highlights
        
    def __getitem__(self, index):
        x = tokenizer.encode_plus(self.transfrom(self.x[index][3:]), max_length=512, return_tensors="pt", pad_to_max_length=True)
        y = tokenizer.encode(self.transfrom(self.y[index][3:]), max_length=150, return_tensors="pt", pad_to_max_length=True)
        return x['input_ids'].view(-1), x['attention_mask'].view(-1), y.view(-1)
    
    @staticmethod
    def transfrom(x):
        x = re.sub("'(.*)'", r"\1", x)
        return x
    
    def __len__(self):
        return len(self.x)

In [9]:
train_ds = MyDataset(train_x, train_y) 
val_ds = MyDataset(val_x, val_y)
test_ds = MyDataset(test_x, test_y)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE)

x, x_mask, y = next(iter(val_loader))
x.shape, x_mask.shape, y.shape

(torch.Size([6, 512]), torch.Size([6, 512]), torch.Size([6, 150]))

## Train

In [10]:
pad_token_id = tokenizer.pad_token_id
def step(inputs_ids, attention_mask, y):
    y_ids = y[:, :-1].contiguous()
    lm_labels = y[:, 1:].clone()
    lm_labels[y[:, 1:] == pad_token_id] = -100
    output = model(inputs_ids, attention_mask=attention_mask, decoder_input_ids=y_ids, lm_labels=lm_labels)
#     output = model(inputs_ids, decoder_input_ids=y, lm_labels=y)
    return output[0] # loss

In [11]:
train_loss = []
val_loss = []
for epoch in range(EPOCHS):
    model.train() 
    start_time = time.time()
    for i, (inputs_ids, attention_mask, y) in enumerate(train_loader):
        inputs_ids = inputs_ids.to(device)
        attention_mask = attention_mask.to(device)
        y = y.to(device)
        
        
        optimizer.zero_grad()
        loss = step(inputs_ids, attention_mask, y)
        train_loss.append(loss.item())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
            
        if (i + 1) % log_interval == 0:
            with torch.no_grad():
                x, x_mask, y = next(iter(val_loader))
                x = x.to(device)
                x_mask = x_mask.to(device)
                y = y.to(device)
                
                v_loss = step(x, x_mask, y)
                v_loss = v_loss.item()
                
                
                elapsed = time.time() - start_time
                print('| epoch {:3d} | [{:5d}/{:5d}] | '
                  'ms/batch {:5.2f} | '
                  'loss {:5.2f} | val loss {:5.2f}'.format(
                    epoch, i, len(train_loader),
                    elapsed * 1000 / log_interval,
                    loss.item(), v_loss))
                start_time = time.time()
                val_loss.append(v_loss)

| epoch   0 | [  199/47853] | ms/batch 231.27 | loss  6.75 | val loss  6.78
| epoch   0 | [  399/47853] | ms/batch 232.28 | loss  6.30 | val loss  6.62
| epoch   0 | [  599/47853] | ms/batch 229.37 | loss  6.59 | val loss  6.54
| epoch   0 | [  799/47853] | ms/batch 230.47 | loss  6.31 | val loss  6.46
| epoch   0 | [  999/47853] | ms/batch 231.63 | loss  5.96 | val loss  6.40


KeyboardInterrupt: 

## Evaluate

In [12]:
from rouge_score import rouge_scorer
from rouge_score import scoring

class RougeScore:
    '''
    mostly from https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/evaluation/metrics.py 
    '''
    
    def __init__(self, score_keys=None)-> None:
        super().__init__()
        if score_keys is None:  
            self.score_keys = ["rouge1", "rouge2", "rougeLsum"]
        
        self.scorer = rouge_scorer.RougeScorer(self.score_keys)
        self.aggregator = scoring.BootstrapAggregator()
        
        
    @staticmethod
    def prepare_summary(summary):
            # Make sure the summary is not bytes-type
            # Add newlines between sentences so that rougeLsum is computed correctly.
            summary = summary.replace(" . ", " .\n")
            return summary
    
    def __call__(self, target, prediction):
        """Computes rouge score.''
        Args:
        targets: string
        predictions: string
        """

        target = self.prepare_summary(target)
        prediction = self.prepare_summary(prediction)
        
        self.aggregator.add_scores(self.scorer.score(target=target, prediction=prediction))

        return 
    
    def reset_states(self):
        self.rouge_list = []

    def result(self):
        result = self.aggregator.aggregate()
        
        for key in self.score_keys:
            score_text = "%s = %.2f, 95%% confidence [%.2f, %.2f]"%(
                key,
                result[key].mid.fmeasure*100,
                result[key].low.fmeasure*100,
                result[key].high.fmeasure*100
            )
            print(score_text)
        
        return {key: result[key].mid.fmeasure*100 for key in self.score_keys}

In [13]:
rouge_score = RougeScore()
predictions = []
for i, (input_ids, attention_mask, y) in enumerate(test_loader):
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    y = y.to(device)
        
    summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask)
    pred = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
    real = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in y]
    for pred_sent, real_sent in zip(pred, real):
        rouge_score(pred_sent, real_sent)
        predictions.append(str("pred sentence: " + pred_sent + "\n\n real sentence: " + real_sent))
    if i > 10:
        break
    
rouge_score.result()

Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence


rouge1 = 10.95, 95% confidence [10.06, 11.90]
rouge2 = 0.09, 95% confidence [0.00, 0.21]
rougeLsum = 9.97, 95% confidence [9.10, 10.82]


{'rouge1': 10.950125198510523,
 'rouge2': 0.0855510805759562,
 'rougeLsum': 9.970408419899043}

In [14]:
for pred in predictions[:10]:
    print("------")
    print(pred)
    print("------") 

------
pred sentence: ;sss ,ss - , , , der der - - - , der - von - - " " " , " "e " " . " " - " , der " "y " " in " " ist " " mit " "s " " wird " " der " - - . " . ist "ey " , ein " "er " "

 real sentence: Experten bezweifeln , ob überfüllte Flugzeuge Passagiere gefährden . Die US - Verbraucherberatungsgruppe sagt , dass Mindestabstände vorgeschrieben werden müssen . Sicherheitstests in Flugzeugen mit mehr Beinfreiheit als von Fluggesellschaften angeboten , werden durchgeführt .
------
------
pred sentence: ;sss - - - , , , der der - -ige - - von , , und , , " " " , " , , ein " "r " " . " "e " " der " " - " " in " "y " " wird " " für " "al " "

 real sentence: Betrunkener Teenager kletterte in Löwengehege eines Zoos in Westindien : Rahul Kumar , 17 , lief auf Tiere zu und schrie : " Heute töte ich einen Löwen ! " Glücklicherweise stürzte er in einen Wassergraben , bevor er die Löwen erreichte und wurde gerettet .
------
------
pred sentence: ; - - - , , , der - -ige - - von - - " " " 

In [None]:
result_path = "../data/bert_result_german.txt"
for pred in predictions:
    with open(result_path, "a") as file:
        file.write(pred + "\n")
