# Libraries

In [1]:
! pip install colorlog

Collecting colorlog
  Downloading colorlog-6.8.2-py3-none-any.whl (11 kB)
Installing collected packages: colorlog
Successfully installed colorlog-6.8.2


In [2]:
# Colab only: mount Google Drive so the dataset under /content/drive/MyDrive
# (see data_path in the data-preparation cell) is readable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:

from collections import defaultdict
import dataclasses
from dataclasses import dataclass, field
import json
import logging
import os
import pickle
import re
import sys
import time
import types
from typing import Dict, List, Optional

import numpy as np
from sklearn.model_selection import train_test_split
from termcolor import colored
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.optim import AdamW, Adam
from torch.utils.data import (
    Dataset,
    DataLoader,
    RandomSampler,
    SequentialSampler
)
from tqdm import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollator,
    EvalPrediction,
    T5ForConditionalGeneration,
    T5TokenizerFast,
    T5Tokenizer,
    Trainer,
    TrainingArguments
)

# Helper functions

In [4]:
def save_pkl(save_object, save_file):
    """Serialize ``save_object`` into the file ``save_file`` using pickle's highest protocol.

    The file is created or overwritten in binary mode.
    """
    with open(save_file, mode='wb') as handle:
        pickle.dump(save_object, handle, pickle.HIGHEST_PROTOCOL)

def load_pkl(load_file):
    """Return the object stored in the pickle file ``load_file``.

    Security: unpickling can execute arbitrary code, so only call this on
    files you created or otherwise trust.
    """
    with open(load_file, mode='rb') as handle:
        return pickle.load(handle)

def load_json(path):
    """Parse the UTF-8 encoded JSON file at ``path`` and return its content.

    The ``with`` block closes the file; no explicit ``close()`` is needed.
    """
    with open(path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)

def write_to_json(output_path, docs):
    """Write ``docs`` to ``output_path`` as UTF-8 JSON, indented by 4 spaces.

    ``ensure_ascii=False`` keeps non-ASCII (e.g. Vietnamese) characters
    readable instead of escaping them. The ``with`` block closes the file.
    """
    with open(output_path, 'w', encoding="utf-8") as fw:
        json.dump(docs, fw, ensure_ascii=False, indent=4)

def check_path(path):
    """Create the directory ``path`` (including missing parents) if nothing exists there.

    ``os.makedirs`` creates intermediate directories, unlike ``os.mkdir``,
    which raised FileNotFoundError for nested paths. ``exist_ok=True`` also
    tolerates the directory being created between the check and the call.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)

def count_parameters(model):
    """Return how many scalar parameters of ``model`` are trainable (``requires_grad``)."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [5]:
def load_dataset(data_path, test_size=0.2, val_size=0.1, random_state=None):
    """Load a JSON list of examples and split it into train / val / test.

    Args:
        data_path: path to a JSON file (read with ``load_json``) whose
            top-level value is a sequence of examples.
        test_size: fraction of all examples held out for testing.
        val_size: fraction of all examples held out for validation. The
            default 0.1 is the value that used to be hard-coded.
        random_state: seed passed to both ``train_test_split`` calls so the
            split is reproducible. ``None`` keeps the non-deterministic split.

    Returns:
        (train_dataset, val_dataset, test_dataset), in that order.

    Raises:
        ValueError: if either fraction is outside (0, 1) or they sum to >= 1.
    """
    if not (0 < test_size < 1 and 0 < val_size < 1 and test_size + val_size < 1):
        raise ValueError(
            f"test_size ({test_size}) and val_size ({val_size}) must be "
            f"fractions in (0, 1) that sum to less than 1"
        )

    json_data = load_json(data_path)

    # First carve out the test set, then split the remainder into train / val.
    train_dataset, test_dataset = train_test_split(
        json_data, test_size=test_size, random_state=random_state
    )
    # val_size is relative to the full dataset; rescale it to the remaining
    # (1 - test_size) portion.
    train_dataset, val_dataset = train_test_split(
        train_dataset, test_size=val_size / (1 - test_size), random_state=random_state
    )
    return train_dataset, val_dataset, test_dataset

# Dataset and DataLoader

In [6]:
class LegalQADataset(Dataset):
    """Map-style dataset over an in-memory list of examples.

    Items are returned exactly as stored. Turning a list of items into
    tensors is left to the DataLoader's ``collate_fn`` (FiDCollator in this
    notebook).
    """

    def __init__(self, data: list):
        # Examples are kept by reference; nothing is copied or preprocessed.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

