In [None]:
import torch
import tensorflow as tf
from models.tf_models.CifarResnet_tf import get_CifarResnet18
from models.tf_models.tf_models import tf_resnet50,tf_densenet121,tf_inception_resnet_v2, tf_resnet18
from utils.tf_datasets import get_cifar_data,get_imagenette_data
import numpy as np

In [None]:
# Report which GPU devices (if any) TensorFlow can see before training starts.
visible_gpus = tf.config.list_physical_devices('GPU')
print(visible_gpus)

# Step 1: Load the dataset and TF model, then test accuracy

### Note: the torch -> ONNX -> TF -> ONNX path is different from the direct TF -> ONNX path

## load dataset

In [None]:
# Dataset choice: the CIFAR loader is kept commented out; Imagenette is active.
# cifar
#x_train,y_train,x_val,y_val,x_test,y_test=get_cifar_data()
# imagenette
# get_imagenette_data is called with batch_size=32 and its result is unpacked
# into the train / validation / test splits used by the cells below.
train_ds,val_ds,test_ds = get_imagenette_data(batch_size=32)

## load model and train

In [None]:
# Choose the Keras architecture to train. Only one assignment should be active;
# the commented lines are the other architectures this notebook can be run with.
# NOTE(review): the save/export paths further down are hard-coded per
# architecture, so they have to be switched together with this line.
#model = tf_resnet50(num_classes=10,retrain=False)
model = tf_resnet18(num_classes=10)
#model = tf_densenet121(num_classes=10,retrain=False)
#model = tf_inception_resnet_v2(num_classes=10,retrain=False)

In [None]:
# Print the layer-by-layer summary of the model selected above.
# (Disabled alternative constructor kept for reference.)
#model = get_Resnet18(num_classes=10)
model.summary()

In [None]:

# Early stopping watches the *training* loss with a patience of 2 epochs.
callback = tf.keras.callbacks.EarlyStopping(patience=2, monitor='loss')

# CIFAR variant (disabled):
#   history = model.fit(x_train, y_train, epochs=EPOCHS, validation_data = (x_val, y_val), batch_size=128,callbacks=[callback])
# Imagenette: train for at most 3 epochs, validating on val_ds after each one.
model.fit(train_ds, epochs=3, validation_data=val_ds, callbacks=[callback])

In [None]:
# Score the trained Keras model on the test split; this is the baseline the
# converted ONNX / PyTorch / TF models are compared against later on.
print("Evaluate on test data")
# CIFAR variant (disabled):
#results = model.evaluate(x_test, y_test, batch_size=128)
results = model.evaluate(test_ds)
print("test loss, test acc:", results)

In [None]:
# Save the trained Keras model as a single .h5 file.
# The active path must match the architecture selected earlier in the notebook.
save_path = "saved_models/tf2torch/resnet18.h5"
#save_path = "saved_models/tf2torch/resnet50.h5"
#save_path = "saved_models/tf2torch/densenet121.h5"
#save_path = "saved_models/tf2torch/inception_resnet_v2.h5"
model.save(save_path)
# NOTE(review): a SavedModel export was announced here ("save as saved_model")
# but no code follows it.


# Step 2: Convert from TensorFlow to ONNX

In [None]:
import tf2onnx
import onnxruntime as rt

# Destination of the exported graph; switch together with the selected architecture.
#   "saved_models/tf2torch/resnet50.onnx"
#   "saved_models/tf2torch/densenet121.onnx"
#   "saved_models/tf2torch/inception_resnet_v2.onnx"
output_path = "saved_models/tf2torch/resnet18.onnx"

# Input signature: float32 tensor named "input", dynamic batch, 224x224x3.
spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),)

# Convert the in-memory Keras model with opset 13 and write it to output_path.
model_proto, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=spec,
    opset=13,
    output_path=output_path,
)

# Names of the graph outputs; onnxruntime needs them when running the model.
output_names = [graph_output.name for graph_output in model_proto.graph.output]

In [None]:
# Display the ONNX graph output names collected during the export; the
# onnxruntime accuracy helper passes them to the session's run() call.
output_names

# Step 3: Test the ONNX model using onnxruntime

In [None]:
import onnx
import onnxruntime as ort
# Load the ONNX file written by the export step so it can be validated.
# The active path has to be the architecture that was actually exported
# (resnet18); previously resnet50.onnx was loaded here, so the checker ran on
# a different (possibly missing) file than the one scored with onnxruntime.
#load_path_onnx = "saved_models/tf2torch/resnet50.onnx"
#load_path_onnx = "saved_models/tf2torch/densenet121.onnx"
#load_path_onnx = "saved_models/tf2torch/inception_resnet_v2.onnx"
load_path_onnx = "saved_models/tf2torch/resnet18.onnx"
onnx_model = onnx.load(load_path_onnx)


