<a href="https://colab.research.google.com/github/rayhern/convert-gpt2-xl-to-onnx/blob/master/Convert_GPT2_XL_Transformer_to_ONXX_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Create ONNX model from transformers models. (Updated: 9-17-2020)
- You must run this notebook in Google Colab Pro.
- The instance needs to be of type GPU with High-RAM.
- By: Ray Hernandez [github: @rayhern, twitter:@bizong]

# Check GPU

In [None]:
# Print the NVIDIA driver / GPU status for this runtime.
!nvidia-smi

In [None]:
#@title Install dependencies.
!pip install onnx onnxruntime-gpu onnxruntime-tools transformers
!pip install -U torch

# Convert GPT2-XL using the onnxruntime-tools Gpt2Helper.

In [None]:
%%time
import os
if os.path.exists('gpt2-final') is False:
  print('making directory ./gpt2-final...')
  %mkdir gpt2-final
from onnxruntime_tools.transformers.gpt2_helper import Gpt2Helper, MyGPT2LMHeadModel
# Can be encoded with cpu or gpu.
device = 'cpu'
# Use GPT2 model wrapper from onnxruntime's gpt2 helper supple.
# conversion will not work without this wrapper.
model = MyGPT2LMHeadModel.from_pretrained('gpt2-xl')
onnx_model_path = "gpt2-final/gpt2-xl.onnx"
print('converting model...')
# Use GPT2Helper from onnxruntime tools to export GPT2-XL.
Gpt2Helper.export_onnx(
  model, 
  device, 
  onnx_model_path, 
  use_external_data_format=True, 
  verbose=True
)
print('finished.')

#/usr/local/lib/python3.6/dist-packages/transformers/modeling_gpt2.py:712: 
# FutureWarning: The `past` argument is deprecated and will be removed in a 
# future version, use `past_key_values` instead. FutureWarning,


# Optimize the GPT2-XL model and convert it to float16.

In [None]:
%%time
from onnxruntime_tools import optimizer
from onnxruntime_tools.transformers.onnx_model_bert import BertOptimizationOptions
from transformers import (
    AutoModelForCausalLM
)
tf_model = AutoModelForCausalLM.from_pretrained('gpt2-xl')
print('num heads: %s. hidden size: %s.' % (tf_model.config.n_head, tf_model.config.n_embd))
options = BertOptimizationOptions('gpt2')
optimized_model = optimizer.optimize_model(
    "gpt2-final/gpt2-xl.onnx",
    model_type='gpt2',
    num_heads=tf_model.config.n_head,
    hidden_size=tf_model.config.n_embd,
    optimization_options=options,
)
optimized_model.convert_model_float32_to_float16()
optimized_model.change_input_to_int32()
optimized_model.save_model_to_file("gpt2-final/gpt2_fp16.onnx", use_external_data_format=True)

# ONNX Model Text Generation Class

In [None]:

from transformers import (
    AutoConfig,
    AutoTokenizer,
)
from psutil import cpu_count
from os import environ
# Constants from the performance optimization available in onnxruntime
# It needs to be done before importing onnxruntime
environ["OMP_NUM_THREADS"] = str(cpu_count(logical=True))
environ["OMP_WAIT_POLICY"] = 'ACTIVE'
from onnxruntime_tools.transformers.gpt2_helper import Gpt2Helper
from onnxruntime.capi._pybind_state import set_seed as ort_set_seed
import onnxruntime as ort
import torch
from torch import Tensor
import torch.nn.functional as F
import numpy
from typing import Iterable, List, Optional, Tuple
import random