## FiD Collator

In [7]:
def encode_passages(batch_text_passages, tokenizer, max_length):
    """Tokenize every example's passages and stack them into batch tensors.

    Args:
        batch_text_passages: one list of passage strings per example.
        tokenizer: Hugging Face style tokenizer, called once per example with
            ``padding='max_length'`` and ``truncation=True``.
        max_length: token length every passage is padded or truncated to.

    Returns:
        (passage_ids, passage_masks): each example's tokenizer output gets a
        new leading dimension and the results are concatenated, giving
        (B, N, max_length) when the tokenizer returns (N, max_length). The
        mask is converted to bool. Every example must have the same number
        of passages for ``torch.cat`` to succeed.
    """
    encodings = [
        tokenizer(
            passages,
            max_length=max_length,
            padding='max_length',
            add_special_tokens=True,
            return_tensors='pt',
            truncation=True
        )
        for passages in batch_text_passages
    ]

    passage_ids = torch.cat([enc['input_ids'].unsqueeze(dim=0) for enc in encodings], dim=0)
    passage_masks = torch.cat([enc['attention_mask'].unsqueeze(dim=0) for enc in encodings], dim=0)
    return passage_ids, passage_masks.bool()

class FiDCollator(object):
    """Collate LegalQADataset examples into Fusion-in-Decoder batches.

    Each example is a dict with the keys 'question_id', 'question',
    'context', 'answer' and 'passages' (a list of passage strings).
    """
    def __init__(self, tokenizer, text_maxlength, answer_maxlength=100, n_passages=5):
        self.tokenizer = tokenizer
        self.text_maxlength = text_maxlength      # tokens per encoder passage
        self.answer_maxlength = answer_maxlength  # <= 0 disables answer truncation
        self.n_passages = n_passages              # max passages kept per example (None keeps all)

    def __call__(self, batch):
        """Return (ids, target_ids, target_mask, passage_ids, passage_masks).

        ``target_ids`` has padding positions set to -100, the default
        ``ignore_index`` of PyTorch's cross-entropy loss. Each passage is
        prefixed with "question context" before tokenization.
        """
        assert all(ex['answer'] is not None for ex in batch), 'every example needs an answer'
        ids = [ex['question_id'] for ex in batch]

        limit_answer = self.answer_maxlength > 0
        answer = self.tokenizer(
            [ex['answer'] for ex in batch],
            max_length=self.answer_maxlength if limit_answer else None,
            padding='max_length',
            return_tensors='pt',
            truncation=limit_answer,
        )
        target_ids = answer["input_ids"]
        target_mask = answer["attention_mask"].bool()
        target_ids = target_ids.masked_fill(~target_mask, -100)

        text_passages = []
        for ex in batch:
            question = ex['question'] + ' ' + ex['context']
            # Honour n_passages. Without this slice, examples carrying more
            # passages than configured produce oversized or ragged batches.
            passages = ex['passages'][:self.n_passages]
            text_passages.append([question + ' ' + p for p in passages])

        passage_ids, passage_masks = encode_passages(text_passages,
                                                     self.tokenizer,
                                                     self.text_maxlength)

        return (ids, target_ids, target_mask, passage_ids, passage_masks)

# Fusion-in-Decoder (FiD) model

In [8]:
from transformers import T5EncoderModel

