In [1]:
# install dependencies
!apt-get update && apt-get install -y libnvinfer-dev libnvinfer-plugin-dev
!pip install onnx onnxruntime tensorrt onnx_graphsurgeon fastapi uvicorn nest-asyncio
!pip install fastapi uvicorn pyngrok pycuda pillow numpy tensorrt python-multipart

0% [Working]            Get:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
0% [Connecting to archive.ubuntu.com (91.189.91.83)] [Waiting for headers] [1 InRelease 3,632 B/3,630% [Connecting to archive.ubuntu.com (91.189.91.83)] [Waiting for headers] [Connected to r2u.stat.il                                                                                                    Hit:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Get:3 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Hit:4 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:5 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Get:6 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:7 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]
Get:8 https://r2u.stat.illinois.edu/ubuntu jammy/main all Packages [8,802 kB]
Get:9 http://security.ubuntu.com/ubuntu jammy-security/universe amd64 

In [2]:
import tensorrt as trt

# Logger handed to the TensorRT builder and ONNX parser below; constructed with WARNING severity.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
# Source ONNX model to convert, and the destination file for the serialized engine.
onnx_model_path = "/content/resnet50_dog_cat.onnx"
trt_engine_path = "/content/resnet50_dog_cat.trt"

def build_engine(onnx_file, input_shape=(1, 3, 224, 224), workspace_bytes=1 << 30):
    """Parse an ONNX model and build a serialized TensorRT engine from it.

    Args:
        onnx_file: Path to the ONNX model file.
        input_shape: Shape assigned to the network's first input before the
            build. Defaults to ``(1, 3, 224, 224)``.
        workspace_bytes: Limit for the builder's WORKSPACE memory pool, in
            bytes. Defaults to ``1 << 30`` (1 GiB).

    Returns:
        The object returned by ``builder.build_serialized_network``, or
        ``None`` when ONNX parsing or the engine build fails (the parser
        errors / a failure message are printed in that case).
    """
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    with trt.Builder(TRT_LOGGER) as builder, \
         builder.create_network(explicit_batch) as network, \
         trt.OnnxParser(network, TRT_LOGGER) as parser:

        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)

        # Parse ONNX model; on failure print every parser error and give up
        with open(onnx_file, "rb") as model:
            if not parser.parse(model.read()):
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return None

        # Set the input tensor shape explicitly (only the first network input is touched)
        input_tensor = network.get_input(0)
        input_tensor.shape = tuple(input_shape)

        # TensorRT 8+ uses `build_serialized_network`
        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine is None:
            print("❌ Failed to build serialized engine!")
            return None

        return serialized_engine

# Convert the ONNX model, then write the serialized engine to disk (or report the failure)
engine = build_engine(onnx_model_path)
if not engine:
    print("❌ Failed to create TensorRT engine.")
else:
    with open(trt_engine_path, "wb") as engine_file:
        engine_file.write(engine)
    print("✅ TensorRT Engine Saved Successfully!")


✅ TensorRT Engine Saved Successfully!


In [3]:
!pip uninstall -y pyngrok
!pip install pyngrok
!ngrok authtoken 2uymyL6OOtXASHSQm7qEaluVbMn_MKxp9Q4677Jq3vKYcovN


Found existing installation: pyngrok 7.2.3
Uninstalling pyngrok-7.2.3:
  Successfully uninstalled pyngrok-7.2.3
Collecting pyngrok
  Using cached pyngrok-7.2.3-py3-none-any.whl.metadata (8.7 kB)
Using cached pyngrok-7.2.3-py3-none-any.whl (23 kB)
Installing collected packages: pyngrok
Successfully installed pyngrok-7.2.3
Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml


In [5]:
import io
import threading

import torch
import numpy as np
from PIL import Image
from fastapi import FastAPI, File, UploadFile, HTTPException
import torchvision.transforms as transforms
from pyngrok import ngrok
import uvicorn
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

# Logger passed to the TensorRT runtime when the engine is deserialized
# (constructed with default arguments)
TRT_LOGGER = trt.Logger()

# Location of the serialized TensorRT engine file read by load_engine()
ENGINE_PATH = "/content/resnet50_dog_cat.trt"

def load_engine(engine_path):
    """Read a serialized TensorRT engine file and return the deserialized engine.

    Args:
        engine_path: Path to the serialized engine on disk.

    Returns:
        Whatever ``trt.Runtime.deserialize_cuda_engine`` returns for the file's bytes.
    """
    with open(engine_path, "rb") as engine_file:
        with trt.Runtime(TRT_LOGGER) as runtime:
            serialized = engine_file.read()
            return runtime.deserialize_cuda_engine(serialized)

# Deserialize the engine once and create the execution context used by infer_tensorrt()
engine = load_engine(ENGINE_PATH)
context = engine.create_execution_context()

