# Install requirements

In [None]:
!pip install gdown
!pip install pytorch-lightning==1.7.7
!pip install torch==1.11.0
!pip install torchmetrics==0.10.0
!pip install transformers==4.20.1
!pip install pandas
!pip install numpy
!pip install sklearn
!pip install wandb # Optional

# Import packages

In [None]:
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from pytorch_lightning import LightningDataModule, seed_everything, Trainer
from pytorch_lightning.core.module import LightningModule
import torchmetrics.functional as MF

from collections import OrderedDict

import wandb
from pytorch_lightning.loggers import WandbLogger

import re

from transformers.optimization import Adafactor, AdafactorSchedule
from transformers import T5ForConditionalGeneration, T5TokenizerFast, T5Config


import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

import gdown

# Define datasets

## Dataset with labels

In [None]:
class SCDataset(Dataset):
    """PyTorch map-style dataset built from a pandas DataFrame for conditional
    text generation, with labels (target texts) provided.

    Attributes
    ----------
    tokenizer : PreTrainedTokenizer
        HuggingFace tokenizer; only ``batch_encode_plus`` is called.
    source_text : pandas.Series
        Column of source (input) strings.
    target_text : pandas.Series
        Column of target (label) strings.
    source_len : int
        Longest source string, in characters. Used as the token ``max_length``
        when encoding sources.
    summ_len : int
        Longest target string, in characters. Used as the token ``max_length``
        when encoding targets.

    Notes
    -----
    Character counts are used as a proxy for token counts, and ``source_len``
    is measured before the prompt is prepended. An unusually token-dense source
    near the maximum length may therefore be truncated.
    """

    # Finetuning prompt prepended to every source text (Russian for "Error correction: ").
    _PROMPT = "Исправление ошибок: "
    # ASCII hyphen plus Unicode dash/minus look-alikes, all normalised to '-'.
    # Compiled once here instead of on every __getitem__ call.
    _DASHES = re.compile(r'[\-\˗\֊\‐\‑\‒\–\—\⁻\₋\−\﹣\－]', re.UNICODE)

    def __init__(self, dataframe, tokenizer, source_text, target_text):
        """Keep the source/target columns and record their maximum string lengths.

        Parameters
        ----------
        dataframe : DataFrame
            Frame holding both text columns.
        tokenizer : PreTrainedTokenizer
            HuggingFace tokenizer.
        source_text : str
            Name of the source-text column.
        target_text : str
            Name of the target-text column.
        """
        self.tokenizer = tokenizer
        self.target_text = dataframe[target_text]
        self.source_text = dataframe[source_text]
        self.source_len = self.source_text.str.len().max()
        self.summ_len = self.target_text.str.len().max()

    def __len__(self):
        """Return the number of rows in the dataset.

        Returns
        -------
        int
            Length of dataset
        """
        return len(self.target_text)

    @classmethod
    def _clean(cls, text):
        """Collapse whitespace runs to single spaces and normalise dashes to '-'."""
        return cls._DASHES.sub('-', " ".join(text.split()))

    def _encode(self, text, max_length):
        """Tokenize one string, padded/truncated to ``max_length`` tokens.

        Returns ``(input_ids, attention_mask)``. Indexing with ``[0]`` drops the
        batch dimension. Unlike ``squeeze()``, it keeps a 1-D result even when
        ``max_length`` is 1. This assumes the tokenizer returns
        ``(1, max_length)`` tensors for a single text.
        """
        encoded = self.tokenizer.batch_encode_plus(
            [text],
            max_length=max_length,
            truncation=True,
            # padding="max_length" alone is sufficient; the deprecated
            # pad_to_max_length=True flag was redundant with it.
            padding="max_length",
            return_tensors="pt",
        )
        return encoded["input_ids"][0], encoded["attention_mask"][0]

    def __getitem__(self, index):
        """Get a training sample by position.

        1) Take the source and target strings at ``index``.
        2) Clean both: collapse whitespace and normalise dash characters.
        3) Prepend the finetuning prompt to the source (not strictly necessary).
        4) Encode both with padding to their maximum lengths.

        Parameters
        ----------
        index : int

        Returns
        -------
        source_ids : LongTensor
        source_mask : LongTensor
        target_ids : LongTensor
        target_mask : LongTensor
        target_text : str
            Clean target text, used for metric calculation.
        """
        source_text = self._clean(self.source_text.iloc[index])
        target_text = self._clean(self.target_text.iloc[index])

        source_ids, source_mask = self._encode(self._PROMPT + source_text, self.source_len)
        target_ids, target_mask = self._encode(target_text, self.summ_len)

        return source_ids, source_mask, target_ids, target_mask, target_text

