# German Text Summary with the tranlated CNN Daily Mail Dataset
This time We will try out a Seq2Seq Bert Model which is pretrained on German data 

In [1]:
from pathlib import Path
import torch
import re
import time

from transformers import BertTokenizer, EncoderDecoderModel

In [2]:
BATCH_SIZE = 6

SHUFFEL_SIZE = 1024

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

learning_rate = 3e-5

EPOCHS = 1

log_interval = 200

## Model

In [3]:
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-german-cased", "bert-base-german-cased").to(device)

tokenizer = BertTokenizer.from_pretrained("bert-base-german-cased")


# model = EncoderDecoderModel.from_encoder_decoder_pretrained("distilbert-base-german-cased", "distilbert-base-german-cased").to(device)

# tokenizer = BertTokenizer.from_pretrained("distilbert-base-german-cased")

In [4]:
# CLS token will work as BOS token
tokenizer.bos_token = tokenizer.cls_token

# SEP token will work as EOS token
tokenizer.eos_token = tokenizer.sep_token

# set decoding params
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.max_length = 142
model.config.min_length = 56
model.config.no_repeat_ngram_size = 3
model.early_stopping = True
model.length_penalty = 2.0
model.num_beams = 4

In [5]:
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate, weight_decay=0.0001)

## Load Dataset

In [6]:
def get_translated_ds(name):
    article_path = "../data/%s/articles_german" % name
    highlights_path = "../data/%s/highlights_german" % name

    articles = [x.rstrip() for x in open(article_path).readlines()]
    highlights = [x.rstrip() for x in open(highlights_path).readlines()]
    
    assert len(articles) == len(highlights)
    return articles, highlights

In [7]:
train_x, train_y = get_translated_ds("train")
test_x, test_y = get_translated_ds("test")
val_x, val_y = get_translated_ds("val")

In [8]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, articles, highlights):
        self.x = articles
        self.y = highlights
        
    def __getitem__(self, index):
        x = tokenizer.encode_plus(self.transfrom(self.x[index]), max_length=512, return_tensors="pt", pad_to_max_length=True)
        y = tokenizer.encode(self.transfrom(self.y[index]), max_length=150, return_tensors="pt", pad_to_max_length=True)
        return x['input_ids'].view(-1), x['attention_mask'].view(-1), y.view(-1)
    
    @staticmethod
    def transfrom(x):
        x = re.sub("'(.*)'", r"\1", x)
        return x
    
    def __len__(self):
        return len(self.x)

In [9]:
train_ds = MyDataset(train_x, train_y) 
val_ds = MyDataset(val_x, val_y)
test_ds = MyDataset(test_x, test_y)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE)

x, x_mask, y = next(iter(val_loader))
x.shape, x_mask.shape, y.shape

(torch.Size([6, 512]), torch.Size([6, 512]), torch.Size([6, 150]))

## Train

In [10]:
pad_token_id = tokenizer.pad_token_id
def step(inputs_ids, attention_mask, y):
    y_ids = y[:, :-1].contiguous()
    lm_labels = y[:, 1:].clone()
    lm_labels[y[:, 1:] == pad_token_id] = -100
    output = model(inputs_ids, attention_mask=attention_mask, decoder_input_ids=y_ids, lm_labels=lm_labels)
    return output[0] # loss

In [16]:
train_loss = []
val_loss = []
for epoch in range(EPOCHS):
    model.train() 
    start_time = time.time()
    for i, (inputs_ids, attention_mask, y) in enumerate(train_loader):
        inputs_ids = inputs_ids.to(device)
        attention_mask = attention_mask.to(device)
        y = y.to(device)
        
        
        optimizer.zero_grad()
        loss = step(inputs_ids, attention_mask, y)
        train_loss.append(loss.item())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
            
        if (i + 1) % log_interval == 0:
            with torch.no_grad():
                x, x_mask, y = next(iter(val_loader))
                x = x.to(device)
                x_mask = x_mask.to(device)
                y = y.to(device)
                
                v_loss = step(x, x_mask, y)
                v_loss = v_loss.item()
                
                
                elapsed = time.time() - start_time
                print('| epoch {:3d} | [{:5d}/{:5d}] | '
                  'ms/batch {:5.2f} | '
                  'loss {:5.2f} | val loss {:5.2f}'.format(
                    epoch, i, len(train_loader),
                    elapsed * 1000 / log_interval,
                    loss.item(), v_loss))
                start_time = time.time()
                val_loss.append(v_loss)

| epoch   0 | [  199/47853] | ms/batch 219.17 | loss  6.04 | val loss  6.29
| epoch   0 | [  399/47853] | ms/batch 220.98 | loss  6.04 | val loss  6.28
| epoch   0 | [  599/47853] | ms/batch 219.11 | loss  6.57 | val loss  6.28
| epoch   0 | [  799/47853] | ms/batch 220.19 | loss  6.39 | val loss  6.32
| epoch   0 | [  999/47853] | ms/batch 220.33 | loss  6.11 | val loss  6.31
| epoch   0 | [ 1199/47853] | ms/batch 220.84 | loss  6.27 | val loss  6.29
| epoch   0 | [ 1399/47853] | ms/batch 218.24 | loss  6.23 | val loss  6.29
| epoch   0 | [ 1599/47853] | ms/batch 220.21 | loss  5.80 | val loss  6.35
| epoch   0 | [ 1799/47853] | ms/batch 221.47 | loss  6.13 | val loss  6.31
| epoch   0 | [ 1999/47853] | ms/batch 220.87 | loss  6.46 | val loss  6.36
| epoch   0 | [ 2199/47853] | ms/batch 220.49 | loss  6.16 | val loss  6.38
| epoch   0 | [ 2399/47853] | ms/batch 221.15 | loss  6.15 | val loss  6.42
| epoch   0 | [ 2599/47853] | ms/batch 220.86 | loss  6.27 | val loss  6.37
| epoch   0 

