## Setup

In [None]:
!pip -q install transformers

In [None]:
from glob import glob
from sklearn.model_selection import GroupKFold, StratifiedKFold
import torch
from torch import nn
import os
import itertools
from datetime import datetime
import time
import random
import torchvision
import pandas as pd
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.cuda.amp import autocast, GradScaler
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F

import sklearn
import warnings
import joblib
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics

# Hugging Face
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# rich: for a better display on terminal
from rich.table import Column, Table
from rich import box
from rich.console import Console

# Device
# Use the GPU whenever PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Running on device: {device}')

In [None]:
# Central configuration for the notebook.
# NOTE(review): 'data_dir', 'T_0', 'min_lr', 'weight_decay', 'num_workers',
# 'accum_iter' and 'verbose_step' are not read by any code visible in this
# file - confirm whether they are still needed.
CFG = {
    'show_examples': False,  # when True, the demo cells print sample data
    # NOTE(review): differs from the module-level `data_dir` used to build the dataframe - confirm
    'data_dir': '/kaggle/input/drowsy-eye-keypoints',
    'cluster_sent_n': 20,  # max number of source sentences sampled per cluster
    'seed': 719,  # seed for the RNGs and for the train/validation split
    'model_arch': "VietAI/vit5-base-vietnews-summarization",  # checkpoint for tokenizer and model
    'source_len': 1024,  # max token length of the tokenized source text
    'target_len': 256,  # max token length of the tokenized target text
    'epochs': 3,
    'train_bs': 2,  # training batch size
    'valid_bs': 2,  # validation batch size
    'T_0': 10,
    'lr': 1e-4,  # learning rate passed to the optimizer
    'min_lr': 1e-6,
    'weight_decay':1e-6,
    'num_workers': 4,
    'accum_iter': 2, # intended to support gradient accumulation, i.e. backprop with an effectively larger batch size
    'verbose_step': 1,
}

In [None]:
# Root of the ViMs dataset and its two sub-trees (source articles, summaries).
data_dir = '/kaggle/input/vims-dataset/ViMs'
original_dir, summary_dir = (
    os.path.join(data_dir, subdir) for subdir in ('original', 'summary')
)

## Utils

In [None]:
def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and configures cuDNN for deterministic runs.

    Args:
        seed (int): Seed value applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one; this call is a no-op
    # when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner can pick different kernels from run to run, so it
    # must be disabled for reproducible results.
    torch.backends.cudnn.benchmark = False

In [None]:
# Module-level rich console used for logging; record=True keeps a copy of
# everything printed through it so it can be exported later (see save_text).
console = Console(record=True)

# render the first two columns of a dataframe as an ASCII table
def display_df(df):
    """Print the first two columns of ``df`` as an ASCII table.

    Every dataframe row becomes one table row: column 0 is shown under
    ``source_text`` and column 1 under ``target_text``.
    """
    headers = ("source_text", "target_text")
    table = Table(
        *[Column(header, justify="center") for header in headers],
        title="Sample Data",
        pad_edge=False,
        box=box.ASCII,
    )

    for record in df.values.tolist():
        table.add_row(record[0], record[1])

    # A fresh console is created here instead of using the module-level one.
    Console().print(table)

# table that accumulates (epoch, step, loss) rows while training runs
training_logger = Table()


def resetTable():
    """Rebind the global ``training_logger`` to a fresh, empty status table."""
    global training_logger

    headers = ("Epoch", "Steps", "Loss")
    training_logger = Table(
        *[Column(header, justify="center") for header in headers],
        title="Training Status",
        pad_edge=False,
        box=box.ASCII,
    )


# start out with a fully configured table
resetTable()

In [None]:
def read_txt(path, article_type, sent=False):
    """Read the body lines of an article or summary text file.

    Blank lines are skipped and trailing whitespace is stripped from every
    kept line. For ``article_type == "original"`` collection starts at the
    first line that begins with "content" (case-insensitive); for any other
    value it starts at the first line of the file. The first collected line
    is always dropped from the result.

    Args:
        path (str): Path of the text file to read.
        article_type (str): ``"original"`` for source articles; any other
            value is handled as a summary.
        sent (bool): When True return the list of lines, otherwise return the
            lines joined by single spaces.

    Returns:
        list[str] | str: The collected lines, or their space-joined string.
    """
    # Non-"original" files are collected from the very first line.
    collecting = article_type != "original"
    content = []
    # Decode explicitly as UTF-8 rather than relying on the locale default,
    # which would corrupt or reject non-ASCII text on some systems.
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not collecting and line.lower().startswith("content"):
                collecting = True
            if collecting:
                stripped = line.rstrip()
                if stripped:
                    content.append(stripped)

    body = content[1:]
    return body if sent else " ".join(body)