## Dataset without labels

In [None]:
class SCPredictDataset(Dataset):
    """PyTorch map-style dataset built from a pandas DataFrame for conditional
    text generation without labels; suitable only for prediction.

    Attributes
    ----------
    tokenizer : PreTrainedTokenizer
        HuggingFace tokenizer; only ``batch_encode_plus`` is called.
    source_text : pandas.Series
        Column of source (input) strings.
    source_len : int
        Longest source string, in characters; used as the token ``max_length``.
        It is measured before the prompt is prepended, so an unusually
        token-dense input near the maximum length may be truncated.
    """

    # Finetuning prompt prepended to every source text (Russian for "Error correction: ").
    _PROMPT = "Исправление ошибок: "
    # ASCII hyphen plus Unicode dash/minus look-alikes, all normalised to '-'.
    # Compiled once here instead of on every __getitem__ call.
    _DASHES = re.compile(r'[\-\˗\֊\‐\‑\‒\–\—\⁻\₋\−\﹣\－]', re.UNICODE)

    def __init__(self, dataframe, tokenizer, source_text):
        """Keep the source column and record its maximum string length.

        Parameters
        ----------
        dataframe : DataFrame
            Frame holding the source-text column.
        tokenizer : PreTrainedTokenizer
            HuggingFace tokenizer.
        source_text : str
            Name of the source-text column.
        """
        self.tokenizer = tokenizer
        self.source_text = dataframe[source_text]
        self.source_len = self.source_text.str.len().max()

    def __len__(self):
        """Get length of dataset.

        Returns
        -------
        len : int
            Length of dataset
        """
        return len(self.source_text)

    def __getitem__(self, index):
        """Get an inference sample by position.

        1) Take the source string at ``index``.
        2) Clean it: collapse whitespace and normalise dash characters.
        3) Prepend the finetuning prompt (not strictly necessary).
        4) Encode with padding to the maximum source length.

        Parameters
        ----------
        index : int

        Returns
        -------
        source_ids : LongTensor
        source_mask : LongTensor
        """
        text = self._DASHES.sub('-', " ".join(self.source_text.iloc[index].split()))

        encoded = self.tokenizer.batch_encode_plus(
            [self._PROMPT + text],
            max_length=self.source_len,
            truncation=True,
            # padding="max_length" alone is sufficient; the deprecated
            # pad_to_max_length=True flag was redundant with it.
            padding="max_length",
            return_tensors="pt",
        )
        # [0] drops the batch dimension. Unlike squeeze(), it stays 1-D even
        # when max_length is 1.
        return encoded["input_ids"][0], encoded["attention_mask"][0]

# Training args

In [None]:
# Training hyper-parameters, exposed below as attributes of an Args object.
args = {
    'workers': 0,  # DataLoader worker processes; 0 keeps memory and CPU load stable.
    # The model was trained for one epoch only. Two epochs might reach ~0.67 accuracy,
    # but there was not enough time to try.
    'epochs': 1,
    'warmup': 0,  # No warmup for one-epoch training. NOTE(review): not read anywhere in this notebook.
    'batch_size': 64,  # ~12GB GPU RAM.
    'lr': 0.001,  # Taken from the T5 paper; suitable for one epoch.
    # NOTE(review): not passed to the Adafactor optimizer in SCModel.configure_optimizers,
    # so it currently has no effect.
    'weight_decay': 5e-2,
    'min_lr': 0.00001,  # Unused: one epoch, so no LR scheduler.
    'seed': 42,  # Random seed for reproducibility.
}


