# Convert PyTorch Model to TensorFlow Lite (TFLite)

This notebook guides you through converting a trained PyTorch model to a mobile-friendly TensorFlow Lite model. It uses ONNX as an interchange format and `onnx2tf` for TensorFlow export, then produces a `.tflite` file ready for deployment.



## What you'll do

- Export the PyTorch model to ONNX
- Simplify/optimize the ONNX graph
- Convert ONNX to TensorFlow SavedModel (`onnx2tf`)
- Convert SavedModel to TFLite with optional optimizations/quantization
- (Optional) Run quick sanity checks and compare outputs



In [9]:
!pip install -q torch torchvision onnx onnx2tf tensorflow onnxscript


In [4]:
!pip install -q sng4onnx onnxsim onnx_graphsurgeon simple_onnx_processing_tools # Helpers for onnx2tf
print("✅ Libraries Installed")

  Preparing metadata (setup.py) ... done
  Building wheel for onnxsim (setup.py) ... done
✅ Libraries Installed


In [5]:
from google.colab import files
import os

if not os.path.exists("mediscan_resnet50.pt"):
    print("Please upload your local 'mediscan_resnet50.pt' file:")
    uploaded = files.upload()
else:
    print("✅ Model file already exists!")

Please upload your local 'mediscan_resnet50.pt' file:


Saving mediscan_resnet50.pt to mediscan_resnet50.pt


**Re-Define and Load the Model**

Load the PyTorch model with exactly the same architecture that was defined in the backend; otherwise the saved weights will not match the layers.

In [6]:
import torch
import torch.nn as nn
from torchvision import models

# Define architecture (must match the training code exactly)
# `weights=None` replaces the deprecated `pretrained=False` in recent torchvision
model = models.resnet50(weights=None)
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_ftrs, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 2)
)

#  Load Weights
device = torch.device('cpu')
model.load_state_dict(torch.load("mediscan_resnet50.pt", map_location=device))
model.eval()
print("✅ PyTorch Model Loaded")



✅ PyTorch Model Loaded
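
Because `load_state_dict` matches parameters by name, a mismatch between the checkpoint and the redefined architecture shows up as missing or unexpected keys. The sketch below is one way to inspect this before loading. `diff_state_dict_keys` is a hypothetical helper name, not part of PyTorch, and the key names in the toy example are written by hand for illustration.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, mirroring strict loading.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys in the checkpoint that the model does not define
    """
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)
    unexpected = sorted(ckpt_set - model_set)
    return missing, unexpected


# Toy example with hand-written key names (no PyTorch needed):
# the model has a Sequential head (fc.0, fc.3), the checkpoint a single Linear.
missing, unexpected = diff_state_dict_keys(
    ["fc.0.weight", "fc.0.bias", "fc.3.weight", "fc.3.bias"],
    ["fc.0.weight", "fc.0.bias", "fc.weight", "fc.bias"],
)
print("Missing:", missing)        # keys the checkpoint does not provide
print("Unexpected:", unexpected)  # keys the model does not define
```

In this notebook, the same helper could be called with `model.state_dict().keys()` and the keys of the dictionary returned by `torch.load("mediscan_resnet50.pt", map_location="cpu")`, assuming the file holds a plain state dict as in the cell above.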


**Export to ONNX**

ONNX acts as a "*universal translator*" for AI models: an interchange format that one framework can export and another can import.

In [10]:
# Create dummy input of the correct shape (Batch_Size, Channels, Height, Width)
dummy_input = torch.randn(1, 3, 224, 224)

onnx_path = "mediscan.onnx"

torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    verbose=False,
    input_names=['input'],
    output_names=['output'],
    opset_version=13 # Requested opset; newer PyTorch exporters may keep opset 18 if down-conversion fails
)
print(f"✅ Exported to {onnx_path}")

W0109 17:15:54.172000 1174 torch/onnx/_internal/exporter/_compat.py:114] Setting ONNX exporter to use operator set version 18 because the requested opset_version 13 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py", line 127, in call
    converted_proto = _c_api_utils.call_onnx_api(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/_c_api_utils.py", line 65, in call_onnx_api
    result = func(proto)
             ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter

Applied 106 of general pattern rewrite rules.
✅ Exported to mediscan.onnx
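
The warning above suggests the exported file may not actually be at opset 13, and the overview lists graph simplification as a step. A minimal sketch of both checks follows, assuming the `onnx` and `onnxsim` packages installed earlier. The helper names and the output filename `mediscan_sim.onnx` are illustrative choices, not something the original notebook defines.

```python
import os


def describe_opsets(opset_pairs):
    """Map (domain, version) pairs to a dict; '' is the default ai.onnx domain."""
    return {(domain or "ai.onnx"): version for domain, version in opset_pairs}


def check_and_simplify(onnx_path, simplified_path):
    """Validate an ONNX file, report its opset, and write a simplified copy."""
    # Imported lazily so the helper above works without these packages.
    import onnx
    from onnxsim import simplify

    model = onnx.load(onnx_path)
    onnx.checker.check_model(model)  # raises if the graph is malformed

    opsets = describe_opsets((o.domain, o.version) for o in model.opset_import)
    print("Opsets in exported file:", opsets)

    simplified, ok = simplify(model)  # returns (model, validation flag)
    if not ok:
        raise RuntimeError("onnxsim could not validate the simplified model")
    onnx.save(simplified, simplified_path)
    return opsets


# Hypothetical filenames; this only runs once the export has produced the file.
if os.path.exists("mediscan.onnx"):
    check_and_simplify("mediscan.onnx", "mediscan_sim.onnx")
```

To use the simplified graph downstream, point `onnx_path` at the new file before running the conversion cell.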


**Convert ONNX to TensorFlow Lite**

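Use `onnx2tf` to convert the ONNX file into a TensorFlow SavedModel, then use the TFLite converter to produce a float16-quantized `.tflite` file. Float16 quantization stores the weights in half precision, which roughly halves the file size (e.g., from about 100 MB to about 50 MB) so the model fits more easily on mobile phones.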
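
To see where such size reductions come from, the file size can be approximated as parameter count times bytes per weight. The sketch below assumes roughly 23.5 million parameters for the ResNet-50 backbone (an approximate figure, not measured from this checkpoint) plus the custom head defined earlier. Real files also carry graph metadata, so treat the results as estimates.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


def size_mib(num_params, bytes_per_param):
    """Approximate storage for the weights alone, in MiB."""
    return num_params * bytes_per_param / (1024 ** 2)


# Assumption: approximate parameter count of ResNet-50 without its final layer.
BACKBONE_PARAMS = 23_500_000
# Custom head from the model definition: Linear(2048, 512) and Linear(512, 2).
head_params = linear_params(2048, 512) + linear_params(512, 2)
total_params = BACKBONE_PARAMS + head_params

for name, nbytes in [("float32", 4), ("float16", 2), ("int8", 1)]:
    print(f"{name}: ~{size_mib(total_params, nbytes):.1f} MiB")
```

Float16 halves the float32 estimate, while a fourfold reduction corresponds to 8-bit integer quantization.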

In [None]:
import os
import subprocess # Gives finer control over command execution than `!`

# Convert ONNX to TensorFlow SavedModel
print(f"Converting ONNX ({onnx_path}) to TensorFlow SavedModel...")
# Pass the arguments as a list (no shell) so paths with spaces are handled safely
command = ["onnx2tf", "-i", onnx_path, "-o", "saved_model_tf"]
result = subprocess.run(command, capture_output=True, text=True)

if result.returncode == 0:
    print("✅ ONNX to TensorFlow conversion command executed.")
    if result.stdout:
        print("onnx2tf Output:\n", result.stdout)
    if result.stderr:
        print("onnx2tf Warnings/Errors:\n", result.stderr)
else:
    print(f"❌ ONNX to TensorFlow conversion command failed with exit code {result.returncode}.")
    print("onnx2tf Output:\n", result.stdout)
    print("onnx2tf Errors:\n", result.stderr)
    raise RuntimeError("ONNX to TensorFlow conversion failed. Please check the logs above.")

# Check if the saved_model_tf directory exists after conversion attempt
saved_model_dir = 'saved_model_tf'
if not os.path.exists(saved_model_dir) or not os.listdir(saved_model_dir):
    raise OSError(f"SavedModel directory '{saved_model_dir}' not found or is empty after onnx2tf conversion.")

print(f"✅ TensorFlow SavedModel created at '{saved_model_dir}'.")

# Convert SavedModel to TFLite
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)

# OPTIMIZATION: Quantize weights to Float16 (Good balance of speed/accuracy)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]

tflite_model = converter.convert()

# Save the file
tflite_path = "mediscan_model.tflite"
with open(tflite_path, 'wb') as f:
    f.write(tflite_model)

print(f"✅ TFLite Model Saved: {tflite_path}")
print(f"Size: {os.path.getsize(tflite_path) / 1024 / 1024:.2f} MB")
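
As a quick version of the optional sanity check from the overview, the PyTorch and TFLite outputs can be compared on the same random input. This is a sketch under assumptions: it relies on `tf.lite.Interpreter` (newer TensorFlow releases may steer you toward a separate LiteRT package instead), and it guesses the input layout from the interpreter's reported shape, since `onnx2tf` typically converts NCHW inputs to NHWC. The helper names and the `1e-2` tolerance are illustrative.

```python
import os


def needs_nhwc(input_shape):
    """True if a 4-D input looks channels-last (N, H, W, 3), not (N, 3, H, W)."""
    shape = list(input_shape)
    return len(shape) == 4 and shape[-1] == 3 and shape[1] != 3


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two flat sequences."""
    return max(abs(x - y) for x, y in zip(a, b))


def compare_outputs(torch_model, tflite_path, tol=1e-2):
    """Run one random input through both models and compare the outputs."""
    # Heavy imports are kept local so the helpers above stay dependency-free.
    import numpy as np
    import tensorflow as tf
    import torch

    x = np.random.rand(1, 3, 224, 224).astype(np.float32)  # NCHW, as in export
    with torch.no_grad():
        expected = torch_model(torch.from_numpy(x)).numpy().ravel()

    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()
    inp = interpreter.get_input_details()[0]
    out = interpreter.get_output_details()[0]

    # Transpose to channels-last only if the converted model expects it.
    feed = x.transpose(0, 2, 3, 1) if needs_nhwc(inp["shape"]) else x
    interpreter.set_tensor(inp["index"], np.ascontiguousarray(feed))
    interpreter.invoke()
    actual = interpreter.get_tensor(out["index"]).ravel()

    diff = max_abs_diff(expected.tolist(), actual.tolist())
    print(f"Max abs difference: {diff:.6f} (tolerance {tol})")
    return diff <= tol


# Runs only inside the notebook, where `model` and the .tflite file exist.
if "model" in globals() and os.path.exists("mediscan_model.tflite"):
    compare_outputs(model, "mediscan_model.tflite")
```

Small differences are expected because the weights are stored in float16; a large gap usually points to a layout or preprocessing mismatch rather than to quantization.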

**Download the TFLite File**

Download the converted file to your local computer.

In [None]:
from google.colab import files
files.download('mediscan_model.tflite')