In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import onnx
from onnx_tf.backend import prepare
from models.torch_models.torch_models import resnet18
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from utils.data_loaders import get_cifar_loader,get_imagenette_loader
import tensorflow as tf
import numpy as np
import onnx_tf
import onnxruntime as ort
from onnx2pytorch import ConvertModel

In [None]:
# Report library versions and choose the compute device used by the rest of the notebook.
print(torch.__version__)
if torch.cuda.is_available():
    device = "cuda"  # default CUDA device
else:
    device = "cpu"   # no GPU visible to PyTorch
print("use device:", device)
print("TensorFlow version:", tf.__version__)

# Step 1: Load the data loaders and train the PyTorch model

In [None]:
# Free cached, currently-unused GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()
# get_cifar_loader (project helper from utils.data_loaders) yields the train/val/test loaders.
train_loader,val_loader,test_loader = get_cifar_loader(batch_size=64)
# Alternative dataset (Imagenette); the smaller batch is presumably due to larger images — TODO confirm.
#train_loader,val_loader,test_loader = get_imagenette_loader(batch_size=8)


In [None]:

# Train the ResNet-18 LightningModule on CIFAR.
epochs = 3
# Stop when "val_acc" (presumably logged by the model's validation step) stops improving.
# NOTE: with max_epochs=3 a patience of 5 can never trigger; it only matters if `epochs` grows.
early_stop_callback = EarlyStopping(monitor="val_acc", min_delta=0.00, patience=5, verbose=False, mode="max")

# Pick the accelerator from what is actually available instead of hard-coding "gpu",
# which makes Trainer fail on CPU-only machines.
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
trainer = Trainer(max_epochs=epochs, fast_dev_run=False, accelerator=accelerator, callbacks=[early_stop_callback])
# resnet18 (project helper) returns a LightningModule; its plain nn.Module is exposed as `.model`.
model = resnet18(num_classes=10).to(device)

trainer.fit(model, train_loader, val_loader)


# Step 2: Test the PyTorch model's accuracy and save the model

In [None]:
# Evaluate the trained model on the held-out test split; the call's return value is shown as the cell output.
trainer.test(model,test_loader)

In [None]:
# Save only the weights (state_dict) of the inner nn.Module, not the Lightning wrapper.
import os

torch_model = model.model
save_path = "saved_models/torch2tf/CifarResnet18.pth"
# torch.save does not create missing parent directories, so make sure they exist.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(torch_model.state_dict(), save_path)

In [None]:
# Check that the saved weights round-trip: load them into a fresh network and
# re-measure test accuracy.
# Load the file written by the previous cell (this used to read "resnet18.pth",
# a different file, so it did not actually verify what was just saved).
load_path = "saved_models/torch2tf/CifarResnet18.pth"
model2 = resnet18(num_classes=10).model
# map_location lets a checkpoint written on GPU be loaded on a CPU-only machine.
model2.load_state_dict(torch.load(load_path, map_location=device))
model2.eval()
model2 = model2.to(device)
_all = 0
_correct = 0
# Gradients are not needed for evaluation; no_grad saves memory and time.
with torch.no_grad():
    for imgs, labels in test_loader:
        pred = torch.argmax(model2(imgs.to(device)), dim=1).to("cpu")
        _all += len(labels)
        # .item() keeps the counter a Python int so the printout is a plain number.
        _correct += (pred == labels).sum().item()
print(f"accuracy:{_correct/_all}")


In [None]:
# Additionally persist the complete nn.Module object (pickled), not only its state_dict.
full_model_save_path = "saved_models/torch2tf/CifarResnet18_model.pth"
torch.save(obj=model2, f=full_model_save_path)

# Step 3: Convert the PyTorch model to ONNX and test the ONNX model's accuracy

In [None]:
# Export the verified PyTorch model to ONNX. The dummy input only fixes the traced
# input shape (1 x 3 x 32 x 32, CIFAR-sized); the batch axis is declared dynamic.
# The dummy input must live on the same device as model2 (was hard-coded to "cuda").
dummy_input = torch.randn(1, 3, 32, 32, device=device)
save_path = "saved_models/torch2tf/CifarResnet18.onnx"

model2.eval()  # export the inference-mode graph
torch.onnx.export(model2,
                  dummy_input,
                  save_path,
                  input_names=["input"],
                  output_names=["output"],
                  # Mark the output batch axis dynamic too; otherwise the graph declares
                  # an output batch size of 1 while batched inputs produce larger outputs.
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}})