@torch.no_grad()
def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> List:
    """Return, per hypothesis, the tokens that would complete an already-seen n-gram.

    Args:
        prev_input_ids: token ids indexed as ``[hypothesis, position]``.
        num_hypos: number of rows of ``prev_input_ids`` to process.
        no_repeat_ngram_size: length ``n`` of the n-grams that must not repeat.
        cur_len: end position (exclusive) of the prefix window that is looked up.

    Returns:
        A list with ``num_hypos`` entries; each entry is the list of token ids
        that followed the current ``n-1`` token prefix earlier in that row.
    """
    # Fewer than n-1 tokens available: nothing can be banned yet.
    if cur_len + 1 < no_repeat_ngram_size:
        return [[] for _ in range(num_hypos)]

    window_start = cur_len + 1 - no_repeat_ngram_size
    banned_tokens = []
    for hypo_idx in range(num_hypos):
        token_list = prev_input_ids[hypo_idx].tolist()
        # Map every (n-1)-token prefix seen so far to the tokens that followed it.
        followers = {}
        shifted_views = [token_list[offset:] for offset in range(no_repeat_ngram_size)]
        for ngram in zip(*shifted_views):
            followers.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])
        # Look up the prefix that the next token would extend.
        current_prefix = tuple(prev_input_ids[hypo_idx, window_start:cur_len].tolist())
        banned_tokens.append(followers.get(current_prefix, []))
    return banned_tokens


@torch.no_grad()
def enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
    """
    Apply the repetition penalty from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__ in place.

    For each of the ``batch_size * num_beams`` rows, every distinct token id
    found in ``prev_output_tokens[row]`` has its score in ``lprobs[row]``
    multiplied by ``repetition_penalty`` when the score is negative and
    divided by it otherwise. Nothing is returned; ``lprobs`` is modified.
    """
    row_count = batch_size * num_beams
    for row in range(row_count):
        seen_tokens = set(prev_output_tokens[row].tolist())
        for token_id in seen_tokens:
            current = lprobs[row, token_id]
            if current < 0:
                # Negative scores must grow in magnitude to become less likely.
                lprobs[row, token_id] = current * repetition_penalty
            else:
                lprobs[row, token_id] = current / repetition_penalty


@torch.no_grad()
def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
    """Collect, for each row of history, the final tokens of bad-word sequences about to be completed.

    ``prev_input_ids`` is iterated row by row (each row is compared with list
    slices, so rows are expected to be plain lists). ``bad_words_ids`` is a
    collection of non-empty token sequences. A sequence's last token is banned
    for a row when the row ends with all of the sequence's preceding tokens;
    single-token sequences are therefore always banned.

    Returns one list of banned token ids per row.
    """
    def _ends_with(history, prefix):
        # An empty prefix means the bad word is a single token: always a match.
        if len(prefix) == 0:
            return True
        # A prefix longer than the history can never match.
        if len(prefix) > len(history):
            return False
        return history[-len(prefix):] == prefix

    banned_tokens = []
    for history in prev_input_ids:
        row_banned = []
        for bad_sequence in bad_words_ids:
            assert len(bad_sequence) > 0, "Banned words token sequences {} cannot have an empty list".format(
                bad_words_ids
            )
            if _ends_with(history, bad_sequence[:-1]):
                row_banned.append(bad_sequence[-1])
        banned_tokens.append(row_banned)
    return banned_tokens


@torch.no_grad()
def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
    """Set the scores of banned tokens to ``-inf`` in place.

    Args:
        scores: score tensor indexed as ``[batch index, vocabulary position]``.
        banned_tokens: one list per batch row holding the vocabulary positions
            to ban for that row (``banned_tokens[i]`` applies to ``scores[i]``).

    Returns:
        None. ``scores`` is modified in place; it is left untouched when no
        token is banned.
    """
    row_positions = []
    vocab_positions = []
    for batch_idx, batch_banned_tokens in enumerate(banned_tokens):
        for token in batch_banned_tokens:
            row_positions.append(batch_idx)
            vocab_positions.append(token)
    if not row_positions:
        return
    # Index the banned coordinates directly instead of building a sparse mask:
    # this avoids the legacy sparse tensor constructor and the dense
    # (batch, vocab) intermediate, and duplicate coordinates are harmless.
    row_index = torch.tensor(row_positions, dtype=torch.long, device=scores.device)
    vocab_index = torch.tensor(vocab_positions, dtype=torch.long, device=scores.device)
    scores[row_index, vocab_index] = -float("inf")