# Optional demo: print the lines read from one gold summary file.
if CFG['show_examples']:
    # NOTE(review): path1 is assigned but never used in this cell.
    path1 = '/kaggle/input/vims-dataset/ViMs/original/Cluster_001/original/10.txt'
    path2 = '/kaggle/input/vims-dataset/ViMs/summary/Cluster_001/0.gold.txt'
    print(read_txt(path2, article_type="summary", sent=True))

## CSV File

In [None]:
def create_csv(data_dir):
    """Index the files below ``data_dir``, one row per cluster.

    The layout is selected by the last path component of ``data_dir``:
    - ``original``: files are looked up in ``data_dir/<cluster>/original/``
    - anything else: files are looked up in ``data_dir/<cluster>/``

    Args:
        data_dir (str): Directory whose immediate children are the clusters.

    Returns:
        pandas.DataFrame: Columns ``cluster`` and ``path``; ``path`` holds
        the sorted list of file paths found for that cluster.
    """
    # Last path component; normpath makes this tolerant of a trailing
    # separator, and it is loop-invariant so it is computed once.
    file_type = os.path.basename(os.path.normpath(data_dir))

    records = {'cluster': [], 'path': []}
    for cluster in os.listdir(data_dir):
        if file_type == "original":
            f_path = os.path.join(data_dir, cluster, file_type)
        else:
            f_path = os.path.join(data_dir, cluster)
        # glob returns files in arbitrary order; sort so that positions in
        # the per-cluster list are reproducible across runs and machines.
        for f in sorted(glob(os.path.join(f_path, '*'))):
            records['cluster'].append(cluster)
            records['path'].append(f)

    df = pd.DataFrame(records)
    return df.groupby('cluster')['path'].apply(list).reset_index()

# Per-cluster lists of source article paths.
original_df = create_csv(original_dir)
original_df.columns = ['cluster', 'original_dir']

# Per-cluster lists of summary paths.
summary_df = create_csv(summary_dir)
summary_df.columns = ['cluster', 'summary_dir']

# Keep only the clusters present in both frames.
df = pd.merge(original_df, summary_df, how='inner', on='cluster')
if CFG['show_examples']:
    print(len(df), len(df['cluster'].unique()), df.head(), sep='\n')

## Dataset