# Labels looked up by the arg-max index of the model output
CLASS_NAMES = ["Cat", "Dog"]

app = FastAPI()

# Image preprocessing pipeline: resize to 224x224, convert to a tensor,
# then normalize each channel with the mean/std below
_resize = transforms.Resize((224, 224))
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose([_resize, _to_tensor, _normalize])

def preprocess_image(image: Image.Image) -> np.ndarray:
    """Run the module-level ``transform`` on a PIL image and return a batched float32 array.

    Args:
        image: PIL image to preprocess.

    Returns:
        The transformed image as a float32 NumPy array with a new leading
        axis of length 1 (the batch dimension).
    """
    transformed = transform(image).numpy().astype(np.float32)
    # Prepend the batch axis
    return transformed[np.newaxis, ...]

def infer_tensorrt(input_tensor):
    """Run one inference pass through the module-level TensorRT ``engine``/``context``.

    Like the original implementation, this treats I/O tensor 0 as the input
    and I/O tensor 1 as the output of the engine.

    Args:
        input_tensor: NumPy array holding the preprocessed input; it is
            converted to a C-contiguous float32 array before upload.

    Returns:
        A float32 NumPy array with the engine's output shape.
    """
    host_input = np.ascontiguousarray(input_tensor, dtype=np.float32)

    # The CUDA context created by `pycuda.autoinit` is current only on the thread
    # that imported it, while the API server runs this function on another thread.
    # Make the context current for the duration of the call.
    cuda_ctx = pycuda.autoinit.context
    cuda_ctx.push()
    try:
        # Newer TensorRT releases address I/O tensors by name; older ones only
        # offer the binding-index API. Support both.
        use_named_io = hasattr(engine, "get_tensor_name")
        if use_named_io:
            input_name = engine.get_tensor_name(0)
            output_name = engine.get_tensor_name(1)
            output_shape = tuple(engine.get_tensor_shape(output_name))
        else:
            output_shape = tuple(engine.get_binding_shape(1))

        host_output = np.empty(output_shape, dtype=np.float32)

        # Size the device buffers from the host arrays so both sides always agree
        d_input = cuda.mem_alloc(host_input.nbytes)
        d_output = cuda.mem_alloc(host_output.nbytes)
        try:
            stream = cuda.Stream()
            cuda.memcpy_htod_async(d_input, host_input, stream)
            if use_named_io:
                context.set_tensor_address(input_name, int(d_input))
                context.set_tensor_address(output_name, int(d_output))
                context.execute_async_v3(stream.handle)
            else:
                bindings = [int(d_input), int(d_output)]
                context.execute_async_v2(bindings, stream.handle, None)
            cuda.memcpy_dtoh_async(host_output, d_output, stream)
            # Wait for the copy back to finish before the result is read or buffers are freed
            stream.synchronize()
        finally:
            # Release device memory deterministically instead of waiting for garbage collection
            d_input.free()
            d_output.free()
    finally:
        cuda_ctx.pop()

    return host_output

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted label with a score.

    Args:
        file: The uploaded image file.

    Returns:
        A dict with ``"prediction"`` (an entry of ``CLASS_NAMES``) and
        ``"confidence"`` (the winning score clamped to [0, 1], rounded to 5 places).

    Raises:
        HTTPException: With status 500 if reading, decoding, preprocessing or
            inference fails for any reason.
    """
    try:
        # Read and decode the upload; convert() returns a new image, so it
        # stays usable after the original file-backed image is closed
        image_bytes = await file.read()
        with Image.open(io.BytesIO(image_bytes)) as uploaded:
            image = uploaded.convert("RGB")

        input_tensor = preprocess_image(image)

        # The engine output may carry a leading batch axis; flatten it so the
        # arg-max index and the score lookup refer to the same 1-D vector
        scores = np.asarray(infer_tensorrt(input_tensor)).reshape(-1)

        predicted_index = int(np.argmax(scores))
        predicted_class = CLASS_NAMES[predicted_index]
        confidence_score = float(scores[predicted_index])
        # NOTE(review): this only clamps the raw score into [0, 1]. If the model
        # emits logits rather than probabilities, a softmax is needed for a
        # meaningful confidence -- confirm against the exported model.
        confidence_score = min(max(confidence_score, 0), 1)

        return {
            "prediction": predicted_class,
            "confidence": round(confidence_score, 5)
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")

# Open a public ngrok tunnel that forwards to the local API port and show its URL
ngrok_tunnel = ngrok.connect(8000)
print("Public URL:", ngrok_tunnel.public_url)

def run():
    """Serve the FastAPI app with uvicorn on 0.0.0.0:8000."""
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)

# Start the server on its own thread so this cell returns while the server keeps running
thread = threading.Thread(target=run)
thread.start()


Public URL: https://3d8f-34-125-28-44.ngrok-free.app
