# GSoC pretest

## Imports

In [None]:
import os

import torch
import timm
import torch.nn as nn
from PIL import Image
from torchvision import transforms
import torch.quantization as quantization
import matplotlib.pyplot as plt
from openvino.runtime import Core
from openvino.runtime import serialize
from openvino.tools import mo
import onnx
import onnxruntime as ort
from onnxruntime.quantization import QuantType, quantize_dynamic

## Load Model

In [None]:
# Print every timm architecture whose name matches 'swin_small*', so the
# variant chosen below can be checked against what this timm version offers.
model_names = timm.list_models('swin_small*')
for model_name in model_names:
    print(model_name)

# Instantiate the chosen Swin-Small variant with pretrained weights.
model = timm.create_model(
    'swin_small_patch4_window7_224',
    pretrained=True,
)

## Run the model without quantization

In [None]:
# Load image and preprocess
image = Image.open('../notebooks/data/image/coco.jpg')
# plt.imshow(image)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.299, 0.224, 0.255])
])
image = transform(image).unsqueeze(0)

# Predict the class of image
with torch.no_grad():
    output = model(image)
    pred = output.argmax(dim=1).item()
    print(f'Predicted class: {pred}')
# After searching, the 208th category in ImageNet is dog.

## Define a function that converts the model to ONNX and OpenVINO IR

In [None]:
def convert_models(model, model_input, path):
    """Export ``model`` to ONNX at ``path`` and to OpenVINO IR beside it.

    The model is traced with ``torch.jit.trace`` on ``model_input``. The
    traced module is exported to ONNX at ``path``. That ONNX file is then
    converted with OpenVINO Model Optimizer and serialized as IR.

    Args:
        model: PyTorch module to export. Tracing records the ops executed in
            the module's current mode, so call ``model.eval()`` first if
            inference behaviour should be captured.
        model_input: example input used for tracing and ONNX export.
        path: destination of the ONNX file, e.g. ``'models/net.onnx'``.

    Returns:
        Path of the written IR ``.xml`` file: ``path`` with its suffix
        replaced by ``.xml``.
    """
    script_model = torch.jit.trace(model, model_input)
    torch.onnx.export(script_model, model_input, path)
    ov_model = mo.convert_model(path)
    # Swap whatever suffix ``path`` has for ``.xml``. The previous
    # ``path[:-4] + 'xml'`` was only correct for a 5-character ``.onnx``.
    ir_path = os.path.splitext(path)[0] + '.xml'
    serialize(ov_model, ir_path)
    return ir_path

## Convert the model to ONNX and OpenVINO IR

In [None]:
# Convert model to onnx and IR
# If you already have a models directory, place comment the next line of code.
!mkdir models
model.eval()
onnx_path = 'models/swin_small_patch4_window7_224.onnx'
model_input = torch.randn(1,3,224,224).cpu() 
convert_models(model, model_input,onnx_path)

## Quantize the model (PyTorch dynamic quantization and ONNX Runtime) and convert it

In [None]:
# ie = Core()
# model_ir = ie.read_model(model=(onnx_path[:-4]+'xml'))
# model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_int8 = quantization.quantize_dynamic(model, dtype=torch.qint8)
# quantization.prepare(model, inplace=True)
# model_int8 = quantization.convert(model, inplace=True)
with torch.no_grad():
    output = model_int8(image)
    pred = output.argmax(dim=1).item()
    print(f'Predicted class: {pred}')
torch.save(model_int8, 'models/quantize_swin.pth')

In [None]:
# Quantize the exported FP32 ONNX model with ONNX Runtime dynamic
# quantization, storing weights as signed 8-bit integers (QInt8).
quantize_path = 'models/quantize_swin.onnx'
quantize_dynamic(
    model_input=onnx_path,
    model_output=quantize_path,
    weight_type=QuantType.QInt8,
    # NOTE(review): `optimize_model` appears to have been removed from
    # quantize_dynamic in newer onnxruntime releases. Drop this argument if
    # the call raises TypeError, and confirm against the installed version.
    optimize_model=True,
)


In [None]:
# Convert the ONNX Runtime-quantized model to OpenVINO IR.
convert_model = mo.convert_model(quantize_path)
# Replace the file suffix with .xml. Unlike `quantize_path[:-4] + 'xml'`,
# this does not depend on the suffix being exactly '.onnx'.
IR_path = os.path.splitext(quantize_path)[0] + '.xml'
serialize(convert_model, IR_path)


In [None]:
model_int8.eval()
quantize_path = 'models/quantize_swin.onnx'
# NOTE(review): exporting the torch-quantized `model_int8` to ONNX/IR (via
# convert_models, or torch.jit.trace + torch.onnx.export) was tried here and
# left disabled. If it is re-enabled, do not write to `quantize_path`: that
# would overwrite the ONNX Runtime-quantized model produced earlier in this
# notebook. Use a separate file name instead.