In [None]:
class MyDataset(Dataset):
    """Cluster-level summarization dataset.

    Each item corresponds to one dataframe row (one cluster). On every access
    the cluster's source files are read with ``read_txt``, a random subset of
    at most ``CFG['cluster_sent_n']`` lines is kept in original order, one of
    the cluster's summary files is picked at random, and both texts are
    tokenized to fixed lengths.
    """

    def __init__(
        self, dataframe, tokenizer, source_len, target_len, source_dir="original_dir", target_dir="summary_dir"
    ):
        """
        Initializes a Dataset class

        Args:
            dataframe (pandas.DataFrame): Input dataframe
            tokenizer (transformers.tokenizer): Transformers tokenizer
            source_len (int): Max length of source text
            target_len (int): Max length of target text
            source_dir (str): Column holding the list of source file paths
            target_dir (str): Column holding the list of summary file paths
        """
        self.tokenizer = tokenizer
        self.df = dataframe
        self.source_len = source_len
        self.target_len = target_len
        self.source_dir = self.df[source_dir]
        self.target_dir = self.df[target_dir]

    def __len__(self):
        """returns the length of dataframe"""

        return len(self.df)

    def _encode(self, text, max_length):
        """Tokenize ``text`` into (input_ids, attention_mask) of length ``max_length``."""
        # padding="max_length" already requests fixed-length padding; the
        # deprecated pad_to_max_length flag is therefore not passed.
        encoded = self.tokenizer(
            [text],
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # squeeze(0) removes only the batch axis, so the sequence axis
        # survives even when max_length is 1.
        input_ids = encoded["input_ids"].squeeze(0).to(dtype=torch.long)
        attention_mask = encoded["attention_mask"].squeeze(0).to(dtype=torch.long)
        return input_ids, attention_mask

    def __getitem__(self, index):
        """return the source/target texts, input ids and attention masks"""

        ## Source text
        # Positional lookup: works for any dataframe index and raises
        # IndexError (not KeyError) past the end, as the sequence protocol
        # expects.
        source_lines = []
        for path in self.source_dir.iloc[index]:
            source_lines += read_txt(path, article_type="original", sent=True)

        # Random subset of lines, re-sorted so the original order is kept.
        keep = sorted(
            random.sample(
                range(len(source_lines)),
                min(CFG['cluster_sent_n'], len(source_lines)),
            )
        )
        source_text = " ".join([source_lines[i] for i in keep])

        ## Target text
        # Pick among the summaries actually available instead of assuming
        # that exactly two exist.
        target_path = random.choice(self.target_dir.iloc[index])
        target_text = read_txt(target_path, article_type="summary")

        ## Tokenize
        source_ids, source_mask = self._encode(source_text, self.source_len)
        target_ids, target_mask = self._encode(target_text, self.target_len)

        return {
            "source_txt": source_text,
            "target_txt": target_text,
            "source_ids": source_ids,
            "source_mask": source_mask,
            "target_ids": target_ids,
            "target_mask": target_mask,
        }

# Optional demo: show the first dataset item (source and target side).
if CFG['show_examples']:
    tokenizer = AutoTokenizer.from_pretrained(CFG['model_arch'])
    dataset = MyDataset(df, tokenizer, CFG['source_len'], CFG['target_len'])
    for output in dataset:
        print("Source:")
        print("Source article: \n", output['source_txt'])
        print("Source input ids length: \n", len(output['source_ids']))
        print("Source input ids: \n", output['source_ids'])
        print("Source attention mask: \n", output['source_mask'])
        print("\n")
        print("Target:\n")
        # The target section must report the target fields, not the source ones.
        print("Target text: \n", output['target_txt'])
        print("Target input ids length: \n", len(output['target_ids']))
        print("Target attention mask: \n", output['target_mask'])
        # Only the first item is shown.
        break

## Train

In [None]:
def train(epoch, tokenizer, model, device, loader, optimizer):
    """Run one training epoch over ``loader``.

    Every 10th step the loss is added to the global ``training_logger``
    table; at the end of the epoch the table is printed and reset.

    Args:
        epoch (int): Epoch number, used only for logging.
        tokenizer: Tokenizer; only its ``pad_token_id`` is read.
        model: Seq2seq model called with ``input_ids``, ``attention_mask``
            and ``labels``; the loss is taken from ``outputs[0]``.
        device: Device the batch tensors are moved to.
        loader: Iterable of batches with keys ``source_ids``,
            ``source_mask`` and ``target_ids``.
        optimizer: Optimizer stepped once per batch.
    """
    model.train()
    for step, data in enumerate(loader):
        ids = data["source_ids"].to(device, dtype=torch.long)
        mask = data["source_mask"].to(device, dtype=torch.long)

        # clone() so the masking below never mutates the batch itself
        # (.to() may return the very same tensor).
        labels = data["target_ids"].to(device, dtype=torch.long).clone()
        # Padding positions are set to -100 so they are excluded from the loss.
        labels[labels == tokenizer.pad_token_id] = -100

        # Only `labels` is passed: the model builds decoder_input_ids itself
        # by shifting the labels right and prepending the decoder start
        # token. Shifting by hand (targets[:, :-1] in, targets[:, 1:] out)
        # never trains the first target token and never shows the decoder
        # the start token it receives at generation time.
        outputs = model(input_ids=ids, attention_mask=mask, labels=labels)
        loss = outputs[0]

        if step % 10 == 0:
            # Log the scalar value rather than the tensor repr.
            training_logger.add_row(str(epoch), str(step), f"{loss.item():.4f}")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    console.print(training_logger)
    resetTable()

In [None]:
def validate(epoch, tokenizer, model, device, loader):
    """Generate text for every batch in ``loader`` and decode it.

    Args:
        epoch: Not used by this function.
        tokenizer: Tokenizer used to decode generated and reference ids.
        model: Model exposing ``eval()`` and ``generate()``.
        device: Device the batch tensors are moved to.
        loader: Iterable of batches with keys ``target_ids``, ``source_ids``
            and ``source_mask``.

    Returns:
        tuple[list[str], list[str]]: Decoded generated texts and decoded
        reference texts, in loader order.
    """
    decode_kwargs = {
        "skip_special_tokens": True,
        "clean_up_tokenization_spaces": True,
    }
    generated_texts = []
    reference_texts = []

    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(loader):
            reference_ids = batch['target_ids'].to(device, dtype=torch.long)
            input_ids = batch['source_ids'].to(device, dtype=torch.long)
            attention_mask = batch['source_mask'].to(device, dtype=torch.long)

            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=256,
                num_beams=2,
                repetition_penalty=2.5,
                length_penalty=1.0,
                early_stopping=True,
            )

            batch_preds = [tokenizer.decode(seq, **decode_kwargs) for seq in generated_ids]
            batch_refs = [tokenizer.decode(seq, **decode_kwargs) for seq in reference_ids]

            # progress message every 10th batch
            if step % 10 == 0:
                console.print(f'Completed {step}')

            generated_texts += batch_preds
            reference_texts += batch_refs

    return generated_texts, reference_texts