class FiDT5(T5ForConditionalGeneration):
    """T5ForConditionalGeneration adapted to Fusion-in-Decoder (FiD).

    Inputs are (B, N, L) tensors: batch, passages per example, tokens per
    passage. This is the layout ``encode_passages`` produces. The encoder is
    replaced by ``EncoderWrapper``, which encodes the N passages of each
    example independently. The decoder then attends over the concatenation
    of all passage encodings.
    """
    def __init__(self, config):
        super().__init__(config)
        # Swap the plain T5 encoder for the FiD EncoderWrapper.
        self.wrap_encoder()

    def forward_(self, **kwargs):
        """Keyword-only variant of ``forward`` that flattens ``input_ids`` and
        ``attention_mask`` to 2-D before calling the T5 forward.

        Unlike ``forward``, it does not update ``self.encoder.n_passages``.
        NOTE(review): not called anywhere in the visible notebook.
        """
        if kwargs.get('input_ids') is not None:
            kwargs['input_ids'] = kwargs['input_ids'].view(kwargs['input_ids'].size(0), -1)
        if kwargs.get('attention_mask') is not None:
            kwargs['attention_mask'] = kwargs['attention_mask'].view(kwargs['attention_mask'].size(0), -1)

        return super(FiDT5, self).forward(
            **kwargs
        )

    # We need to resize as B x (N * L) instead of (B * N) x L here
    # because the T5 forward method uses the input tensors to infer
    # dimensions used in the decoder.
    # EncoderWrapper resizes the inputs as (B * N) x L.
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Flatten (B, N, L) inputs to B x (N * L) and run the T5 forward.

        A 3-D ``input_ids`` records N on the encoder as ``n_passages``.
        Already-flattened 2-D inputs (e.g. coming from ``generate``) keep the
        previously recorded value.
        """
        if input_ids is not None:
            # inputs might already have been resized in the generate method
            if input_ids.dim() == 3:
                self.encoder.n_passages = input_ids.size(1)
            input_ids = input_ids.view(input_ids.size(0), -1)
        if attention_mask is not None:
            attention_mask = attention_mask.view(attention_mask.size(0), -1)
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

    def generate(self, input_ids, **kwargs):
        """Generate from (B, N, L) ``input_ids``.

        The generation utilities expect 2-D tensors. After recording N on the
        encoder, ``input_ids`` is flattened to B x (N * L), and so is
        ``attention_mask`` when one is passed.
        """
        self.encoder.n_passages = input_ids.size(1)
        attention_mask = kwargs.get('attention_mask')
        if attention_mask is not None:
            kwargs['attention_mask'] = attention_mask.view(attention_mask.size(0), -1)
        return super().generate(input_ids=input_ids.view(input_ids.size(0), -1), **kwargs)

    def wrap_encoder(self, use_checkpoint=False):
        """
        Wrap T5 encoder to obtain a Fusion-in-Decoder model.
        """
        self.encoder = EncoderWrapper(self.encoder, use_checkpoint=use_checkpoint)

    def unwrap_encoder(self):
        """
        Unwrap Fusion-in-Decoder encoder, useful to load T5 weights.

        Restores the bare T5 encoder and strips the CheckpointWrapper around
        each block, so parameter names match a plain T5 state dict.
        """
        self.encoder = self.encoder.encoder
        self.encoder.block = nn.ModuleList([mod.module for mod in self.encoder.block])

    def load_t5(self, state_dict):
        """Load a plain T5 ``state_dict`` into this FiD model.

        The encoder is temporarily unwrapped so the keys line up, then
        re-wrapped with checkpointing disabled.
        """
        self.unwrap_encoder()
        self.load_state_dict(state_dict)
        self.wrap_encoder()

    def set_checkpoint(self, use_checkpoint):
        """
        Enable or disable checkpointing in the encoder.
        See https://pytorch.org/docs/stable/checkpoint.html

        Only valid while the encoder is wrapped (blocks are CheckpointWrappers).
        """
        for mod in self.encoder.encoder.block:
            mod.use_checkpoint = use_checkpoint

    def reset_score_storage(self):
        """
        Reset score storage, only used when cross-attention scores are saved
        to train a retriever.
        """
        for mod in self.decoder.block:
            mod.layer[1].EncDecAttention.score_storage = None

    def get_crossattention_scores(self, context_mask):
        """
        Cross-attention scores are aggregated to obtain a single scalar per
        passage. This scalar can be seen as a similarity score between the
        question and the input passage. It is obtained by averaging the
        cross-attention scores obtained on the first decoded token over heads,
        layers, and tokens of the input passage.

        More details in Distilling Knowledge from Reader to Retriever:
        https://arxiv.org/abs/2012.04584.

        Args:
            context_mask: boolean passage mask indexed as (B, N, L), e.g. the
                passage_masks produced by FiDCollator.

        Returns:
            (B, N) tensor of per-passage scores.

        Requires ``overwrite_forward_crossattention`` so that each decoder
        block fills ``score_storage``.
        """
        scores = []
        n_passages = context_mask.size(1)
        for mod in self.decoder.block:
            scores.append(mod.layer[1].EncDecAttention.score_storage)
        # NOTE(review): assumes each stored tensor has a single query position
        # (the first decoded token), so concatenating on dim 2 stacks layers.
        scores = torch.cat(scores, dim=2)
        bsz, n_heads, n_layers, _ = scores.size()
        # batch_size, n_head, n_layers, n_passages, text_maxlength
        scores = scores.view(bsz, n_heads, n_layers, n_passages, -1)
        scores = scores.masked_fill(~context_mask[:, None, None], 0.)
        scores = scores.sum(dim=[1, 2, 4])
        ntokens = context_mask.sum(dim=[2]) * n_layers * n_heads
        scores = scores/ntokens
        return scores

    def overwrite_forward_crossattention(self):
        """
        Replace cross-attention forward function, only used to save
        cross-attention scores.
        """
        for mod in self.decoder.block:
            attn = mod.layer[1].EncDecAttention
            attn.forward = types.MethodType(cross_attention_forward, attn)

class EncoderWrapper(nn.Module):
    """
    Encoder Wrapper for T5 Wrapper to obtain a Fusion-in-Decoder model.

    Takes inputs flattened to (B, N * L) and encodes the N passages of each
    example independently as a (B * N, L) batch. The encoder's last hidden
    state is then reshaped back to (B, N * L, hidden), so the decoder
    attends over all passages at once. ``n_passages`` (N) must be set before
    ``forward`` runs; FiDT5.forward and FiDT5.generate set it from their
    3-D inputs.
    """
    def __init__(self, encoder, use_checkpoint=False, n_passages=None):
        super().__init__()
        self.encoder = encoder
        # Mirror the wrapped encoder's main_input_name (presumably read by
        # Hugging Face generation utilities).
        self.main_input_name = encoder.main_input_name
        # Initialised here so a missing value produces a clear error in forward.
        self.n_passages = n_passages
        apply_checkpoint_wrapper(self.encoder, use_checkpoint)

    def forward(self, input_ids=None, attention_mask=None, **kwargs,):
        if self.n_passages is None:
            raise RuntimeError(
                "EncoderWrapper.n_passages is not set; pass (B, N, L) inputs to "
                "FiDT5 or set model.encoder.n_passages before encoding."
            )
        # total_length = n_passages * passage_length
        bsz, total_length = input_ids.shape
        passage_length = total_length // self.n_passages
        input_ids = input_ids.view(bsz * self.n_passages, passage_length)
        if attention_mask is not None:
            attention_mask = attention_mask.view(bsz * self.n_passages, passage_length)
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

        fused_shape = (bsz, self.n_passages * passage_length, -1)
        if isinstance(outputs, tuple):
            # return_dict=False: the last hidden state is the first element.
            return (outputs[0].view(*fused_shape),) + tuple(outputs[1:])
        outputs.last_hidden_state = outputs.last_hidden_state.view(*fused_shape)
        return outputs

class CheckpointWrapper(torch.nn.Module):
    """
    Wrapper replacing None outputs by empty tensors, which allows the use of
    checkpointing.

    When ``use_checkpoint`` is set and the module is training, the wrapped
    block runs under ``torch.utils.checkpoint.checkpoint``. None outputs are
    swapped for empty placeholder tensors inside the checkpoint and swapped
    back to None afterwards. Otherwise the block is called directly.
    """
    def __init__(self, module, use_checkpoint=False):
        super().__init__()
        self.module = module
        self.use_checkpoint = use_checkpoint

    def forward(self, hidden_states, attention_mask, position_bias, **kwargs):
        if not (self.use_checkpoint and self.training):
            return self.module(hidden_states, attention_mask, position_bias, **kwargs)

        # None-valued keyword arguments are dropped before checkpointing.
        kwargs = {k: v for k, v in kwargs.items() if v is not None}

        def custom_forward(*inputs):
            output = self.module(*inputs, **kwargs)
            empty = torch.tensor(
                [],
                dtype=torch.float,
                device=output[0].device,
                requires_grad=True)
            return tuple(x if x is not None else empty for x in output)

        output = torch.utils.checkpoint.checkpoint(
            custom_forward,
            hidden_states,
            attention_mask,
            position_bias
        )
        # Turn the empty placeholders back into None. `x.size() != 0` compared
        # a torch.Size with an int, which is always True, so it never did this.
        return tuple(
            None if (torch.is_tensor(x) and x.numel() == 0) else x
            for x in output
        )

def apply_checkpoint_wrapper(t5stack, use_checkpoint):
    """
    Wrap each block of the encoder to enable checkpointing.

    ``t5stack.block`` is replaced in place by a new ModuleList holding one
    CheckpointWrapper per original block, in the same order.
    """
    t5stack.block = nn.ModuleList(
        [CheckpointWrapper(layer, use_checkpoint) for layer in t5stack.block]
    )

def cross_attention_forward(
        self,
        input,
        mask=None,
        kv=None,
        position_bias=None,
        past_key_value_state=None,
        head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
    """
    This only works for computing cross attention over the input.

    Replacement forward for a decoder cross-attention module. It is bound to
    each decoder block's ``layer[1].EncDecAttention`` by
    ``FiDT5.overwrite_forward_crossattention``. Besides computing the
    attention output, it stores the raw attention scores (mask and position
    bias added, before softmax) in ``self.score_storage``. It does so only
    while that attribute is None, i.e. on the first call after
    ``FiDT5.reset_score_storage``.

    Args:
        input: decoder hidden states, unpacked as (bsz, qlen, dim).
        mask: optional additive mask, added to the scores.
        kv: hidden states to attend over (required; asserted not None).
        position_bias: optional precomputed bias. When None it is computed
            with ``self.compute_bias(qlen, klen)``, which the assert requires
            ``self.has_relative_attention_bias`` for.
        past_key_value_state: optional cached (k, v), used instead of
            projecting ``kv``.
        head_mask: must be None (asserted).
        query_length: accepted but unused.
        use_cache: if True, (k, v) is returned as the second element.
        output_attentions: if True, the attention probabilities are appended.

    Returns:
        tuple: (output, (k, v) or None[, attn][, position_bias]); the
        position bias is appended only when ``self.has_relative_attention_bias``.

    NOTE(review): the parameter names above must match the keyword arguments
    the installed transformers version passes to ``EncDecAttention.forward``.
    Confirm this before calling ``overwrite_forward_crossattention``.
    ``self.score_storage`` must already exist (e.g. set by
    ``reset_score_storage``). Relies on ``F`` being ``torch.nn.functional``.
    """
    assert(kv != None)
    assert(head_mask == None)
    assert(position_bias != None or self.has_relative_attention_bias)

    bsz, qlen, dim = input.size()
    n_heads, d_heads = self.n_heads, self.d_kv
    klen = kv.size(1)

    # Project and split heads: (bsz, n_heads, seq_len, d_kv).
    q = self.q(input).view(bsz, -1, n_heads, d_heads).transpose(1, 2)
    if past_key_value_state == None:
        k = self.k(kv).view(bsz, -1, n_heads, d_heads).transpose(1, 2)
        v = self.v(kv).view(bsz, -1, n_heads, d_heads).transpose(1, 2)
    else:
        k, v = past_key_value_state

    # Raw attention scores: (bsz, n_heads, qlen, klen).
    scores = torch.einsum("bnqd,bnkd->bnqk", q, k)

    if mask is not None:
        scores += mask

    if position_bias is None:
        position_bias = self.compute_bias(qlen, klen)
    scores += position_bias

    # Keep only the first call's scores until reset_score_storage() clears them.
    if self.score_storage is None:
        self.score_storage = scores

    attn = F.softmax(scores.float(), dim=-1).type_as(scores)
    attn = F.dropout(attn, p=self.dropout, training=self.training)

    # Weighted sum of values, heads merged back to (bsz, -1, inner_dim), then
    # the output projection.
    output = torch.matmul(attn, v)
    output = output.transpose(1, 2).contiguous().view(bsz, -1, self.inner_dim)
    output = self.o(output)

    if use_cache:
        output = (output,) + ((k, v),)
    else:
        output = (output,) + (None,)

    if output_attentions:
        output = output + (attn,)

    if self.has_relative_attention_bias:
        output = output + (position_bias,)

    return output

# Logger

In [9]:
import logging
from colorlog import ColoredFormatter
import os

# Colourised console output: %(log_color)s picks a colour from the record's level name.
formatter = ColoredFormatter(
    "%(log_color)s[%(asctime)s] %(message)s",
    datefmt=None,          # default asctime format
    reset=True,
    log_colors=dict(
        DEBUG='cyan',
        INFO='white,bold',
        INFOV='cyan,bold',  # only matters if a custom 'INFOV' level is registered
        WARNING='yellow',
        ERROR='red,bold',
        CRITICAL='red,bg_white',
    ),
    secondary_log_colors={},
    style='%',
)

# Console handler that lets everything from DEBUG upward through.
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
ch.setFormatter(formatter)

# Notebook-wide logger used by the training / evaluation helpers.
log = logging.getLogger('rn')
log.setLevel(logging.DEBUG)
log.propagate = False   # workaround for duplicated logs in ipython
log.handlers = []       # re-running this cell must not stack duplicate handlers
log.addHandler(ch)

In [10]:
def batch_to_device(batch, device=torch.device('cuda')):
    """Return ``batch`` as a list, with every element that has a ``.to``
    method moved to ``device``.

    Other elements, such as the list of question ids FiDCollator puts first,
    pass through unchanged. Errors raised by ``.to`` itself (e.g. CUDA
    out-of-memory) now propagate. The previous bare ``except`` swallowed
    them and silently left the tensor on its old device.
    """
    return [item.to(device) if hasattr(item, 'to') else item for item in batch]

def train_one(args, model, device, optimizer, epoch, train_loader):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        args: needs ``notqdm``; ``max_grad_norm`` (optional) enables gradient
            clipping when it is a positive number.
        model: seq2seq model whose forward returns an object with ``.loss``.
        device: device each batch is moved to.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: epoch number, used only for display.
        train_loader: yields FiDCollator batches
            (ids, target_ids, target_mask, passage_ids, passage_masks).

    Returns:
        float: mean training loss of the epoch (nan for an empty loader).
    """
    model.train()
    if not args.notqdm:
        train_loader = tqdm(train_loader, leave=True, desc=colored(f'Training on train - Epoch {epoch}', 'blue'))

    max_grad_norm = getattr(args, 'max_grad_norm', None)
    loss_sum = 0.0
    n_batches = 0
    for batch in train_loader:
        batch = batch_to_device(batch, device=device)

        optimizer.zero_grad()
        loss = model(input_ids=batch[3], attention_mask=batch[4], labels=batch[1]).loss
        loss.backward()
        if max_grad_norm is not None and max_grad_norm > 0:
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        # .item() detaches the value. Accumulating the loss tensor itself kept
        # every batch's autograd graph alive until the end of the epoch.
        loss_sum += loss.item()
        n_batches += 1

    mean_loss = loss_sum / n_batches if n_batches else float('nan')
    log.info(f'[Epoch: {epoch}] [Loss = {mean_loss:.4f}] ')
    return mean_loss

