# Train and deploy a Faster R-CNN Object Detection model

<a target="_blank" href="https://colab.research.google.com/github/unionai-oss/faster-rcnn-object-detection-computer-vision-train-and-deploy/blob/main/tutorial.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>



### Setup

In [None]:
# Detect Google Colab: a successful import of `google.colab` is taken as the
# signal that the notebook is running there.
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

# In Colab, clone the tutorial repository, make it the working directory and
# install the dependencies listed in its requirements.txt.
if IN_COLAB:
    !git clone https://github.com/unionai-oss/faster-rcnn-object-detection-computer-vision-train-and-deploy
    %cd faster-rcnn-object-detection-computer-vision-train-and-deploy
    !pip install -r requirements.txt

### 🔐 Authenticate
To use **Union.ai**, you'll need to authenticate your account. Follow the appropriate step based on your setup:  

##### 🔸 **Using Union BYOC Enterprise**  

If you're using a **[Union BYOC Enterprise](https://www.union.ai/pricing)** account, log in with the following command:  
```bash
union create login --host <union-host-url>
```

Replace `<union-host-url>` with your organization's Union instance URL.

##### 🔸 Using Union Serverless
If you're using [Union Serverless](https://www.union.ai/), authenticate by running the command in the code cell below.  

Create an account for free at [Union.ai](https://union.ai) if you don't have one yet:
 

In [None]:
# 🌟 Authenticate to Union Serverless using the device-flow login
!union create login --serverless --auth device-flow

## Training Faster RCNN Object Detection Model Pipeline

Run the command below to train a Faster RCNN Object Detection model using the Union.ai CLI. This command will create a new pipeline and start the training process.

The first time you run this command, it will take a while to download the model and set up the environment.

The subsequent runs will be faster as the container, model, and data will be cached.


In [None]:
# 👇 Run this command to build the container and start the training workflow
# remotely; `--epochs` sets the workflow's `epochs` input.
!union run --remote workflows/train-frcnn-pipeline.py faster_rcnn_train_workflow --epochs 3

### 🔎 Explore the Code  

- The command above is using files from the [`workflows/`](workflows/train-frcnn-pipeline.py) and [`tasks`](tasks/) folders that got cloned on setup.

- The code is included in this notebook for reference, using the `%%writefile` magic command so that you can overwrite the files if you want to make changes.

- You do not need to run the code cells with `%%writefile` unless you want to make changes to the pipeline or tasks.


In [None]:
%%writefile workflows/train-frcnn-pipeline.py

from union import workflow

from tasks.data import download_hf_dataset, verify_data_and_annotations
from tasks.model import download_model, evaluate_model, train_model, upload_model_to_hub


# %% ------------------------------
# Object Detection Workflow
# --------------------------------
@workflow
def faster_rcnn_train_workflow(
    epochs: int = 3, classes: int = 3, hf_repo_id: str = ""
) -> None:
    """Download data and a pretrained model, then train and evaluate it.

    Steps: download the Hugging Face dataset and the pretrained model,
    render a data verification grid, fine-tune the model, and evaluate the
    fine-tuned model on the same dataset directory.

    Args:
        epochs: Number of training epochs, forwarded to ``train_model`` as
            ``num_epochs``.
        classes: Number of object classes, forwarded to ``train_model`` as
            ``num_classes``.
        hf_repo_id: Hugging Face repository for the trained model. Only
            referenced by the commented-out ``upload_model_to_hub`` step.
    """

    dataset_dir = download_hf_dataset(
        repo_id="sagecodes/union_flyte_swag_object_detection"
    )
    model_file = download_model()
    # The verification task's output is not consumed by any later step.
    verify_data_and_annotations(dataset_dir=dataset_dir)
    trained_model = train_model(
        model_file=model_file,
        dataset_dir=dataset_dir,
        num_epochs=epochs,
        num_classes=classes,
    )
    evaluate_model(model=trained_model, dataset_dir=dataset_dir)
    # upload_model_to_hub(model=trained_model, repo_name=hf_repo_id) # uncomment to upload the model to Hugging Face Hub


# union run --remote workflows/train-frcnn-pipeline.py faster_rcnn_train_workflow --epochs 3
# union run --remote workflows/train-frcnn-pipeline.py faster_rcnn_train_workflow --epochs 3 --hf_repo_id "sagecodes/cv-object-rcnn"

In [None]:
%%writefile containers.py
from flytekit import ImageSpec, Resources
# from union.actor import ActorEnvironment

# Container image definition imported by the tasks in tasks/data.py and
# tasks/model.py; dependencies come from requirements.txt.
# NOTE(review): the image name "fine-tune-qlora" looks like a leftover from a
# different project — confirm whether it should be renamed.
container_image = ImageSpec(
     name="fine-tune-qlora",
    requirements="requirements.txt",
    pip_extra_index_url=["https://download.pytorch.org/whl/cu118"],  # extra index so pip can find the +cu118 PyTorch builds
    builder="union",
    cuda="11.8",  # ensure GPU + CUDA layer is available
    apt_packages=["gcc", "g++"],  # compilers; presumably for packages that build native extensions
)


In [None]:
%%writefile requirements.txt
union==0.1.181
flytekit==1.15.4
torch==2.5.1
torchvision==0.20.1
matplotlib==3.10.3
pycocotools==2.0.8
datasets==2.14.4
huggingface_hub
python-dotenv==1.1.0
opencv-python==4.11.0.86

In [None]:
%%writefile tasks/data.py

"""
this module contains the data loading and preprocessing functions
"""

import os
from textwrap import dedent

from dotenv import load_dotenv
from flytekit import Deck, Resources, current_context, task
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile

from containers import container_image
from tasks.helpers import dataset_dataloader, image_to_base64

load_dotenv()


# %% ------------------------------
# Download dataset - task
# --------------------------------
@task(
    container_image=container_image,
    enable_deck=True,
    cache=True,
    cache_version="1.333",
    requests=Resources(cpu="2", mem="2Gi"),
)
def download_hf_dataset(
    repo_id: str = "sagecodes/union_swag_coco",
    local_dir: str = "dataset",
    sub_folder: str = "swag",
) -> FlyteDirectory:
    """Download a Hugging Face dataset repository and return it as a directory.

    Args:
        repo_id: Hugging Face dataset repository to download.
        local_dir: Local directory passed to ``snapshot_download``; it is
            created first when non-empty.
        sub_folder: Optional folder inside the downloaded repository to
            return instead of the repository root.

    Returns:
        A ``FlyteDirectory`` pointing at the downloaded repository, or at
        ``sub_folder`` inside it when one is given.

    Raises:
        FileNotFoundError: If ``sub_folder`` does not exist in the download.
    """
    from huggingface_hub import snapshot_download

    if local_dir:
        os.makedirs(local_dir, exist_ok=True)

    # Download the dataset repository
    repo_path = snapshot_download(
        repo_id=repo_id, repo_type="dataset", local_dir=local_dir
    )

    if sub_folder:
        # Return only the requested folder of the dataset; fail with a clear
        # message instead of an opaque listdir error when it is missing.
        repo_path = os.path.join(repo_path, sub_folder)
        if not os.path.isdir(repo_path):
            raise FileNotFoundError(
                f"Sub-folder '{sub_folder}' not found in dataset '{repo_id}' "
                f"(looked in {repo_path})"
            )

    print(f"Dataset downloaded to {repo_path}")
    print(f"Files in dataset directory: {os.listdir(repo_path)}")

    return FlyteDirectory(repo_path)


# %% ------------------------------
# visualize data - task
# --------------------------------
@task(
    container_image=container_image,
    enable_deck=True,
    requests=Resources(cpu="2", mem="4Gi"),
)
def verify_data_and_annotations(dataset_dir: FlyteDirectory) -> FlyteFile:
    """Plot up to nine dataset images with their annotated bounding boxes.

    The 3x3 grid is saved as ``data_verification_grid.png``, embedded in a
    "Data Verification" deck and returned as a ``FlyteFile``.

    Args:
        dataset_dir: Directory holding the images and ``train.json``.

    Returns:
        The saved grid image.
    """
    import matplotlib.patches as patches
    import matplotlib.pyplot as plt

    # Materialize the FlyteDirectory locally before reading from it
    dataset_dir.download()
    loader = dataset_dataloader(
        root=dataset_dir.path, annFile="train.json", shuffle=True
    )

    max_images = 9
    fig, axes = plt.subplots(3, 3, figsize=(15, 15))
    axes = axes.flatten()  # 1-D view so subplots can be indexed by a counter

    shown = 0
    for images, targets in loader:
        for image, annotations in zip(images, targets):
            if shown >= max_images:
                break  # grid is full

            ax = axes[shown]
            # CHW tensor -> HWC layout expected by imshow
            ax.imshow(image.cpu().permute(1, 2, 0))

            for annotation in annotations:
                # bbox is [x_min, y_min, width, height], which is exactly
                # what patches.Rectangle takes
                x_min, y_min, width, height = annotation["bbox"]
                ax.add_patch(
                    patches.Rectangle(
                        (x_min, y_min),
                        width,
                        height,
                        linewidth=2,
                        edgecolor="r",
                        facecolor="none",
                    )
                )

            shown += 1

        if shown >= max_images:
            break  # stop pulling batches once the grid is full

    plt.tight_layout()

    # Persist the grid so it can be embedded in the deck and returned
    output_img = "data_verification_grid.png"
    plt.savefig(output_img)
    plt.close()

    verification_image_base64 = image_to_base64(output_img)

    deck = Deck("Data Verification")
    deck.append(
        dedent(
            f"""
    <div style="font-family: Arial, sans-serif; line-height: 1.6;">
       <h2 style="color: #2C3E50;">Data Verification: Images and Annotations</h2>
        <img src="data:image/png;base64,{verification_image_base64}" width="600">
    </div>
    """
        )
    )
    # Put this deck first in the task's deck list
    current_context().decks.insert(0, deck)

    return FlyteFile(output_img)



In [None]:
%%writefile tasks/model.py

# %%
import base64
import os
from textwrap import dedent

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.optim as optim
import torchvision
from dotenv import load_dotenv
from flytekit import Deck, Resources, Secret, current_context, task
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from torchvision.models.detection.faster_rcnn import (
    FasterRCNN_MobileNet_V3_Large_320_FPN_Weights,
)
from torchvision.ops import box_iou
from torchvision.transforms import transforms as T
from typing_extensions import Annotated
from union import Artifact

from containers import container_image
from tasks.helpers import dataset_dataloader, image_to_base64

load_dotenv()

# Define Artifacts
FRCCNPreTrainedModel = Artifact(name="frccn_pretrained_model")
FRCCNFineTunedModel = Artifact(name="frccn_fine_tuned_model")


# %% ------------------------------
# download model - task
# --------------------------------
@task(
    container_image=container_image,
    cache=True,
    cache_version="1.334",
    requests=Resources(cpu="2", mem="2Gi"),
)
def download_model() -> Annotated[FlyteFile, FRCCNPreTrainedModel]:
    """Download the pretrained Faster R-CNN MobileNetV3 320 FPN model.

    The whole model object is pickled with ``torch.save`` so that downstream
    tasks can restore it with ``torch.load(..., weights_only=False)``.

    Returns:
        The pickled model file, registered as the ``frccn_pretrained_model``
        artifact.
    """
    # Pass an enum *member*: handing over the enum class itself only works
    # through torchvision's deprecated legacy-argument path (with a warning).
    # The former `weights_only=True` kwarg is not a builder parameter and was
    # silently ignored, so it is dropped.
    model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(
        weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT
    )

    model_path = "frccn_mobilenet_pretrained_model.pth"
    torch.save(model, model_path)

    return FRCCNPreTrainedModel.create_from(model_path)


# %% ------------------------------
# train model - task
# --------------------------------
@task(
    container_image=container_image,
    enable_deck=True,
    requests=Resources(cpu="2", mem="8Gi", gpu="1"),
)
def train_model(
    model_file: FlyteFile,
    dataset_dir: FlyteDirectory,
    num_epochs: int,
    num_classes: int = 2,
    conf_thresh: float = 0.75,
    validate_every_n_epochs: int = 1,
) -> Annotated[FlyteFile, FRCCNFineTunedModel]:
    """Fine-tune a pickled Faster R-CNN model and return it as an artifact.

    The box-predictor head is replaced so that it outputs ``num_classes + 1``
    classes (the extra one is the background), the model is trained with SGD,
    and per-epoch metrics are rendered into a "Training Metrics" deck.

    Args:
        model_file: File written by ``torch.save(model, ...)`` that holds the
            pretrained model.
        dataset_dir: Directory with the images and a ``train.json``
            annotation file whose entries carry ``bbox`` as
            ``[x_min, y_min, width, height]`` and ``category_id``.
        num_epochs: Number of training epochs; must be at least 1.
        num_classes: Number of object classes, excluding the background.
        conf_thresh: Minimum prediction score considered during validation.
        validate_every_n_epochs: Validate every N epochs (values below 1 are
            treated as 1). The final epoch is always validated.

    Returns:
        The fully pickled model as it is after the last epoch, registered as
        the ``frccn_fine_tuned_model`` artifact.

    Raises:
        ValueError: If ``num_epochs`` is smaller than 1.
    """
    if num_epochs < 1:
        # Fail early: with zero epochs there are no metrics to plot.
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")
    validate_every = max(1, validate_every_n_epochs)

    num_classes = num_classes + 1  # + 1 for the background class
    print(f"Using confidence threshold: {conf_thresh} for evaluation")

    best_mean_iou = 0
    model_dir = "models"

    device = "cuda" if torch.cuda.is_available() else "cpu"

    dataset_dir.download()
    os.makedirs(model_dir, exist_ok=True)
    local_dataset_dir = dataset_dir.path  # Use the local path for FlyteDirectory
    data_loader = dataset_dataloader(root=local_dataset_dir, annFile="train.json")
    # NOTE(review): validation reads the same train.json as training, so the
    # "val" metrics are measured on training data — confirm this is intended.
    test_data_loader = dataset_dataloader(root=local_dataset_dir, annFile="train.json")

    # Load pretrained model (a full pickled module, hence weights_only=False)
    model = torch.load(model_file, map_location="cpu", weights_only=False)

    # Replace the classification head to match the number of classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = (
        torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
            in_features, num_classes
        )
    )

    model.to(device)

    # Optimizer and learning-rate schedule
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    def to_detection_targets(raw_targets):
        """Turn per-image annotation lists into torchvision target dicts."""
        converted = []
        for annotations in raw_targets:
            # reshape(-1, 4) keeps images without annotations as an empty
            # (0, 4) tensor instead of a 1-D tensor that cannot be indexed.
            boxes = torch.tensor(
                [obj["bbox"] for obj in annotations], dtype=torch.float32
            ).reshape(-1, 4)
            # [x_min, y_min, width, height] -> [x_min, y_min, x_max, y_max]
            boxes[:, 2:] += boxes[:, :2]
            labels = torch.tensor(
                [obj["category_id"] for obj in annotations], dtype=torch.int64
            )
            converted.append({"boxes": boxes.to(device), "labels": labels.to(device)})
        return converted

    def validate(loader):
        """Return (mean best-match IoU, label accuracy) over ``loader``."""
        model.eval()
        iou_list = []
        correct_predictions, total_predictions = 0, 0

        with torch.no_grad():
            for images, raw_targets in loader:
                images = [img.to(device) for img in images]
                targets = to_detection_targets(raw_targets)
                outputs = model(images)

                for output, target in zip(outputs, targets):
                    if "scores" not in output:
                        continue

                    keep = output["scores"] > conf_thresh
                    pred_boxes = output["boxes"][keep]
                    pred_labels = output["labels"][keep]
                    true_boxes = target["boxes"]
                    true_labels = target["labels"]

                    if pred_boxes.size(0) == 0 or true_boxes.size(0) == 0:
                        continue

                    # For every prediction, the best-overlapping ground truth
                    best_iou, best_match = box_iou(pred_boxes, true_boxes).max(dim=1)
                    iou_list.append(best_iou.mean().item())

                    correct_predictions += (
                        (pred_labels == true_labels[best_match]).sum().item()
                    )
                    total_predictions += len(pred_labels)

        mean_iou = sum(iou_list) / len(iou_list) if iou_list else 0
        accuracy = correct_predictions / total_predictions if total_predictions else 0
        print(f"Mean IoU: {mean_iou:.4f}, Accuracy: {accuracy:.4f}", flush=True)
        return mean_iou, accuracy

    epoch_logs = []
    num_steps = len(data_loader)

    for epoch in range(1, num_epochs + 1):
        model.train()
        total_loss = 0.0
        for step, (images, raw_targets) in enumerate(data_loader, start=1):
            images = [image.to(device) for image in images]
            targets = to_detection_targets(raw_targets)

            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            step_loss = losses.item()
            total_loss += step_loss

            # 1-based counters, matching the epoch numbers in the metrics log
            print(
                f"Epoch [{epoch}/{num_epochs}], Step [{step}/{num_steps}], Loss: {step_loss:.4f}",
                flush=True,
            )

        lr_scheduler.step()

        avg_train_loss = total_loss / max(num_steps, 1)

        if epoch % validate_every == 0 or epoch == num_epochs:
            mean_iou, accuracy = validate(test_data_loader)
            if mean_iou > best_mean_iou:
                best_mean_iou = mean_iou
                torch.save(model.state_dict(), os.path.join(model_dir, "best_model.pth"))
                print("Best model saved")
        else:
            # NaN leaves a gap in the plot for epochs without validation
            mean_iou = accuracy = float("nan")

        epoch_logs.append(
            {
                "epoch": epoch,
                "train_loss": avg_train_loss,
                "val_accuracy": accuracy,
                "val_mean_iou": mean_iou,
            }
        )

    print("Training completed.")
    # Keep outputs in model_dir rather than writing into the input dataset
    model_path = os.path.join(model_dir, "frccn_finetuned_model.pth")
    torch.save(model, model_path)

    df = pd.DataFrame(epoch_logs)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(df["epoch"], df["train_loss"], label="Train Loss", marker="o")
    ax.plot(df["epoch"], df["val_accuracy"], label="Val Accuracy", marker="s")
    ax.plot(df["epoch"], df["val_mean_iou"], label="Val Mean IoU", marker="^")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Metric")
    ax.set_title("Training Metrics")
    ax.legend()
    ax.grid(True)

    plot_path = os.path.join(model_dir, "training_metrics.png")
    plt.tight_layout()
    plt.savefig(plot_path)
    plt.close()

    # Use the shared helper, as the evaluation task in this module does
    plot_base64 = image_to_base64(plot_path)
    deck = Deck("Training Metrics")
    deck.append(
        f"""
    <h2>Training Progress</h2>
    <img src="data:image/png;base64,{plot_base64}" width="600"/>
    <h3>Last Epoch:</h3>
    <pre>{df.tail(1).to_string(index=False)}</pre>
    """
    )
    current_context().decks.insert(0, deck)

    return FRCCNFineTunedModel.create_from(model_path)