In [None]:
def get_accuracy(ort_sess, loader=None, input_name="input", output_name="output"):
    """Compute top-1 accuracy of an ONNX Runtime session over a data loader.

    Args:
        ort_sess: object exposing ``run(output_names=..., input_feed=...)``,
            e.g. an ``onnxruntime.InferenceSession``.
        loader: iterable of ``(images, labels)`` torch-tensor batches. ``None``
            (default) uses the notebook-global ``test_loader``, looked up at call time.
        input_name: graph input name; the default matches ``input_names`` used at export.
        output_name: graph output name; the default matches ``output_names`` used at export.

    Returns:
        Fraction of correctly classified samples, as a Python float.

    Raises:
        ValueError: if the loader yields no samples.
    """
    if loader is None:
        loader = test_loader
    _correct = 0
    _all = 0
    for imgs, labels in loader:
        output = ort_sess.run(output_names=[output_name], input_feed={input_name: imgs.numpy()})
        # output[0] holds the logits for the batch; the class axis is 1.
        pred = np.argmax(output[0], axis=1)
        _all += len(labels)
        _correct += int((pred == labels.numpy()).sum())
    if _all == 0:
        raise ValueError("loader yielded no samples")
    return _correct / _all
# Score the exported ONNX graph with ONNX Runtime.
ort_sess = ort.InferenceSession(
    'saved_models/torch2tf/CifarResnet18.onnx')
acc = get_accuracy(ort_sess)
print("accuracy of onnx model from torch:", acc)

# Step 4: Convert the ONNX model to TensorFlow, test its accuracy, and save it

In [None]:
# Reload the exported ONNX graph and convert it with onnx-tf.
load_path = "saved_models/torch2tf/CifarResnet18.onnx"
onnx_model = onnx.load(load_path)
# prepare (onnx_tf.backend) builds a TensorFlow representation that can be run and exported.
tf_rep = prepare(onnx_model)

In [None]:
# Test accuracy of the onnx-tf representation; it should match the ONNX Runtime result.
_all = 0
_correct = 0
for imgs, labels in test_loader:
    # Feed a NumPy array explicitly instead of a torch tensor. The batch keeps the
    # NCHW layout the PyTorch-exported graph was traced with.
    pred = np.argmax(tf_rep.run(imgs.numpy())[0], axis=1)
    _all += len(labels)
    _correct += int((pred == labels.numpy()).sum())
print(f"accuracy:{_correct/_all}")


In [None]:
# Write the onnx-tf representation to a directory that is later loaded back with tf.saved_model.load.
save_path = "saved_models/torch2tf/CifarResnet18"
tf_rep.export_graph(save_path)

In [None]:
# Test training the converted TF model

In [None]:
# Re-prepare the ONNX graph with training enabled so its TF graph can be fine-tuned below.
tf_rep = prepare(onnx_model, training_mode=True)

In [None]:
# Fine-tune the onnx-tf graph with TF1-style sessions to check the converted model is trainable.
tf_compat = tf.compat.v1
epochs = 1
# `batch_size` was never defined in this notebook (NameError); use the same value
# as the PyTorch data loaders.
batch_size = 64
# Boolean placeholder that onnx-tf exposes when prepared with training_mode=True.
training_flag_placeholder = tf_rep.tensor_dict[
    onnx_tf.backend.training_flag_name]
input_name = onnx_model.graph.input[0].name
output_name = onnx_model.graph.output[0].name