class GPT2ONNXModel:
    """Text generation loop driven by an ONNX Runtime session of an exported GPT-2 model.

    The tokenizer and model config are loaded from ``gpt2_model_path`` with
    ``AutoTokenizer`` / ``AutoConfig``; the network itself is executed through
    an ``onnxruntime.InferenceSession`` created from ``onnx_model_path``.
    """

    def __init__(self, onnx_model_path, gpt2_model_path, device='cuda', verbose=False):
        """Load config and tokenizer, seed the RNGs and open the ONNX Runtime session.

        Args:
            onnx_model_path: path of the ONNX model file handed to onnxruntime.
            gpt2_model_path: name or path given to ``AutoConfig`` / ``AutoTokenizer``.
            device: ``'cuda'`` selects ``CUDAExecutionProvider``; any other value
                selects ``CPUExecutionProvider``.
            verbose: print progress / debug information when truthy.
        """
        self.config = AutoConfig.from_pretrained(gpt2_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(gpt2_model_path)
        # The EOS token doubles as the padding token; padding goes on the left.
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.num_attention_heads = self.config.n_head
        self.hidden_size = self.config.n_embd
        self.num_layer = self.config.n_layer
        self.verbose = verbose
        # Seed torch, numpy and onnxruntime with fresh random seeds.
        torch.random.manual_seed(random.randint(1, 9999999))
        torch.cuda.manual_seed(random.randint(1, 9999999))
        torch.manual_seed(random.randint(1, 9999999))
        numpy.random.seed(random.randint(1, 9999999))
        ort_set_seed(random.randint(1, 9999999))
        # Set our session options.
        options = ort.SessionOptions()
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
        # Only CPU supports parallel.
        if device == 'cpu':
            options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
        # Start the ONNX session.
        self.session = ort.InferenceSession(
            onnx_model_path,
            sess_options=options,
            providers=['CUDAExecutionProvider' if device == 'cuda' else 'CPUExecutionProvider']
        )
        self.device = device
        if self.verbose:
            print('ONNX Session Created!')

    @torch.no_grad()
    def generate(self, input_text, max_length=30, min_length=10, temperature: Optional[float] = None,
        do_sample: Optional[bool] = True, top_k: Optional[int] = None, top_p: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None, repetition_penalty: Optional[float] = None,
        num_return_sequences: Optional[int] = None, no_repeat_ngram_size: Optional[int] = None,
        decoder_start_token_id: Optional[int] = None, length_penalty: Optional[float] = None,
        ):
        """Generate ``num_return_sequences`` continuations of ``input_text``.

        Any argument left as ``None`` falls back to the attribute of the same
        name on ``self.config``. ``max_length`` is the maximum number of
        decoding steps (new tokens) per sequence, and the EOS token is
        suppressed while the step index is below ``min_length``.
        ``decoder_start_token_id`` and ``length_penalty`` are accepted and
        resolved but are not used by this decoding loop.

        Returns:
            A list of decoded strings (prompt plus continuation), one per
            requested sequence.

        Raises:
            AssertionError: if one of the sampling arguments is out of range.
        """
        # Get the config variables from model config if not user supplied.
        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length
        temperature = temperature if temperature is not None else self.config.temperature
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        no_repeat_ngram_size = no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        num_return_sequences = num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
        decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        if self.verbose:
            print('top_k: %s. top_p: %s. temperature: %s.' % (top_k, top_p, temperature))
        # Make sure user inputs are correct.
        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert temperature > 0, "`temperature` should be strictly positive."
        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert (
            isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
        ), "`no_repeat_ngram_size` should be a positive integer."
        # Get end of string token from the tokenizer.
        eos_token_id = self.tokenizer.eos_token_id
        # Get all sequences requested.
        sequences = []
        for seq in range(num_return_sequences):
            # Get our original input ids, attention mask, and position_ids from input_text.
            input_ids, attention_mask, position_ids, past = self.get_inputs(input_text)
            if self.verbose:
                print("sequence:", seq + 1)
                print("input_ids:", input_ids)
                print("attention_mask:", attention_mask)
                print("position_ids:", position_ids)
            batch_size = input_ids.size(0)
            # Tracks which rows have already produced EOS.
            has_eos = torch.zeros(batch_size, dtype=torch.bool).to(self.device)
            # Full history (prompt + generated tokens); `input_ids` itself shrinks
            # to the newest token after the first step because `past` carries the rest.
            all_token_ids = input_ids.clone()
            # Get X amount of tokens/words for each sequence.
            for step in range(max_length):
                outputs = self.inference_with_io_binding(
                    input_ids,
                    position_ids,
                    attention_mask,
                    past
                )
                next_token_logits = outputs[0][:, -1, :]
                # Post process scores against the FULL token history so that the
                # repetition penalty, n-gram ban and bad-word ban see every
                # previous token rather than only the last one.
                scores = self.postprocess_next_token_scores(
                    next_token_logits,
                    all_token_ids,
                    no_repeat_ngram_size,
                    bad_words_ids,
                    step,
                    min_length,
                    max_length,
                    eos_token_id,
                    batch_size,
                    1,
                    repetition_penalty
                )
                if do_sample:
                    # Temperature (higher temperature => more likely to sample low probability tokens)
                    if temperature != 1.0:
                        scores = scores / temperature
                    # Top-p/top-k filtering
                    next_token_logscores = self.top_k_top_p_filtering(scores, top_k, top_p)
                    # Sample
                    probs = F.softmax(next_token_logscores, dim=-1)
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    # Greedy decoding on the post-processed scores.
                    next_tokens = torch.argmax(scores, dim=-1)
                has_eos = has_eos | (next_tokens == eos_token_id)
                # Rows that already finished keep emitting EOS.
                tokens_to_add = next_tokens.masked_fill(has_eos, eos_token_id)
                all_token_ids = torch.cat([all_token_ids, tokens_to_add.unsqueeze(-1)], dim=-1)

                # Stop as soon as every row has produced EOS; there is no need
                # to prepare the inputs of a step that will never run.
                if torch.all(has_eos):
                    break

                # Update input_ids, attention_mask, position_ids and past
                input_ids = tokens_to_add.clone().detach().reshape([batch_size, 1])
                position_ids = (position_ids[:, -1] + 1).reshape(batch_size, 1)
                attention_mask = torch.cat([attention_mask, torch.ones([batch_size, 1]).type_as(attention_mask)], 1)
                past = []
                for i in range(self.num_layer):
                    past_i = torch.from_numpy(outputs[i + 1]) if isinstance(outputs[i + 1], numpy.ndarray) else outputs[i + 1].clone().detach()
                    past.append(past_i)

            sequences += self.tokenizer.batch_decode(all_token_ids.tolist(), skip_special_tokens=True)
        return sequences

    @torch.no_grad()
    def get_inputs(self, prompt_text):
        """Tokenize ``prompt_text`` and build the tensors for the first decoding step.

        Returns:
            ``(input_ids, attention_mask, position_ids, empty_past)`` where
            ``empty_past`` is a list with one zero-length past tensor per layer.
        """
        encodings_dict = self.tokenizer.batch_encode_plus([prompt_text], padding=True)
        input_ids = torch.tensor(encodings_dict['input_ids'], dtype=torch.int64).to(self.device)
        attention_mask = torch.tensor(encodings_dict['attention_mask'], dtype=torch.float32).to(self.device)
        # Positions count attended tokens only; masked positions are clamped to 0.
        position_ids = (attention_mask.long().cumsum(-1) - 1).to(self.device)
        position_ids.masked_fill_(position_ids < 0, 0)
        # Empty Past State for generating first word
        batch_size = input_ids.size(0)
        past_shape = [2, batch_size, self.num_attention_heads, 0, self.hidden_size // self.num_attention_heads]
        empty_past = []
        for i in range(self.num_layer):
            empty_past.append(torch.empty(past_shape, dtype=torch.float32).to(self.device))
        return input_ids, attention_mask, position_ids, empty_past

    @torch.no_grad()
    def inference_with_io_binding(self, input_ids, position_ids, attention_mask, past):
        """Run one forward pass of the ONNX session through ``Gpt2Helper`` IO binding.

        Returns the outputs read back from the bound buffers; the caller uses
        element 0 as the logits and elements ``1..num_layer`` as the new past.
        """
        output_shapes = Gpt2Helper.get_output_shapes(
            batch_size=input_ids.size(0),
            past_sequence_length=past[0].size(3),
            sequence_length=input_ids.size(1),
            config=self.config
        )
        output_buffers = Gpt2Helper.get_output_buffers(output_shapes, self.device)
        io_binding = Gpt2Helper.prepare_io_binding(
            self.session,
            input_ids,
            position_ids,
            attention_mask,
            past,
            output_buffers,
            output_shapes
        )
        self.session.run_with_iobinding(io_binding)
        outputs = Gpt2Helper.get_outputs_from_io_binding_buffer(
            self.session,
            output_buffers,
            output_shapes,
            return_numpy=False
        )
        return outputs

    @torch.no_grad()
    def top_k_top_p_filtering(self, logits: Tensor, top_k: int = 0, top_p: float = 1.0,
        filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1,
        ) -> torch.Tensor:
        """
        Filter a logits distribution in place with top-k and/or nucleus (top-p) filtering.

        Args:
            logits: logits distribution shape (batch size, vocabulary size)
            top_k: if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p: if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            filter_value: value to use for filtering.
            min_tokens_to_keep: The minimum amount of tokens to keep.

        Returns:
            The same ``logits`` tensor, with filtered positions set to ``filter_value``.

        """
        if top_k > 0:
            top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
            # Remove all tokens with a probability less than the last token of the top-k
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = filter_value
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
            sorted_indices_to_remove = cumulative_probs > top_p
            if min_tokens_to_keep > 1:
                # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
                sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            # scatter sorted tensors to original indexing
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = filter_value
        return logits

    @torch.no_grad()
    def postprocess_next_token_scores(self, scores, input_ids, no_repeat_ngram_size, bad_words_ids, cur_len,
        min_length, max_length, eos_token_id, batch_size, num_beams, repetition_penalty
        ):
        """Apply repetition penalty, min-length EOS masking, n-gram and bad-word bans to ``scores`` in place.

        Args:
            scores: next-token scores indexed as ``[row, vocabulary position]``.
            input_ids: token history for each row; pass the full sequence
                generated so far, not just the newest token.
            cur_len: step counter compared against ``min_length`` to decide
                whether EOS is still forbidden. The n-gram ban does not use it;
                it takes the sequence length from ``input_ids`` directly.

        Returns:
            The same ``scores`` tensor after in-place modification.
        """
        # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
        if repetition_penalty != 1.0:
            enforce_repetition_penalty_(
              scores,
              batch_size,
              num_beams,
              input_ids,
              repetition_penalty,
            )

        # set eos token prob to zero if min_length is not reached
        if eos_token_id is not None and cur_len < min_length:
            scores[:, eos_token_id] = -float("inf")

        if no_repeat_ngram_size > 0:
            # calculate a list of banned tokens to prevent repetitively generating the same ngrams
            num_batch_hypotheses = batch_size * num_beams
            # The n-gram lookup slices `input_ids` by position, so it needs the
            # real sequence length rather than the decoding step counter.
            banned_batch_tokens = calc_banned_ngram_tokens(
                input_ids, num_batch_hypotheses, no_repeat_ngram_size, input_ids.size(-1)
            )
            for i, banned_tokens in enumerate(banned_batch_tokens):
                scores[i, banned_tokens] = -float("inf")

        if bad_words_ids is not None:
            # Exclude EOS token (already processed)
            bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
            # calculate a list of banned tokens according to bad words
            banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
            # Modify the scores in place by setting the banned tokens logits to `-inf`
            set_scores_to_inf_for_banned_tokens(scores, banned_tokens)

        return scores



In [None]:
%%time
onnx_model = GPT2ONNXModel(
  'gpt2-final/gpt2_fp16.onnx', 
  'gpt2-xl', 
  device='cuda', 
  verbose=False
)

# Generate some text with your new optimized GPT2-XL model!

In [None]:
%%time
generated = onnx_model.generate(
  'George Washington was', 
  temperature=1.0, 
  top_k=50, 
  top_p=0.92,
  max_length=100,
  min_length=5,
  num_return_sequences=3,
  repetition_penalty=1.0,
  do_sample=True
)
print(generated)