class Args:
    """Minimal namespace that exposes its keyword arguments as attributes."""

    def __init__(self, **entries):
        self.__dict__.update(entries)

    def __repr__(self):
        """Show all stored hyper-parameters, e.g. ``Args(lr=0.001, seed=42)``."""
        fields = ", ".join(f"{key}={value!r}" for key, value in self.__dict__.items())
        return f"{type(self).__name__}({fields})"


args = Args(**args)

# Define model 

In [None]:
class SCModel(LightningModule):
    """LightningModule wrapping a HuggingFace seq2seq model for spelling correction.

    Attributes
    ----------
    args : Args
        Hyper-parameter container; ``configure_optimizers`` reads ``args.lr``.
    model : PreTrainedModel
        HuggingFace text model.
    tokenizer : PreTrainedTokenizer
        HuggingFace tokenizer, used for the pad id and for decoding generations.
    """

    def __init__(self, args, model, tokenizer):
        """Store args, model and tokenizer.

        ``model`` and ``tokenizer`` are excluded from ``save_hyperparameters``,
        so only ``args`` is recorded as a hyper-parameter.
        """
        super().__init__()
        self.save_hyperparameters(ignore=["model", "tokenizer"])
        self.model = model
        self.tokenizer = tokenizer
        self.args = args
        # NOTE(review): presumably shadows the epoch-end hook with None so Lightning
        # treats it as absent and does not collect per-step outputs. Confirm for the
        # installed pytorch-lightning version.
        self.training_epoch_end = None

    def forward(self, batch):
        """Compute the seq2seq loss for a labelled batch.

        Parameters
        ----------
        batch : tuple
            ``(source_ids, source_mask, target_ids, target_mask, target_texts)``.

        Returns
        -------
        torch.FloatTensor
            Loss of the model.
        """
        source_ids, source_mask, target_ids, target_mask, _ = batch

        # Label value -100 is used for padding positions so they do not influence
        # the loss. masked_fill returns a new tensor, so the caller's batch is
        # left untouched.
        labels = target_ids.masked_fill(target_ids == self.tokenizer.pad_token_id, -100)

        return self.model(
            input_ids=source_ids,
            attention_mask=source_mask,
            labels=labels,
            decoder_attention_mask=target_mask,
        )[0]  # First output is the loss when labels are supplied.

    def val_forward(self, batch):
        """Generate corrections with beam search and decode them to strings.

        Parameters
        ----------
        batch : tuple
            Input batch; only ``batch[0]`` (input ids) and ``batch[1]``
            (attention mask) are used.

        Returns
        -------
        list[str]
            Decoded predictions, one per input row.
        """
        source_ids = batch[0]
        source_mask = batch[1]
        generated_ids = self.model.generate(
            input_ids=source_ids,
            attention_mask=source_mask,
            max_length=16,  # Common maximum from the train and public test datasets.
            num_beams=10,  # Large beam, but inference time is still acceptable.
            repetition_penalty=5.0,  # Many repeated punctuation marks (!?-.,) were observed, hence the high penalty.
            length_penalty=1.0,  # Neutral; the output length distribution was already good.
            early_stopping=True,
            # NOTE(review): top_p is a sampling parameter. Without do_sample=True,
            # beam search presumably ignores it; confirm for the installed transformers.
            top_p=0.75,
        )
        # Decode without special tokens and with tokenization spaces cleaned up.
        return [
            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for g in generated_ids
        ]

    def training_step(self, batch, batch_idx):
        """Training step.

        Parameters
        ----------
        batch : tuple
            Input batch
        batch_idx : int

        Returns
        -------
        FloatTensor
            Training loss
        """
        loss = self.forward(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step: log the loss only.

        Parameters
        ----------
        batch : tuple
            Input batch
        batch_idx : int
        """
        loss = self.forward(batch)
        # The batch mixes tensors with a list of strings, so pass the batch size
        # explicitly instead of letting Lightning infer it.
        self.log("val_loss", loss, batch_size=batch[0].size(0))

    def test_step(self, batch, batch_idx):
        """Testing step: compute the loss, generate predictions and log metrics.

        Metrics are computed only here, not during validation, to speed up training.
        Epoch values are batch-size-weighted means of per-batch scores. They are
        not corpus-level BLEU/CER/WER.

        Parameters
        ----------
        batch : tuple
            Input batch
        batch_idx: int

        Returns
        -------
        list[str]
            Predicted strings
        """
        y = batch[-1]
        batch_size = batch[0].size(0)

        loss = self.forward(batch)
        self.log("test_loss", loss, batch_size=batch_size)

        preds = self.val_forward(batch)
        # Exact-match accuracy (as on the leaderboard); an empty reference never counts as a match.
        exact_match = torch.tensor(
            [1.0 if pred == ref and ref != '' else 0.0 for pred, ref in zip(preds, y)]
        )
        self.log('acc', exact_match.mean(), on_epoch=True, batch_size=batch_size)
        self.log('bleu', MF.bleu_score(preds, y), on_epoch=True, batch_size=batch_size)
        self.log('cer', MF.char_error_rate(preds, y), on_epoch=True, batch_size=batch_size)
        self.log('wer', MF.word_error_rate(preds, y), on_epoch=True, batch_size=batch_size)

        return preds

    def predict_step(self, batch, batch_idx):
        """Inference step: predict corrected strings.

        Parameters
        ----------
        batch : tuple
            Input batch
        batch_idx: int

        Returns
        -------
        list[str]
            Predicted strings
        """
        return self.val_forward(batch)

    def configure_optimizers(self):
        """Create an Adafactor optimizer with a fixed learning rate (``args.lr``).

        No LR scheduler is used because finetuning runs for one epoch. To add one,
        for example ``AdafactorSchedule(optimizer)``, also return it under
        ``"lr_scheduler"`` with ``"monitor": "val_loss"``.
        NOTE(review): ``args.weight_decay`` is not passed here, so no weight decay is applied.
        """
        optimizer = Adafactor(
            self.model.parameters(),
            scale_parameter=False,
            relative_step=False,
            lr=self.args.lr,
        )
        return {"optimizer": optimizer}

# Load and prepare data

In [None]:
df = pd.read_csv(filepath_or_buffer='./data/train.csv')  # Training data; adjust the path to your setup.

In [None]:
# Collapse every run of whitespace into a single space in both text columns.
for _column in ('correct_text', 'corrupted_text'):
    df[_column] = df[_column].str.replace(r'\s+', ' ', regex=True)
del _column

In [None]:
# Hold out 1% of the data (stratified by category) as a test set the model never sees.
df_full, df_test = train_test_split(
    df,
    test_size=0.01,
    stratify=df['category'],
    random_state=args.seed,
)

In [None]:
# Split the remaining data 85/15 into train and validation, again stratified by category.
df_train, df_val = train_test_split(
    df_full,
    test_size=0.15,
    stratify=df_full['category'],
    random_state=args.seed,
)

# Prepare HuggingFace ruT5-base model and tokenizer

In [None]:
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # Prefer the GPU when available.

In [None]:
# Pretrained ruT5 from Sber (text-to-text objective). The base size is used instead of
# large to keep training time and memory consumption lower.
_pretrained_name = "sberbank-ai/ruT5-base"
tokenizer = T5TokenizerFast.from_pretrained(_pretrained_name)
t_model = T5ForConditionalGeneration.from_pretrained(_pretrained_name, return_dict=True)
t_model = t_model.to(dev)

# Define datasets and loaders

In [None]:
# One labelled dataset per split: source = corrupted text, target = correct text.
train, val, test = (
    SCDataset(split_df, tokenizer, 'corrupted_text', 'correct_text')
    for split_df in (df_train, df_val, df_test)
)

In [None]:
# Validation/test loaders use half the training batch size for stable memory consumption.
_loader_kwargs = dict(num_workers=args.workers, pin_memory=True)
train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val, batch_size=args.batch_size // 2, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test, batch_size=args.batch_size // 2, shuffle=False, **_loader_kwargs)

# Define PL Trainer

In [None]:
seed_everything(args.seed)
model = SCModel(args, model=t_model, tokenizer=tokenizer)
trainer = Trainer(
    max_epochs=args.epochs,
    num_sanity_val_steps=0,  # Skip the validation sanity check before training.
    devices=1,  # Trained on one GPU.
    accelerator="auto",
    logger=WandbLogger(project="t5ru-optimizing-nontok-batches"),  # Weights & Biases was used to track training.
    default_root_dir='./checkpoint',
    # callbacks=[checkpoint_callback, lr_monitor_callback, early_stop_callback],  # Optional
    log_every_n_steps=20,  # Log metrics every 20 training steps.
    # Accumulate gradients over 4 batches per optimizer step: the effective batch
    # size is 4 * args.batch_size. This enlarges the batch without extra memory;
    # it does not make training faster.
    accumulate_grad_batches=4,
)

# Training loop

In [None]:
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)  # Training loop.

# Test loop

In [None]:
trainer.test(model=model, dataloaders=test_loader)  # Evaluate on the held-out 1% split.

# Inference

In [None]:
# Download the fine-tuned checkpoint from Google Drive.
gd_id = '1lmT1tCbX-s3MDxMhvO1NvXbMffeXpPLp'
checkpoint_url = 'https://drive.google.com/uc?id=' + gd_id
gdown.download(checkpoint_url, 'checkpoint.ckpt', quiet=False)

In [None]:
# Prepare the Lightning checkpoint for a bare HF model. SCModel stores the HF model
# in its `model` attribute, so every weight key carries a leading "model." prefix.
# Strip only that leading prefix. str.replace would also remove any later
# occurrence inside a key.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
_prefix = 'model.'
state_dict = OrderedDict(
    (key[len(_prefix):] if key.startswith(_prefix) else key, tensor)
    for key, tensor in torch.load('./checkpoint.ckpt', map_location=dev)["state_dict"].items()
)

In [None]:
# Build a ruT5-base with its architecture from the config only (no pretrained weights);
# the fine-tuned weights are loaded into it afterwards.
tokenizer = T5TokenizerFast.from_pretrained("sberbank-ai/ruT5-base")
config = T5Config.from_pretrained("sberbank-ai/ruT5-base")
t_model = T5ForConditionalGeneration(config)
t_model = t_model.to(dev)

In [None]:
# Ask for the prediction batch size (the prompt is Russian for "Batch size for prediction: ").
_answer = input("Размер батча для предсказания: ")
pred_batch_size = int(_answer)

In [None]:
# Unlabelled private test set -> dataset -> ordered (unshuffled) loader.
df_pred = pd.read_csv('./data/private_test.csv')  # Or your path.
pred = SCPredictDataset(df_pred, tokenizer, 'corrupted_text')
pred_loader = DataLoader(
    pred,
    batch_size=pred_batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

In [None]:
seed_everything(args.seed)
model = SCModel(args, model=t_model, tokenizer=tokenizer)
model.model.load_state_dict(state_dict)  # Load the fine-tuned weights into the HF model.

# Single-device trainer used only for inference.
trainer = Trainer(
    accelerator="auto",
    devices=1,
    num_sanity_val_steps=0,
    max_epochs=args.epochs,
)

In [None]:
predictions_df = trainer.predict(model=model, dataloaders=pred_loader, return_predictions=True)  # Inference loop.

In [None]:
# trainer.predict yields one list of strings per batch; flatten to one row per input.
_flat_predictions = [text for batch_texts in predictions_df for text in batch_texts]
predictions_df_df = pd.DataFrame(_flat_predictions)

In [None]:
predictions_df_df.to_csv(path_or_buf='sample.csv', header=False, index=False)  # Save predictions without header or index.