# LiteRT using Google AI Edge for on-device object detection
This notebook converts the YOLO11 object detection model to the LiteRT (.tflite) format using Google AI Edge tools, runs it with the LiteRT interpreter, and exports it so that it can be deployed on Android for on-device inference.

Developed by [Levi Lin](https://github.com/gy6543721).

#### Step 1: Install dependencies

In [None]:
!pip install ultralytics
!pip install ai-edge-model-explorer
!pip install ai-edge-litert

#### Step 2: Import libraries

In [None]:
import json
import os
import random

import cv2
import matplotlib.pyplot as plt
import model_explorer
import numpy as np
import yaml
from ai_edge_litert.interpreter import Interpreter
from google.colab import files
from PIL import Image
from ultralytics import YOLO

#### Step 3: Convert YOLO11 model to LiteRT (TF Lite)

In [None]:
# Load the pretrained YOLO11 nano checkpoint and convert it to LiteRT (TF Lite).
# Later cells read the exported files from `yolo11n_saved_model/`.
YOLO11_WEIGHTS = "yolo11n.pt"
model = YOLO(YOLO11_WEIGHTS)
model.export(format="tflite")

Download the sample images and video, or load your own image.

In [None]:
# Download sample image and video.
!wget https://raw.githubusercontent.com/gy6543721/LiteRT/main/assets/test_image.jpg
!wget https://raw.githubusercontent.com/gy6543721/LiteRT/main/assets/test_image_2.jpg
!wget https://raw.githubusercontent.com/gy6543721/LiteRT/main/assets/test_video.mp4

image = Image.open('test_image_2.jpg')

plt.figure(figsize=(12, 8))
plt.imshow(image)
plt.axis('off')
plt.show()

In [None]:
LITE_RT_EXPORT_PATH = "yolo11n_saved_model/" # @param {type : 'string'}
LITE_RT_MODEL = "yolo11n_float32.tflite" # @param {type : 'string'}

# os.path.join works whether or not LITE_RT_EXPORT_PATH ends with "/", so
# editing the form field cannot silently produce a wrong path.
LITE_RT_MODEL_PATH = os.path.join(LITE_RT_EXPORT_PATH, LITE_RT_MODEL)

# Load the exported TF Lite model through the Ultralytics API as a sanity check.
litert_model = YOLO(LITE_RT_MODEL_PATH, task = 'detect')

# Input image.
image = 'test_image_2.jpg' # @param {type : 'string'}

# Run inference on the input image and display the first result.
result = litert_model(image)
result[0].show()

#### Step 4: Visualize the LiteRT model

In [None]:
# Open the exported .tflite file in Google AI Edge Model Explorer to inspect the converted graph.
model_explorer.visualize(LITE_RT_MODEL_PATH)

#### Step 5: Create labelmap

In [None]:
metadata_file = "metadata.yaml" # @param {type : 'string'}
json_file = "labels.json" # @param {type : 'string'}

# metadata.yaml is read from the export directory; os.path.join tolerates a
# missing trailing "/" on LITE_RT_EXPORT_PATH.
metadata_path = os.path.join(LITE_RT_EXPORT_PATH, metadata_file)

with open(metadata_path, "r", encoding="utf-8") as file:
    metadata = yaml.safe_load(file)

# `names` is keyed by class id. JSON object keys are always strings, so integer
# ids are written as "0", "1", ... — postprocess_output looks them up with str(class_id).
names = metadata.get("names", {})
if not names:
    # Keep going (an empty labelmap is still valid JSON) but make the problem visible.
    print(f"Warning: no class names found in {metadata_path}; labels.json will be empty.")

with open(json_file, 'w', encoding="utf-8") as file:
    json.dump(names, file, indent=2)

print("Labelmap created.")

Labelmap created.


#### Step 6: Run the TF Lite model with the LiteRT interpreter

In [None]:
# Load the TF Lite model.
interpreter = Interpreter(model_path = LITE_RT_MODEL_PATH)
interpreter.allocate_tensors()

# Get input and output details.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

input_size = input_details[0]['shape'][1]

print(f"Model input size: {input_size}")
print(f"Output tensor shape: {output_details[0]['shape']}")

#### Step 7: Define utility functions

`load_labels`: Loads the `labels.json` file.

`load_image`: Loads the input image, converts it to RGB, resizes it to the model input size, and scales pixel values to [0, 1].

`detect`: Runs the LiteRT model on an image file or a video frame.

`postprocess_output`: Filters the model output by confidence and scales the normalized bounding-box coordinates to pixel coordinates of the original image.

`generate_color_map`: Assigns a random color to each label (colors are not guaranteed to be distinct).

`inference_image`: Draws the detections on an image, displays it, and returns the annotated image.

`inference_video`: Draws the detections on a video frame and writes it to the output video.

In [None]:
# Load labels.
def load_labels(label_file):
    """Read the JSON labelmap written in Step 5.

    Args:
        label_file: Path to a JSON file such as ``labels.json``.

    Returns:
        The decoded JSON object. For ``labels.json`` this is a dict whose keys
        are class ids as strings and whose values are label names.
    """
    with open(label_file, 'r') as handle:
        decoded = json.load(handle)
    return decoded


# Load and preprocess image.
def load_image(image_path, input_size):
    """Read an image from disk and prepare it as model input.

    Args:
        image_path: Path of the image file to read.
        input_size: Side length in pixels the image is resized to (the model
            input is treated as square).

    Returns:
        ``(image, (original_height, original_width))`` where ``image`` is an
        RGB array resized to ``input_size`` x ``input_size`` with values scaled
        to [0, 1], and the tuple is the image size before resizing.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    original_height, original_width = image.shape[:2]
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR.
    image = cv2.resize(image, (input_size, input_size))
    image = image / 255.0
    return image, (original_height, original_width)


# Run inference.
def detect(input_data, is_video_frame=False):
    """Run the LiteRT interpreter on a single image.

    Uses the module-level ``interpreter``, ``input_details`` and
    ``output_details`` created in Step 6.

    Args:
        input_data: An image file path, or (when ``is_video_frame`` is True)
            a BGR frame as returned by ``cv2.VideoCapture.read``.
        is_video_frame: Whether ``input_data`` is an in-memory frame rather
            than a path.

    Returns:
        ``(outputs, (original_height, original_width))`` where ``outputs`` is
        the list of output tensors in ``output_details`` order.
    """
    side = input_details[0]['shape'][1]

    if not is_video_frame:
        model_input, (original_height, original_width) = load_image(input_data, side)
    else:
        original_height, original_width = input_data.shape[:2]
        rgb_frame = cv2.cvtColor(input_data, cv2.COLOR_BGR2RGB)
        model_input = cv2.resize(rgb_frame, (side, side)) / 255.0

    # Add the batch dimension and match the float32 input the interpreter is fed.
    batch = np.expand_dims(model_input, axis=0).astype(np.float32)
    interpreter.set_tensor(input_details[0]['index'], batch)
    interpreter.invoke()

    outputs = [interpreter.get_tensor(detail['index']) for detail in output_details]
    return outputs, (original_height, original_width)



# Postprocess the output.
def postprocess_output(output_data, original_dims, labels, confidence_threshold,
                       iou_threshold=0.7):
    """Convert the LiteRT output into detections in original-image pixel coordinates.

    Two YOLO TF Lite output layouts are handled:

    * End-to-end (NMS already applied), shape ``(1, N, 6)``: each row is
      ``[x_min, y_min, x_max, y_max, confidence, class_id]`` with coordinates
      normalized to [0, 1]. Rows above the threshold are used as-is.
    * Raw, shape ``(1, 4 + num_classes, num_anchors)`` — what an Ultralytics
      TFLite export produces without NMS, e.g. ``(1, 84, 8400)`` for an
      80-class model at 640 px. Each column is ``[cx, cy, w, h, score_0, ...]``
      with normalized centre/size. The best class per anchor is taken and
      class-wise non-maximum suppression is applied.

    Args:
        output_data: List of output tensors from ``detect``; only the first is used.
        original_dims: ``(height, width)`` of the original image in pixels.
        labels: Dict mapping class-id strings to label names.
        confidence_threshold: Detections must score strictly above this value.
        iou_threshold: IoU above which a lower-scoring box of the same class
            is suppressed (raw layout only).

    Returns:
        List of dicts with keys ``"box"`` (``[y_min, x_min, y_max, x_max]`` as
        ints, clamped to the image), ``"score"`` (float), ``"class"`` (int) and
        ``"label"`` (``"Unknown"`` if the id is not in ``labels``).
    """
    output_tensor = np.asarray(output_data[0])
    original_height, original_width = original_dims

    if _is_raw_yolo_output(output_tensor):
        boxes, scores, class_ids = _decode_raw_yolo_output(output_tensor, confidence_threshold)
        keep = _class_wise_nms(boxes, scores, class_ids, iou_threshold)
    else:
        rows = output_tensor[0]
        boxes = rows[:, :4]
        scores = rows[:, 4]
        class_ids = rows[:, 5].astype(int)
        keep = [i for i in range(len(rows)) if scores[i] > confidence_threshold]

    detections = []
    for i in keep:
        x_min, y_min, x_max, y_max = boxes[i]
        class_id = int(class_ids[i])
        detections.append({
            "box": [
                _to_pixels(y_min, original_height),
                _to_pixels(x_min, original_width),
                _to_pixels(y_max, original_height),
                _to_pixels(x_max, original_width),
            ],
            "score": float(scores[i]),
            "class": class_id,
            "label": labels.get(str(class_id), "Unknown"),
        })

    return detections


def _to_pixels(coordinate, size):
    """Scale a normalized coordinate to a pixel index, clamped to [0, size]."""
    return int(min(max(float(coordinate), 0.0), 1.0) * size)


def _is_raw_yolo_output(output_tensor):
    """True for the raw ``(1, 4 + nc, anchors)`` layout, False for ``(1, N, 6)``."""
    _, rows, cols = output_tensor.shape
    # Raw outputs are channels-first with many more anchors than channels;
    # a last dimension of exactly 6 is the end-to-end [box, conf, class] row.
    return cols != 6 and rows < cols


def _decode_raw_yolo_output(output_tensor, confidence_threshold):
    """Pick the best class per anchor, drop low scores, and convert boxes to corner form."""
    predictions = output_tensor[0].T  # (anchors, 4 + nc)
    class_scores = predictions[:, 4:]
    class_ids = class_scores.argmax(axis=1)
    scores = class_scores.max(axis=1)
    mask = scores > confidence_threshold
    cx, cy, w, h = predictions[mask, :4].T
    boxes = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)
    return boxes, scores[mask], class_ids[mask]


def _iou(box, others):
    """IoU between one corner-form box and an ``(M, 4)`` array of corner-form boxes."""
    inter_w = np.clip(np.minimum(box[2], others[:, 2]) - np.maximum(box[0], others[:, 0]), 0, None)
    inter_h = np.clip(np.minimum(box[3], others[:, 3]) - np.maximum(box[1], others[:, 1]), 0, None)
    intersection = inter_w * inter_h
    area = (box[2] - box[0]) * (box[3] - box[1])
    other_areas = (others[:, 2] - others[:, 0]) * (others[:, 3] - others[:, 1])
    return intersection / np.maximum(area + other_areas - intersection, 1e-12)


def _class_wise_nms(boxes, scores, class_ids, iou_threshold):
    """Greedy NMS per class; returns kept indices ordered by descending score."""
    keep = []
    for class_id in np.unique(class_ids):
        candidates = np.flatnonzero(class_ids == class_id)
        candidates = candidates[np.argsort(-scores[candidates], kind="stable")]
        while candidates.size:
            best, rest = candidates[0], candidates[1:]
            keep.append(best)
            candidates = rest[_iou(boxes[best], boxes[rest]) <= iou_threshold]
    return sorted(keep, key=lambda idx: -scores[idx])


# Generate color map for labels.
def generate_color_map(labels, seed=None):
    """Assign a random RGB color to every label name.

    Args:
        labels: Dict whose values are label names (e.g. the labelmap returned
            by ``load_labels``).
        seed: Optional seed. When given, colors come from a private
            ``random.Random(seed)``, so the same labels always get the same
            colors and the global ``random`` state is left untouched. When
            None (default), the module-level ``random`` generator is used.

    Returns:
        Dict mapping each label name to a ``[r, g, b]`` list of ints in 0..255.
        Labels that share a name share a single entry.
    """
    rng = random if seed is None else random.Random(seed)
    return {label: [rng.randint(0, 255) for _ in range(3)] for label in labels.values()}


# Inference on image.
def inference_image(image_path, detections, color_map):
    """Draw detections on an image, display it, and return it in BGR order.

    Args:
        image_path: Path of the image the detections were computed on.
        detections: List of dicts from ``postprocess_output`` with keys
            ``"box"`` (``[y_min, x_min, y_max, x_max]`` in pixels), ``"label"``
            and ``"score"``.
        color_map: Dict mapping label name to a drawing color. Labels missing
            from it (e.g. ``"Unknown"``) are drawn in gray instead of raising
            KeyError.

    Returns:
        The annotated image in BGR channel order, ready for ``cv2.imwrite``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    font_scale = 1
    for detection in detections:
        y_min, x_min, y_max, x_max = detection['box']
        label = detection['label']
        score = detection['score']
        color = color_map.get(label, (128, 128, 128))

        cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 3)

        text = f'{label}: {score:.2f}'
        text_width, text_height = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 3)[0]

        # The label is drawn above the box; for boxes touching the top edge it is
        # pushed down so it stays inside the image instead of being clipped away.
        label_bottom = max(y_min, text_height + 10)
        label_start = (x_min, label_bottom - text_height - 10)
        label_end = (x_min + text_width, label_bottom)
        cv2.rectangle(image, label_start, label_end, color, -1)

        text_position = (x_min, label_bottom - 5)
        cv2.putText(image, text, text_position, cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), 2)

    # Show the image (RGB, which is what matplotlib expects).
    plt.figure(figsize=(12, 8))
    plt.imshow(image)
    plt.axis('off')
    plt.show()

    # Convert back to BGR for OpenCV writers such as cv2.imwrite.
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