def test_one(args, model, device, test_loader, mode='test'):
    """Evaluate ``model`` on ``test_loader`` and return the mean per-batch loss.

    Args:
        args: needs ``notqdm``.
        model: seq2seq model whose forward returns an object with ``.loss``.
        device: device each batch is moved to.
        test_loader: yields FiDCollator batches
            (ids, target_ids, target_mask, passage_ids, passage_masks).
        mode: split name shown in the progress bar and log lines ('val' or 'test').

    Returns:
        float: mean loss over the batches (nan for an empty loader).
    """
    log.info(f'Evaluating on {mode} set...')
    model.eval()

    if not args.notqdm:
        test_loader = tqdm(test_loader, leave=True, desc=colored(f'Testing on {mode}', 'yellow'))

    all_loss = []
    # Evaluation needs no gradients; skipping graph construction saves memory.
    with torch.no_grad():
        for batch in test_loader:
            batch = batch_to_device(batch, device=device)
            loss = model(input_ids=batch[3], attention_mask=batch[4], labels=batch[1]).loss
            all_loss.append(loss.item())

    avg_loss = float(np.mean(all_loss)) if all_loss else float('nan')
    log.info(f'[{mode.capitalize()} Summary] [Loss = {avg_loss:.4f}] ')
    return avg_loss