# %% ------------------------------
# evaluate model - task
# --------------------------------


@task(
    container_image=container_image,
    enable_deck=True,
    requests=Resources(cpu="2", mem="8Gi", gpu="1"),
)
def evaluate_model(
    model: torch.nn.Module, dataset_dir: FlyteDirectory, threshold: float = 0.75
) -> str:
    """Evaluate a detection model and publish an "Evaluation Results" deck.

    For every image, predictions scoring above ``threshold`` are matched to
    the ground-truth box they overlap most. The report lists the mean
    best-match IoU and label accuracy per image and overall, and the first
    nine images that have predictions are plotted in a 3x3 grid.

    Args:
        model: The model to evaluate.
        dataset_dir: Directory with the images and ``train.json``.
        threshold: Minimum prediction score for a box to be counted.

    Returns:
        The plain-text evaluation report.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    dataset_dir.download()
    local_dataset_dir = dataset_dir.path
    data_loader = dataset_dataloader(
        root=local_dataset_dir, annFile="train.json", shuffle=False
    )

    model.to(device)
    model.eval()

    num_images = 9  # Number of images to display in the grid
    fig, axes = plt.subplots(3, 3, figsize=(15, 15))
    axes = axes.flatten()

    iou_list, report = [], []
    correct_predictions, total_predictions = 0, 0
    images_plotted = 0
    global_image_index = 0

    with torch.no_grad():
        for images, raw_targets in data_loader:
            images = [image.to(device) for image in images]
            outputs = model(images)

            for i, (output, annotations) in enumerate(zip(outputs, raw_targets)):
                image_index = global_image_index + i

                # reshape(-1, 4) keeps un-annotated images as an empty (0, 4)
                # tensor so the column arithmetic below cannot fail.
                true_boxes = (
                    torch.tensor(
                        [obj["bbox"] for obj in annotations], dtype=torch.float32
                    )
                    .reshape(-1, 4)
                    .to(device)
                )
                # [x_min, y_min, width, height] -> [x_min, y_min, x_max, y_max]
                true_boxes[:, 2:] += true_boxes[:, :2]
                true_labels = torch.tensor(
                    [obj["category_id"] for obj in annotations], dtype=torch.int64
                ).to(device)

                # Keep only confident predictions; filter scores once here
                keep = output["scores"] > threshold
                pred_boxes = output["boxes"][keep]
                pred_scores = output["scores"][keep]
                pred_labels = output["labels"][keep]

                if pred_boxes.size(0) == 0 or true_boxes.size(0) == 0:
                    report.append(
                        f"Image {image_index}: No valid predictions or ground truths"
                    )
                    continue

                # For every prediction, the best-overlapping ground truth
                best_iou, best_match = box_iou(pred_boxes, true_boxes).max(dim=1)
                image_correct = (pred_labels == true_labels[best_match]).sum().item()
                correct_predictions += image_correct
                total_predictions += len(pred_labels)

                mean_iou = best_iou.mean().item()
                iou_list.append(mean_iou)

                # Accuracy of *this* image, not the running total
                image_accuracy = image_correct / len(pred_labels)
                report.append(
                    f"Image {image_index}: IoU = {mean_iou:.4f}, Accuracy = {image_accuracy:.4f}"
                )

                # Plot only the first 9 images that have predictions
                if images_plotted < num_images:
                    ax = axes[images_plotted]
                    ax.imshow(images[i].cpu().permute(1, 2, 0))

                    for bbox, score, label in zip(
                        pred_boxes.cpu().numpy(),
                        pred_scores.cpu().tolist(),
                        pred_labels.cpu().tolist(),
                    ):
                        ax.add_patch(
                            patches.Rectangle(
                                (bbox[0], bbox[1]),
                                bbox[2] - bbox[0],
                                bbox[3] - bbox[1],
                                linewidth=2,
                                edgecolor="r",
                                facecolor="none",
                            )
                        )
                        ax.text(
                            bbox[0],
                            bbox[1],
                            f"{label}: {score:.2f}",
                            color="white",
                            fontsize=8,
                            bbox=dict(facecolor="red", alpha=0.5),
                        )
                    ax.axis("off")
                    images_plotted += 1

            global_image_index += len(images)

    overall_iou = sum(iou_list) / len(iou_list) if iou_list else 0
    overall_accuracy = (
        correct_predictions / total_predictions if total_predictions else 0
    )

    pred_boxes_imgs = "prediction_grid.png"
    plt.tight_layout()
    plt.savefig(pred_boxes_imgs)
    plt.close()

    train_image_base64 = image_to_base64(pred_boxes_imgs)

    # Assemble the report line by line: dedent() cannot strip the margin once
    # a multi-line, unindented per-image report is interpolated into it.
    overall_report = "\n".join(
        [
            "",
            f"Overall Metrics on predictions with confidence threshold {threshold}:",
            "----------------",
            f"Mean IoU: {overall_iou:.4f}",
            f"Mean Accuracy: {overall_accuracy:.4f}",
            "",
            "Per-Image Metrics:",
            "------------------",
            "\n".join(report),
            "",
        ]
    )

    ctx = current_context()
    deck = Deck("Evaluation Results")
    html_report = dedent(
        f"""
    <div style="font-family: Arial, sans-serif; line-height: 1.6;">
       <h2 style="color: #2C3E50;">Predicted Bounding Boxes</h2>
        <img src="data:image/png;base64,{train_image_base64}" width="600">
    </div>               
    <div style="font-family: Arial, sans-serif; line-height: 1.6;">
        <h2 style="color: #2C3E50;">Evaluation Report</h2>
        <pre>{overall_report}</pre>
    </div>
    """
    )
    deck.append(html_report)
    ctx.decks.insert(0, deck)

    return overall_report


# %% ------------------------------
# upload model to hub - task
# --------------------------------
@task(
    container_image=container_image,
    requests=Resources(cpu="2", mem="2Gi"),
    secret_requests=[Secret(group=None, key="hf_token")],
)
def upload_model_to_hub(model: torch.nn.Module, repo_name: str) -> str:
    """Upload the model's state dict to a Hugging Face Hub repository.

    The token is read from the ``HF_TOKEN`` environment variable, falling
    back to the ``hf_token`` Flyte secret. The repository is created when it
    does not exist yet.

    Args:
        model: Model whose ``state_dict`` is uploaded.
        repo_name: Target repository id, e.g. ``"user/repo"``.

    Returns:
        A message containing the repository URL.

    Raises:
        ValueError: If ``repo_name`` is empty.
    """
    from huggingface_hub import HfApi

    # Fail before saving anything or talking to the Hub: the workflow's
    # default hf_repo_id is an empty string.
    if not repo_name:
        raise ValueError("repo_name must be a non-empty Hugging Face repo id")

    ctx = current_context()
    model_path = "best_model.pth"  # local file holding the state dict

    # Save the model's state dictionary
    torch.save(model.state_dict(), model_path)

    # Prefer the local environment variable, then the Flyte secret
    hf_token = os.getenv("HF_TOKEN")
    if hf_token is None:
        hf_token = ctx.secrets.get(key="hf_token")
        print("Using Hugging Face token from Flyte secrets.")
    else:
        print("Using Hugging Face token from environment variable.")

    # Create the repository (if it doesn't exist) on Hugging Face Hub
    api = HfApi()
    api.create_repo(repo_name, token=hf_token, exist_ok=True)

    # Upload the model to the Hugging Face repository
    api.upload_file(
        path_or_fileobj=model_path,  # Path to the local file
        path_in_repo="pytorch_model.bin",  # Destination path in the repo
        repo_id=repo_name,
        commit_message="Upload Faster R-CNN model",
        token=hf_token,
    )

    return f"Model uploaded to Hugging Face Hub: https://huggingface.co/{repo_name}"



> **💡 Note:**  
> In more complex ML workflows, **data pipelines** are often separate from **model training pipelines**.  
> For simplicity, we'll combine them into a single workflow in this example.  

## Run Model Locally

Let's pull down the same dataset locally so we can run the model on example images.

In [None]:
# Let's run the download_hf_dataset task from earlier locally (no --remote flag)
!union run tasks/data.py download_hf_dataset

We will pull down the latest Model Artifact from Union and save it locally. 


In [None]:
from union import Artifact, UnionRemote
from flytekit.types.file import FlyteFile
import torch

# --------------------------------------------------
# Download & load the fine-tuned model from Union Artifacts
# --------------------------------------------------
FRCCNFineTunedModel = Artifact(name="frccn_fine_tuned_model")

query = FRCCNFineTunedModel.query(
    project="default",
    domain="development",
    # version="anmrqcq8pfbnlp42j2vp/n3/0/o0"  # Optional: specify version. Will download the latest version if not specified
)
remote = UnionRemote()
artifact = remote.get_artifact(query=query)
model_file: FlyteFile = artifact.get(as_type=FlyteFile)
# The artifact is a fully pickled model, so weights_only=False is required;
# unpickling can execute code, so only load artifacts you trust.
model = torch.load(model_file.download(), map_location="cpu", weights_only=False)

# Move the model to the inference device: the following cells send their
# input tensors to `device`, and model and inputs must be on the same one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()


Let's run the model locally on an example image and draw the bounding boxes on the image. 

In [None]:
import cv2
import numpy as np
import requests
import torch
from flytekit.types.file import FlyteFile
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms import functional as F
from union import Artifact, UnionRemote
from io import BytesIO

# Define labels map (label 0 is the background class)
labels_map = {0: "Background", 1: "Union Sticker", 2: "Flyte Sticker"}

# Check and set the available device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")


# Download the label font; fall back to PIL's built-in font when the download
# or parsing fails so the rest of the cell still runs.
font_url = "https://github.com/google/fonts/raw/refs/heads/main/apache/ultra/Ultra-Regular.ttf"
try:
    response = requests.get(font_url, timeout=30)
    response.raise_for_status()
    font = ImageFont.truetype(BytesIO(response.content), size=20)
except (requests.RequestException, OSError) as err:
    print(f"Could not load font from {font_url} ({err}); using PIL's default font")
    font = ImageFont.load_default()


# Function to draw bounding boxes
def draw_boxes(image, boxes, labels, scores, labels_map, score_threshold=0.6):
    """Draw detections scoring above ``score_threshold`` onto ``image``.

    Args:
        image: PIL image to draw on; it is modified in place.
        boxes: Iterable of ``[x_min, y_min, x_max, y_max]`` boxes.
        labels: Iterable of integer class labels, aligned with ``boxes``.
        scores: Iterable of confidence scores, aligned with ``boxes``.
        labels_map: Mapping from class label to display name; labels missing
            from it are shown as ``"Class <label>"``.
        score_threshold: Detections with a score at or below this value are
            skipped.

    Returns:
        The same ``image`` object with boxes and captions drawn on it.
    """
    draw = ImageDraw.Draw(image, 'RGBA')

    # NOTE(review): the palettes are keyed by 0 and 1 while the label map in
    # this notebook names classes 1 and 2, so label 2 uses the green fallback
    # — confirm the intended keys.
    colors = {
        0: (255, 173, 10, 200),  # orange outline
        1: (28, 140, 252, 200),  # blue outline
    }
    colors_fill = {
        0: (255, 173, 10, 100),  # translucent orange fill
        1: (28, 140, 252, 100),  # translucent blue fill
    }
    default_color = (0, 255, 0, 200)
    default_fill = (0, 255, 0, 100)

    for box, label, score in zip(boxes, labels, scores):
        if score <= score_threshold:
            continue

        label = int(label)  # plain int for dict lookups and formatting
        x_min, y_min, x_max, y_max = (float(v) for v in box[:4])
        color = colors.get(label, default_color)
        fill_color = colors_fill.get(label, default_fill)

        corners = [(x_min, y_min), (x_max, y_max)]
        draw.rectangle(corners, outline=color, width=3)
        draw.rectangle(corners, fill=fill_color)

        name = labels_map.get(label, f"Class {label}")
        label_text = f"{name}: {score:.2f}"

        # getbbox() returns (left, top, right, bottom) offsets from the text
        # origin, not (width, height). Place the origin so the text's bottom
        # edge sits on the top edge of the box, then paint the background
        # over exactly the area the glyphs cover.
        left, top, right, bottom = font.getbbox(label_text)
        origin = (x_min, y_min - bottom)
        draw.rectangle(
            [(origin[0] + left, origin[1] + top), (origin[0] + right, y_min)],
            fill=color,
        )
        draw.text(origin, label_text, fill="white", font=font)

    return image


# Load a single test image. The path is relative to the repository root (the
# working directory after setup), like the video path used further below, so
# it works both in Colab and locally.
image_path = 'dataset/swag/images/1bd5a6b5-20240916_133544.jpg'
image = Image.open(image_path).convert("RGB")
image_tensor = F.to_tensor(image).unsqueeze(0).to(device)

# The model must be on the same device as its input, and in eval mode
model.to(device)
model.eval()

# Run inference
with torch.no_grad():
    outputs = model(image_tensor)

# Get the boxes, labels, and scores
boxes = outputs[0]['boxes'].cpu().numpy()
labels = outputs[0]['labels'].cpu().numpy()
scores = outputs[0]['scores'].cpu().numpy()

# Define labels map
labels_map = {0: "Background", 1: "Union Sticker", 2: "Flyte Sticker"}

# Draw the boxes on the image
image_with_boxes = draw_boxes(image, boxes, labels, scores, labels_map)

# Display the image
image_with_boxes.show()

# Save the image
image_with_boxes.save('output_image.jpg')

We can use the same bounding-box drawing function to loop over the frames of a video.  

In [None]:
import cv2
import numpy as np
import requests
import torch
from flytekit.types.file import FlyteFile
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms import functional as F
from union import Artifact, UnionRemote
from io import BytesIO


# ------------------------------------
# create video writer
# ------------------------------------

# Video path and properties
video_path = "dataset/swag/videos/union_sticker_video.mp4"
video = cv2.VideoCapture(video_path)
# A missing or unreadable file does not raise on its own; without this check
# the cell would silently produce an empty output video.
if not video.isOpened():
    raise FileNotFoundError(f"Could not open video: {video_path}")
width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
frames_per_second = video.get(cv2.CAP_PROP_FPS)
num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))


# Initialize video writer with the same frame size and rate as the input
video_writer = cv2.VideoWriter(
    "object_detection_video.mp4",
    cv2.VideoWriter_fourcc(*"mp4v"),
    fps=float(frames_per_second),
    frameSize=(width, height),
    isColor=True,
)
if not video_writer.isOpened():
    video.release()
    raise RuntimeError("Could not open object_detection_video.mp4 for writing")


def run_inference_video(video, model, device, labels_map):
    """Yield every frame of ``video`` with the model's detections drawn on it.

    Frames are read with ``video.read()`` until it reports no more frames.
    Each yielded frame is a BGR array produced by ``cv2.cvtColor``.
    """
    grabbed, bgr_frame = video.read()
    while grabbed:
        # OpenCV delivers BGR; the model and PIL work with RGB
        pil_image = Image.fromarray(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
        batch = F.to_tensor(pil_image).unsqueeze(0).to(device)

        with torch.no_grad():
            detections = model(batch)[0]

        annotated = draw_boxes(
            pil_image,
            detections["boxes"].cpu().numpy(),
            detections["labels"].cpu().numpy(),
            detections["scores"].cpu().numpy(),
            labels_map,
        )

        # Back to OpenCV's BGR layout for the video writer
        yield cv2.cvtColor(np.array(annotated), cv2.COLOR_RGB2BGR)

        grabbed, bgr_frame = video.read()


# Run inference and write the video. The capture and writer handles are
# released even if inference fails part-way, so the output file gets closed.
try:
    for frame in run_inference_video(video, model, device, labels_map):
        video_writer.write(frame)
finally:
    video.release()
    video_writer.release()


This example below shows how you could run the model on a live video feed. This won't run in the notebook, but you can run it in your local environment. 


In [None]:
# THIS WON'T WORK IN COLAB

# Webcam example 

import torch
import cv2
import time
from torchvision.transforms import functional as F
from huggingface_hub import hf_hub_download
from union import Artifact, UnionRemote
from flytekit.types.file import FlyteFile

# --------------------------------------------------
# Load the fine-tuned Faster R-CNN model from a Union Artifact
# --------------------------------------------------
FRCCNFineTunedModel = Artifact(name="frccn_fine_tuned_model")

# Optional: pin a specific artifact by also passing
# version="anmrqcq8pfbnlp42j2vp/n3/0/o0" to query().
query = FRCCNFineTunedModel.query(project="default", domain="development")

remote = UnionRemote()
artifact = remote.get_artifact(query=query)
model_file: FlyteFile = artifact.get(as_type=FlyteFile)

# weights_only=False unpickles the complete model object, so only load
# artifacts from a source you trust. The model is loaded on the CPU first and
# moved to the selected device below.
model = torch.load(model_file.download(), map_location="cpu", weights_only=False)

# Use the GPU when one is available, and put the model in evaluation mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# --------------------------------------------------
# Function to process a single frame and draw bounding boxes
# --------------------------------------------------
def process_frame(frame, score_threshold=0.5):
    """Run detection on one BGR frame and draw the confident detections on it.

    Uses the module-level ``model`` and ``device``. The frame is annotated in
    place and the same array is returned.

    Args:
        frame: Image array in BGR channel order (it is converted with
            ``cv2.COLOR_BGR2RGB`` before inference).
        score_threshold: Detections are drawn only when their score is
            strictly greater than this value. Defaults to 0.5, the value that
            was previously hard-coded.

    Returns:
        ``frame`` itself, with a green rectangle and a
        ``"Class <label>: <score>"`` caption for every kept detection.
    """
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    image_tensor = F.to_tensor(frame_rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        prediction = model(image_tensor)

    # A single image was submitted, so only the first result is used.
    boxes = prediction[0]['boxes'].cpu().numpy()
    scores = prediction[0]['scores'].cpu().numpy()
    labels = prediction[0]['labels'].cpu().numpy()

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.9
    thickness = 2
    color = (0, 255, 0)

    for box, score, class_id in zip(boxes, scores, labels):
        if score > score_threshold:
            x_min, y_min, x_max, y_max = (int(v) for v in box)
            cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), color, thickness)

            label = f"Class {class_id}: {score:.2f}"
            # putText anchors the text at its bottom-left corner. For a box
            # touching the top edge, y_min - 10 would place the caption above
            # the image, so keep the baseline at least one text height down.
            (_, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
            text_y = max(y_min - 10, text_height)
            cv2.putText(frame, label, (x_min, text_y),
                        font, font_scale, color, thickness)

    return frame

# --------------------------------------------------
# Run feed with frame skipping option for efficiency
# --------------------------------------------------
def run_video_feed(skip_frames=5):
    """Show live webcam detections in a window until 'q' is pressed.

    Frames are read from capture device 0. Detection runs on every
    ``skip_frames``-th frame; in between, the most recently processed frame
    stays on screen. The measured loop rate is drawn on each processed frame.

    Args:
        skip_frames: Run detection on one frame out of this many. Values below
            1 are treated as 1 (process every frame) instead of raising a
            ZeroDivisionError in the modulo below.

    Returns:
        None. Prints an error and returns early if the camera cannot be
        opened; stops when a frame cannot be read.
    """
    frame_skip = max(1, int(skip_frames))
    frame_count = 0
    last_processed_frame = None

    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Could not open video stream.")
        return

    # perf_counter is monotonic, so the interval can never be negative.
    prev_time = time.perf_counter()

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Error: Failed to capture frame.")
                break

            current_time = time.perf_counter()
            elapsed = current_time - prev_time
            # Two reads can land on the same clock tick; avoid dividing by 0.
            fps = 1 / elapsed if elapsed > 0 else 0.0
            prev_time = current_time

            if frame_count % frame_skip == 0:
                last_processed_frame = process_frame(frame)
                fps_text = f"FPS: {fps:.2f}"
                cv2.putText(last_processed_frame, fps_text, (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)

            if last_processed_frame is not None:
                cv2.imshow('Object Detection RCNN', last_processed_frame)

            frame_count += 1

            # waitKey also lets the display window process its events.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always free the camera and close the window, including when
        # inference raises or the user interrupts the loop.
        cap.release()
        cv2.destroyAllWindows()

# Start the webcam detection loop only when this code runs as the main
# program, not when it is imported as a module.
if __name__ == "__main__":
    run_video_feed()

## Serving as an App on Union

We can serve our model as an app on Union. This allows us to run the model in a production environment and make it available for use by other applications or users.

This example will use Gradio, but we could also use any other web framework like Flask or FastAPI to serve our model as an API. 

In [None]:
# 👇 Run this command to Serve the trained model in a gradio application
!union deploy apps app.py frccn-object-detection-gradio


Just like the training pipeline, the app code is included in this notebook for reference. The `%%writefile` magic command overwrites the files, so you can make changes directly in the notebook. Running the cells below is not required, since `app.py` and `app_main.py` are already in the repository.

In [None]:
%%writefile app.py

import os
from datetime import timedelta
from union import Artifact, ImageSpec, Resources
from union.app import App, Input, ScalingMetric
from flytekit.extras.accelerators import GPUAccelerator, L4

# Point to your object detection model artifact
# (the same artifact name the notebook queries when loading the model).
FRCCNFineTunedModel = Artifact(name="frccn_fine_tuned_model")

# Container image used by the app defined below (passed as container_image).
image_spec = ImageSpec(
    name="union-serve-frccn-object-detector",
    # Python dependencies installed into the image.
    packages=[
        "gradio==5.29.0",
        "torch==2.5.1",
        "torchvision==0.20.1",
        "union-runtime>=0.1.18",
        "opencv-python-headless",
    ],
    # System packages; presumably required by OpenCV / video handling —
    # TODO confirm they are still needed with the headless OpenCV build.
    apt_packages=["ffmpeg", "libsm6", "libxext6"],
    # NOTE(review): confirm this CUDA version matches the pinned torch build.
    cuda="11.8",
    builder="union",
)

# App definition for the Gradio object-detection UI implemented in app_main.py.
gradio_app = App(
    name="frccn-object-detection-gradio",
    container_image=image_spec,
    # --- Code shipped with the app and the command that starts it ---
    include=["./app_main.py"],  # Include your Gradio code
    args=["python", "app_main.py"],
    port=8080,  # same port app_main.py passes to demo.launch()
    # --- Model artifact, downloaded and exposed as "downloaded-model" ---
    inputs=[
        Input(
            name="downloaded-model",
            value=FRCCNFineTunedModel.query(),
            download=True,
        )
    ],
    # --- Compute: identical requests and limits, one L4 GPU ---
    requests=Resources(cpu="2", mem="8Gi", gpu="1"),
    limits=Resources(cpu="2", mem="8Gi", gpu="1"),
    accelerator=L4,
    # --- Scaling settings ---
    min_replicas=0,
    max_replicas=1,
    scaledown_after=timedelta(minutes=2),
    scaling_metric=ScalingMetric.Concurrency(2),
)

# union deploy apps app.py frccn-object-detection-gradio



In [None]:
%%writefile app_main.py

import time

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from torchvision.transforms import functional as F

# Load model from artifact or fallback path
try:
    from union_runtime import get_input

    model_path = get_input("downloaded-model")
except Exception:
    # union_runtime is missing or the input is unavailable: fall back to a
    # checkpoint in the working directory. Catching Exception (not a bare
    # except) lets KeyboardInterrupt and SystemExit propagate.
    model_path = "frccn_fine_tuned_model.pth"

# Load the model
# weights_only=False unpickles the complete model object, so only load
# checkpoints from a source you trust.
model = torch.load(model_path, map_location="cpu", weights_only=False)
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Class id -> display name used when labelling detections.
labels_map = {1: "union", 2: "flyte"}


def detect_objects(frame: np.ndarray) -> np.ndarray:
    """Detect objects in an uploaded image and return an annotated copy.

    The image is converted to RGB and resized to 320x240 before inference with
    the module-level ``model`` on ``device``. Detections scoring above 0.5 are
    drawn with a green box and a ``"<name>: <score>"`` caption, where the name
    comes from ``labels_map`` (falling back to the numeric label). The device
    type and the elapsed processing time are overlaid in the top-left corner.

    Args:
        frame: Image array as delivered by the Gradio image input, or ``None``
            when the user submits without uploading an image.

    Returns:
        The resized, annotated image as an RGB array, or ``None`` when no
        image was provided (instead of failing inside ``Image.fromarray``).
    """
    if frame is None:
        return None

    start = time.time()

    pil_img = Image.fromarray(frame).convert("RGB").resize((320, 240))
    img_tensor = F.to_tensor(pil_img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)

    boxes = outputs[0]["boxes"].cpu()
    scores = outputs[0]["scores"].cpu()
    labels = outputs[0]["labels"].cpu()

    # Keep only confident detections and convert them to plain Python values
    # so drawing and formatting do not depend on tensor objects.
    threshold = 0.5
    selected = scores > threshold
    kept = zip(
        boxes[selected].tolist(),
        labels[selected].tolist(),
        scores[selected].tolist(),
    )

    draw = ImageDraw.Draw(pil_img)
    for (x1, y1, x2, y2), label, score in kept:
        draw.rectangle([x1, y1, x2, y2], outline="green", width=3)
        draw.text(
            (x1, y1),
            f"{labels_map.get(label, label)}: {score:.2f}",
            fill="white",
        )

    # Overlay elapsed time and device info. The timer covers preprocessing,
    # inference and box drawing, not the forward pass alone.
    end = time.time()
    inference_time = (end - start) * 1000  # ms
    debug_text = f"{device.type.upper()} | {inference_time:.1f} ms"
    # Solid black background: the image is RGB, so an alpha value in the fill
    # would have no blending effect.
    draw.rectangle([0, 0, 200, 20], fill=(0, 0, 0))
    draw.text((5, 2), debug_text, fill="white")

    return np.array(pil_img)


# Gradio UI: one uploaded image in, one annotated image out.
_upload_component = gr.Image(type="numpy", label="Upload Image")
_result_component = gr.Image(type="numpy", label="Detection Result")

demo = gr.Interface(
    title="Union Faster RCNN Object Detection",
    description="Upload an image to run Faster RCNN object detection.",
    fn=detect_objects,
    inputs=_upload_component,
    outputs=_result_component,
)


# Start the Gradio server when this file is executed directly. It listens on
# all interfaces on port 8080, the port declared for the app in app.py.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=8080)

# union deploy apps app.py frccn-object-detection-gradio