# TorchScriptable T5 with TorchText
## Motivation

[TorchScript](https://pytorch.org/docs/stable/jit.html) is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency, such as in a standalone C++ program. This makes it possible to train models in PyTorch using familiar tools in Python and then export the model via TorchScript to a production environment where Python programs may be disadvantageous for performance and multi-threading reasons. 

TorchText recently introduced the prototype [GenerationUtils](https://github.com/pytorch/text/blob/1b72eba0a07295d74d168c99fd8a5586a0943aa3/torchtext/prototype/generate.py#L13) class. It wraps TorchText's [T5Model](https://github.com/pytorch/text/blob/670e52a3df658f6332f2904cfed67308f3f5adce/torchtext/models/t5/model.py#L67) and generates text in a similar way to the [HuggingFace `generate`](https://huggingface.co/docs/transformers/v4.27.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) function. However, although both T5Model and its tokenizer are TorchScriptable on their own, this property is lost once the model is wrapped with GenerationUtils.


## Technical details
We've implemented a "hacky" solution that wraps the full `generate` functionality inside a `forward` function. We will work with the PyTorch team and **hopefully this code (with some modifications) can later be added to TorchText's T5Model**.

To do so, we added:
1. `T5TorchGenerative`, a subclass of T5Model that:
   - extracts the decoding code from `T5Model.forward()` into a standalone `decode` function with a single return type;
   - adds GenerationUtils' `generate` functionality as a method (similar to HuggingFace).
2. `TorchScriptableT5`, a module that implements the full generative logic in its `forward` method.
3. Helper functions that build a JIT (TorchScript) model from a predefined T5 bundle.



## Issue Example
Currently, scripting the wrapped model raises an exception:

In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torchtext.prototype.generate import GenerationUtils
from torchtext.models import T5_SMALL_GENERATION

# The tokenizer object is torchscriptable
tokenizer = T5_SMALL_GENERATION.transform()
tokenizer_jit = torch.jit.script(tokenizer)

# The T5 model is also torchscriptable
model = T5_SMALL_GENERATION.get_model()
model_jit = torch.jit.script(model)


# But after wrapping with GenerationUtils, the model is no longer torchscriptable
generative_model = GenerationUtils(model)
generative_model_jit = torch.jit.script(generative_model)

NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/Users/rbahumi/miniconda3/envs/conda_pytorch/lib/python3.10/site-packages/torchtext/prototype/generate.py", line 37
    def __init__(self, model: nn.Module, **kwargs) -> None:
                                          ~~~~~~~ <--- HERE
        self.model = model
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)


GenerationUtils cannot be scripted for several reasons:
1. It takes keyword arguments from a dictionary (`**kwargs`), which is the error reported above.
2. Some of its functions accept Optional values.
3. Some of its functions can return more than one type.
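
The first cause is the easiest to spot ahead of time. As a rough pre-check (a sketch only, not part of TorchText or PyTorch), the standard `inspect` module can list which methods of a class declare `*args` or `**kwargs`. The helper name `find_variadic_methods` and the `Wrapper` class are hypothetical; `Wrapper` merely copies the `__init__` shape shown in the traceback.

```python
import inspect
from typing import Dict, List


def find_variadic_methods(cls: type) -> Dict[str, List[str]]:
    """Map each function defined on `cls` to its *args / **kwargs parameter names."""
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    found: Dict[str, List[str]] = {}
    for name, func in vars(cls).items():
        if not inspect.isfunction(func):
            continue  # skip non-function class attributes such as __doc__
        params = inspect.signature(func).parameters.values()
        names = [p.name for p in params if p.kind in variadic_kinds]
        if names:
            found[name] = names
    return found


class Wrapper:
    """Hypothetical stand-in with the same __init__ shape as in the traceback."""

    def __init__(self, model, **kwargs) -> None:
        self.model = model
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)

    def generate(self, inputs, max_length: int = 256):
        return inputs


print(find_variadic_methods(Wrapper))  # {'__init__': ['kwargs']}
```

This only flags variadic signatures; it says nothing about the other two causes, which still need explicit type annotations and a single return type per function.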


In the next section, we suggest a (currently) hacky solution to solve this. 

# Generate results using T5TorchGenerative
We'll define a new class called `T5TorchGenerative` (a subclass of T5Model) that encapsulates the `generate` functionality of GenerationUtils in a TorchScriptable way, a wrapper module `TorchScriptableT5` that runs tokenization, generation, and decoding in its `forward` method, and helper functions that build the scripted model from a T5 bundle.

In [2]:
from typing import List, Optional
import torch
import torch.nn.functional as F
from torch import Tensor
from torchtext.models import T5Model, T5Conf
from torchtext.models.t5.modules import PAST_KEY_VALUES_TYPE, T5Decoder, T5Encoder, ENCODER_OUTPUTS_TYPE


DEFAULT_MAX_SEQ_LEN = 256


class T5TorchGenerative(T5Model):
    """
    A quick and dirty T5Model subclass that encapsulates the GenerationUtils functionality
    inside the instance.

    Motivation: make the generate functionality TorchScriptable.

    TODO: implement beam search once it is added to GenerationUtils.

    """
    @torch.jit.export
    def _prepare_decoder_ids_for_generation(
            self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None
    ):
        return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx

    @torch.jit.export
    def decode(
            self,
            encoder_outputs: ENCODER_OUTPUTS_TYPE,
            decoder_tokens: Optional[Tensor] = None,
            encoder_mask: Optional[Tensor] = None,
            decoder_mask: Optional[Tensor] = None,
            encoder_padding_mask: Optional[Tensor] = None,
            decoder_padding_mask: Optional[Tensor] = None,
            past_key_values: Optional[List[PAST_KEY_VALUES_TYPE]] = None,
            return_past_key_values: bool = False,
    ) -> Tensor:
        """
        This method's code was copied from the T5Model::forward() function. 
        It only does the decoder part, and returns a tensor instead of multiple return values wrapped in a dictionary type.

        In the future, it might be helpful if we can call this function from forward, and remove the duplicate code. 
        """

        assert self.decoder is not None
        assert encoder_outputs is not None

        encoder_output = encoder_outputs.get("encoder_output")
        assert torch.jit.isinstance(encoder_output, Tensor)

        batch_size = encoder_output.size(0)
        encoder_output_device = encoder_output.device

        # decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx.
        if decoder_tokens is None:
            decoder_tokens = (
                    torch.ones((batch_size, 1), device=encoder_output_device, dtype=torch.long) * self.padding_idx
            )

        if decoder_padding_mask is None:
            decoder_padding_mask = decoder_tokens.eq(self.padding_idx)
            # The T5 implementation uses the padding idx to start the sequence, so ignore it when masking
            decoder_padding_mask[:, 0] = False

        decoder_embeddings = self.token_embeddings(decoder_tokens)
        decoder_outputs = self.decoder(
            decoder_embeddings,
            memory=encoder_output,
            tgt_mask=decoder_mask,
            memory_mask=encoder_mask,
            tgt_key_padding_mask=decoder_padding_mask,
            memory_key_padding_mask=encoder_padding_mask,
            past_key_values=past_key_values,
            return_past_key_values=return_past_key_values,
        )

        decoder_output = decoder_outputs.get("decoder_output")
        assert torch.jit.isinstance(decoder_output, Tensor)

        if self.linear_head:
            assert self.lm_head is not None
            # Rescale output before projecting on vocab. This happens when the encoder and decoder share the
            # same word embeddings, which is always the case in our t5 implementation.
            # See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661
            decoder_output = decoder_output * (self.embedding_dim ** -0.5)
            decoder_output = self.lm_head(decoder_output)

        return decoder_output

    @torch.jit.export
    def greedy_search(
            self,
            input_ids: torch.Tensor,
            max_length: int,
            eos_idx: int,
            encoder_outputs: ENCODER_OUTPUTS_TYPE,
            pad_idx: Optional[int] = None,
    ) -> torch.Tensor:
        """Greedy search decoding for text generation. Takes the most likely next token every time.

        Inputs:
            input_ids (Tensor): Decoder token ids generated so far, of shape (batch_size, seq_len).
            max_length (int): Maximum length of the generated sequences.
            eos_idx (int): End of sequence index.
            encoder_outputs (ENCODER_OUTPUTS_TYPE): Encoder outputs for the input prompt(s).
            pad_idx (int, optional): Padding index used to fill sequences that have already finished.

        Returns:
            Batch of sequences decoded by greedy search.
        """
        unfinished_sequences = torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long)

        while True:
            decoder_output = self.decode(
                decoder_tokens=input_ids,
                encoder_mask=None,
                decoder_mask=None,
                encoder_padding_mask=None,
                decoder_padding_mask=None,
                encoder_outputs=encoder_outputs,
                past_key_values=None,
                return_past_key_values=True
            )

            # Compute log-probabilities for the last position and take the most likely next token
            probs = F.log_softmax(decoder_output[:, -1], dim=-1)
            _, next_tokens = torch.topk(probs, 1)

            # For any finished sequences, padding idx should be the last token
            if eos_idx is not None:
                if pad_idx is not None:
                    next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)

            # Append the next tokens to the previous tokens
            input_ids = torch.cat([input_ids, next_tokens], dim=-1)

            if eos_idx is not None:
                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx).long())

            # Stop iterating once all sequences are finished or exceed the max_length
            if unfinished_sequences.max() == 0 or len(input_ids[0]) >= max_length:
                break

        return input_ids

    @torch.jit.export
    def generate(
            self,
            inputs: torch.Tensor,
            num_beams: Optional[int] = None,
            max_length: int = DEFAULT_MAX_SEQ_LEN,
            pad_idx: int = 0,
            eos_idx: int = 1,
    ) -> torch.Tensor:
        encoder_outputs = self.encoder(inputs)
        inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, pad_idx=pad_idx)

        if num_beams is None or num_beams == 1:
            return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, encoder_outputs=encoder_outputs)
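        # elif num_beams > 1:
        #     return self.beam_search(inputs, num_beams, max_length)
        else:
            # Beam search is not implemented yet, so any other value is rejected
            raise ValueError("Only greedy search is supported: `num_beams` must be None or 1.")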
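
The finished-sequence bookkeeping in `greedy_search` can be traced without a model. The following is a framework-free toy, not the TorchText implementation: plain lists replace tensors, and a hypothetical `next_token_scores` callback (with the made-up `TARGETS` table) stands in for the T5 decoder. It follows the same steps as the loop above: take the argmax, overwrite finished rows with the padding index, append, update the unfinished flags, and stop when every row is finished or `max_length` is reached.

```python
from typing import Callable, List

PAD_IDX = 0
EOS_IDX = 1


def toy_greedy_search(
    input_ids: List[List[int]],
    next_token_scores: Callable[[int, List[int]], List[float]],
    max_length: int,
    eos_idx: int = EOS_IDX,
    pad_idx: int = PAD_IDX,
) -> List[List[int]]:
    sequences = [list(row) for row in input_ids]  # copy so the caller's lists are untouched
    unfinished = [1] * len(sequences)  # 1 = still generating, 0 = EOS already emitted

    while True:
        for i, row in enumerate(sequences):
            scores = next_token_scores(i, row)
            next_token = max(range(len(scores)), key=scores.__getitem__)  # argmax
            # Same arithmetic as the tensor version: finished rows receive pad_idx
            next_token = next_token * unfinished[i] + pad_idx * (1 - unfinished[i])
            row.append(next_token)
            unfinished[i] *= int(next_token != eos_idx)
        # Stop once all rows are finished or the length limit is reached
        if max(unfinished) == 0 or len(sequences[0]) >= max_length:
            break
    return sequences


# Hypothetical "decoder": each batch row simply replays a fixed target sequence.
TARGETS = [[3, EOS_IDX], [4, 2, 3, EOS_IDX]]


def scripted_scores(batch_index: int, row: List[int]) -> List[float]:
    step = len(row) - 1  # tokens generated so far (each row starts with one pad token)
    target = TARGETS[batch_index]
    wanted = target[step] if step < len(target) else EOS_IDX
    return [1.0 if token == wanted else 0.0 for token in range(5)]


print(toy_greedy_search([[PAD_IDX], [PAD_IDX]], scripted_scores, max_length=10))
# [[0, 3, 1, 0, 0], [0, 4, 2, 3, 1]]
```

The first row finishes after two steps and is padded with `PAD_IDX` while the second row keeps generating, which is the behaviour the `unfinished_sequences` tensor provides in the scripted method.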

In [3]:
from typing import Optional, Union, Dict, Any
from torchtext import _TEXT_BUCKET
from urllib.parse import urljoin
from torchtext._download_hooks import load_state_dict_from_url


def build_model(
    config: T5Conf,
    T5Class=T5Model,
    freeze_model: bool = False,
    checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None,
    strict: bool = False,
    dl_kwargs: Optional[Dict[str, Any]] = None,
) -> T5Model:
    """Model builder that can override the default T5Model class.

    (reference: https://github.com/pytorch/text/blob/a1dc61b8e80df70fe7a35b9f5f5cc7e19c7dd8a3/torchtext/models/t5/bundler.py#L113)

    Args:
        config (T5Conf): An instance of class T5Conf that defines the model configuration
        T5Class (type): The T5Model class (or subclass) to instantiate. (Default: `T5Model`)
        freeze_model (bool): Indicates whether to freeze the model weights. (Default: `False`)
        checkpoint (str or Dict[str, torch.Tensor]): Path to or actual model state_dict. state_dict can have partial weights, i.e. only for the encoder. (Default: ``None``)
        strict (bool): Passed to :func:`torch.nn.Module.load_state_dict` method. (Default: `False`)
        dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: `None`)
    """
    model = T5Class(config, freeze_model)
    if checkpoint is not None:
        if torch.jit.isinstance(checkpoint, Dict[str, torch.Tensor]):
            state_dict = checkpoint
        elif isinstance(checkpoint, str):
            dl_kwargs = {} if dl_kwargs is None else dl_kwargs
            state_dict = load_state_dict_from_url(checkpoint, **dl_kwargs)
        else:
            raise TypeError(
                "checkpoint must be of type `str` or `Dict[str, torch.Tensor]` but got {}".format(type(checkpoint))
            )

        model.load_state_dict(state_dict, strict=strict)

    return model


def load_model(bundle, T5Class=T5TorchGenerative):
    """Load the pretrained weights of a T5 bundle into the given model class.

    Example usage:
    >>> model = load_model(bundle=T5_SMALL_GENERATION, T5Class=T5TorchGenerative)
    """
    return build_model(config=bundle.config, T5Class=T5Class, checkpoint=bundle._path)


def get_model_from_bundle(bundle):
    model = load_model(bundle=bundle, T5Class=T5TorchGenerative)
    tokenizer = bundle.transform()
    full_model = TorchScriptableT5(model=model, transform=tokenizer)
    return full_model

def get_jit_from_bundle(bundle):
    full_model = get_model_from_bundle(bundle)
    full_model_jit = torch.jit.script(full_model)
    return full_model_jit

In [4]:
from typing import List, Union

DEFAULT_MAX_LENGTH: int = 100


class TorchScriptableT5(torch.nn.Module):
    def __init__(self, model, transform, cuda: bool = False):
        super().__init__()
        # Stored as `use_cuda` so that the flag does not shadow `nn.Module.cuda()`
        self.use_cuda = cuda
        self.transform = transform

        if cuda:
            model = model.cuda()

        self.model = model
        self.model.eval()

    def forward(self, texts: List[str], max_length: int = DEFAULT_MAX_LENGTH) -> Union[List[str], str]:
        input_ids = self.transform(texts)
        if self.use_cuda:
            input_ids = input_ids.cuda()
        raw_outputs = self.model.generate(input_ids, max_length=max_length)

        # The tokenizer decodes a batch, so promote a single 1-D sequence to a batch of one
        if raw_outputs.dim() == 1:
            raw_outputs_list: List[List[int]] = raw_outputs[None, :].tolist()
        else:
            raw_outputs_list: List[List[int]] = raw_outputs.tolist()

        res = self.transform.decode(raw_outputs_list)
        return res

In [5]:
# `build_model` and `load_model` are unchanged from In [3]; only the bundle helpers
# are redefined here, now with an optional `cuda` flag.

def get_model_from_bundle(bundle, cuda=False):
    model = load_model(bundle=bundle, T5Class=T5TorchGenerative)
    tokenizer = bundle.transform()
    full_model = TorchScriptableT5(model=model, transform=tokenizer, cuda=cuda)
    return full_model

def get_jit_from_bundle(bundle, cuda=False):
    full_model = get_model_from_bundle(bundle, cuda=cuda)
    full_model_jit = torch.jit.script(full_model)
    return full_model_jit

# The new model is an E2E generation model

In [6]:
SUMMARIZE_PROMPT = "summarize"
TRANSLATE_TO_GERMAN = "translate English to German"
QUESTION_PROMPT = "question"
CONTEXT_PROMPT = "context"


def summarize_text(text):
    return f"{SUMMARIZE_PROMPT}: {text}"


def en_to_german_text(text):
    return f"{TRANSLATE_TO_GERMAN}: {text}"


def qa_text(context, question):
    # The question mark is appended here, so pass the question without one
    return f"{QUESTION_PROMPT}: {question}? {CONTEXT_PROMPT}: {context}"


In [7]:
from torchtext.models import T5_SMALL_GENERATION, T5_LARGE_GENERATION, T5_3B_GENERATION, T5_11B_GENERATION

In [10]:
EXAMPLE_INPUT =  [
    'question: What does Nir likes to eat? context: Nir is a PM on the Care AI team. Nir only eats vegeterian food and he loves Pizza',
    'question: Who likes to eat pizza? context: Nir is a PM on the Care AI team. Nir only eats vegeterian food and he loves Pizza',
    "summarize: studies say that owning a dog is good for you",
]

t5_large = get_jit_from_bundle(T5_LARGE_GENERATION)

In [None]:
%time t5_large(EXAMPLE_INPUT, max_length=100)

In [29]:
# Try to load to GPU and compare the time difference 
t5_large_gpu = get_jit_from_bundle(T5_LARGE_GENERATION, cuda=True)
%time t5_large_gpu(EXAMPLE_INPUT, max_length=100)

CPU times: user 2.93 s, sys: 47.5 ms, total: 2.98 s
Wall time: 2.95 s


['Pizza',
 'Nir',
 'studies say owning a dog is good for you . a dog is a good companion, a companion for life .']
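
`%time` measures a single call, which for a scripted model may include one-off warm-up cost. A hedged alternative is to run a few untimed warm-up calls and then keep the best of several timed calls. The sketch below uses only the standard library; `best_wall_time` is a hypothetical helper, not a TorchText or PyTorch API. For the GPU model, CUDA kernels run asynchronously, so the timed callable would also need to call `torch.cuda.synchronize()` before returning for the measurement to be meaningful.

```python
import time
from typing import Any, Callable, Tuple


def best_wall_time(fn: Callable[[], Any], repeats: int = 3, warmup: int = 1) -> Tuple[float, Any]:
    """Return (best wall-clock seconds, last result) over `repeats` timed calls."""
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    for _ in range(warmup):
        fn()  # untimed calls, so one-off setup cost is not measured
    best = float("inf")
    result = None
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn()
        best = min(best, time.perf_counter() - start)
    return best, result


# Stand-in workload; with the notebook's objects the callable could be
# lambda: t5_large(EXAMPLE_INPUT, max_length=100)
seconds, value = best_wall_time(lambda: sum(range(100_000)))
print(f"best of 3: {seconds:.6f} s, result={value}")
```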

### Save the model locally as a TorchScript (JIT) file

In [34]:
model_filename = 'flan_t5_large_generation.pt'
torch.jit.save(t5_large, model_filename)

In [36]:
!ls -lath | head -3

total 2.9G
-rw-r--r-- 1 rbahumi rbahumi 2.9G Mar 20 03:40 flan_t5_large_generation.pt
drwxr-xr-x 1 rbahumi rbahumi 1.1K Mar 20 03:39 .
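
The saved file can be read back with `torch.jit.load(model_filename)`, without the Python class definitions from this notebook. The T5 tokenizer may rely on TorchText's custom operators, so `import torchtext` might still be required before loading; that is an assumption and was not verified here. The round trip itself can be tried on a small scale first. The sketch below uses a hypothetical `EchoModule` with the same `List[str]` in, `List[str]` out shape as `TorchScriptableT5`, and an in-memory buffer instead of a file.

```python
import io
from typing import List

import torch


class EchoModule(torch.nn.Module):
    """Hypothetical stand-in with a List[str] -> List[str] forward, like TorchScriptableT5."""

    def forward(self, texts: List[str]) -> List[str]:
        outputs: List[str] = []
        for text in texts:
            outputs.append("echo: " + text)
        return outputs


scripted = torch.jit.script(EchoModule())

# torch.jit.save accepts a file name or a file-like object
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)

# The restored module no longer needs the EchoModule class definition
restored = torch.jit.load(buffer)
print(restored(["hello", "world"]))  # ['echo: hello', 'echo: world']
```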