# Inference on video.
def inference_video(frame, detections, color_map, out):
    """Draw detections on one video frame and write it to ``out``.

    Args:
        frame: BGR frame as read from ``cv2.VideoCapture``.
        detections: List of dicts from ``postprocess_output`` with keys
            ``"box"`` (``[y_min, x_min, y_max, x_max]`` in pixels), ``"label"``
            and ``"score"``.
        color_map: Dict mapping label name to a drawing color. Labels missing
            from it (e.g. ``"Unknown"``) are drawn in gray instead of raising
            KeyError.
        out: Writer with a ``write(frame)`` method (``cv2.VideoWriter`` in this
            notebook); receives the annotated frame in BGR order.
    """
    # Draw in RGB so colors match those used by inference_image.
    image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    font_scale = 0.5
    for detection in detections:
        y_min, x_min, y_max, x_max = detection['box']
        label = detection['label']
        score = detection['score']
        color = color_map.get(label, (128, 128, 128))

        cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 3)

        text = f'{label}: {score:.2f}'
        text_width, text_height = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 3)[0]

        # Keep the label inside the frame for boxes touching the top edge.
        label_bottom = max(y_min, text_height + 10)
        label_start = (x_min, label_bottom - text_height - 10)
        label_end = (x_min + text_width, label_bottom)
        cv2.rectangle(image, label_start, label_end, color, -1)

        text_position = (x_min, label_bottom - 5)
        cv2.putText(image, text, text_position, cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), 1)

    out.write(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

#### Step 8: Visualize inference results on an image or a video

In [None]:
# Inference settings. The `# @param` comments are Colab form annotations, so
# these values can be edited from the form instead of the code.
label_file = 'labels.json' # @param {type : 'string'}
input_type = 'image' # @param ['image', 'video']
image_path = 'test_image_2.jpg' # @param {type : 'string'}
video_path = 'test_video.mp4' # @param {type : 'string'}
confidence_threshold = 0.4 # @param {type : 'slider', min:0, max:1, step: 0.1}

# Load the labelmap written in Step 5 and assign a random drawing color to each label.
labels = load_labels(label_file)
color_map = generate_color_map(labels)

In [None]:
if input_type == 'image':
    output_data, original_dims = detect(image_path)
    detections = postprocess_output(output_data, original_dims, labels, confidence_threshold)
    output_img = inference_image(image_path, detections, color_map)

    # Prefix only the file name so paths with directories stay valid
    # (e.g. "data/a.jpg" -> "data/output_a.jpg", not "output_data/a.jpg").
    image_dir, image_name = os.path.split(image_path)
    output_image = os.path.join(image_dir, 'output_' + image_name)
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(output_image, output_img):
        raise IOError(f"Could not write {output_image}")
    print(f"Image saved as {output_image}")

else:
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise FileNotFoundError(f"Could not open video: {video_path}")

    # Swap only the extension; str.replace('mp4', 'avi') would also rewrite
    # "mp4" occurring anywhere else in the path.
    video_dir, video_name = os.path.split(video_path)
    output_video = os.path.join(video_dir, 'output_' + os.path.splitext(video_name)[0] + '.avi')

    # Keep the source frame rate so the output plays at the original speed;
    # fall back to 20 fps when the container does not report one.
    fps = cap.get(cv2.CAP_PROP_FPS) or 20.0
    frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(output_video, fourcc, fps, frame_size)

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            output_data, original_dims = detect(frame, is_video_frame=True)
            detections = postprocess_output(output_data, original_dims, labels, confidence_threshold)

            inference_video(frame, detections, color_map, out)
    finally:
        # Release both handles even if inference fails part-way through.
        cap.release()
        out.release()
    print(f"Output video saved as {output_video}")

#### Step 9: Download output image and video (optional)

In [None]:
# Download whichever outputs Step 8 produced. A single run of Step 8 defines
# only one of `output_image` / `output_video`, so skip names that are undefined
# or whose file does not exist instead of failing with NameError.
for output_path in (globals().get('output_image'), globals().get('output_video')):
    if output_path and os.path.exists(output_path):
        files.download(output_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

#### Step 10: Download the LiteRT model

Download the exported LiteRT model for on-device deployment.

In [None]:
# Trigger a browser download of the exported .tflite model (google.colab, so Colab only).
files.download(LITE_RT_MODEL_PATH)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>