### Settings

Install the missing requirements in the Colab VM.

In [None]:
# Install the libraries used by the ONNX export and quantization cells below.
!pip install transformers onnx onnxruntime

Download the CodeGen pre-trained model and tokenizer. 

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Use the GPU when one is available, otherwise stay on the CPU.
# The model is moved explicitly with `.to(device)` below; the deprecated
# `torch.set_default_tensor_type` is intentionally not used because it changes
# the default for tensors created later in the session, including the ones
# meant for the CPU benchmarks.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer and model weights are loaded from the same checkpoint identifier.
tokenizer_id = "Salesforce/codegen-350M-mono"
model_id = "Salesforce/codegen-350M-mono"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

In [None]:
# Save the tokenizer first and then the model into the same local directory,
# which the ONNX export command reads through its --model argument.
for artifact in (tokenizer, model):
    artifact.save_pretrained("local-pt-checkpoint")

### Conversion to ONNX Format

Convert the pre-trained model to the ONNX format using the export tool available in the Transformers library. The command below also validates the exported model at the end of the conversion process.

In [None]:
# Run the `transformers.onnx` command-line module on the checkpoint saved in
# "local-pt-checkpoint", requesting the "causal-lm" feature with the PyTorch
# framework; "onnx/" is passed as the output location.
!python -m transformers.onnx --feature "causal-lm" --framework pt --export_with_transformers --model=local-pt-checkpoint onnx/

### Quantization

Apply 8-bit dynamic quantization to the converted ONNX model.

In [None]:
# Destination file for the 8-bit quantized model.
quantized_model_path = "model.quant.onnx"
# Full-precision ONNX model used as the quantization input.
onnx_model_path = "onnx/model.onnx"

In [None]:
import os
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

def quantize_onnx_model(onnx_model_path, quantized_model_path):
    """Quantize an ONNX model file with dynamic quantization and int8 weights.

    Args:
        onnx_model_path: Path of the ONNX model file to quantize.
        quantized_model_path: Path where the quantized model is written.
    """
    # `quantize_dynamic` is given the file paths directly, so the model is not
    # loaded with `onnx.load` beforehand (the loaded object was never used).
    quantize_dynamic(onnx_model_path,
                     quantized_model_path,
                     weight_type=QuantType.QInt8)

quantize_onnx_model(onnx_model_path, quantized_model_path)

# Print the on-disk size of both model files, converted from bytes to MB.
bytes_per_mb = 1024*1024
for size_label, model_file in (
    ('ONNX full precision model size (MB):', onnx_model_path),
    ('ONNX quantized model size (MB):', quantized_model_path),
):
    print(size_label, os.path.getsize(model_file)/bytes_per_mb)

### Benchmarks

Define some utility functions to benchmark different versions of the model with different execution providers in ONNX Runtime.

In [None]:
from contextlib import contextmanager
from dataclasses import dataclass
from time import perf_counter, time
from typing import List, Optional

from onnxruntime import (
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
    get_all_providers,
    get_available_providers,
)
from tqdm import trange

def create_model_for_provider(model_path: str, provider: str, thread_pooling=False) -> InferenceSession:
  """Create an ONNX Runtime inference session bound to one execution provider.

  Args:
    model_path: Path of the ONNX model file to load.
    provider: Name of the execution provider to run on
      (for example "CPUExecutionProvider").
    thread_pooling: When true, `intra_op_num_threads` is set to 1 on the
      session options; otherwise the option is left untouched.

  Returns:
    The created `InferenceSession`, with all graph optimizations enabled and
    fallback disabled.

  Raises:
    AssertionError: If `provider` is not available in this installation.
  """
  # Check against the providers usable in the installed build rather than
  # every provider name the library knows about.
  available_providers = get_available_providers()
  assert provider in available_providers, f"provider {provider} not found, {available_providers}"

  options = SessionOptions()
  if thread_pooling:
    options.intra_op_num_threads = 1
  options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

  session = InferenceSession(model_path, options, providers=[provider])
  # Benchmarks must measure the requested provider only.
  session.disable_fallback()

  return session

@contextmanager
def track_infer_time(buffer: List[float]):
    """Context manager that times its body and records the duration.

    Args:
        buffer: List to which the elapsed time, in seconds, is appended once
            the body finishes. Nothing is appended if the body raises.
    """
    # perf_counter is monotonic and high-resolution, so measurements are not
    # affected by system clock adjustments.
    start = perf_counter()
    yield
    buffer.append(perf_counter() - start)

@dataclass
class OnnxInferenceResult:
  """Timing measurements collected for one benchmarked configuration."""
  # Duration of each timed inference run, in seconds.
  model_inference_time: List[float]
  # Path associated with the benchmarked model file; None is stored for the
  # PyTorch benchmarks, which have no ONNX file.
  optimized_model_path: Optional[str]

Prepare the input to use for benchmarking the original model (PyTorch Tensor) and the ONNX versions (numpy array).

In [None]:
from transformers import CodeGenTokenizerFast

tokenizer = CodeGenTokenizerFast.from_pretrained(model_id)

# Tokenize the benchmark prompt once as PyTorch tensors ...
benchmark_prompt = "def hello_world():"
model_inputs = tokenizer(benchmark_prompt, return_tensors="pt")

# ... and convert each tensor to a numpy array for the ONNX Runtime sessions.
inputs_onnx = {
    input_name: input_tensor.cpu().detach().numpy()
    for input_name, input_tensor in model_inputs.items()
}

Benchmark inference of the original PyTorch model on CPU.

In [None]:
from transformers import CodeGenModel

# (torch device, result label) pairs to benchmark. A list keeps the iteration
# order deterministic if more entries are added.
PROVIDERS = [
    ("cpu", "PyTorch CPU"),
]