with tf_rep.graph.as_default():
    with tf_compat.Session() as sess:
        # Integer class labels for the current batch.
        y_truth = tf_compat.placeholder(tf.int64, [None], name='y-input')
        tf_rep.tensor_dict["y_truth"] = y_truth
        loss_op = tf.reduce_mean(
            tf_compat.losses.sparse_softmax_cross_entropy(
                labels=tf_rep.tensor_dict['y_truth'],
                logits=tf_rep.tensor_dict[output_name]))
        opt_op = tf_compat.train.AdamOptimizer().minimize(loss_op)
        # Batch accuracy: fraction of argmax predictions equal to the labels.
        eval_op = tf.reduce_mean(input_tensor=tf.cast(
            tf.equal(tf.argmax(input=tf_rep.tensor_dict[output_name], axis=1),
            tf_rep.tensor_dict['y_truth']), tf.float32))
        # TODO: get_cifar_data is not imported or defined in this notebook (only
        # get_cifar_loader / get_imagenette_loader are imported); import it before running.
        x_train,y_train,x_val,y_val,x_test,y_test = get_cifar_data()
        train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(batch_size, drop_remainder=True)
        sess.run(tf_compat.global_variables_initializer())
        for epoch in range(1, epochs + 1):
            step = 1
            next_batch = tf_compat.data.make_one_shot_iterator(train_ds).get_next()
            while True:
                # Only fetching the next batch can signal end-of-data; keep the
                # try narrow so real errors from the training step are not masked.
                try:
                    next_batch_value = sess.run(next_batch)
                except tf.errors.OutOfRangeError:
                    break
                feed_dict = {
                    # x_train is presumably NHWC; the PyTorch-exported graph expects NCHW.
                    tf_rep.tensor_dict[input_name]: next_batch_value[0].transpose((0, 3, 1, 2)),
                    tf_rep.tensor_dict['y_truth']: next_batch_value[1].flatten(),
                    training_flag_placeholder: True,
                }
                loss, accuracy, _ = sess.run([loss_op, eval_op, opt_op], feed_dict=feed_dict)
                print('Epoch {}, train step {}, loss:{}, accuracy:{}'.format(epoch, step, loss, accuracy))
                step += 1
                


# Step 5: Load the saved TensorFlow model and test its accuracy

In [None]:
# Load the SavedModel directory written by tf_rep.export_graph.
load_path = "saved_models/torch2tf/CifarResnet18"
loaded = tf.saved_model.load(export_dir=load_path)

In [None]:
# Show the available signature names, then grab the default serving function.
print(list(loaded.signatures))
infer = loaded.signatures["serving_default"]
# Name of the first output in the serving signature's output dict.
key = list(infer.structured_outputs)[0]

In [None]:
# Test accuracy of the reloaded SavedModel through its serving signature.
_all = 0
_correct = 0
for imgs, labels in test_loader:
    # Convert the torch batch to a TF tensor explicitly rather than relying on
    # implicit conversion of a foreign tensor type.
    out = infer(**{'input': tf.constant(imgs.numpy())})
    pred = np.argmax(out[key].numpy(), axis=1)
    _all += len(labels)
    _correct += int((pred == labels.numpy()).sum())
print(f"accuracy:{_correct/_all}")

In [None]:
# Additional check: feed one sample to both the PyTorch model and the TF SavedModel
# and compare their raw outputs.
# Take the sample explicitly from test_loader instead of reusing `imgs` left over
# from an earlier loop (hidden state that breaks out-of-order / fresh-kernel runs).
sample_imgs, _ = next(iter(test_loader))
test_img = sample_imgs[0].unsqueeze(0)
model2.eval()
with torch.no_grad():
    out1 = model2(test_img.to(device)).cpu()
out2 = infer(**{'input': tf.constant(test_img.numpy())})[key]
print(out1.numpy())
print(out2.numpy())
# Tiny differences are presumably float rounding; large ones point to a conversion problem.
print("max abs diff:", np.abs(out1.numpy() - out2.numpy()).max())

# Step 6: Convert the ONNX model back to PyTorch and test its accuracy

In [None]:
# Re-import the ONNX graph as a PyTorch module using onnx2pytorch.
load_path = "saved_models/torch2tf/CifarResnet18.onnx"
onnx_model = onnx.load(load_path)
# NOTE: this rebinds `torch_model`, which earlier held the trained model's inner nn.Module.
torch_model = ConvertModel(onnx_model,debug=False)

In [None]:
def get_acc_from_converted_pytorch_model(model, loader=None):
    """Top-1 accuracy of a model evaluated one sample at a time.

    Samples are fed individually (batch size 1), as in the original; presumably
    because the onnx2pytorch-converted model does not accept larger batches — TODO confirm.

    Args:
        model: callable mapping a ``(1, ...)`` tensor to class scores. If it is an
            ``nn.Module`` it is evaluated in eval mode, and its previous train/eval
            mode is restored afterwards.
        loader: iterable of ``(images, labels)`` batches. ``None`` (default) uses
            the notebook-global ``test_loader``, looked up at call time.

    Returns:
        Fraction of correctly classified samples (float).

    Raises:
        ValueError: if the loader yields no samples.
    """
    if loader is None:
        loader = test_loader
    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()
    _all = 0
    _correct = 0
    try:
        with torch.no_grad():
            for imgs, labels in loader:
                _all += len(labels)
                for img, label in zip(imgs, labels):
                    output = model(img.unsqueeze(0))
                    _correct += (torch.argmax(output) == label).item()
    finally:
        # Restore the caller's mode even if evaluation raised.
        if was_training:
            model.train()
    if _all == 0:
        raise ValueError("loader yielded no samples")
    return _correct / _all