In [None]:
def T5Trainer(dataframe, output_dir="/kaggle/working/"):
    """Fine-tune the seq2seq checkpoint named in ``CFG`` and save the results.

    Steps: seed the RNGs, load tokenizer and model, split ``dataframe``
    80/20 into train and validation rows, train for ``CFG['epochs']`` epochs,
    save model and tokenizer, generate predictions for the validation rows
    and write them to CSV, then export the console log.

    Args:
        dataframe (pandas.DataFrame): One row per cluster, with the columns
            expected by ``MyDataset``.
        output_dir (str): Directory receiving ``model_files/``,
            ``predictions.csv`` and ``logs.txt``; created if missing.
    """

    # Set random seeds and deterministic pytorch for reproducibility
    seed_everything(CFG['seed'])

    # Make sure the output location exists before anything is written to it.
    os.makedirs(output_dir, exist_ok=True)
    model_path = os.path.join(output_dir, "model_files")
    predictions_path = os.path.join(output_dir, "predictions.csv")
    logs_path = os.path.join(output_dir, "logs.txt")

    # logging
    console.log(f"""[Model]: Loading {CFG["model_arch"]}...\n""")

    # tokenizer for encoding the text
    tokenizer = AutoTokenizer.from_pretrained(CFG['model_arch'])

    # Defining the model
    model = AutoModelForSeq2SeqLM.from_pretrained(CFG['model_arch'])
    model = model.to(device)

    # logging
    console.log("[Data]: Reading data...\n")

    # Train/validation split: 80% of the rows are sampled for training, the
    # remaining rows form the validation set.
    train_size = 0.8
    train_dataset = dataframe.sample(frac=train_size, random_state=CFG["seed"])
    val_dataset = dataframe.drop(train_dataset.index).reset_index(drop=True)
    train_dataset = train_dataset.reset_index(drop=True)

    console.print(f"FULL Dataset: {dataframe.shape}")
    console.print(f"TRAIN Dataset: {train_dataset.shape}")
    console.print(f"TEST Dataset: {val_dataset.shape}\n")

    # Creating the training and validation datasets for the dataloaders
    training_set = MyDataset(
        train_dataset,
        tokenizer,
        CFG['source_len'],
        CFG['target_len'],
    )
    val_set = MyDataset(
        val_dataset,
        tokenizer,
        CFG['source_len'],
        CFG['target_len'],
    )

    # Defining the parameters for creation of dataloaders
    train_params = {
        "batch_size": CFG["train_bs"],
        "shuffle": True,
        "num_workers": 0,
    }

    val_params = {
        "batch_size": CFG["valid_bs"],
        "shuffle": False,
        "num_workers": 0,
    }

    # Dataloaders for the training and validation stages
    training_loader = DataLoader(training_set, **train_params)
    val_loader = DataLoader(val_set, **val_params)

    # Optimizer used to update the model weights during training
    optimizer = torch.optim.Adam(
        params=model.parameters(), lr=CFG["lr"]
    )

    # Training loop
    console.log("[Initiating Fine Tuning]...\n")

    for epoch in range(CFG["epochs"]):
        train(epoch, tokenizer, model, device, training_loader, optimizer)

    console.log("[Saving Model]...\n")
    # Saving the model after training
    model.save_pretrained(model_path)
    tokenizer.save_pretrained(model_path)

    # evaluating validation dataset
    console.log("[Initiating Validation]...\n")
    predictions, actuals = validate(0, tokenizer, model, device, val_loader)
    final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})
    final_df.to_csv(predictions_path)

    console.log("[Validation Completed.]\n")
    console.print(f"[Model] Model saved @ {model_path}\n")
    console.print(
        f"[Validation] Generation on Validation data saved @ {predictions_path}\n"
    )
    console.print(f"[Logs] Logs saved @ {logs_path}\n")

    # Export the recorded console output last, so that the log file also
    # contains the completion messages printed above.
    console.save_text(logs_path)

## Main

In [None]:
# Entry point: run the fine-tuning pipeline on the module-level dataframe `df`.
if __name__ == '__main__':
    T5Trainer(df)