In [11]:
def train(train_loader, val_loader, model, args):
    """Train ``model`` with AdamW and early stopping on validation loss.

    Each epoch runs ``train_one`` and then ``test_one`` on ``val_loader``.
    - ``args.save_model == 1``: the best-so-far weights are checkpointed under
      ``./tmp-runs/<timestamp>/`` and restored once training ends.
    - Training stops after ``args.patience`` consecutive epochs without
      val-loss improvement.
    - ``args.save``: the final weights, the arguments and an average-loss
      summary are written under ``./saved-runs/<timestamp>/``.
    """
    # Directory for this run's best-so-far checkpoints.
    out_dir = os.path.abspath(os.path.join(
                                  os.path.curdir,
                                  "tmp-runs",
                                  str(int(time.time() * 1e7))))
    os.makedirs(out_dir, exist_ok=True)

    best_path = None

    # Move the model before creating the optimizer (the conventional order).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    log.info('Setting up optimizer...')
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    scheduler = None
    if args.lr_scheduler == 'ExponentialLR':
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.ExponentialLR_gamma)

    log.info('Training begins...')

    min_val_loss = np.inf
    sub_cycle = 0  # consecutive epochs without val-loss improvement

    epoch_loss = {'train': [], 'val': []}
    for epoch in range(1, args.num_epochs + 1):
        overall_train_loss = train_one(args, model, device, optimizer, epoch, train_loader)
        epoch_loss['train'].append(overall_train_loss)

        overall_val_loss = test_one(args, model, device, val_loader, 'val')
        epoch_loss['val'].append(overall_val_loss)

        # Track improvement independently of checkpointing, so early stopping
        # also works when save_model != 1.
        if overall_val_loss < min_val_loss:
            log.info(f"Val loss improves from {min_val_loss} to {overall_val_loss}.")
            min_val_loss = overall_val_loss
            sub_cycle = 0
            if args.save_model == 1:
                best_path = os.path.join(out_dir, args.model_name + '-' + str(epoch))
                log.info("Save cur best model to {}".format(best_path))
                torch.save(model.state_dict(), best_path + '.pth')
        else:
            log.info("Val loss does NOT improve from previous.")
            sub_cycle += 1

        # Break if the val loss hasn't improved in the past `patience` epochs.
        if sub_cycle >= args.patience:
            break

        if scheduler is not None:
            scheduler.step()

    # best_path stays None when no checkpoint was saved; restoring it then
    # would fail on `None + '.pth'`.
    if best_path is not None:
        log.info("End of training. Restore the best weights")
        model.load_state_dict(torch.load(best_path + '.pth', map_location=device))
    else:
        log.info("End of training. No checkpoint was saved; keeping the final weights")

    if args.save:
        out_dir = os.path.abspath(os.path.join(
                                      os.path.curdir,
                                      "saved-runs",
                                      str(int(time.time() * 1e7))))
        os.makedirs(out_dir, exist_ok=True)

        best_path = os.path.join(out_dir, f'best-{args.model_name}')
        torch.save(model.state_dict(), best_path + '.pth')

        with open(best_path + '_args.txt', 'w') as f:
            for attr, value in sorted(args.__dict__.items()):
                f.write("{}={}\n".format(attr, value))

        # Average the per-epoch losses. float() accepts Python floats as well
        # as 0-dim tensors. Dividing the list itself by an int raised TypeError.
        avg_train_loss = float(np.mean([float(x) for x in epoch_loss['train']]))
        avg_val_loss = float(np.mean([float(x) for x in epoch_loss['val']]))
        with open(best_path + '_summary.txt', 'w') as f:
            f.write("{} = {}\n".format('Avg. Train loss', avg_train_loss))
            f.write("{} = {}\n".format('Avg. Val loss', avg_val_loss))