In [None]:
acc = get_acc_from_converted_pytorch_model(model=torch_model)
print("acc:" + str(acc))

In [None]:
# use onnx2torch (a second ONNX -> PyTorch converter, compared against onnx2pytorch above)
from onnx2torch import convert
load_path = "saved_models/torch2tf/CifarResnet18.onnx"
onnx_model = onnx.load(load_path)
# convert builds a PyTorch model from the ONNX graph; it is moved to the notebook's device.
torch_model_2 = convert(onnx_model).to(device)

In [None]:
# Test accuracy of the onnx2torch-converted model, evaluated batch-wise on `device`.
# Evaluate in inference mode (a freshly built nn.Module starts in training mode).
torch_model_2.eval()
_correct = 0
_all = 0
with torch.no_grad():
    for imgs, labels in test_loader:
        _all += len(labels)
        out = torch_model_2(imgs.to(device)).to("cpu")
        pred = torch.argmax(out, dim=1)
        _correct += (pred == labels).sum().item()
print(f"accuracy: {_correct/_all}")

In [None]:
# Train the converted PyTorch model with PyTorch Lightning
import torchmetrics
import pytorch_lightning as pl
import torch.nn.functional as F
learning_rate = 1e-2  # default Adam learning rate used by LitModel


class LitModel(pl.LightningModule):
    """LightningModule wrapper that trains and evaluates an arbitrary classifier.

    Uses cross-entropy loss with Adam, and logs ``train_loss``, ``val_loss``,
    ``val_acc``, ``test_loss`` and ``test_acc``.

    Args:
        model: ``nn.Module`` mapping an image batch to class logits.
        lr: Adam learning rate. ``None`` (default) means "use the module-level
            ``learning_rate`` when the optimizer is configured", which matches
            the previous behaviour.
    """

    def __init__(self, model, lr=None):
        super(LitModel, self).__init__()
        self.model = model
        self.lr = lr
        # NOTE(review): torchmetrics >= 0.11 requires
        # Accuracy(task="multiclass", num_classes=...); the argument-free form
        # only works with older torchmetrics releases.
        self.test_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        lr = learning_rate if self.lr is None else self.lr
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        return optimizer

    def _shared_step(self, batch):
        """Run the model on one (images, labels) batch; return (logits, labels, loss)."""
        images, labels = batch
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)
        return outputs, labels, loss

    def training_step(self, batch, batch_idx):
        _, _, loss = self._shared_step(batch)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        outputs, labels, loss = self._shared_step(batch)
        self.valid_acc(outputs, labels)
        self.log("val_loss", loss)
        self.log("val_acc", self.valid_acc)
        return {"val_loss": loss}

    def test_step(self, batch, batch_idx):
        outputs, labels, loss = self._shared_step(batch)
        self.test_acc(outputs, labels)
        self.log("test_loss", loss)
        self.log('test_acc', self.test_acc)
        return {"test_loss": loss}

    def validation_epoch_end(self, outputs):
        # `outputs` collects the dicts returned by validation_step for this epoch.
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss}
        # Lightning 1.x does not use this hook's return value (including the legacy
        # "log" key), so log the epoch mean explicitly.
        # NOTE(review): this hook was removed in Lightning 2.0.
        self.log("avg_val_loss", avg_loss)
        # Key was misspelled "val_lss".
        return {"val_loss": avg_loss, "log": tensorboard_logs}


In [None]:
# Fine-tune the onnx2torch-converted model with Lightning to check it is trainable.
train_model = LitModel(torch_model_2)
epochs = 1
early_stop_callback = EarlyStopping(monitor="val_acc", min_delta=0.00, patience=5, verbose=False, mode="max")

# Pick the accelerator from what is available instead of hard-coding "gpu",
# which makes Trainer fail on CPU-only machines.
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
trainer = Trainer(max_epochs=epochs, fast_dev_run=False, accelerator=accelerator, callbacks=[early_stop_callback])

trainer.fit(train_model, train_loader, val_loader)
