In [16]:
import sys
import os
import torch
import importlib
import inspect
from PIL import Image
import torchvision.transforms as T
import cv2
import numpy as np

from utils.util import *

In [17]:
# Checkpoint file that the checkpoint-loading cell passes to torch.load.
WEIGHT_PATH = "weights/best.pt"
# Image file that the inference cells open with PIL.
IMG_PATH = "dataset/VietNam_street.png"
# Image names noted by the author; presumably alternative files under
# dataset/ that can be swapped into IMG_PATH -- TODO confirm:
#   VietNam_street
#   Highway

In [18]:
# Load the checkpoint onto the CPU first so loading also works on machines
# without a GPU; the model is moved to the target device further down.
#
# SECURITY: the checkpoint stores a fully pickled nn.Module (not just a
# state_dict), so it has to be unpickled with weights_only=False. Unpickling
# can execute arbitrary code embedded in the file -- only load checkpoints
# that come from a trusted source. Passing the flag explicitly also keeps the
# cell working on PyTorch releases where weights_only defaults to True.
ckpt = torch.load(WEIGHT_PATH, map_location="cpu", weights_only=False)

# Report what the checkpoint contains before touching its entries.
print("checkpoint type:", type(ckpt))
if not isinstance(ckpt, dict):
    raise TypeError(
        f"Expected the checkpoint to be a dict, got {type(ckpt).__name__}"
    )
print("checkpoint keys:", ckpt.keys())

# Extract the model and make sure it really is a module before claiming so.
if "model" not in ckpt:
    # Single formatted message (the keys were previously passed as a second
    # exception argument, which rendered the error as a tuple).
    raise ValueError(
        f"No 'model' key in checkpoint. Found keys: {list(ckpt.keys())}"
    )
model = ckpt["model"]
if not isinstance(model, torch.nn.Module):
    raise TypeError(
        f"ckpt['model'] is {type(model).__name__}, expected torch.nn.Module"
    )
print("Loaded nn.Module directly from checkpoint.")

# Move to the inference device and switch to eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
if device.type == "cpu":
    # Half-precision kernels are not available for every op on CPU, so run
    # in float32 there; later cells match the input dtype to the model.
    model = model.float()


checkpoint type: <class 'dict'>
checkpoint keys: dict_keys(['epoch', 'model'])
Loaded nn.Module directly from checkpoint.


  ckpt = torch.load(WEIGHT_PATH, map_location="cpu")


In [19]:
def to_cpu_float(t):
    """Return a detached float32 CPU copy of a tensor.

    Args:
        t: Any object. Only ``torch.Tensor`` instances are converted.

    Returns:
        ``t.detach().cpu().float()`` for tensors; ``t`` itself otherwise.
    """
    if isinstance(t, torch.Tensor):
        return t.detach().cpu().float()
    return t


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# assume 'model' is already loaded (the nn.Module from your checkpoint)
model.to(device).eval()

# show dtype of model params
param_dtype = next(model.parameters()).dtype
print("Model param dtype:", param_dtype)

# preprocess image (adapt resize if repo expects different)
# Open inside a context manager so the file handle is closed immediately;
# convert() hands back an independent in-memory image.
with Image.open(IMG_PATH) as img_file:
    img = img_file.convert("RGB")
transform = T.Compose([
    T.Resize((640, 640)),
    T.ToTensor(),        # yields float32 in [0,1]
])
x = transform(img).unsqueeze(0)   # adds a leading batch dimension

# Move to the device and cast to whatever dtype the weights use, so the
# input matches fp16 as well as any other parameter dtype.
if param_dtype == torch.half:
    print("Converting input to half (fp16) to match model.")
x = x.to(device=device, dtype=param_dtype)

with torch.no_grad():
    out = model(x)   # forward

# Inspect the output structure. Shapes do not depend on device or dtype, so
# they are read directly instead of copying the tensors to the CPU first.
# to_cpu_float stays defined for any later cell that needs CPU float32 copies.
if isinstance(out, torch.Tensor):
    print("output.shape:", out.shape)
elif isinstance(out, (list, tuple)):
    for idx, item in enumerate(out):
        if isinstance(item, torch.Tensor):
            print(idx, item.shape)
        else:
            print(idx, type(item))
else:
    print(type(out))

Model param dtype: torch.float16
Converting input to half (fp16) to match model.
output.shape: torch.Size([1, 84, 8400])


# Inference

In [20]:
# --- Load image ---
# Open inside a context manager so the file handle is closed immediately;
# convert() hands back an independent in-memory image.
with Image.open(IMG_PATH) as img_file:
    img = img_file.convert("RGB")

# repo usually expects [B,C,H,W] normalized to 0-1, size 640x640
# NOTE(review): T.Resize((640, 640)) stretches the image rather than
# letterboxing it. If scale_coords assumes letterbox padding, boxes on
# non-square images will be offset -- confirm against utils.util.
transform = T.Compose([
    T.Resize((640, 640)),
    T.ToTensor(),   # converts to float32 [0,1]
])
x = transform(img).unsqueeze(0).to(device)

# Match dtype
if next(model.parameters()).dtype == torch.half:
    x = x.half()

# --- Forward pass ---
with torch.no_grad():
    raw_output = model(x)

# The model may return either a bare prediction tensor or a tuple/list whose
# first entry is the prediction. Only unpack in the latter case: indexing a
# bare [B, C, N] tensor with [0] would drop the batch dimension that the
# per-image result of non_max_suppression (indexed with [0] below) relies on.
pred = raw_output[0] if isinstance(raw_output, (list, tuple)) else raw_output

# Post-process in float32: fp16 cannot represent pixel coordinates above
# 2048 exactly.
pred = pred.float()

# --- Postprocess ---
# Apply NMS: filter boxes by confidence/IoU, then take the first image's result.
detections = non_max_suppression(pred, confidence_threshold=0.25, iou_threshold=0.45)[0]

if detections is not None and len(detections):
    # Rescale boxes from the network input size back to the original image
    # size (PIL's size is (W, H), hence the reversal to (H, W)).
    detections[:, :4] = scale_coords(x.shape[2:], detections[:, :4], img.size[::-1]).round()

    # Draw boxes with OpenCV (which works on BGR arrays).
    img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    for *box, conf, cls in detections.tolist():
        label = f"{int(cls)} {conf:.2f}"
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(img_cv, (x1, y1), (x2, y2), (0,255,0), 2)
        cv2.putText(img_cv, label, (x1, y1-2),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)

    # cv2.imwrite signals failure through its return value instead of
    # raising, so check it before reporting success.
    if not cv2.imwrite("inference_result.jpg", img_cv):
        raise IOError("cv2.imwrite failed to write inference_result.jpg")
    print("Saved inference_result.jpg with detections.")
else:
    print("No detections found.")

Saved inference_result.jpg with detections.


In [21]:
# Run one more forward pass with autograd disabled and describe what the
# model returns: its type, plus the shape of the result (or of each entry
# when the result is a list/tuple).
with torch.no_grad():
    out = model(x)

print(type(out))
if not isinstance(out, (list, tuple)):
    # Single object: show its shape when it has one, otherwise None.
    print(getattr(out, "shape", None))
else:
    for idx, item in enumerate(out):
        print(idx, type(item), getattr(item, "shape", None))

<class 'torch.Tensor'>
torch.Size([1, 84, 8400])
