# MT5 Small Model

In this notebook, we fine-tune the MT5 sequence-to-sequence Transformer model to translate a structured natural-language card specification into Java code.

### Check for CUDA Compatibility

In [8]:
import torch
import torch.nn as nn
assert torch.cuda.is_available()

True

The path used for saving and loading the model depends on where this notebook is hosted.

In [9]:
using_google_drive = False
# PATH is the folder of this repository
if using_google_drive:
    from google.colab import drive
    drive.mount('/content/gdrive')
    PATH = '/content/gdrive/MyDrive/Colab Notebooks/Final Project'
else:
    PATH = '.'

In [10]:
%%bash
pip -q install transformers
pip -q install tqdm
pip -q install sentencepiece 

# Tokenizer for the MT5 Model

In [11]:
# Tokenizers
import transformers
pretrained_model_name = 'google/mt5-small'

tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name)
context_ids = tokenizer.encode("When this creature enters the battle field, target creature loses 2 life.")
print(tokenizer.convert_ids_to_tokens(context_ids))

['▁', 'When', '▁this', '▁c', 'reature', '▁enter', 's', '▁the', '▁battle', '▁field', ',', '▁target', '▁c', 'reature', '▁los', 'es', '▁2', '▁life', '.', '</s>']
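The `▁` prefix in the output above is SentencePiece's word-boundary marker. This is why a word such as "creature" is split into `▁c` + `reature`. As a rough illustration, the pieces can be turned back into text by concatenating them and mapping `▁` to a space. `pieces_to_text` is a hypothetical helper and is not the tokenizer's own `decode`, which also handles normalization and a fuller set of special tokens.

```python
SPIECE_UNDERLINE = "\u2581"  # "▁", SentencePiece's word-boundary marker


def pieces_to_text(pieces, special_tokens=("</s>", "<pad>", "<unk>")):
    """Rebuild text from SentencePiece pieces (simplified sketch).

    Special tokens are dropped, the remaining pieces are concatenated,
    and each word-boundary marker becomes a space.
    """
    kept = [p for p in pieces if p not in special_tokens]
    return "".join(kept).replace(SPIECE_UNDERLINE, " ").strip()


pieces = ['▁', 'When', '▁this', '▁c', 'reature', '▁enter', 's', '▁the',
          '▁battle', '▁field', ',', '▁target', '▁c', 'reature', '▁los',
          'es', '▁2', '▁life', '.', '</s>']
print(pieces_to_text(pieces))
# When this creature enters the battle field, target creature loses 2 life.
```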


# 1. Dataset Collection and Processing

In this section, we load the dataset. Changes to individual datapoints are made in the `preprocess_datapoint` method, which currently leaves the dataset unchanged.

## 1.1. Loading the Dataset

Here, we load the Magic training and testing datasets (`train_magic.*` and `test_magic.*`).

In [12]:
with open(PATH + '/datasets_uwu/train_magic.in') as f:
    train_x = f.readlines()
with open(PATH + '/datasets_uwu/train_magic.out') as f:
    train_y = f.readlines()
with open(PATH + '/datasets_uwu/test_magic.in') as f:
    test_x = f.readlines()
with open(PATH + '/datasets_uwu/test_magic.out') as f:
    test_y = f.readlines()
    
# Structure the dataset in the form of a list of {input, labels} pairs
training_dataset = [{ 'card': x, 'code': y } for x, y in zip(train_x, train_y)]
testing_dataset  = [{ 'card': x, 'code': y } for x, y in zip(test_x,  test_y )]

dataset = {
    "train": training_dataset,
    "test":  testing_dataset
}
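`readlines()` keeps each line's trailing newline. In addition, `zip` silently stops at the shorter list if the `.in` and `.out` files ever have different line counts. Below is a minimal sketch of a stricter loader. It assumes one example per line in each file, with line *i* of the input file aligned to line *i* of the output file. `load_parallel` is a hypothetical helper and is not part of the original notebook.

```python
import os
import tempfile


def load_parallel(in_path, out_path):
    """Load aligned input/output files into a list of {'card', 'code'} dicts.

    Assumes line i of `in_path` corresponds to line i of `out_path`.
    Trailing newlines are stripped, and a line-count mismatch raises
    an error instead of being silently truncated by zip().
    """
    with open(in_path, encoding="utf-8") as f:
        cards = [line.rstrip("\n") for line in f]
    with open(out_path, encoding="utf-8") as f:
        codes = [line.rstrip("\n") for line in f]
    if len(cards) != len(codes):
        raise ValueError(
            f"{in_path} has {len(cards)} lines but {out_path} has {len(codes)}"
        )
    return [{"card": x, "code": y} for x, y in zip(cards, codes)]


# Demo on two tiny temporary files standing in for train_magic.in / .out
with tempfile.TemporaryDirectory() as tmp:
    in_path = os.path.join(tmp, "toy.in")
    out_path = os.path.join(tmp, "toy.out")
    with open(in_path, "w", encoding="utf-8") as f:
        f.write("card one\ncard two\n")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("code one\ncode two\n")
    toy_dataset = load_parallel(in_path, out_path)

print(toy_dataset)
# [{'card': 'card one', 'code': 'code one'}, {'card': 'card two', 'code': 'code two'}]
```

In the notebook, this would be called as `load_parallel(PATH + '/datasets_uwu/train_magic.in', PATH + '/datasets_uwu/train_magic.out')`.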

## 1.2. Preprocessing a Single Datapoint

To accommodate later changes to the dataset or to the card representation, we define a method for preprocessing a single datapoint. Such preprocessing could replace characters or remove parts of the input or output. The method can also return `None` to drop a datapoint entirely.

For now, the method is the identity function: it returns each dict unchanged.

In [13]:
def preproc_init(tokenizer_for_model):
    """ 
    Use this to assign global variables inside each new worker process
    
    Parameters
    ----------
    tokenizer_for_model: fn
        The tokenizer for the pretrained transformer model
    """
    global tokenizer
    tokenizer = tokenizer_for_model

def preprocess_datapoint(datapoint):
    """
    Currently an identity function, kept as a hook for later preprocessing
    
    This method preprocesses a single datapoint loaded above. This can involve
    replacing characters, removing parts of the input or output, etc. The current
    implementation applies no change to the dict. It can also return None to
    remove this datapoint.

    Parameters
    ----------
    datapoint: dict
        The dict containing the initial values of one datapoint in the dataset.
        Each datapoint has the following two fields:
        "card": the string for the card description and metadata
        "code": the string for the card implementation in Java
    
    Returns
    -------
    dict or None
        A new representation for this individual datapoint, or None to drop it.
    """
    
    # We have access to global vars defined in preproc_init
    return datapoint
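The docstring above describes non-identity preprocessing, and the sketch below gives one illustrative example. It normalizes whitespace and returns `None` for datapoints that should be dropped. The specific rules (collapsing whitespace in the card text, dropping empty inputs or targets) are hypothetical examples, not choices made in this notebook. The toy datapoints are also made up.

```python
def preprocess_datapoint_example(datapoint):
    """Hypothetical preprocessing: tidy whitespace, drop unusable pairs.

    Returns a new dict, or None to signal that the datapoint should be
    removed (preprocess_dataset filters out falsy results).
    """
    card = " ".join(datapoint["card"].split())  # collapse runs of whitespace
    code = datapoint["code"].strip()            # trim only the ends of the Java code
    if not card or not code:
        return None  # nothing to learn from an empty input or target
    return {"card": card, "code": code}


raw = [
    {"card": "Toy  Card\n", "code": "public class ToyCard {}\n"},
    {"card": "Empty target", "code": "   \n"},
]
processed = [d for d in map(preprocess_datapoint_example, raw) if d]
print(processed)
# [{'card': 'Toy Card', 'code': 'public class ToyCard {}'}]
```

Only the ends of the code string are trimmed, because collapsing whitespace inside Java could alter string literals.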

## 1.3. Preprocessing the Full Dataset

Here, we use Python's `multiprocessing` package to apply our preprocessing method to every datapoint in the dataset in parallel. Note that `Pool` runs its workers as separate processes, not threads.

In [14]:
import json
import random
from multiprocessing import Pool
from tqdm import tqdm
    
def preprocess_dataset(dataset_list, processes, tokenizer):
    """
    Preprocesses the entire dataset using `processes` worker processes
    
    This method opens a pool of `processes` worker processes, runs
    `preproc_init` once in each worker, and applies the preprocessing
    method to every datapoint.

    Parameters
    ----------
    dataset_list: dict[]
        A list of datapoints, where each datapoint is in the shape:
        "card": the string for the card description and metadata
        "code": the string for the card implementation in Java
    processes: int
        The number of worker processes to run the preprocessing on
    tokenizer: fn
        The tokenizer for the particular pretrained model
    
    Returns
    -------
    dict[]
        The preprocessed datapoints, with any None results removed
    """
    
    # Open a pool of worker processes; preproc_init sets each worker's global tokenizer
    with Pool(processes, initializer=preproc_init, initargs=(tokenizer,)) as p:
        processed_dataset = list(tqdm(p.imap(preprocess_datapoint, dataset_list), total=len(dataset_list)))
    # Remove None values in the list
    processed_dataset = [x for x in processed_dataset if x]
    
    # Close the file handle once the JSON has been written
    with open(PATH + "/processed_dataset.json", 'w') as f:
        json.dump(processed_dataset, f)
    return processed_dataset

processed_dataset = preprocess_dataset(dataset['train'], 16, tokenizer)


100%|██████████| 11969/11969 [00:00<00:00, 20032.87it/s]


# 2. Building the Model

Below, we define two classes. The first, `ModelOutputs`, stores the outputs of our model. The second, `CardTranslationModel`, subclasses PyTorch's `nn.Module` and wraps a pretrained transformer model.

## 2.1. PyTorch Model Wrapper

In this example, we wrap the MT5 model provided by the Hugging Face `transformers` package.

In [15]:
class ModelOutputs:
    def __init__(self, output_logits=None, loss=None):
        """
        An object containing the output of the CardTranslationModel
        
        Parameters
        ----------
        output_logits : torch.tensor
            shape (batch_size, ans_len, vocab_size) The unnormalized scores over the vocabulary
        loss : torch.tensor
            shape () The scalar loss of the output

        """
        self.output_logits = output_logits
        self.loss = loss
        
class CardTranslationModel(nn.Module):

    def __init__(self, lm=None):
        """
        Initializes the CardTranslationModel with the provided language model

        Parameters
        ----------
        lm : pretrained transformer
            The pretrained transformer which will be fine tuned for the Card Translation Task

        """
        super(CardTranslationModel, self).__init__()
        self.lm = lm
    
    def forward(self, input_ids=None, attention_mask=None, label_ids=None):
        """
        The forward pass function for the wrapped transformer model.

        The implementation of the PyTorch forward method. Feeds the input_ids, 
        attention_mask, and label_ids to the transformer's forward method, 
        and returns a portion of the transformer's output that is relevant 
        to our task. 

        Parameters
        ----------
        input_ids : torch.tensor
            shape (batch_size, seq_len) ids of the concatenated input tokens
        attention_mask : torch.tensor
            shape (batch_size, seq_len) concatenated attention masks
        label_ids: torch.tensor
            shape (batch_size, ans_len) the expected code output

        Returns
        -------
        ModelOutputs
            An object holding the two outputs of our transformer model:
            output_logits: torch.tensor
                shape (batch_size, ans_len, vocab_size) The unnormalized scores over the vocabulary
            loss: torch.tensor
                shape () The cross-entropy (NLL) loss for the batch, computed by the wrapped transformer model
            

        """
        # Feed our input ids into the pretrained transformer
        lm_output = self.lm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=label_ids,
            use_cache=False
        )
            
        return ModelOutputs(
            output_logits=lm_output['logits'],
            loss=lm_output['loss'])
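The `loss` returned by the wrapped model is a token-level cross-entropy: the negative log-likelihood of each target token under the softmax of `output_logits`, averaged over target positions. To our understanding, the Hugging Face T5/MT5 implementations compute it with PyTorch's default `ignore_index=-100`, so any label position set to `-100` is excluded from the average. The pure-Python sketch below shows that computation for intuition only. `token_cross_entropy` is a hypothetical helper, not the library's code.

```python
import math


def token_cross_entropy(logits, labels, ignore_index=-100):
    """Mean negative log-likelihood over non-ignored target positions.

    logits: list of per-position score lists, shape (seq_len, vocab_size)
    labels: list of target token ids, shape (seq_len,)
    Positions whose label equals `ignore_index` contribute nothing.
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue
        # log of the softmax normalizer, via log-sum-exp for numerical stability
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[label]  # -log softmax(scores)[label]
        count += 1
    return total / count if count else float("nan")


# Uniform scores over a 4-token vocabulary: every scored position costs log(4)
uniform = [[0.0, 0.0, 0.0, 0.0]] * 3
print(token_cross_entropy(uniform, [1, 2, 3]))        # about 1.386 (= log 4)
print(token_cross_entropy(uniform, [1, -100, -100]))  # also log 4; ignored positions are skipped
```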

## 2.2. Instantiating Our Model

Once we have defined our wrapper, we load the pretrained transformer model and instantiate a `CardTranslationModel`. We also move the model to the GPU to speed up training.

In [16]:
from transformers import MT5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained(pretrained_model_name)
# Create the CardTranslationModel using the MT5 Conditional Generation model
lm_pretrained = MT5ForConditionalGeneration.from_pretrained(pretrained_model_name)
model = CardTranslationModel(lm_pretrained).cuda()

# 3. Training

Now that we have loaded the pretrained model and wrapped it in a PyTorch class, we can begin training our model.

In [17]:
import torch

# Hyper-parameters: you could try playing with different settings
num_epochs = 1
learning_rate = 3e-5
weight_decay = 1e-5
eps = 1e-6
batch_size = 2
card_max_length = 448
code_max_length = 448

# Calculating the total number of optimization steps
num_training_cases = len(processed_dataset)
t_total = (num_training_cases // batch_size + 1) * num_epochs

# Initializing an AdamW optimizer
ext_optim = torch.optim.AdamW(model.parameters(), lr=learning_rate,
                              eps=eps, weight_decay=weight_decay)

print("***** Training Info *****")
print("  Num examples = %d" % num_training_cases)
print("  Num Epochs = %d" % num_epochs)
print("  Batch size = %d" % batch_size)
print("  Total optimization steps = %d" % t_total)

***** Training Info *****
  Num examples = 11969
  Num Epochs = 1
  Batch size = 2
  Total optimization steps = 5985
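`t_total` above is computed as `n // batch_size + 1` batches per epoch. When `n` is not a multiple of `batch_size`, this matches the number of batches produced by `range(0, n, batch_size)`: 11,969 examples give 5,985 steps. When `n` is a multiple, it overcounts by one. For example, the 100-example subset used in the training loop would report 51 steps, while the loop actually runs 50. Ceiling division gives the exact count. Below is a small sketch; `num_optimization_steps` is a hypothetical helper.

```python
def num_optimization_steps(num_examples, batch_size, num_epochs):
    """Exact number of optimizer steps when every batch, including a final
    partial one, triggers one step: ceil(num_examples / batch_size) per epoch."""
    batches_per_epoch = -(-num_examples // batch_size)  # integer ceiling division
    return batches_per_epoch * num_epochs


print(num_optimization_steps(11969, 2, 1))  # 5985, same as the original formula
print(num_optimization_steps(100, 2, 1))    # 50, where n // b + 1 would give 51
```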


## 3.1. Vectorize Inputs

The `vectorize_batch` method takes a batch (a slice of the preprocessed dataset) and splits it into the input strings (`card`) and the expected output strings (`code`). It tokenizes both sets of strings to produce three `torch.tensor` instances: `input_ids`, `input_attn_mask`, and `label_ids`.

In [18]:
def vectorize_batch(batch, tokenizer):
    """
    Converts a batch of processed datapoints into tensors of token ids
    
    Tokenizes the card and code strings of every datapoint in the batch and
    returns the resulting tensors on the GPU.
    
    Parameters
    ----------
    batch: dict[]
        shape (batch_size) A list of dictionaries with "card" and "code" keys
    tokenizer: fn
        Converts the batch to a tensor of input and output ids
    
    Returns
    -------
    input_ids: torch.tensor
        shape (batch_size, max_input_len)
    input_attn_mask: torch.tensor
        shape (batch_size, max_input_len)
    label_ids: torch.tensor
        shape (batch_size, max_output_len)
    """
    
    # Separate the batch into input and output
    card_batch = [card_data['card'] for card_data in batch]
    code_batch = [code_data['code'] for code_data in batch]
    
    # Encode the card's natural language representation
    card_encode = tokenizer.batch_encode_plus(
        card_batch,
        max_length = card_max_length,
        truncation = True,
        padding = 'longest',
        return_attention_mask = True,
        return_tensors = 'pt'
    )

    # Encode the card's java code representation
    code_encode = tokenizer.batch_encode_plus(
        code_batch,
        max_length = code_max_length,
        truncation = True,
        padding = 'longest',
        return_attention_mask = True,
        return_tensors = 'pt'
    )
    
    # Move the training batch to GPU
    card_ids        = card_encode['input_ids'].cuda()
    card_attn_mask  = card_encode['attention_mask'].cuda()
    code_ids        = code_encode['input_ids'].cuda()
    
    return card_ids, card_attn_mask, code_ids
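With `padding='longest'`, every sequence in a batch is right-padded to the length of the longest one using the tokenizer's pad token. The attention mask marks real tokens with 1 and padding with 0. Note that `vectorize_batch` passes the padded `code_ids` straight through as labels, so the loss scores pad positions in the target like any other token. A common remedy is to replace pad ids in the labels with `-100`, which the Hugging Face seq2seq loss ignores. This is an assumption on our part and not something this notebook does. As far as we know, the T5-family models map `-100` back to the pad id when they build decoder inputs, so the masking does not corrupt those inputs.

The pure-Python sketch below mimics both steps on lists of token ids. `pad_batch` and `mask_label_padding` are hypothetical helpers, and `PAD_ID = 0` is assumed to match the MT5 tokenizer's `pad_token_id`.

```python
PAD_ID = 0            # assumed pad_token_id (0 for the T5/MT5 tokenizers)
IGNORE_INDEX = -100   # label value skipped by PyTorch's CrossEntropyLoss by default


def pad_batch(sequences, pad_id=PAD_ID):
    """Right-pad token-id lists to the longest one, like padding='longest'.

    Returns (input_ids, attention_mask) as lists of lists.
    """
    longest = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_id] * (longest - len(seq)) for seq in sequences]
    attention_mask = [[1] * len(seq) + [0] * (longest - len(seq)) for seq in sequences]
    return input_ids, attention_mask


def mask_label_padding(label_ids, pad_id=PAD_ID):
    """Replace pad ids in target sequences with IGNORE_INDEX.

    With torch tensors, the equivalent is labels.masked_fill(labels == pad_id, -100).
    """
    return [[IGNORE_INDEX if tok == pad_id else tok for tok in seq] for seq in label_ids]


code_ids, code_mask = pad_batch([[37, 12, 1], [52, 1]])
print(code_ids)                      # [[37, 12, 1], [52, 1, 0]]
print(code_mask)                     # [[1, 1, 1], [1, 1, 0]]
print(mask_label_padding(code_ids))  # [[37, 12, 1], [52, 1, -100]]
```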

## 3.2. Training Loop

Now, we define and run the training loop. Our dataset is large, and using the entire `processed_dataset` often leads to out-of-memory errors on the CUDA device, even on Google Colab. We therefore had to truncate `processed_dataset` from the roughly 12,000 available datapoints to the first few thousand.

As future work, we may look for a more compact representation of the code, or use few-shot approaches that require less fine-tuning, to reduce memory consumption when adapting the model.

Because some of the training was done on Google Colab, which limits GPU time and disconnects idle sessions, we also exported the model periodically. To load a saved checkpoint, uncomment the next cell.

In [19]:
# model.load_state_dict(torch.load(PATH + '/checkpoint'))

<All keys matched successfully>

Once the model has been loaded from the checkpoint, we can resume training below. A separate cell after the training loop exports the model checkpoint, overwriting any existing file at the same path.

In [22]:
import gc

model.train()
max_grad_norm = 1

training_dataset = processed_dataset[:100]
num_training_cases = len(training_dataset)

step_id = 0
for _ in range(num_epochs):

    random.shuffle(training_dataset)

    for i in range(0, num_training_cases, batch_size):
        gc.collect()
        torch.cuda.empty_cache()

        batch = training_dataset[i: i + batch_size]
        input_ids, input_attn_mask, label_ids = vectorize_batch(batch, tokenizer)

        gc.collect()
        torch.cuda.empty_cache()

        model.zero_grad() # Does the same as ext_optim.zero_grad()

        # Get the model outputs (output logits and loss),
        # stored as a ModelOutputs object
        outputs = model(            
            input_ids=input_ids,
            attention_mask=input_attn_mask,
            label_ids=label_ids
        )
        
        gc.collect()
        torch.cuda.empty_cache()

        # Back-propagate the loss signal and clip the gradients
        loss = outputs.loss.mean()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        # Update the neural network parameters (no learning-rate scheduler is used)
        ext_optim.step()

        if step_id % 100 == 0:
            print(f'At step {step_id}, the extraction loss = {loss}')
        
        step_id += 1

        # Drop references to this step's tensors so their GPU memory can be reclaimed
        # (tensor.detach() returns a new tensor and frees nothing in place)
        del input_ids
        del input_attn_mask
        del label_ids
        del outputs
        del loss

        torch.cuda.empty_cache()

print('Finished Training')

At step 0, the extraction loss = 1.1861380338668823
At step 100, the extraction loss = 0.8956267237663269
At step 200, the extraction loss = 0.5786179304122925
At step 300, the extraction loss = 0.41659775376319885
At step 400, the extraction loss = 0.7746285796165466
At step 500, the extraction loss = 0.5667914748191833
At step 600, the extraction loss = 0.7931115031242371
At step 700, the extraction loss = 0.7759041786193848
At step 800, the extraction loss = 0.8454432487487793
At step 900, the extraction loss = 0.5475283265113831
At step 1000, the extraction loss = 0.6904891729354858
At step 1100, the extraction loss = 0.4824003279209137
At step 1200, the extraction loss = 0.36978965997695923
At step 1300, the extraction loss = 0.48812296986579895
At step 1400, the extraction loss = 0.34553778171539307
At step 1500, the extraction loss = 0.5009756684303284
At step 1600, the extraction loss = 0.48247480392456055
At step 1700, the extraction loss = 0.5818668603897095
At step 1800, the
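Each epoch of the loop above shuffles the training list in place and then slices it into consecutive chunks of `batch_size`. The final batch may therefore be smaller when the dataset size is not a multiple of `batch_size`. The same iteration can be sketched as a generator. `iterate_batches` is a hypothetical helper. Its `seed` argument is an addition for reproducible shuffles, since the notebook calls `random.shuffle` on the unseeded global RNG. The helper also shuffles a copy instead of mutating the caller's list.

```python
import random


def iterate_batches(dataset, batch_size, seed=None):
    """Yield shuffled batches, mirroring the slicing in the training loop.

    A copy is shuffled, so the caller's list order is left untouched;
    passing a seed makes the order reproducible across runs.
    """
    order = list(dataset)
    random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]


batches = list(iterate_batches(list(range(7)), batch_size=2, seed=0))
print([len(b) for b in batches])  # [2, 2, 2, 1]
```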

In [19]:
torch.save(model.state_dict(), PATH + '/checkpoint')
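The cell above saves only the model weights. Resuming from that file therefore restarts AdamW's moment estimates and the step counter from scratch. Below is a minimal sketch that stores all three together, assuming a single-process run. `save_checkpoint`, `load_checkpoint`, and the dictionary keys are hypothetical names. The demo uses a tiny `nn.Linear` in place of `CardTranslationModel`.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, step):
    """Save model weights, optimizer state, and the step counter in one file."""
    torch.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": step},
        path,
    )


def load_checkpoint(path, model, optimizer, map_location="cpu"):
    """Restore model and optimizer in place and return the saved step counter."""
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint["step"]


# Demo: take one optimizer step on a toy model, save, then restore into a fresh copy
torch.manual_seed(0)
toy = nn.Linear(4, 2)
opt = torch.optim.AdamW(toy.parameters(), lr=3e-5)
toy(torch.randn(3, 4)).sum().backward()
opt.step()

with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "checkpoint.pt")
    save_checkpoint(ckpt_path, toy, opt, step=1)

    restored = nn.Linear(4, 2)
    restored_opt = torch.optim.AdamW(restored.parameters(), lr=3e-5)
    resumed_step = load_checkpoint(ckpt_path, restored, restored_opt)

print(resumed_step)  # 1
```

In this notebook, the call would look like `save_checkpoint(PATH + '/checkpoint', model, ext_optim, step_id)`. A file written this way holds a dictionary rather than a bare state dict. It should therefore be read back with `load_checkpoint` instead of the commented-out `model.load_state_dict(torch.load(...))` cell.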