def test(test_loader, model, args):
    """Evaluate ``model`` on ``test_loader`` and return the mean test loss.

    The model is moved to CUDA (when available) only if its parameters are
    not already on a CUDA device. The computed loss is returned; it used to
    be discarded.
    """
    log.info('Testing begins...')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if not next(model.parameters()).is_cuda:
        model.to(device)
    overall_test_loss = test_one(args, model, device, test_loader, 'test')
    return overall_test_loss

# Training

In [12]:
class TrainingArguments:
    """Hyper-parameters read by train(), train_one(), test_one() and the data cells.

    Every attribute keeps its previous default and can be overridden by
    keyword, e.g. ``TrainingArguments(lr=1e-5, num_epochs=3)``.

    NOTE: this local class shadows ``transformers.TrainingArguments``
    imported in the imports cell. Only this local class is instantiated in
    the visible cells.
    """
    def __init__(
        self,
        *,
        num_epochs=10,
        model_name='vit5-base-multitask-model',
        save_model=1,
        lr=5e-6,
        weight_decay=1e-4,
        max_grad_norm=1.0,
        patience=5,
        notqdm=False,
        lr_scheduler='ExponentialLR',
        ExponentialLR_gamma=0.86,
        save=True,
        passage_length=256,
        target_length=256,
        num_passage=5,
    ) -> None:
        self.num_epochs = num_epochs                    # upper bound on training epochs
        self.model_name = model_name                    # prefix of saved checkpoint files
        self.save_model = save_model                    # 1 -> checkpoint the best val-loss epoch
        self.lr = lr                                    # AdamW learning rate
        self.weight_decay = weight_decay                # AdamW weight decay
        # Gradient-norm clipping threshold. NOTE(review): only effective if the
        # training loop clips gradients with it.
        self.max_grad_norm = max_grad_norm
        self.patience = patience                        # early-stopping patience, in epochs
        self.notqdm = notqdm                            # True disables tqdm progress bars
        self.lr_scheduler = lr_scheduler                # only 'ExponentialLR' is handled by train()
        self.ExponentialLR_gamma = ExponentialLR_gamma  # per-epoch LR decay factor
        self.save = save                                # save final weights / args / summary

        # Collator sizes: tokens per passage, max answer tokens, passages per example.
        self.passage_length = passage_length
        self.target_length = target_length
        self.num_passage = num_passage

