In [5]:
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import onnx
from onnx_tf.backend import prepare
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import tensorflow as tf

from models.torch_models import LitCifarResnet
from utils.data_loaders import get_cifar_loader

In [6]:
# Report the installed framework versions and choose the compute device:
# the GPU when CUDA is available, otherwise the CPU.
print(torch.__version__)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("use device:", device)
print("TensorFlow version:", tf.__version__)

1.13.0+cu116
use device: cuda


# Step 1: Load the CIFAR-10 data loaders and train the PyTorch model

In [7]:
# Build the CIFAR-10 loaders and train the Lightning ResNet.
# EarlyStopping monitors "val_acc", so the validation loader passed to fit()
# must be the held-out validation split. Passing test_loader here (as before)
# would drive early stopping and model selection from the test set and leak it.
BATCH_SIZE = 256
train_loader, val_loader, test_loader = get_cifar_loader(batch_size=BATCH_SIZE)

epochs = 10
early_stop_callback = EarlyStopping(monitor="val_acc", min_delta=0.00, patience=5, verbose=False, mode="max")

# Keep the Trainer's accelerator consistent with the device chosen above, so the
# notebook still runs on a CPU-only machine.
trainer = Trainer(
    max_epochs=epochs,
    fast_dev_run=False,
    accelerator="gpu" if device == "cuda" else "cpu",
    callbacks=[early_stop_callback],
)
# Lightning moves the module to the trainer's accelerator itself.
model = LitCifarResnet()

trainer.fit(model, train_loader, val_loader)


Files already downloaded and verified
Files already downloaded and verified


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type     | Params
---------------------------------------
0 | model     | ResNet   | 11.2 M
1 | test_acc  | Accuracy | 0     
2 | valid_acc | Accuracy | 0     
---------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


# Step 2: Test the PyTorch model's accuracy and save its weights

In [143]:
# Evaluate the trained model on the held-out test split; the returned list of
# metric dicts is displayed as the cell output.
trainer.test(model, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.8144000172615051
        test_loss           0.9494243860244751
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.9494243860244751, 'test_acc': 0.8144000172615051}]

In [144]:
# Save only the weights of the inner network (`model.model`). They are reloaded
# into `LitCifarResnet().model` below for export.
torch_model = model.model
save_path = "saved_models/torch2tf/CifarResnet18.pth"
# Create the output directory so the cell works on a fresh checkout.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(torch_model.state_dict(), save_path)

In [347]:
# Check the saved checkpoint: reload it into a fresh network and re-measure test
# accuracy. It should match the trainer.test result above.
load_path = "saved_models/torch2tf/CifarResnet18.pth"
model2 = LitCifarResnet().model
# map_location lets a checkpoint written on the GPU load on a CPU-only machine.
model2.load_state_dict(torch.load(load_path, map_location=device))
model2.eval()
model2 = model2.to(device)
_all = 0
_correct = 0
# Inference only: skip building the autograd graph to save memory and time.
with torch.no_grad():
    for imgs, labels in test_loader:
        pred = torch.argmax(model2(imgs.to(device)), dim=1).to("cpu")
        _all += len(labels)
        _correct += (pred == labels).sum().item()
print(f"accuracy:{_correct/_all}")


accuracy:0.8144000172615051


# Step 3: Convert the PyTorch model to an ONNX model

In [412]:
# Export the reloaded PyTorch network to ONNX. The dummy input only drives
# tracing, so it must live on the same device as model2. The old hardcoded
# "cuda" broke on CPU-only machines.
model2.eval()
dummy_input = torch.randn(1, 3, 32, 32, device=device)
save_path = "saved_models/torch2tf/CifarResnet18.onnx"

torch.onnx.export(model2,
                  dummy_input,
                  save_path,
                  input_names=["input"],
                  output_names=["output"],
                  # Mark the batch dimension dynamic on both ends so the graph
                  # advertises variable-size batches for input and output.
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}})


# Step 4: Convert the ONNX model to a TensorFlow model, test its accuracy, and save it

In [413]:
# Load the exported ONNX graph, validate it, and wrap it in an onnx-tf
# TensorFlow backend representation.
load_path = "saved_models/torch2tf/CifarResnet18.onnx"
onnx_model = onnx.load(load_path)
# Fail early with a clear error if the exported graph is malformed.
onnx.checker.check_model(onnx_model)
tf_rep = prepare(onnx_model)

In [414]:
# Evaluate the onnx-tf representation on the same test split. Accuracy should
# match the PyTorch result if the conversion is faithful.
_all = _correct = 0
for imgs, labels in test_loader:
    logits = tf_rep.run(imgs)[0]
    pred = np.argmax(logits, axis=1)
    _correct += (pred == labels.numpy()).sum()
    _all += len(labels)
print(f"accuracy:{_correct/_all}")


accuracy:0.8144


In [415]:
# Write the converted model out as a TensorFlow SavedModel directory.
save_path = "saved_models/torch2tf/CifarResnet18"
tf_rep.export_graph(save_path)



INFO:tensorflow:Assets written to: saved_models/torch2tf/CifarResnet18\assets


INFO:tensorflow:Assets written to: saved_models/torch2tf/CifarResnet18\assets


# Step 5: Load the TensorFlow SavedModel and test its accuracy

In [416]:
# Reload the SavedModel written in step 4.
load_path = "saved_models/torch2tf/CifarResnet18"
loaded = tf.saved_model.load(load_path)

In [418]:
# List the available signatures, pick the default serving function, and take
# the first key of its output dict for indexing results below.
signature_names = list(loaded.signatures.keys())
print(signature_names)
infer = loaded.signatures["serving_default"]
key = [*infer.structured_outputs][0]

['serving_default']


In [419]:
# Evaluate the reloaded SavedModel on the test split. Accuracy should again
# match the PyTorch and onnx-tf results.
_all = _correct = 0
for imgs, labels in test_loader:
    outputs = infer(input=imgs)
    pred = np.argmax(outputs[key], axis=1)
    _correct += (pred == labels.numpy()).sum()
    _all += len(labels)
print(f"accuracy:{_correct/_all}")

accuracy:0.8144


In [420]:
# Spot-check one sample from the last test batch: the PyTorch network and the
# TF SavedModel should produce matching logits up to float round-off.
test_img = imgs[0].unsqueeze(0)
model2.eval()
with torch.no_grad():  # inference only, no autograd graph needed
    out1 = model2(test_img.to(device)).cpu()
out2 = infer(**{'input': test_img})[key]
print(out1.numpy())
print(out2.numpy())
# Turn the visual comparison into a real check. Observed differences are ~1e-5.
np.testing.assert_allclose(out1.numpy(), out2.numpy(), rtol=1e-4, atol=1e-4)

[[-6.654736  -7.450909  -4.149689  -2.9547741 -2.7491271  1.6382511
  -8.316099  10.531688  -8.481063  -4.960392 ]]
[[-6.654734  -7.450907  -4.1496887 -2.9547796 -2.749123   1.638253
  -8.316097  10.531686  -8.481064  -4.96039  ]]


In [41]:
# Load CIFAR-10 through Keras. This copy was checked to be identical to the one
# downloaded via PyTorch. Keras caches it under ~/.keras/datasets
# (%USERPROFILE%\.keras\datasets on Windows).
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