# Benchmark results keyed by label; later cells add their own entries.
results = {}

for device, label in PROVIDERS:

    # Move every tokenizer output to the device under test.
    model_inputs_on_device = {
        arg_name: tensor.to(device)
        for arg_name, tensor in model_inputs.items()
    }

    model_pt = CodeGenModel.from_pretrained(model_id).to(device)

    # Only forward passes are measured, so autograd is disabled to keep
    # gradient bookkeeping out of the timings.
    with torch.no_grad():
        for _ in trange(10, desc="Warming up"):
            model_pt(**model_inputs_on_device)

        # Time 100 forward passes, one measurement per run.
        time_buffer = []
        for _ in trange(100, desc=f"Tracking inference time on PyTorch"):
            with track_infer_time(time_buffer):
                model_pt(**model_inputs_on_device)

    # Store the result (there is no ONNX model file for PyTorch runs)
    results[label] = OnnxInferenceResult(
        time_buffer,
        None
    )

Benchmark the model converted to ONNX format in ONNX Runtime (CPU).

In [None]:
# (execution provider, result label) pairs to benchmark. A list keeps the
# iteration order deterministic if more entries are added.
PROVIDERS = [
    ("CPUExecutionProvider", "ONNX CPU"),
]

for provider, label in PROVIDERS:
    # Bound to `session` so the Hugging Face `model` loaded earlier in the
    # notebook is not overwritten by an InferenceSession.
    session = create_model_for_provider(onnx_model_path, provider)

    # Same number of warm-up runs as the PyTorch benchmark.
    for _ in trange(10, desc="Warming up"):
        session.run(None, inputs_onnx)

    # Time 100 runs, one measurement per run.
    time_buffer = []
    for _ in trange(100, desc=f"Tracking inference time on {provider}"):
        with track_infer_time(time_buffer):
            session.run(None, inputs_onnx)

    results[label] = OnnxInferenceResult(
        time_buffer,
        session.get_session_options().optimized_model_filepath
    )

Compare benchmark results visually.

In [None]:
import numpy as np
import plotly.express as px

# Mean latency per benchmarked configuration, converted from seconds to milliseconds
time_results = {
    name: np.mean(result.model_inference_time) * 1e3
    for name, result in results.items()
}

# Bar chart with one bar per configuration.
bar_title = "Average inference time (ms) for each provider"
axis_labels = {'x':'Provider', 'y':'Avg Inference time (ms)'}
fig = px.bar(
    x=time_results.keys(),
    y=time_results.values(),
    title=bar_title,
    labels=axis_labels,
    text_auto='.2s',
)
fig.show()

Quantize and benchmark the original PyTorch model on CPU.

In [None]:
import torch 

# Dynamically quantize the Linear layers of the PyTorch model to int8 on CPU.
model_pt_quantized = torch.quantization.quantize_dynamic(
    model_pt.to("cpu"), {torch.nn.Linear}, dtype=torch.qint8
)

# Only forward passes are measured, so autograd is disabled to keep gradient
# bookkeeping out of the timings.
with torch.no_grad():
    # One untimed run before measuring.
    model_pt_quantized(**model_inputs)

    # Time 100 forward passes, one measurement per run.
    time_buffer = []
    for _ in trange(100):
        with track_infer_time(time_buffer):
            model_pt_quantized(**model_inputs)

# Store the result (there is no ONNX model file for PyTorch runs)
results["PyTorch CPU Quantized"] = OnnxInferenceResult(
    time_buffer,
    None
)

Benchmark the ONNX quantized model in the ONNX runtime (CPU).

In [None]:
# Session for the quantized ONNX file on the CPU execution provider.
quantized_model = create_model_for_provider(quantized_model_path, "CPUExecutionProvider")

# One untimed run before measuring.
outputs = quantized_model.run(None, inputs_onnx)

# Time 100 runs, one measurement per run.
time_buffer = []
progress = trange(100, desc="Tracking inference time on CPUExecutionProvider with quantized model")
for _ in progress:
    with track_infer_time(time_buffer):
        outputs = quantized_model.run(None, inputs_onnx)

results["ONNX CPU Quantized"] = OnnxInferenceResult(
    model_inference_time=time_buffer,
    optimized_model_path=quantized_model_path,
)

Compare all the benchmark results visually.

In [None]:
# Per-configuration mean and standard deviation of the latency, in milliseconds
time_results = {
    name: np.mean(result.model_inference_time) * 1e3
    for name, result in results.items()
}
time_results_std = {
    name: np.std(result.model_inference_time) * 1000
    for name, result in results.items()
}

# Bar chart with one bar per configuration, colored by its mean latency.
bar_title = "Average inference time (ms) for each provider"
axis_labels = {'x':'Provider', 'y':'Avg Inference time (ms)'}
fig = px.bar(
    x=time_results.keys(),
    y=time_results.values(),
    title=bar_title,
    labels=axis_labels,
    color=time_results.values(),
    color_continuous_scale=px.colors.sequential.Tealgrn,
    text_auto='.2s',
)
fig.show()

In [None]:
import plotly.graph_objects as go

# Bar chart of the mean latencies with the standard deviation as error bars.
mean_latencies_ms = list(time_results.values())

fig = go.Figure()
fig.add_trace(go.Bar(
    name='Control',
    x=list(time_results.keys()), y=mean_latencies_ms,
    error_y=dict(type='data', array=list(time_results_std.values())),
    marker=dict(
        # The colorscale and color bar only take effect when `color` is a
        # numerical array, so the bars are colored by their mean latency.
        color=mean_latencies_ms,
        colorscale='Tealgrn',
        showscale=True
    )
))
fig.update_xaxes(title_text="Provider")
fig.update_yaxes(title_text="Avg Inference time (ms)")
fig.show()