| epoch   0 | [21799/47853] | ms/batch 222.39 | loss  6.32 | val loss  6.45
| epoch   0 | [21999/47853] | ms/batch 220.21 | loss  6.36 | val loss  6.47
| epoch   0 | [22199/47853] | ms/batch 220.65 | loss  6.22 | val loss  6.50
| epoch   0 | [22399/47853] | ms/batch 221.04 | loss  6.45 | val loss  6.43
| epoch   0 | [22599/47853] | ms/batch 221.80 | loss  6.33 | val loss  6.48
| epoch   0 | [22799/47853] | ms/batch 221.40 | loss  6.43 | val loss  6.45
| epoch   0 | [22999/47853] | ms/batch 220.76 | loss  6.29 | val loss  6.50
| epoch   0 | [23199/47853] | ms/batch 220.18 | loss  6.52 | val loss  6.48
| epoch   0 | [23399/47853] | ms/batch 221.55 | loss  6.09 | val loss  6.43
| epoch   0 | [23599/47853] | ms/batch 220.06 | loss  6.51 | val loss  6.46
| epoch   0 | [23799/47853] | ms/batch 221.38 | loss  6.20 | val loss  6.46
| epoch   0 | [23999/47853] | ms/batch 220.30 | loss  6.52 | val loss  6.42
| epoch   0 | [24199/47853] | ms/batch 221.55 | loss  6.29 | val loss  6.49
| epoch   0 

KeyboardInterrupt: 

## Evaluate

In [17]:
from rouge_score import rouge_scorer
from rouge_score import scoring

class RougeScore:
    '''
    mostly from https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/evaluation/metrics.py 
    '''
    
    def __init__(self, score_keys=None)-> None:
        super().__init__()
        if score_keys is None:  
            self.score_keys = ["rouge1", "rouge2", "rougeLsum"]
        
        self.scorer = rouge_scorer.RougeScorer(self.score_keys)
        self.aggregator = scoring.BootstrapAggregator()
        
        
    @staticmethod
    def prepare_summary(summary):
            # Make sure the summary is not bytes-type
            # Add newlines between sentences so that rougeLsum is computed correctly.
            summary = summary.replace(" . ", " .\n")
            return summary
    
    def __call__(self, target, prediction):
        """Computes rouge score.''
        Args:
        targets: string
        predictions: string
        """

        target = self.prepare_summary(target)
        prediction = self.prepare_summary(prediction)
        
        self.aggregator.add_scores(self.scorer.score(target=target, prediction=prediction))

        return 
    
    def reset_states(self):
        self.rouge_list = []

    def result(self):
        result = self.aggregator.aggregate()
        
        for key in self.score_keys:
            score_text = "%s = %.2f, 95%% confidence [%.2f, %.2f]"%(
                key,
                result[key].mid.fmeasure*100,
                result[key].low.fmeasure*100,
                result[key].high.fmeasure*100
            )
            print(score_text)
        
        return {key: result[key].mid.fmeasure*100 for key in self.score_keys}

In [21]:
rouge_score = RougeScore()
predictions = []
for i, (input_ids, attention_mask, y) in enumerate(test_loader):
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    y = y.to(device)
        
    summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask)
    pred = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
    real = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in y]
    for pred_sent, real_sent in zip(pred, real):
        rouge_score(pred_sent, real_sent)
        predictions.append(str("pred sentence: " + pred_sent + "\n\n real sentence: " + real_sent))
    if i > 40:
        break
    
rouge_score.result()

Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate sequence
Setting `pad_token_id` to 4 (first `eos_token_id`) to generate s

rouge1 = 13.29, 95% confidence [12.57, 13.93]
rouge2 = 0.21, 95% confidence [0.14, 0.30]
rougeLsum = 11.68, 95% confidence [11.08, 12.28]


{'rouge1': 13.289427538754687,
 'rouge2': 0.21045961905902774,
 'rougeLsum': 11.683025084899315}

In [22]:
for pred in predictions[:10]:
    print("------")
    print(pred)
    print("------") 

------
pred sentence: ##93 ;EU Die dess -s - - -s " " " , die der " " - " "en " " sagt dass " " der " der dess "s " sagt " " . sagt dass die der der der des - - der des "s der deses " "s ist der des . sagt er " " sagte dass " der der "e " " ist . . sagt die der , sei " " nicht " "

 real sentence: 0 ; Experten bezweifeln , ob überfüllte Flugzeuge Passagiere gefährden . Die US - Verbraucherberatungsgruppe sagt , dass Mindestabstände vorgeschrieben werden müssen . Sicherheitstests in Flugzeugen mit mehr Beinfreiheit als von Fluggesellschaften angeboten , werden durchgeführt .
------
------
pred sentence: ##92 ;EU Die vonnn der vonN - - -n wurden , der der von Polizei , , , und Polizei , . Polizei , und , , . , , in , , die der , , der , und Frau , , von , , " , , zu , , sie , , sich , , ein , , er , , im von Polizei zu , er zu , sie . Polizei . Polizei zu und , und wurden , , wegen , ,n , , mit Polizei , in zu , die , , es , , dem zu , der zu , es . Polizei wurde , , Polizei , zu . Poliz

In [23]:
result_path = "../data/bert_result_german.txt"
for pred in predictions:
    with open(result_path, "a") as file:
        file.write(pred + "\n")
