In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import onnx
from onnx_tf.backend import prepare
from models.torch_models.torch_models import resnet18
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from utils.data_loaders import get_cifar_loader,get_imagenette_loader
import tensorflow as tf
import numpy as np

import onnxruntime as ort
from onnx2pytorch import ConvertModel

2022-11-23 10:53:03.017990: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-23 10:53:04.026639: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-23 10:53:07.189152: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-11-23 10:53:07.189650: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or 

In [2]:
# check torch version and device
print(torch.__version__)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("use device:",device)
print("TensorFlow version:", tf.__version__)

1.13.0+cu117
use device: cuda
TensorFlow version: 2.10.0


# step 1 load pytorch model dataloader and train

In [6]:
torch.cuda.empty_cache()
#train_loader,val_loader,test_loader = get_cifar_loader(batch_size=2)
train_loader,val_loader,test_loader = get_imagenette_loader(batch_size=8)




In [None]:
epochs=3
early_stop_callback = EarlyStopping(monitor="val_acc", min_delta=0.00, patience=5, verbose=False, mode="max")

trainer = Trainer(max_epochs=epochs,fast_dev_run=False,accelerator="gpu",callbacks=[early_stop_callback])
model = resnet18(num_classes=10).to(device)

trainer.fit(model,train_loader,val_loader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type     | Params
---------------------------------------
0 | model     | ResNet   | 11.2 M
1 | test_acc  | Accuracy | 0     
2 | valid_acc | Accuracy | 0     
---------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

running
running


Training: 0it [00:00, ?it/s]

# step 2 test pytorch model accuracy and save model

In [9]:
trainer.test(model,test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.7864000201225281
        test_loss           0.9610121250152588
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.9610121250152588, 'test_acc': 0.7864000201225281}]

In [144]:
torch_model = model.model
save_path = "saved_models/torch2tf/CifarResnet18.pth"
torch.save(torch_model.state_dict(),save_path)

In [10]:
# check if saved model is correct
load_path = "saved_models/torch2tf/resnet18.pth"
model2 = resnet18(num_classes=10).model
model2.load_state_dict(torch.load(load_path))
model2.eval()
model2=model2.to(device)
_all=0
_correct=0
for imgs,labels in test_loader:
    pred = torch.argmax(model2(imgs.to(device)),axis=1).to("cpu")
    _all+=len(labels)
    _correct+=(pred==labels).sum()
print(f"accuracy:{_correct/_all}")


accuracy:0.8144000172615051


In [8]:
# also save the whole model
full_model_save_path = "saved_models/torch2tf/CifarResnet18_model.pth"
torch.save(model2,full_model_save_path)

# step 3 convert pytorch model to onnx model and test onnx model accuracy

In [412]:
dummy_input = torch.randn(1,3,32,32,device="cuda")
save_path = "saved_models/torch2tf/CifarResnet18.onnx"

torch.onnx.export(model2,
                  dummy_input,
                  save_path,
                  input_names=["input"],
                  output_names=["output"],
                  dynamic_axes={'input':{0:'batch_size'}})


In [4]:
def get_accuracy(ort_sess):
    _correct=0
    _all=0
    for imgs,labels in test_loader:
        output = ort_sess.run(output_names=['output'],input_feed={'input': imgs.numpy()})
        pred = np.argmax(output[0],axis=1)
        _all+=len(labels)
        _correct+=(pred==labels.numpy()).sum()
    return _correct/_all
ort_sess = ort.InferenceSession('saved_models/torch2tf/CifarResnet18.onnx')
acc = get_accuracy(ort_sess)
print(f"accuracy of onnx model from torch: {acc}")

accuracy of onnx model from torch: 0.8144


# step 4 convert onnx model to tf model test accuracy and save model

In [413]:
load_path = "saved_models/torch2tf/CifarResnet18.onnx"
onnx_model = onnx.load(load_path)
tf_rep = prepare(onnx_model)

In [414]:
_all=0
_correct=0
for imgs,labels in test_loader:
    pred = np.argmax(tf_rep.run(imgs)[0],axis=1)
    _all+=len(labels)
    _correct+=(pred==labels.numpy()).sum()
print(f"accuracy:{_correct/_all}")


accuracy:0.8144


In [415]:
save_path = "saved_models/torch2tf/CifarResnet18"
tf_rep.export_graph(save_path)



INFO:tensorflow:Assets written to: saved_models/torch2tf/CifarResnet18\assets


INFO:tensorflow:Assets written to: saved_models/torch2tf/CifarResnet18\assets


# step 5 load tf model and test accuracy

In [11]:
load_path = "saved_models/torch2tf/CifarResnet18"
loaded=tf.saved_model.load(load_path)

In [418]:
print(list(loaded.signatures.keys())) 
infer = loaded.signatures["serving_default"]
key=list(infer.structured_outputs.keys())[0]

['serving_default']


In [419]:
_all=0
_correct=0
for imgs,labels in test_loader:
    out = infer(**{'input': imgs})
    pred = np.argmax(out[key],axis=1)
    _all+=len(labels)
    _correct+=(pred==labels.numpy()).sum()
print(f"accuracy:{_correct/_all}")

accuracy:0.8144


In [420]:
# additional test one single sample if they produce same result
test_img=imgs[0].unsqueeze(0)
model2.eval()
out1=model2(test_img.to(device)).cpu()
out2=infer(**{'input': test_img})[key]
print(out1.detach().numpy())
print(out2.numpy())

[[-6.654736  -7.450909  -4.149689  -2.9547741 -2.7491271  1.6382511
  -8.316099  10.531688  -8.481063  -4.960392 ]]
[[-6.654734  -7.450907  -4.1496887 -2.9547796 -2.749123   1.638253
  -8.316097  10.531686  -8.481064  -4.96039  ]]


# step 6 convert onnx model back to torch model and test accuracy

In [5]:
load_path = "saved_models/torch2tf/CifarResnet18.onnx"
onnx_model = onnx.load(load_path)
torch_model = ConvertModel(onnx_model,debug=False)

  layer.weight.data = torch.from_numpy(numpy_helper.to_array(weight))


In [6]:
def get_acc_from_converted_pytorch_model(model):
    _all=0
    _correct=0
    for imgs,labels in test_loader:
        _all+=len(labels)
        for img,label in zip(imgs,labels):
            output = model(img.unsqueeze(0))
            _correct+=(torch.argmax(output)==label).item()
    return _correct/_all


In [8]:
acc=get_acc_from_converted_pytorch_model(torch_model)
print(f"acc:{acc}")

acc:0.8144


In [11]:
# use onnx2torch
from onnx2torch import convert
load_path = "saved_models/torch2tf/resnet18.onnx"
onnx_model = onnx.load(load_path)
torch_model_2 = convert(load_path).to(device)
_correct=0
_all=0
for (img,labels) in test_loader:
    _all+=len(labels)
    out=torch_model_2(img.to(device)).to("cpu")
    pred = torch.argmax(out,axis=1)
    _correct+=(pred==labels).sum()
print(f"accuracy: {_correct/_all}")

accuracy: 0.6853502988815308