args = TrainingArguments()

## Baseline model and tokenizer

In [None]:
# Pretrained ViT5 base checkpoint from the Hugging Face Hub. `vit5.config` is
# used later to build the FiD model, and `tokenizer` feeds FiDCollator.
checkpoint = "VietAI/vit5-base"
tokenizer = T5Tokenizer.from_pretrained(checkpoint)
vit5 = T5ForConditionalGeneration.from_pretrained(checkpoint)

## Data preparation

In [None]:
data_path = '/content/drive/MyDrive/LegalT5/data/json/Vi-LegalQA-FiD.json'
# load_dataset returns (train, val, test) in that order. test_size=0.1 plus
# the 0.1 validation fraction used inside load_dataset gives a
# train-val-test split of 0.8-0.1-0.1.
train_data, val_data, test_data = load_dataset(data_path, test_size=0.1)

print(f"Num. train examples: {len(train_data)}")
print(f"Num. val examples: {len(val_data)}")
print(f"Num. test examples: {len(test_data)}")

In [15]:
# One map-style dataset per split.
train_dataset, val_dataset, test_dataset = (
    LegalQADataset(split) for split in (train_data, val_data, test_data)
)

# Turns lists of raw examples into FiD tensors.
collator = FiDCollator(
    tokenizer,
    text_maxlength=args.passage_length,
    answer_maxlength=args.target_length,
    n_passages=args.num_passage,
)

