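<a href="https://colab.research.google.com/github/nyadla-sys/pytorch_2_tflite/blob/main/tinynn_pytorch_to_tflite_int8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PyTorch → int8 TFLite with TinyNeuralNetwork

This notebook applies post-training quantization to a pretrained torchvision MobileNetV2 with TinyNeuralNetwork. It calibrates on 100 images from the `cats_and_dogs_filtered` training set, converts the result to an int8 TFLite model, and checks the model on a sample dog image with the TFLite interpreter.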

In [None]:
!pip install git+https://github.com/alibaba/TinyNeuralNetwork.git

Collecting git+https://github.com/alibaba/TinyNeuralNetwork.git
  Cloning https://github.com/alibaba/TinyNeuralNetwork.git to /tmp/pip-req-build-sap89wfs
  Running command git clone -q https://github.com/alibaba/TinyNeuralNetwork.git /tmp/pip-req-build-sap89wfs
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting ruamel.yaml>=0.16.12
  Downloading ruamel.yaml-0.17.21-py3-none-any.whl (109 kB)
[K     |████████████████████████████████| 109 kB 5.4 MB/s 
Collecting python-igraph>=0.9.6
  Downloading python_igraph-0.9.9-py3-none-any.whl (9.1 kB)
Collecting tflite==2.3.0
  Downloading tflite-2.3.0-py2.py3-none-any.whl (79 kB)
[K     |████████████████████████████████| 79 kB 7.3 MB/s 
[?25hCollecting PyYAML>=5.3.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |█████████████████████

In [None]:
!wget --no-check-certificate \
    https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip \
    -O /content/cats_and_dogs_filtered.zip

--2022-03-11 15:24:52--  https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.152.128, 173.194.198.128, 173.194.74.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.152.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 68606236 (65M) [application/zip]
Saving to: ‘/content/cats_and_dogs_filtered.zip’


2022-03-11 15:24:53 (222 MB/s) - ‘/content/cats_and_dogs_filtered.zip’ saved [68606236/68606236]



In [None]:
import os
import zipfile

# Unpack the archive into /content (creates /content/cats_and_dogs_filtered/...).
# The context manager closes the archive even if extraction fails part-way.
local_zip = '/content/cats_and_dogs_filtered.zip'
with zipfile.ZipFile(local_zip, 'r') as zip_ref:
    zip_ref.extractall('/content')

In [None]:
import random

from glob import glob

from PIL import Image
import torch
from torchvision import transforms
import torchvision.models as models


from tinynn.converter import TFLiteConverter
from tinynn.graph.quantization.quantizer import PostQuantizer
from tinynn.graph.tracer import model_tracer
from tinynn.util.cifar10 import get_dataloader, train_one_epoch, validate
from tinynn.util.train_util import DLContext, get_device, train


# Seed every RNG used from here on: `random` shuffles the calibration file list and
# `torch` generates the dummy input, so re-running the notebook is reproducible.
random.seed(0)
torch.manual_seed(0)

with model_tracer():
  # Pretrained torchvision MobileNetV2, switched to inference mode before tracing.
  model = models.mobilenet_v2(pretrained=True)
  model.eval()

  # Provide a viable input for the model: a batch of one 3x224x224 image.
  dummy_input = torch.rand((1, 3, 224, 224))

  # Insert observers for post-training quantization (asymmetric; per_tensor=False
  # selects non-per-tensor, i.e. per-channel, quantization). Despite the variable name,
  # no quantization-aware training happens: the model is only calibrated and converted.
  quantizer = PostQuantizer(model, dummy_input, work_dir='out', config={'asymmetric': True, 'per_tensor': False})
  qat_model = quantizer.quantize()

print(qat_model)

# Use DataParallel to speed up calibration when possible.
# `nn` is never imported in this notebook, so go through `torch.nn`
# (a bare `nn.DataParallel` raised NameError on multi-GPU machines).
if torch.cuda.device_count() > 1:
  qat_model = torch.nn.DataParallel(qat_model)

# Move model to the appropriate device
device = get_device()
qat_model.to(device=device)


import os

# Number of training images fed through the model to calibrate the observers.
NUM_CALIBRATION_IMAGES = 100

# Standard ImageNet preprocessing (same as used for inference); loop-invariant, so build it once.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# A recursive `**/*` glob also matches the sub-directories themselves (e.g. train/cats),
# which Image.open cannot read, so keep regular files only. Sorting first makes the
# seeded shuffle independent of the filesystem's listing order.
dataset_list = sorted(
    path for path in glob('/content/cats_and_dogs_filtered/train/**/*', recursive=True)
    if os.path.isfile(path)
)
random.shuffle(dataset_list)
# Slicing (rather than indexing range(N)) tolerates a dataset smaller than N.
calibration_files = dataset_list[:NUM_CALIBRATION_IMAGES]

# Observers record activation ranges on every forward pass. eval() puts layers such as
# dropout into inference behaviour so the recorded ranges match deployment, and
# no_grad() avoids building an unused autograd graph.
qat_model.eval()
with torch.no_grad():
  for filename in calibration_files:
    with Image.open(filename) as input_image:
      # convert('RGB') guards against grayscale/RGBA files, which would not match the
      # 3-channel Normalize above.
      input_tensor = preprocess(input_image.convert('RGB'))
    qat_model(torch.unsqueeze(input_tensor, 0).to(device=device))

print(f'Calibrated on {len(calibration_files)} images')
  

with torch.no_grad():
  # Conversion needs the bare module: unwrap the DataParallel container that is added
  # when more than one GPU is available (a no-op otherwise).
  if isinstance(qat_model, torch.nn.DataParallel):
    qat_model = qat_model.module
  qat_model.eval()
  qat_model.cpu()

  # The step below converts the model to an actual quantized model, which uses the quantized kernels.
  qat_model = torch.quantization.convert(qat_model)

  # When converting quantized models, please ensure the quantization backend is set.
  torch.backends.quantized.engine = quantizer.backend

  # Convert the quantized model to TFLite:
  #   quantize_target_type='int8' -> emit int8 quantized tensors;
  #   input_transpose=False       -> keep the PyTorch NCHW input layout (1, 3, 224, 224);
  #   fuse_quant_dequant=True     -> fuse the quantize/dequantize ops at the model boundary,
  #                                  so callers feed and read quantized tensors directly.
  # For a strict symmetric-quantization check (pre-defined zero points), also pass
  # `strict_symmetric_check=True`.
  converter = TFLiteConverter(qat_model, dummy_input, tflite_path='out/qat_model.tflite', quantize_target_type='int8', input_transpose=False, fuse_quant_dequant=True)
  converter.convert()



Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth


  0%|          | 0.00/13.6M [00:00<?, ?B/s]

MobileNetV2_qat(
  (fake_quant_0): QuantStub(
    (activation_post_process): HistogramObserver()
  )
  (features_0_0): Conv2d(
    3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
    (activation_post_process): HistogramObserver()
  )
  (features_0_1): Identity()
  (features_0_2): ReLU6(
    inplace=True
    (activation_post_process): HistogramObserver()
  )
  (features_1_conv_0_0): Conv2d(
    32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32
    (activation_post_process): HistogramObserver()
  )
  (features_1_conv_0_1): Identity()
  (features_1_conv_0_2): ReLU6(
    inplace=True
    (activation_post_process): HistogramObserver()
  )
  (features_1_conv_1): Conv2d(
    32, 16, kernel_size=(1, 1), stride=(1, 1)
    (activation_post_process): HistogramObserver()
  )
  (features_1_conv_2): Identity()
  (features_2_conv_0_0): Conv2d(
    16, 96, kernel_size=(1, 1), stride=(1, 1)
    (activation_post_process): HistogramObserver()
  )
  (features_2_conv_0_1): Ident

  src_bin_begin // dst_bin_width, 0, self.dst_nbins - 1
  src_bin_end // dst_bin_width, 0, self.dst_nbins - 1
INFO (tinynn.converter.base) Generated model saved to out/qat_model.tflite


In [None]:
# Download an example image from the PyTorch Hub repository.
# (`urllib.URLopener` only exists in Python 2; on Python 3 it raised AttributeError, which a
# bare `except:` silently redirected to urlretrieve. Call the Python 3 API directly and let
# real download errors surface.)
import urllib.request

url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)

import numpy as np
import tensorflow as tf
from PIL import Image
from torchvision import transforms

tflite_model_path = '/content/out/qat_model.tflite'
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()

# Get input and output tensor metadata (index, dtype, (scale, zero_point)).
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


def quantize_input(values, details):
  """Quantize a float array for a TFLite input tensor described by `details`.

  Uses q = clip(round(x / scale) + zero_point) to the range of the tensor's integer dtype.
  A scale of 0 means the tensor is not quantized, so values are only cast to its dtype.
  """
  scale, zero_point = details['quantization']
  dtype = details['dtype']
  if scale == 0:
    return values.astype(dtype)
  info = np.iinfo(dtype)
  return np.clip(np.round(values / scale) + zero_point, info.min, info.max).astype(dtype)


def dequantize_output(values, details):
  """Map a TFLite output tensor back to real values: x = (q - zero_point) * scale."""
  scale, zero_point = details['quantization']
  if scale == 0:
    return values.astype(np.float32)
  return (values.astype(np.float32) - zero_point) * scale


def softmax(logits):
  """Numerically stable softmax over the last axis."""
  shifted = logits - np.max(logits, axis=-1, keepdims=True)
  exps = np.exp(shifted)
  return exps / np.sum(exps, axis=-1, keepdims=True)


scale, zero_point = input_details[0]['quantization']
print(scale)
print(zero_point)

# Test the model on image data, with the same preprocessing used for calibration.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
with Image.open(filename) as input_image:
  input_tensor = preprocess(input_image.convert('RGB'))
print(input_tensor.shape)

# Add the batch dimension (the model keeps the NCHW layout) and quantize with the
# input tensor's own dtype and parameters instead of assuming qint8.
input_array = quantize_input(input_tensor.unsqueeze(0).numpy(), input_details[0])

print("torch input_tensor size:")
print(input_array.shape, input_array.dtype)
interpreter.set_tensor(input_details[0]['index'], input_array)

interpreter.invoke()

# get_tensor() returns a copy of the tensor data
# use tensor() in order to get a pointer to the tensor
output_data = interpreter.get_tensor(output_details[0]['index'])

# The raw output is quantized, unnormalized class scores, not probabilities (scaling the
# int8 value by 100 is what produced "7400%"). Dequantize, then softmax to get a confidence.
probabilities = softmax(dequantize_output(output_data, output_details[0]))
label_index = int(np.argmax(probabilities[0]))

print("Predicted value . Label index: {}, confidence: {:2.0f}%"
      .format(label_index,
              100 * probabilities[0][label_index]))

0.01871182955801487
-15
torch.Size([3, 224, 224])
torch input_tensor size:
(1, 3, 224, 224)
[[[[-118 -118 -117 ... -124 -119 -118]
   [-122 -116 -117 ... -120 -118 -110]
   [-122 -119 -117 ... -125 -120 -116]
   ...
   [ -94 -101 -102 ...  -61  -74  -72]
   [ -97 -101 -102 ...  -71 -102  -92]
   [ -98  -94  -82 ...  -64  -83  -82]]

  [[-113 -113 -112 ... -121 -117 -117]
   [-113 -114 -114 ... -120 -118 -115]
   [-113 -114 -114 ... -120 -119 -117]
   ...
   [ -68  -67  -68 ...  -40  -49  -51]
   [ -68  -68  -69 ...  -45  -75  -63]
   [ -66  -68  -58 ...  -36  -58  -55]]

  [[-101 -100  -98 ... -107 -106 -107]
   [-104 -102 -102 ... -107 -107 -103]
   [-103 -104 -103 ... -108 -106 -105]
   ...
   [ -85  -96 -100 ...  -63  -77  -73]
   [ -84  -97 -101 ...  -66  -95  -89]
   [ -83  -91  -81 ...  -55  -75  -83]]]]
Predicted value . Label index: 258, confidence: 7400%
