In [None]:
%%capture
!pip install -r https://raw.githubusercontent.com/shitkov/mt5_text_generation/main/requirements.txt

In [None]:
import gzip
import itertools
import json
import os
import os.path
import random
import re
import urllib.request

from bs4 import BeautifulSoup
import numpy as np
import pandas as pd
from razdel import tokenize
from rouge import Rouge
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import torch
from torch import cuda
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import MT5Model, T5Tokenizer, MT5ForConditionalGeneration

In [None]:
class ExtData:
    '''
    Fetch a gzip-compressed JSON-lines archive of news articles and extract
    cleaned (text, title) pairs from it.
    '''

    def __init__(self, path, lines):
        '''
        path: URL of the .json.gz archive, or a path to a local copy
        lines: number of leading archive lines to process
               (non-negative int, or None for all lines)
        '''
        self.path = path
        self.lines = lines
        # Number of records skipped by the last load_data() call because
        # parsing/cleaning raised ValueError, KeyError or TypeError.
        self.skipped = 0

    def _clean_text(self, text):
        '''
        Clean input text.

        Drop HTML markup together with <script>/<style> contents, replace
        every character other than Cyrillic/Latin letters, digits, spaces
        and ':%.,!?-' with a space, collapse runs of spaces, and remove the
        space left in front of '.' and ','.
        '''
        soup = BeautifulSoup(text, features="html.parser")

        for script in soup(["script", "style"]):
            script.extract()

        text = soup.get_text()

        text = re.sub(r'[^А-Яа-я0-9ЁёA-Za-z :%.,!?-]', ' ', text)
        text = re.sub(r" +", " ", text).strip()
        text = text.replace(' .', '.')
        text = text.replace(' ,', ',')

        return text

    def _fetch(self):
        '''
        Return a local path to the archive.

        If self.path is an existing local file it is used directly.
        Otherwise it is treated as a URL and downloaded into the current
        directory under its basename, unless that file already exists.
        Raises the urllib error if the download fails (instead of silently
        continuing as the old shell `wget` call did).
        '''
        if os.path.isfile(self.path):
            return self.path
        local_path = os.path.basename(self.path)
        if not os.path.isfile(local_path):
            # Download under a temporary name and rename only on success so
            # an interrupted download never leaves a truncated archive that
            # a later run would mistake for a cached copy.
            partial_path = local_path + '.part'
            urllib.request.urlretrieve(self.path, partial_path)
            os.replace(partial_path, local_path)
        return local_path

    def load_data(self):
        '''
        1. Download the archive (once; later calls reuse the local copy)
        2. Stream-decompress it line by line
        3. Parse each JSON record and clean its 'text' and 'title' fields

        Records whose parsing/cleaning raises ValueError (e.g. invalid JSON),
        KeyError (missing 'text' or 'title') or TypeError (e.g. a record that
        is not a JSON object) are skipped and counted in self.skipped.
        Pairs whose cleaned text or title is empty are dropped.

        Returns: (texts, titles), two aligned lists of cleaned strings.
        '''
        archive_path = self._fetch()

        texts = []
        titles = []
        self.skipped = 0
        # Iterating the text stream splits only on line breaks (unlike
        # str.splitlines(), which also splits on e.g. U+2028, a character
        # allowed unescaped inside JSON strings) and never holds the whole
        # decompressed archive in memory.
        with gzip.open(archive_path, 'rt', encoding='utf-8') as f:
            for line in tqdm(itertools.islice(f, self.lines), total=self.lines):
                try:
                    record = json.loads(line)
                    text = self._clean_text(record['text'])
                    title = self._clean_text(record['title'])
                except (ValueError, KeyError, TypeError):
                    self.skipped += 1
                    continue

                if text != '' and title != '':
                    texts.append(text)
                    titles.append(title)

        return texts, titles

    def get_summ_len(self, data, max_len=None):
        '''
        Get the target (summary) length to tokenize to.

        data: non-empty list of summary strings
        max_len: upper bound on the result; defaults to the notebook-level
                 MAX_LEN constant (looked up at call time)
        Returns the longest summary length in razdel tokens, raised to the
        next multiple of 10 strictly above it, capped at max_len.
        Raises ValueError if data is empty.
        '''
        if max_len is None:
            max_len = MAX_LEN
        summ_len_list = [len(list(tokenize(i))) for i in data]
        if not summ_len_list:
            raise ValueError('get_summ_len needs at least one summary')
        summ_max_len = max(summ_len_list)
        return min((summ_max_len // 10 + 1) * 10, max_len)

In [None]:
class CustomDataset(Dataset):
    '''
    Map-style dataset of (text, title) pairs, each tokenized and padded /
    truncated to a fixed length, for seq2seq fine-tuning.
    '''

    def __init__(self, texts, titles, tokenizer, source_len, summ_len):
        '''
        texts: list of source documents
        titles: list of target summaries, aligned with texts
        tokenizer: Hugging Face tokenizer (called with return_tensors='pt')
        source_len: padded/truncated token length of each source text
        summ_len: padded/truncated token length of each target summary
        '''
        self.texts = texts
        self.summaries = titles
        self.tokenizer = tokenizer
        self.source_len = source_len
        self.summ_len = summ_len

    def __len__(self):
        return len(self.texts)

    def _encode(self, text, max_length):
        '''Tokenize one string into 1-D (input_ids, attention_mask) long tensors.'''
        encoded = self.tokenizer(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # With return_tensors='pt' each field has a leading batch dimension of
        # size 1. Taking row 0 (rather than squeeze()) keeps the result 1-D
        # even when max_length == 1.
        return encoded['input_ids'][0].long(), encoded['attention_mask'][0].long()

    def __getitem__(self, idx):
        '''
        Return the encoded pair at idx as a dict with 'input_ids',
        'attention_mask' (length source_len) and 'target_ids',
        'target_attention_mask' (length summ_len).
        '''
        source_ids, source_mask = self._encode(self.texts[idx], self.source_len)
        target_ids, target_mask = self._encode(self.summaries[idx], self.summ_len)

        return {
            'input_ids': source_ids,
            'attention_mask': source_mask,
            'target_ids': target_ids,
            'target_attention_mask': target_mask
            }

In [None]:
class Summarizer:
    '''
    Fine-tune an mT5 model (google/mt5-small as the base) to generate titles
    for texts.

    Training data is cut into consecutive slices of `news_batch_size`
    examples. After every slice the model is scored on the validation set,
    model and tokenizer are saved to `path`, and the index of the next slice
    to train is written to `path + 'log.txt'`, so an interrupted run resumes
    from that slice.
    '''

    def __init__(self, path, lr, max_sent_len, summ_len, epochs, data_train, data_valid, labels_train, labels_valid, news_batch_size):
        '''
        path: checkpoint directory, used as a string prefix (keep the
              trailing separator, e.g. '/content/')
        lr: Adam learning rate
        max_sent_len: padded/truncated token length of source texts
        summ_len: padded/truncated token length of targets; also the
                  generation max_length
        epochs: number of passes over all training slices
        data_train, labels_train: training texts and titles
        data_valid, labels_valid: validation texts and titles
        news_batch_size: number of training examples per slice
        '''
        self.model_name = 'google/mt5-small'
        self.path = path
        self.tokenizer = None
        self.model = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.optimizer = None
        self.lr = lr
        self.max_sent_len = max_sent_len
        self.summ_len = summ_len
        self.epochs = epochs
        self.data_train = None
        self.data_valid = data_valid
        self.labels_train = None
        self.labels_valid = labels_valid

        self._get_model()
        self._get_optimizer()
        self._get_slices(data_train, labels_train, news_batch_size)

    def fit(self, loader):
        '''
        Run one optimisation pass over `loader`.

        Each batch must provide 'input_ids', 'attention_mask' and
        'target_ids' (as produced by CustomDataset). Target padding is mapped
        to -100 so the loss ignores it; the source attention mask is passed
        so the encoder does not attend to padding.
        '''
        self.model.train()
        n_steps = len(loader)
        # Print progress ~10 times per pass (at least every 100 steps); the
        # max(1, ...) guard avoids modulo-by-zero for loaders shorter than 10.
        log_every = max(1, min(100, n_steps // 10))

        for step, batch in enumerate(loader):
            ids = batch['input_ids'].to(self.device, dtype=torch.long)
            mask = batch['attention_mask'].to(self.device, dtype=torch.long)
            labels = batch['target_ids'].to(self.device, dtype=torch.long)
            # Out-of-place so the caller's batch tensors are left untouched.
            labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)

            outputs = self.model(
                input_ids=ids,
                attention_mask=mask,
                labels=labels
                )

            loss = outputs[0]

            if step % log_every == 0:
                print('Step: ' + str(step) + '/' + str(n_steps))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def predict(self, loader):
        '''
        Generate summaries for every batch in `loader`.

        Returns (predictions, actuals): decoded model outputs and decoded
        'target_ids' references, both lists of strings with special tokens
        removed.
        '''
        self.model.eval()
        predictions = []
        actuals = []

        with torch.no_grad():
            for batch in loader:
                y = batch['target_ids'].to(self.device, dtype=torch.long)
                ids = batch['input_ids'].to(self.device, dtype=torch.long)
                mask = batch['attention_mask'].to(self.device, dtype=torch.long)

                generated_ids = self.model.generate(
                    input_ids=ids,
                    attention_mask=mask,
                    max_length=self.summ_len
                    )

                predictions.extend(self._decode(g) for g in generated_ids)
                actuals.extend(self._decode(t) for t in y)

        return predictions, actuals

    def _decode(self, token_ids):
        '''Decode one id sequence to text, dropping special tokens.'''
        return self.tokenizer.decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    def train(self):
        '''
        Train for self.epochs epochs, one slice at a time.

        After each slice: score the validation set with get_rouge_score,
        print the score, save model and tokenizer to self.path, and record
        the next slice index in the resume log. The first epoch starts from
        the slice recorded in the log; later epochs start from slice 0.
        The epoch number itself is not persisted, so after a restart epoch
        numbering begins at 0 again.
        '''
        start_slice = self._get_start_slice()

        valid_set = CustomDataset(
                self.data_valid,
                self.labels_valid,
                self.tokenizer,
                self.max_sent_len,
                self.summ_len
            )

        valid_loader = DataLoader(
                valid_set,
                batch_size=1,
                shuffle=False,
                num_workers=0
            )

        for epoch in range(self.epochs):
            for slc in range(start_slice, len(self.data_train)):
                train_set = CustomDataset(
                    self.data_train[slc],
                    self.labels_train[slc],
                    self.tokenizer,
                    self.max_sent_len,
                    self.summ_len)

                training_loader = DataLoader(
                        train_set,
                        batch_size=1,
                        shuffle=True,
                        num_workers=0
                    )

                self.fit(training_loader)
                hypothesis, reference = self.predict(valid_loader)

                score = get_rouge_score(hypothesis, reference)

                print('Epoch: {0}, Slice: {1}, rouge: {2}'.format(epoch, slc, score))
                print('=========')

                self._save_checkpoint()
                # Written after the checkpoint so the log never points past
                # what is on disk. It stores the NEXT slice so a restart does
                # not retrain the slice that was just saved.
                self._write_log(slc + 1)

            # A resumed offset applies to the first epoch only.
            start_slice = 0
            self._write_log(0)

    def _save_checkpoint(self):
        '''Save model weights and tokenizer to self.path.'''
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
        model_to_save.save_pretrained(self.path)
        # Saving the tokenizer too lets _get_model reload everything from
        # self.path instead of failing over to the base model.
        self.tokenizer.save_pretrained(self.path)

    def _write_log(self, next_slice):
        '''Overwrite the resume log with the index of the next slice to train.'''
        with open(self.path + 'log.txt', 'w') as f:
            f.write(str(next_slice) + '\n')

    def _get_optimizer(self):
        self.optimizer = torch.optim.Adam(
                params=self.model.parameters(),
                lr=self.lr
            )

    def _get_model(self):
        '''
        Load tokenizer and model, each from self.path when possible, else
        from the base model self.model_name. They fall back independently,
        so a checkpoint directory holding only model weights still restores
        the fine-tuned weights.
        '''
        self.tokenizer = self._load_pretrained(T5Tokenizer)
        self.model = self._load_pretrained(MT5ForConditionalGeneration).to(self.device)

    def _load_pretrained(self, cls):
        '''Return cls.from_pretrained(self.path), or the base model on failure.'''
        try:
            return cls.from_pretrained(self.path)
        except Exception:
            # Missing checkpoint files surface as different exception types
            # depending on class and library version (typically OSError), so
            # any Exception triggers the fallback; KeyboardInterrupt and
            # SystemExit still propagate, unlike with a bare `except:`.
            return cls.from_pretrained(self.model_name)

    def _get_start_slice(self):
        '''
        Return the slice to resume from, as recorded in the last non-empty
        line of the resume log. A missing log is created with 0; an empty
        log yields 0.
        '''
        logfile_path = self.path + 'log.txt'

        if os.path.exists(logfile_path):
            with open(logfile_path, 'r') as f:
                entries = [line.strip() for line in f if line.strip()]
            return int(entries[-1]) if entries else 0

        self._write_log(0)
        return 0

    def _get_slices(self, x_train, y_train, nbs):
        '''
        Split texts and labels into consecutive slices of nbs examples. The
        last slice may be shorter, so no training example is dropped.
        Raises ValueError if nbs is not positive.
        '''
        if nbs <= 0:
            raise ValueError('news_batch_size must be positive, got {0}'.format(nbs))
        self.data_train = [x_train[i:i + nbs] for i in range(0, len(x_train), nbs)]
        self.labels_train = [y_train[i:i + nbs] for i in range(0, len(y_train), nbs)]

In [None]:
def get_rouge_score(hypothesis, reference):
    '''
    Score generated texts against references.

    hypothesis: list of generated strings
    reference: list of reference strings, aligned with hypothesis
    Returns the mean of the ROUGE-1, ROUGE-2 and ROUGE-L F-scores (each
    first averaged over all pairs), rounded to 3 decimals.
    Raises ValueError if the lists are empty or differ in length.
    '''
    if len(hypothesis) != len(reference):
        raise ValueError(
            'hypothesis and reference differ in length: {0} vs {1}'.format(len(hypothesis), len(reference))
        )
    if not hypothesis:
        raise ValueError('cannot score empty hypothesis/reference lists')

    rouge = Rouge()

    # NOTE(review): the `rouge` package appears to split texts into sentences
    # on '.' and to reject a hypothesis with no non-empty sentence, so outputs
    # such as '' or '...' would abort the whole evaluation. They are replaced
    # with a placeholder instead (verify against the installed version).
    hypothesis = [h if h.replace('.', '').strip() else 'empty generate error' for h in hypothesis]

    scores = rouge.get_scores(hypothesis, reference)
    r1_fm = np.mean([s['rouge-1']['f'] for s in scores])
    r2_fm = np.mean([s['rouge-2']['f'] for s in scores])
    rlcs_fm = np.mean([s['rouge-l']['f'] for s in scores])
    return round(np.mean([r1_fm, r2_fm, rlcs_fm]), 3)

In [None]:
# Reproducibility: one seed for the train/valid/test split and for the RNGs.
SEED = RANDOM_STATE = 42

# Data
PATH = 'https://raw.githubusercontent.com/shitkov/mt5_text_generation/main/ria_10k.json.gz'
LINES = 1000            # leading archive lines to read
TEST_SIZE = 100         # examples in each of the validation and test sets

# Model / training
MAX_LEN = 512           # token length of encoded source texts; also caps the summary length
NEWS_BATCH_SIZE = 100   # training examples per slice (validation/checkpoint granularity)
EPOCHS = 2
LEARNING_RATE = 1e-4

# Output: checkpoint + resume-log directory (Colab default); it is used as a
# string prefix, so keep the trailing '/'.
SAVE_PATH = '/content/'

In [None]:
# Download the archive and extract cleaned (text, title) pairs from its
# first LINES records; `ed` is reused below to size the summaries.
ed = ExtData(PATH, LINES)
texts, titles = ed.load_data()

100%|██████████| 1000/1000 [00:01<00:00, 824.54it/s]


In [None]:
# Hold out 2 * TEST_SIZE examples, then split them evenly into a validation
# set (scored after every training slice) and a final test set.
X_train, X_test, y_train, y_test = train_test_split(texts, titles, test_size=TEST_SIZE*2, random_state=RANDOM_STATE)
X_valid, X_test, y_valid, y_test = train_test_split(X_test, y_test, test_size=TEST_SIZE, random_state=RANDOM_STATE)

In [None]:
# Target length: longest training title in razdel tokens, raised to the next
# multiple of 10 above it and capped at MAX_LEN.
# NOTE(review): this counts razdel word tokens but is used as an mT5 subword
# length, which is usually longer, so some titles may be truncated — confirm.
summ_len = ed.get_summ_len(y_train)

In [None]:
# Seed every RNG in play (Python, NumPy, and PyTorch, whose manual_seed covers
# all devices) so DataLoader shuffling and training are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
# cuDNN benchmark mode autotunes kernels per input shape and may pick
# non-deterministic ones, which would undercut the flag above.
torch.backends.cudnn.benchmark = False

In [None]:
# Build the trainer: it tries to load tokenizer and model from SAVE_PATH and
# falls back to google/mt5-small, creates an Adam optimizer, and cuts the
# training data into slices of NEWS_BATCH_SIZE examples.
summarizer = Summarizer(
    path=SAVE_PATH,
    lr=LEARNING_RATE,
    max_sent_len=MAX_LEN,
    summ_len=summ_len,
    epochs=EPOCHS,
    data_train=X_train,
    data_valid=X_valid,
    labels_train=y_train,
    labels_valid=y_valid,
    news_batch_size=NEWS_BATCH_SIZE
)

Downloading:   0%|          | 0.00/4.31M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/82.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/553 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.20G [00:00<?, ?B/s]

In [None]:
# Train for EPOCHS epochs. After every slice this prints the validation ROUGE
# score and saves the model to SAVE_PATH.
summarizer.train()

Step: 0/100
Step: 10/100
Step: 20/100
Step: 30/100
Step: 40/100
Step: 50/100
Step: 60/100
Step: 70/100
Step: 80/100
Step: 90/100
Epoch: 0, Slice: 0, rouge: 0.022
Step: 0/100
Step: 10/100
Step: 20/100
Step: 30/100
Step: 40/100
Step: 50/100
Step: 60/100
Step: 70/100
Step: 80/100
Step: 90/100
Epoch: 0, Slice: 1, rouge: 0.02
Step: 0/100
Step: 10/100
Step: 20/100
Step: 30/100
Step: 40/100
Step: 50/100
Step: 60/100
Step: 70/100
Step: 80/100
Step: 90/100
Epoch: 0, Slice: 2, rouge: 0.034
Step: 0/100
Step: 10/100
Step: 20/100
Step: 30/100
Step: 40/100
Step: 50/100
Step: 60/100
Step: 70/100
Step: 80/100
Step: 90/100
Epoch: 0, Slice: 3, rouge: 0.024
Step: 0/100
Step: 10/100
Step: 20/100
Step: 30/100
Step: 40/100
Step: 50/100
Step: 60/100
Step: 70/100
Step: 80/100
Step: 90/100
Epoch: 0, Slice: 4, rouge: 0.027
Step: 0/100
Step: 10/100
Step: 20/100
Step: 30/100
Step: 40/100
Step: 50/100
Step: 60/100
Step: 70/100
Step: 80/100
Step: 90/100
Epoch: 0, Slice: 5, rouge: 0.022
Step: 0/100
Step: 10/100
Step

In [None]:
# Held-out test set, encoded with the trainer's tokenizer and the same source
# and summary lengths used in training; fixed order, one example per batch.
test_set = CustomDataset(
        X_test,
        y_test,
        summarizer.tokenizer,
        summarizer.max_sent_len,
        summarizer.summ_len
    )

test_loader = DataLoader(
        test_set,
        batch_size=1,
        shuffle=False,
        num_workers=0
    )

In [None]:
# Generate titles for the test texts; also returns the decoded reference titles.
hypothesis, reference = summarizer.predict(test_loader)                

In [None]:
# Mean of ROUGE-1, ROUGE-2 and ROUGE-L F-scores, rounded to 3 decimals.
score = get_rouge_score(hypothesis, reference)

In [None]:
# Test-set score (displayed as the cell output).
score

0.061