In [None]:
# Run the ONNX checker on the model loaded in the previous cell.
onnx.checker.check_model(onnx_model)


In [None]:
"""
for cifar10
def get_accuracy(ort_sess):
    _correct=0
    _all=y_test.shape[0]
    batch_size=256
    num_batches = _all//batch_size
    for i in range(num_batches):
        if i!=num_batches-1:
            output=ort_sess.run(output_names=output_names,input_feed={"input":x_test[i*batch_size:(i+1)*batch_size]})
            pred = np.argmax(output[0],axis=1).reshape(-1,1)
            _correct+=(pred==y_test[i*batch_size:(i+1)*batch_size]).sum()
            
        else:
            output=ort_sess.run(output_names=output_names,input_feed={"input":x_test[i*batch_size:]})
            pred = np.argmax(output[0],axis=1).reshape(-1,1)
            _correct+=(pred==y_test[i*batch_size:]).sum()
    return _correct/_all
"""
# for imagenette
def get_accuracy(ort_sess, dataset=None, out_names=None):
    """Return the top-1 accuracy of an onnxruntime session over a batched dataset.

    Args:
        ort_sess: object exposing ``run(output_names=..., input_feed=...)``.
        dataset: iterable of ``(imgs, labels)`` batches. ``imgs`` must offer a
            ``.numpy()`` method; ``labels`` must hold one score vector per
            sample, so that argmax over axis 1 yields the class index.
            Defaults to the notebook-level ``test_ds``.
        out_names: output names requested from the session. Defaults to the
            notebook-level ``output_names``.

    Returns:
        The fraction of samples whose arg-max prediction equals the arg-max label.
    """
    # Fall back to the notebook globals so existing ``get_accuracy(sess)`` calls
    # keep working; passing the arguments explicitly avoids depending on
    # whichever ``test_ds`` happens to be bound at call time.
    if dataset is None:
        dataset = test_ds
    if out_names is None:
        out_names = output_names

    _correct = 0
    _all = 0
    for imgs, labels in dataset:
        _all += len(labels)
        # Column vectors on both sides so the comparison is element-wise.
        truth = np.argmax(labels, axis=1).reshape(-1, 1)
        output = ort_sess.run(output_names=out_names, input_feed={"input": imgs.numpy()})
        pred = np.argmax(output[0], axis=1).reshape(-1, 1)
        _correct += (truth == pred).sum()
    #print(f"accuracy of onnx model: {_correct/_all}")
    return _correct / _all
        

In [None]:
# Open an onnxruntime session on the exported file and score it on the test
# split. The active path must match the architecture exported earlier; the
# commented lines are the other architectures.
#ort_sess = ort.InferenceSession('saved_models/tf2torch/resnet50.onnx')
#ort_sess = ort.InferenceSession('saved_models/tf2torch/densenet121.onnx')
#ort_sess = ort.InferenceSession('saved_models/tf2torch/inception_resnet_v2.onnx')
ort_sess = ort.InferenceSession('saved_models/tf2torch/resnet18.onnx')
# Accuracy of the ONNX model, to be compared with the Keras test accuracy.
onnx_acc = get_accuracy(ort_sess)

print(f"accuracy of onnx model from tf: {onnx_acc}")


# Step 4: Convert ONNX to PyTorch and test accuracy

In [None]:
from onnx2torch import convert
import onnx

In [None]:
# Rebuild the network as a PyTorch module from the exported ONNX file.
load_path_onnx = "saved_models/tf2torch/resnet18.onnx"
onnx_model = onnx.load(load_path_onnx)
pytorch_model = convert(onnx_model)

# Re-create the Imagenette splits with a smaller batch size (8 instead of 32)
# for the PyTorch evaluation that follows.
train_ds, val_ds, test_ds = get_imagenette_data(batch_size=8)

# Switch to eval mode; as the cell's last expression this also displays the module.
pytorch_model.eval()

In [None]:
# Accuracy of the converted PyTorch model on the (TensorFlow) test split.
_correct = 0
_all = 0
# Pure inference: disable autograd so no graph is recorded for each batch.
with torch.no_grad():
    for imgs, labels in test_ds:
        _all += len(labels)
        # Labels come as one score vector per sample; argmax gives the class index.
        labels = np.argmax(labels, axis=1)
        # TF tensor -> NumPy -> torch tensor (layout is left untouched).
        out = pytorch_model(torch.from_numpy(imgs.numpy()))
        pred = torch.argmax(out, dim=1).numpy()
        _correct += (pred == labels).sum()
