# SSDLite + MobileNetV3 — Inference and Evaluation
This notebook loads trained SSDLite+MobileNetV3 checkpoints, prepares the
validation/test dataloaders, and runs evaluation/visualization helpers (PR
curves, confusion matrix, per-class metrics).

Notes:
- Helper modules are provided by the `ssdlite_mobnetv3_adis` package (dataset,
  model, plot, evaluate, trainer, utils).
- The cells below perform only inference and evaluation — no training.


In [None]:
# clone the ADIS repository
!git clone https://github.com/sathishkumar67/SSD_MobileNetV3_ADIS.git
# move the files to the current directory
!mv /kaggle/working/SSD_MobileNetV3_ADIS/* /kaggle/working/ # move the files to the current directory
# upgrade pip
!pip install --upgrade pip
# install the required packages
!pip install  -r requirements.txt 

In [None]:
# Notebook imports and reproducibility settings
# This cell imports required libraries and project modules used for inference
# Keep imports minimal here; heavy work is delegated to `ssdlite_mobnetv3_adis` helpers
import os
import random

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torch.utils.data import DataLoader, RandomSampler
from torchvision import transforms

from ssdlite_mobnetv3_adis.dataset import collate_fn, SSDLITEOBJDET_DATASET, CachedSSDLITEOBJDET_DATASET
from ssdlite_mobnetv3_adis.evaluate import compute_average_metrics
from ssdlite_mobnetv3_adis.inference import draw_detections
from ssdlite_mobnetv3_adis.model import SSDLITE_MOBILENET_V3_Large
from ssdlite_mobnetv3_adis.utils import unzip_file

# Reproducibility: a single seed is pushed into every RNG this notebook relies on
# (Python's `random`, NumPy, and torch on CPU and on all CUDA devices).
RANDOM_SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(RANDOM_SEED)

# cuDNN: disable benchmark mode and request deterministic behaviour.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Dataset / repository constants and download step
# Update REPO_ID / DATASET_NAME if you have a different dataset or hub location
REPO_ID = "pt-sk/ADIS"
DATASET_NAME = "balanced_dataset"
REPO_TYPE = "dataset"
FILENAME_IN_REPO = f"{DATASET_NAME}.zip"
LOCAL_DIR = os.getcwd()
DATASET_PATH = f"{LOCAL_DIR}/{FILENAME_IN_REPO}"       # expected location of the downloaded archive
DATASET_FOLDER_PATH = f"{LOCAL_DIR}/{DATASET_NAME}"    # folder later passed to the datasets as root_dir

# Class list and counts (CLASSES excludes background; NUM_CLASSES_WITH_BG adds background)
# NOTE(review): the order presumably matches the label indices used at training time —
# confirm before reordering.
CLASSES = ['Cat', 'Cattle', 'Chicken', 'Deer', 'Dog', 'Squirrel', 'Eagle', 'Goat', 'Rodents', 'Snake']
NUM_CLASSES = len(CLASSES)
NUM_CLASSES_WITH_BG = NUM_CLASSES + 1    # 1 for background class

# Download and extract only when the dataset folder is missing, so that re-running the
# notebook does not fetch and unzip the archive again. Delete DATASET_FOLDER_PATH to
# force a fresh download/extraction (e.g. after an interrupted unzip).
# Assumes the archive extracts to DATASET_FOLDER_PATH — TODO confirm against the zip layout.
if not os.path.isdir(DATASET_FOLDER_PATH):
    # hf_hub_download returns the local path of the fetched file; pass that path on
    # instead of assuming where the file was written.
    downloaded_zip_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME_IN_REPO,
        repo_type=REPO_TYPE,
        local_dir=LOCAL_DIR)
    unzip_file(downloaded_zip_path, LOCAL_DIR)

In [None]:
# DataLoader construction and caching notes
# This cell builds the cached dataset wrappers, seeded samplers and DataLoaders for the
# train/val/test splits. BATCH_SIZE and the worker count are the knobs to tune for
# your environment.
CUDA_AVAILABLE = torch.cuda.is_available()
# Pin host memory only when CUDA is available, so this cell also runs on CPU-only
# machines (the model cell likewise falls back to "cpu" when no GPU is present).
PIN_MEMORY_DEVICE = "cuda:0" if CUDA_AVAILABLE else ""
# os.cpu_count() can return None; fall back to one worker so num_workers is always a
# positive int (persistent_workers=True requires at least one worker).
NUM_CORES = os.cpu_count() or 1
BATCH_SIZE = 64 # Adjust based on your system's memory capacity


def _build_dataset(split):
    """Return the cached dataset wrapper for one split ("train", "val" or "test").

    Per the project notes, CachedSSDLITEOBJDET_DATASET creates/reads an LMDB cache.
    """
    return CachedSSDLITEOBJDET_DATASET(
        dataset_class=SSDLITEOBJDET_DATASET,
        root_dir=DATASET_FOLDER_PATH,
        split=split,
        num_classes=NUM_CLASSES_WITH_BG)


def _build_sampler(dataset):
    """Return a RandomSampler over `dataset` driven by its own generator seeded with RANDOM_SEED."""
    return RandomSampler(dataset, generator=torch.Generator().manual_seed(RANDOM_SEED))


def _build_loader(dataset, sampler):
    """Return a DataLoader over `dataset` that draws indices from `sampler`.

    Batches are assembled by the project's `collate_fn` (per the project notes it
    converts numpy HWC images to CHW tensors and packs the targets).
    """
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        sampler=sampler,
        num_workers=NUM_CORES,
        collate_fn=collate_fn,
        pin_memory=CUDA_AVAILABLE,
        persistent_workers=True,
        prefetch_factor=2,
        pin_memory_device=PIN_MEMORY_DEVICE)


# Dataset objects, one per split
train_dataset = _build_dataset("train")
val_dataset = _build_dataset("val")
test_dataset = _build_dataset("test")

# Seeded samplers: the shuffled order is reproducible across runs
train_sampler = _build_sampler(train_dataset)
val_sampler = _build_sampler(val_dataset)
test_sampler = _build_sampler(test_dataset)

# DataLoaders, one per split
train_loader = _build_loader(train_dataset, train_sampler)
val_loader = _build_loader(val_dataset, val_sampler)
test_loader = _build_loader(test_dataset, test_sampler)

In [None]:
# Device selection and model checkpoint loading
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path to the saved checkpoint — set this to your trained checkpoint file before running
best_ckpt_path = ""
# Fail early with an actionable message instead of an opaque error from inside torch.load
if not os.path.isfile(best_ckpt_path):
    raise FileNotFoundError(
        f"Checkpoint not found at {best_ckpt_path!r}; "
        "set `best_ckpt_path` to the path of a trained checkpoint file.")

# Load onto the CPU first; the model is moved to `device` after the weights are restored.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
best_ckpt = torch.load(best_ckpt_path, map_location="cpu")

# Build the model with the correct number of classes (includes background)
best_model = SSDLITE_MOBILENET_V3_Large(num_classes_with_bg=NUM_CLASSES_WITH_BG)
# strict=True: the checkpoint's 'model_state_dict' keys must match the model's exactly
best_model.load_state_dict(best_ckpt['model_state_dict'], strict=True)
best_model.to(device)
best_model.eval()  # switch to evaluation mode for inference
# Model is now ready for inference
print("Model loaded successfully!")

In [None]:
# Evaluate the loaded checkpoint on one split.
# `compute_average_metrics` comes from `ssdlite_mobnetv3_adis.evaluate` and is imported
# in the imports cell together with the other project helpers.
eval_metrics = compute_average_metrics(
    best_model,
    val_loader,  # swap in test_loader (or any other DataLoader) to evaluate a different split
    device,
    CLASSES,
)
# Display the returned metrics as the cell output
eval_metrics

In [None]:
# specify the image path for inference
image_path = ""

# Decode the file as 3-channel RGB so grayscale, palette and RGBA images yield the same
# (3, H, W) layout as ordinary colour images. The `with` block closes the file handle
# once the pixel data has been copied by `convert`.
with Image.open(image_path) as pil_image:
    rgb_image = pil_image.convert("RGB")

# Resize to the 320x320 input size and convert to a float tensor
preprocess = transforms.Compose([
    transforms.Resize((320, 320)),
    transforms.ToTensor(),
])

# Add a leading batch dimension and move the tensor to the selected device
image = preprocess(rgb_image).unsqueeze(0).to(device)

In [None]:
# Run the model on `image_path` and draw its detections on the image.
# The arguments are collected first so the tunable values are easy to spot.
detection_kwargs = dict(
    image_or_path=image_path,
    model=best_model,
    device=device,
    classes=CLASSES,
    conf_thresh=0.07,   # confidence threshold passed to draw_detections
    input_size=320,     # input size passed to draw_detections
    show=True,
)
draw_detections(**detection_kwargs)