In [1]:
import tensorflow as tf
from models.tf_models.CifarResnet_tf import get_Resnet18
from utils.tf_datasets import get_cifar_data
import numpy as np

# step 1 Load dataset and tf models and test accuracy

### Note: torch-onnx-tf-onnx is different from tf-onnx

## load dataset

In [2]:
x_train,y_train,x_val,y_val,x_test,y_test=get_cifar_data()

## load model and train

In [3]:
model = get_Resnet18(num_classes=10)
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 build_res_net (BuildResNet)  (None, 10)               11183562  
                                                                 
Total params: 11,183,562
Trainable params: 11,173,962
Non-trainable params: 9,600
_________________________________________________________________


In [4]:
EPOCHS=20
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=2)
history = model.fit(x_train, y_train, epochs=EPOCHS, validation_data = (x_val, y_val), batch_size=512,callbacks=[callback])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20


In [5]:
print("Evaluate on test data")
results = model.evaluate(x_test, y_test, batch_size=128)
print("test loss, test acc:", results)

Evaluate on test data
test loss, test acc: [1.8803019523620605, 0.6891000270843506]


In [6]:
# save as .hs file
save_path = "saved_models/tf2torch/CifarResnet18.h5"
model.save(save_path)
# save as saved_model


# step 2 convert from tensorflow to onnx

In [7]:
import tf2onnx
import onnxruntime as rt


spec = (tf.TensorSpec((None, 32, 32, 3), tf.float32, name="input"),)
output_path = "saved_models/tf2torch/CifarResnet18_from_keras.onnx"
model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13, output_path=output_path)
output_names = [n.name for n in model_proto.graph.output]

In [8]:
output_names

['build_res_net']

# step 3 test onnx model using onnxruntime 

In [9]:
import onnx
import onnxruntime as ort
load_path_1 = "saved_models/tf2torch/CifarResnet18_from_keras.onnx"
onnx_model_1 = onnx.load(load_path_1)


In [10]:
onnx.checker.check_model(onnx_model_1)


In [11]:
def get_accuracy(ort_sess):
    _correct=0
    _all=y_test.shape[0]
    batch_size=256
    num_batches = _all//batch_size
    for i in range(num_batches):
        if i!=num_batches-1:
            output=ort_sess.run(output_names=output_names,input_feed={"input":x_test[i*batch_size:(i+1)*batch_size]})
            pred = np.argmax(output[0],axis=1).reshape(-1,1)
            _correct+=(pred==y_test[i*batch_size:(i+1)*batch_size]).sum()
            
        else:
            output=ort_sess.run(output_names=output_names,input_feed={"input":x_test[i*batch_size:]})
            pred = np.argmax(output[0],axis=1).reshape(-1,1)
            _correct+=(pred==y_test[i*batch_size:]).sum()
    return _correct/_all

In [12]:
ort_sess_1 = ort.InferenceSession('saved_models/tf2torch/CifarResnet18_from_keras.onnx')
#ort_sess_2 = ort.InferenceSession('saved_models/torch2tf/CifarResnet18.onnx')
# Print accuracy Result
acc_1 = get_accuracy(ort_sess_1)
#acc_2 = get_accuracy(ort_sess_2)
print(f"accuracy of onnx model from tf: {acc_1}")
#print(f"accuracy of onnx model from torch: {acc_2}")

accuracy of onnx model from tf: 0.6891


# step 4 convert onnx to torch

In [13]:
from onnx2pytorch import ConvertModel

In [14]:
torch_model = ConvertModel(onnx_model_1,debug=False,experimental=True)
torch_model.eval()
torch_model.cpu()

  layer.weight.data = torch.from_numpy(numpy_helper.to_array(weight))


ConvertModel(
  (Transpose_model/build_res_net/conv2d/Conv2D__6:0): Transpose()
  (Conv_model/build_res_net/batch_normalization/FusedBatchNormV3:0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (Relu_model/build_res_net/Relu:0): ReLU(inplace=True)
  (Conv_model/build_res_net/sequential/basic_block/batch_normalization_1/FusedBatchNormV3:0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (Relu_model/build_res_net/sequential/basic_block/Relu:0): ReLU(inplace=True)
  (Conv_model/build_res_net/sequential/basic_block/batch_normalization_2/FusedBatchNormV3:0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (Add_model/build_res_net/sequential/basic_block/add/add:0): Add()
  (Relu_model/build_res_net/sequential/basic_block/Relu_1:0): ReLU(inplace=True)
  (Conv_model/build_res_net/sequential/basic_block_1/batch_normalization_3/FusedBatchNormV3:0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (Relu_model/build_

# step 5 test torch model accuracy

In [31]:
import torch
_all=len(y_test)
_correct=0
for img,label in zip(x_test,y_test):
    img = torch.from_numpy(img).unsqueeze(0)
    out=torch.argmax(torch_model(img)).item()
    if out==label:
        _correct+=1
print(f"accuracy: {_correct/_call}")

NameError: name '_call' is not defined