print(f"pytorch model acc: {_correct/_all}")

In [None]:
# test training converted torch model

In [None]:
# train torch model
import torchmetrics
import pytorch_lightning as pl
import torch.nn.functional as F
learning_rate = 1e-2


class LitModel(pl.LightningModule):
    """Lightning wrapper used to check that the converted torch model can be trained.

    The wrapped ``model`` is called on batches permuted from ``(0, 1, 2, 3)`` to
    ``(0, 2, 3, 1)``.
    NOTE(review): this presumably turns channels-first loader batches into the
    channels-last layout the ONNX graph was exported with -- confirm against
    ``get_imagenette_loader``.
    """

    def __init__(self, model):
        """Store the wrapped model and create the validation/test accuracy metrics."""
        super().__init__()
        self.model = model
        self.test_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()

    def forward(self, x):
        """Run the wrapped model on ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Adam over all parameters, using the notebook-level ``learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=learning_rate)

    def _forward_loss(self, batch):
        """Permute the images, run the model and return ``(outputs, labels, loss)``."""
        images, labels = batch
        images = images.permute(0, 2, 3, 1)
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)
        return outputs, labels, loss

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        _, _, loss = self._forward_loss(batch)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Log the loss and running accuracy for one validation batch."""
        outputs, labels, loss = self._forward_loss(batch)
        self.valid_acc(outputs, labels)
        self.log("val_loss", loss)
        self.log("val_acc", self.valid_acc)
        return {"val_loss": loss}

    def test_step(self, batch, batch_idx):
        """Log the loss and running accuracy for one test batch."""
        outputs, labels, loss = self._forward_loss(batch)
        self.test_acc(outputs, labels)
        self.log("test_loss", loss)
        self.log('test_acc', self.test_acc)
        return {"test_loss": loss}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses collected during the epoch."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss}

        # Key was misspelled "val_lss"; it now matches the name logged per batch.
        return {"val_loss": avg_loss, "log": tensorboard_logs}

In [None]:
from pytorch_lightning import Trainer
from utils.data_loaders import get_imagenette_loader

# Wrap the converted network so Lightning can drive it.
train_model = LitModel(pytorch_model)
epochs = 1

# PyTorch-side Imagenette loaders, requested with a batch size of 4.
train_loader, val_loader, test_loader = get_imagenette_loader(batch_size=4)

# CPU-only run with fast_dev_run enabled: this is a trainability check,
# not a full training run.
trainer = Trainer(accelerator="cpu", fast_dev_run=True, max_epochs=epochs)
trainer.fit(train_model, train_loader, val_loader)

# Step 5: Convert the ONNX model back to a TF model and train

In [None]:
from onnx_tf.backend import prepare
import tensorflow as tf
import onnx_tf
# Reload the exported ResNet-18 ONNX file and hand it to the onnx-tf backend;
# tf_rep is the object that is exported to disk in the next cell.
load_path_onnx = "saved_models/tf2torch/resnet18.onnx"
onnx_model = onnx.load(load_path_onnx)
tf_rep = prepare(onnx_model)

In [None]:
# save and load
# Export the onnx-tf representation to save_path; the following cell reads it
# back with tf.saved_model.load. Note that this rebinds the save_path name
# that earlier held the .h5 path.
save_path = "saved_models/tf2torch/resnet18"
tf_rep.export_graph(save_path)

In [None]:
# Load the exported graph back and list the signatures it offers.
loaded=tf.saved_model.load(save_path)
print(list(loaded.signatures.keys())) 
# Callable used for inference in the accuracy loop.
infer = loaded.signatures["serving_default"]
# Name of the first entry in the signature's structured outputs; used to pick
# the prediction tensor out of the dict that infer() returns.
key=list(infer.structured_outputs.keys())[0]

In [None]:
# Accuracy of the re-imported TensorFlow graph on the test split.
_all = 0
_correct = 0
for imgs, labels in test_ds:
    out = infer(**{'input': imgs})
    pred = np.argmax(out[key], axis=1)
    # Labels come as one score vector per sample (the ONNX and PyTorch
    # evaluations argmax them too); reduce them to class indices so the
    # comparison is index-vs-index instead of index-vs-vector.
    truth = np.argmax(labels.numpy(), axis=1)
    _all += len(labels)
    _correct += (pred == truth).sum()
print(f"accuracy:{_correct/_all}")