In [1]:
import cv2
import time
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms, models
from ultralytics import YOLO
from IPython.display import display, HTML, clear_output
from PIL import Image
import matplotlib.pyplot as plt

Config

In [12]:
# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Model paths
yolo_model_path = "model/best.pt"
vit_model_path = "model/ViT_best_model.pt"
restnet_model_path = "model/restnet_best_model.pt"

# Class labels (must match training data order)
classes = ['Bus', 'Hatchback', 'MPV', 'Motorcycle', 'PickUp', 'SUV', 'Sedan', 'Truck']

# Image preprocessing for ViT
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

Using device: cuda


Load model

In [13]:
# Load YOLO model
yolo_model = YOLO(yolo_model_path)

In [14]:
# Load ViT model
vit_model = models.vit_b_16(weights=None)
vit_model.heads = nn.Linear(vit_model.heads.head.in_features, len(class_name))
vit_model.load_state_dict(torch.load(vit_model_path, map_location=device, weights_only=True))
vit_model.to(device)
vit_model.eval()

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

Inference function

In [None]:
# classifying function
def classify_crop(image):
    # Convert BGR to RGB
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    # Preprocess
    input_tensor = transform(image_rgb).unsqueeze(0).to(device)
    
    # Inference
    with torch.no_grad():
        outputs = vit_model(input_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        confidence, predicted_idx = torch.max(probabilities, 1)
    
    class_name = classes[predicted_idx.item()]
    confidence_value = confidence.item()
    
    return class_name, confidence_value

In [None]:
# process frame per frame function
def process_frame(frame, conf_threshold: float = 0.2):
    # Run YOLO detection
    results = yolo_model(frame, verbose=False)[0]
    
    # Process each detection
    for box in results.boxes:
        # Get bounding box coordinates
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        det_conf = box.conf[0].item()
        
        if det_conf < conf_threshold:
            continue
        
        # Ensure valid crop dimensions
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(frame.shape[1], x2), min(frame.shape[0], y2)
        
        if x2 <= x1 or y2 <= y1:
            continue
        
        # Crop detected region
        crop = frame[y1:y2, x1:x2]
        
        # Classify the crop
        class_name, cls_conf = classify_crop(crop)
        
        # Draw bounding box
        cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
        
        # Prepare label text
        label = f"{class_name} {cls_conf:.2f}"
        
        # Draw label background
        (label_w, label_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        cv2.rectangle(frame, (x1, y1 - label_h - 10), (x1 + label_w, y1), (255, 0, 0), -1)
        
        # Draw label text
        cv2.putText(frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
    
    return frame

In [None]:
# processing video
def process_video(input_path: str, output_path: str):
    # Open input video
    cap = cv2.VideoCapture(input_path)
    
    if not cap.isOpened():
        raise ValueError(f"Cannot open video: {input_path}")
    
    # Get video properties
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    print(f"Input video: {width}x{height} @ {fps}fps, {total_frames} frames")
    
    # Create video writer
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    
    # Process frames
    frame_count = 0
    start_time = time.time()
    
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        
        # Process frame
        annotated_frame = process_frame(frame)
        
        # Write output
        out.write(annotated_frame)
        
        frame_count += 1
        
        # Progress update every 30 frames
        if frame_count % 30 == 0:
            elapsed = time.time() - start_time
            fps_actual = frame_count / elapsed
            progress = (frame_count / total_frames) * 100
            clear_output(wait=True)
            print(f"Progress: {progress:.1f}% ({frame_count}/{total_frames}) - {fps_actual:.1f} fps")
    
    # Cleanup
    cap.release()
    out.release()
    
    # Final statistics
    total_time = time.time() - start_time
    avg_fps = frame_count / total_time
    clear_output(wait=True)
    print(f"Completed! Processed {frame_count} frames in {total_time:.2f}s ({avg_fps:.1f} fps)")
    print(f"Output saved to: {output_path}")

Inference

In [21]:
# Input and output paths
vid_input = "traffic_test.mp4"
vid_output = "output_inference.mp4"

# Run video processing
process_video(vid_input, vid_output)

Completed! Processed 5920 frames in 621.19s (9.5 fps)
Output saved to: output_inference.mp4