# Shuffle the training split only; val and test are read in order.
train_sampler = RandomSampler(train_dataset)
val_sampler = SequentialSampler(val_dataset)
test_sampler = SequentialSampler(test_dataset)

train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    collate_fn=collator,
    sampler=train_sampler,
    num_workers=os.cpu_count(),
)
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    collate_fn=collator,
    sampler=val_sampler,
    num_workers=os.cpu_count(),
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    collate_fn=collator,
    sampler=test_sampler,
    num_workers=os.cpu_count(),
)

## Model

In [16]:
# creating model: build the FiD model from the ViT5 config, then copy in the
# pretrained ViT5 weights. FiDT5(config) on its own is randomly initialised.
model = FiDT5(vit5.config)
model.load_t5(vit5.state_dict())

In [17]:
# train(train_loader, val_loader, model, args)

In [19]:
# test(test_loader, model, args)

# Inference

In [None]:
# Inference with a published fine-tuned checkpoint. NOTE: this rebinds the
# `checkpoint`, `tokenizer` and `model` names used by the training cells above.
# NOTE(review): the checkpoint name mentions FiD, but it is loaded as a plain
# T5ForConditionalGeneration and fed a single question without passages.
# Confirm this matches how the checkpoint was trained.
checkpoint = "thaingo/vit5_law_base_fid"

tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = T5ForConditionalGeneration.from_pretrained(checkpoint)

In [None]:
# Example question ("Question: Am I allowed to ride without wearing a helmet?").
question = "Câu hỏi: Tôi có được phép đi xe không đội mũ bảo hiểm không?"

# NOTE: `input` shadows Python's built-in input(); the next cell unpacks it into generate().
input = tokenizer(question, return_tensors='pt')

In [None]:
# Beam-search decoding settings passed to model.generate().
generator_args = {
    "max_length": 256,
    "num_beams": 4,
    "length_penalty": 1.5,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
}

# The decode keyword is `clean_up_tokenization_spaces`; the misspelt
# `clean_up_token_spaces` is not recognised, so the option was never applied.
result = tokenizer.batch_decode(
    model.generate(**input, **generator_args),
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)

print(result)