# AAI-521 Final Project – Group 3  
## 03 – Vehicle Classification Model (LMV vs HMV)

**Goal:**  
Train and evaluate a convolutional neural network (CNN) using the cropped
vehicle dataset created in Notebook 02.

This notebook includes:
- Loading preprocessed crops and labels
- Train/validation split
- CNN architecture
- Training loop and learning curves
- Evaluation (accuracy, classification report, confusion matrix)
- Qualitative results (annotated frames and video)

#### Imports & Configuration

In [None]:
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split

import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, ConfusionMatrixDisplay

PROJECT_ROOT = Path().resolve().parent
DATA_ROOT = PROJECT_ROOT / "data"
IMAGES_ROOT = DATA_ROOT / "DETRAC-Images"
TRAIN_ANN_ROOT = DATA_ROOT / "DETRAC-Train-Annotations"
CROPPED_DATA_PATH = PROJECT_ROOT / "outputs" / "cropped_vehicle_dataset.npz"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Images root:", IMAGES_ROOT)
print("Train annotations root:", TRAIN_ANN_ROOT)
print("Using device:", DEVICE)
print("Loading dataset from:", CROPPED_DATA_PATH)

#### Load Dataset

In [None]:
data = np.load(CROPPED_DATA_PATH, allow_pickle=True)
images = data["images"]  # (N, H, W, 3), float32 in [0,1]
labels = data["labels"]  # (N,)
class_to_idx = data["class_to_idx"].item()
metadata = data["metadata"]

idx_to_class = {v: k for k, v in class_to_idx.items()}

print("Images shape:", images.shape)
print("Labels shape:", labels.shape)
print("Classes:", idx_to_class)

## 1. Train / Validation Split

We split the cropped dataset into training and validation sets
using an 80/20 random split with a fixed random seed for reproducibility.
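
As a minimal sketch of the reproducible split described above, the example below passes an explicit `torch.Generator` to `random_split` instead of relying on the global seed. The toy dataset and the helper name `split_indices` are assumptions made for illustration and are not part of the project code.

```python
import torch
from torch.utils.data import TensorDataset, random_split


def split_indices(n_total: int, train_ratio: float = 0.8, seed: int = 42):
    """Hypothetical helper: return (train_indices, val_indices) for a seeded split."""
    dataset = TensorDataset(torch.arange(n_total))
    n_train = int(train_ratio * n_total)
    n_val = n_total - n_train
    # A dedicated generator keeps the split independent of other random calls.
    generator = torch.Generator().manual_seed(seed)
    train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=generator)
    return list(train_ds.indices), list(val_ds.indices)


train_a, val_a = split_indices(10)
train_b, val_b = split_indices(10)
print(len(train_a), len(val_a))
print(train_a == train_b and val_a == val_b)
```

With the same seed, both calls should produce identical index lists, which is the property the fixed seed is meant to guarantee.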

#### Prepare Tensors & Split

In [None]:
X = torch.from_numpy(images).float().permute(0, 3, 1, 2)  # (N,3,H,W), float32
y = torch.from_numpy(labels).long()  # CrossEntropyLoss expects int64 class indices

dataset = TensorDataset(X, y)

train_ratio = 0.8
n_total = len(dataset)
n_train = int(train_ratio * n_total)
n_val = n_total - n_train

torch.manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [n_train, n_val])

BATCH_SIZE = 128

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}")

## 2. CNN Architecture

We use a simple convolutional neural network with four convolutional
blocks followed by two fully connected layers.

This is our **baseline** model for vehicle classification.
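
The input size of the first fully connected layer depends on the crop resolution. The sketch below works through that arithmetic, assuming 3x3 convolutions with `padding=1` (spatial size unchanged) and 2x2 max pooling (spatial size halved, rounded down). The helper name `flattened_features` and the 64x64 and 128x128 crop sizes are illustrative assumptions; 64x64 is simply one size that yields `256 * 4 * 4`.

```python
def flattened_features(height: int, width: int, channels: int = 256, num_pools: int = 4) -> int:
    """Size of the flattened feature map after the convolutional blocks.

    Assumes each block keeps the spatial size in its convolution and
    floor-divides it by 2 in its max-pooling layer.
    """
    for _ in range(num_pools):
        height //= 2
        width //= 2
    return channels * height * width


print(flattened_features(64, 64))    # 256 * 4 * 4
print(flattened_features(128, 128))  # 256 * 8 * 8
```

If the crops produced in Notebook 02 use a different resolution, the same calculation gives the value to use in place of `256 * 4 * 4`.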

#### CNN Model Definition

In [None]:
class VehicleClassifier(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 256),  # 256 channels x 4 x 4 map (e.g. 64x64 input after four 2x2 pools); adjust if the crop size differs
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

num_classes = len(class_to_idx)
model = VehicleClassifier(num_classes=num_classes).to(DEVICE)
print(model)

#### Training Setup

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

NUM_EPOCHS = 15

history = {
    "train_loss": [],
    "val_loss": [],
    "train_acc": [],
    "val_acc": [],
}
best_val_acc = 0.0
best_state_dict = None

## 3. Training Loop

We train for a fixed number of epochs and track:

- Training and validation loss
- Training and validation accuracy

We keep the model weights that achieved the best validation accuracy.
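
`nn.CrossEntropyLoss` returns the mean loss over a batch by default, so the loop weights each batch loss by its batch size before dividing by the number of samples. The sketch below illustrates why this matters when the last batch is smaller; the loss values, batch sizes, and the helper name `epoch_mean_loss` are made up for illustration.

```python
def epoch_mean_loss(batch_losses, batch_sizes):
    """Per-sample mean loss over an epoch from per-batch mean losses."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Hypothetical epoch: two full batches of 128 and a final batch of 44.
losses = [0.50, 0.40, 0.10]
sizes = [128, 128, 44]

weighted = epoch_mean_loss(losses, sizes)
unweighted = sum(losses) / len(losses)
print(f"weighted:   {weighted:.4f}")
print(f"unweighted: {unweighted:.4f}")
```

The unweighted average gives the small final batch the same influence as a full batch, which would bias the reported epoch loss.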

#### Training Loop

In [None]:
for epoch in range(1, NUM_EPOCHS + 1):
    # ---- Training ----
    model.train()
    train_loss = 0.0
    train_correct = 0
    n_train = 0

    for xb, yb in train_loader:
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)

        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * xb.size(0)
        preds = logits.argmax(dim=1)
        train_correct += (preds == yb).sum().item()
        n_train += xb.size(0)

    train_loss /= n_train
    train_acc = train_correct / n_train

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    val_correct = 0
    n_val = 0

    with torch.no_grad():
        for xb, yb in val_loader:
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)

            logits = model(xb)
            loss = criterion(logits, yb)

            val_loss += loss.item() * xb.size(0)
            preds = logits.argmax(dim=1)
            val_correct += (preds == yb).sum().item()
            n_val += xb.size(0)

    val_loss /= n_val
    val_acc = val_correct / n_val

    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["train_acc"].append(train_acc)
    history["val_acc"].append(val_acc)

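    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # state_dict() returns references to the live tensors, so clone them;
        # otherwise later epochs would overwrite the "best" snapshot in place.
        best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}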

    print(
        f"Epoch {epoch:02d}/{NUM_EPOCHS} "
        f"- train_loss: {train_loss:.4f}, train_acc: {train_acc:.3f} "
        f"- val_loss: {val_loss:.4f}, val_acc: {val_acc:.3f}"
    )

print(f"Best validation accuracy: {best_val_acc:.3f}")
model.load_state_dict(best_state_dict)

## 4. Learning Curves

We plot training and validation loss and accuracy across epochs.
These figures will be included in the final report.

#### Learning Curves

In [None]:
epochs = range(1, NUM_EPOCHS + 1)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# ----- LOSS PLOT -----
axes[0].plot(epochs, history["train_loss"], label="Train", marker="o")
axes[0].plot(epochs, history["val_loss"], label="Validation", marker="o")
axes[0].set_title("Training vs Validation Loss")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Cross-Entropy Loss")
axes[0].legend()
axes[0].grid(True, linestyle="--", alpha=0.4)

# ----- ACCURACY PLOT -----
axes[1].plot(epochs, history["train_acc"], label="Train", marker="o")
axes[1].plot(epochs, history["val_acc"], label="Validation", marker="o")
axes[1].set_title("Training vs Validation Accuracy")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Accuracy")
axes[1].legend()
axes[1].set_ylim(0.8, 1.01)  # adjust if needed
axes[1].grid(True, linestyle="--", alpha=0.4)

plt.tight_layout()
plt.show()

#### Save Model

In [None]:
MODELS_DIR = PROJECT_ROOT / "models"
MODELS_DIR.mkdir(exist_ok=True)

MODEL_PATH = MODELS_DIR / "vehicle_classifier.pth"
torch.save(model.state_dict(), MODEL_PATH)
print("Saved best model to:", MODEL_PATH)

## 5. Evaluation on Validation Set

We compute:
- Overall accuracy
- Classification report (precision, recall, F1)
- Normalized confusion matrix
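
To make the metrics above concrete, here is a small pure-Python sketch of overall accuracy and a row-normalized confusion matrix, which is what `normalize="true"` means in scikit-learn. The labels and the helper name `normalized_confusion` are hypothetical and unrelated to the actual validation results.

```python
def normalized_confusion(targets, preds, num_classes):
    """Row-normalized confusion matrix: entry [i][j] is the fraction of
    true-class-i samples that were predicted as class j."""
    counts = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(targets, preds):
        counts[t][p] += 1
    matrix = []
    for row in counts:
        row_total = sum(row)
        matrix.append([c / row_total if row_total else 0.0 for c in row])
    return matrix


# Hypothetical labels for two classes (0 and 1), for illustration only.
targets = [0, 0, 0, 0, 1, 1]
preds = [0, 0, 0, 1, 1, 0]

accuracy = sum(t == p for t, p in zip(targets, preds)) / len(targets)
matrix = normalized_confusion(targets, preds, num_classes=2)
print(accuracy)
print(matrix)
```

Each diagonal entry of the row-normalized matrix is the recall of that class, so the matrix stays readable even when the classes are imbalanced.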

#### Evaluation

In [None]:
model.eval()
all_preds = []
all_targets = []

with torch.no_grad():
    for xb, yb in val_loader:
        xb = xb.to(DEVICE)
        logits = model(xb)
        preds = logits.argmax(dim=1).cpu().numpy()
        all_preds.append(preds)
        all_targets.append(yb.numpy())

all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)

print(classification_report(
    all_targets, all_preds,
    target_names=[idx_to_class[i] for i in range(num_classes)]
))

fig, ax = plt.subplots(figsize=(6, 6))
ConfusionMatrixDisplay.from_predictions(
    all_targets, all_preds,
    display_labels=[idx_to_class[i] for i in range(num_classes)],
    normalize="true",
    ax=ax,
    cmap="Blues",
)
ax.set_title("Normalized Confusion Matrix")
plt.tight_layout()
plt.show()

## 6. Qualitative Results – Annotated Frames

We apply the trained CNN to original DETRAC frames and visualize its
prediction for each annotated vehicle bounding box.

#### Imports & Paths

In [None]:
import cv2
import xml.etree.ElementTree as ET

# from src.utils_detrac import load_detrac_annotations
def load_detrac_annotations(xml_path: Path):
    """
    Parse a UA-DETRAC XML file into a dict:
    { frame_num (int): [ { 'id': str, 'bbox': [x, y, w, h], 'class': str }, ... ] }
    """
    tree = ET.parse(str(xml_path))
    root = tree.getroot()

    annotations = {}

    for frame in root.findall("frame"):
        frame_num = int(frame.get("num"))

        target_list = frame.find("target_list")
        if target_list is None:
            annotations[frame_num] = []
            continue

        targets = []
        for target in target_list.findall("target"):
            tid = target.get("id")

            box = target.find("box")
            attr = target.find("attribute")
            if box is None or attr is None:
                continue

            left = float(box.get("left"))
            top = float(box.get("top"))
            width = float(box.get("width"))
            height = float(box.get("height"))
            vehicle_class = attr.get("vehicle_type")

            targets.append({
                "id": tid,
                "bbox": [left, top, width, height],
                "class": vehicle_class,
            })

        annotations[frame_num] = targets

    return annotations

# define TARGET_SIZE using the cropped dataset shape (to match training)
# images: (N, H, W, 3)
crop_h, crop_w = images.shape[1], images.shape[2]
TARGET_SIZE = (crop_w, crop_h)  # (width, height) for cv2.resize
print("CNN input size (W,H):", TARGET_SIZE)
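
For reference, the sketch below shows the XML layout that `load_detrac_annotations` expects, using a hand-written one-frame sample. The sequence name and all values are made up, the sample is reduced to the elements the parser reads, and the parsing loop is a condensed stand-in for the function above rather than a replacement for it.

```python
import xml.etree.ElementTree as ET

# Hypothetical one-frame annotation; values are made up for illustration.
SAMPLE_XML = """
<sequence name="MVI_00000">
  <frame num="1">
    <target_list>
      <target id="1">
        <box left="10.5" top="20.0" width="30.0" height="15.5"/>
        <attribute vehicle_type="car"/>
      </target>
    </target_list>
  </frame>
</sequence>
"""

root = ET.fromstring(SAMPLE_XML.strip())
parsed = {}
for frame in root.findall("frame"):
    targets = []
    for target in frame.find("target_list").findall("target"):
        box = target.find("box")
        targets.append({
            "id": target.get("id"),
            "bbox": [float(box.get(k)) for k in ("left", "top", "width", "height")],
            "class": target.find("attribute").get("vehicle_type"),
        })
    parsed[int(frame.get("num"))] = targets

print(parsed)
```

The result maps each frame number to a list of targets, with boxes stored as `[x, y, w, h]` in pixel coordinates.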

#### Predict Class for a Single Crop

In [None]:
import torch.nn.functional as F

def predict_vehicle_class(crop_rgb_np, model, device=DEVICE):
    """
    crop_rgb_np: numpy array (H, W, 3) in [0, 1] or [0, 255]
    Returns: (pred_label_str, confidence_float)
    """
    # Ensure float32 in [0, 1]. astype returns a copy, so the in-place
    # division below never modifies the caller's array.
    crop = crop_rgb_np.astype(np.float32)

    if crop.max() > 1.0:
        crop /= 255.0

    # (H,W,3) -> (1,3,H,W) tensor
    tensor = torch.from_numpy(crop).permute(2, 0, 1).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        logits = model(tensor)
        probs = F.softmax(logits, dim=1)
        conf, pred_idx = probs.max(dim=1)

    pred_idx = int(pred_idx.item())
    conf = float(conf.item())
    pred_label = idx_to_class[pred_idx]

    return pred_label, conf

#### Annotate a Single Frame

In [None]:
def annotate_frame(
    seq_id: str,
    frame_num: int,
    model,
    images_root=IMAGES_ROOT,
    ann_root=TRAIN_ANN_ROOT,
    target_size=TARGET_SIZE,
    device=DEVICE,
):
    """
    Load a DETRAC frame + its annotations, run CNN on each bbox crop,
    and overlay predicted labels on the original frame.
    Returns: annotated RGB image as numpy array.
    """
    seq_images_dir = images_root / seq_id
    xml_path = ann_root / f"{seq_id}.xml"

    assert seq_images_dir.exists(), f"Image folder not found: {seq_images_dir}"
    assert xml_path.exists(), f"XML file not found: {xml_path}"

    annotations = load_detrac_annotations(xml_path)

    img_file = seq_images_dir / f"img{frame_num:05d}.jpg"
    assert img_file.exists(), f"Image not found: {img_file}"

    # Load original frame (BGR -> RGB)
    frame_bgr = cv2.imread(str(img_file))
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    h_img, w_img = frame_rgb.shape[:2]

    # Copy for drawing
    vis = frame_rgb.copy()

    # Iterate over targets in this frame
    targets = annotations.get(frame_num, [])
    if not targets:
        print(f"No vehicles found in frame {frame_num}")
        return vis

    for t in targets:
        x, y, w, h = t["bbox"]

        # Clamp bbox to image bounds
        x1 = max(int(x), 0)
        y1 = max(int(y), 0)
        x2 = min(int(x + w), w_img)
        y2 = min(int(y + h), h_img)

        if x2 <= x1 or y2 <= y1:
            continue

        # Crop and resize to CNN input size
        crop = frame_rgb[y1:y2, x1:x2]
        crop_resized = cv2.resize(crop, target_size)

        # Predict class with the CNN
        pred_label, conf = predict_vehicle_class(crop_resized, model, device=device)

        # Draw bbox and label directly on the RGB copy. OpenCV drawing calls
        # do not interpret channel order, so the colour tuples below are RGB.
        cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)  # green box
        label_text = f"{pred_label} ({conf:.2f})"
        cv2.putText(
            vis,
            label_text,
            (x1, max(y1 - 5, 0)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 255, 255),  # cyan text
            1,
            cv2.LINE_AA,
        )

    return vis
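
The bounding-box handling inside `annotate_frame` can be isolated as a small helper, sketched below. The name `clamp_bbox` and the 960x540 frame size are assumptions used only for this example.

```python
def clamp_bbox(x, y, w, h, img_w, img_h):
    """Convert an (x, y, w, h) box to integer corners clipped to the image.

    Returns (x1, y1, x2, y2), or None when the clipped box is empty.
    """
    x1 = max(int(x), 0)
    y1 = max(int(y), 0)
    x2 = min(int(x + w), img_w)
    y2 = min(int(y + h), img_h)
    if x2 <= x1 or y2 <= y1:
        return None
    return x1, y1, x2, y2


# A frame size of 960x540 is used here only as an example.
print(clamp_bbox(-5.0, 10.2, 50.0, 40.0, 960, 540))
print(clamp_bbox(940.0, 500.0, 60.0, 80.0, 960, 540))
print(clamp_bbox(1000.0, 10.0, 20.0, 20.0, 960, 540))
```

Boxes that extend past the frame are clipped to its edges, and a box that lies entirely outside the frame yields `None` so that no empty crop reaches `cv2.resize`.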

#### Show one annotated frame

In [None]:
# Choose a sequence and frame to visualize
SEQ_ID = "MVI_20011"   # adjust to any available sequence
FRAME_NUM = 1          # any frame number that exists

annotated = annotate_frame(SEQ_ID, FRAME_NUM, model)

plt.figure(figsize=(8, 6))
plt.imshow(annotated)
plt.title(f"{SEQ_ID} – Frame {FRAME_NUM} (Predicted Labels)")
plt.axis("off")
plt.show()

## 7. (Optional) Qualitative Results – Annotated Video

We generate a short video clip with predicted labels overlaid on each frame.

#### Annotate a Frame in BGR for Video

In [None]:
def annotate_frame_bgr_for_video(
    seq_id: str,
    frame_num: int,
    model,
    images_root=IMAGES_ROOT,
    ann_root=TRAIN_ANN_ROOT,
    target_size=TARGET_SIZE,
    device=DEVICE,
):
    """
    Similar to annotate_frame, but returns annotated frame in BGR format,
    suitable for cv2.VideoWriter.
    """
    seq_images_dir = images_root / seq_id
    xml_path = ann_root / f"{seq_id}.xml"

    annotations = load_detrac_annotations(xml_path)

    img_file = seq_images_dir / f"img{frame_num:05d}.jpg"
    if not img_file.exists():
        return None

    frame_bgr = cv2.imread(str(img_file))
    if frame_bgr is None:
        return None

    h_img, w_img = frame_bgr.shape[:2]
    vis = frame_bgr.copy()

    targets = annotations.get(frame_num, [])
    for t in targets:
        x, y, w, h = t["bbox"]

        x1 = max(int(x), 0)
        y1 = max(int(y), 0)
        x2 = min(int(x + w), w_img)
        y2 = min(int(y + h), h_img)

        if x2 <= x1 or y2 <= y1:
            continue

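        # Crop from the untouched frame, not from `vis`, so boxes and labels
        # already drawn for earlier targets never leak into a later crop.
        crop = frame_bgr[y1:y2, x1:x2]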
        crop_resized = cv2.resize(crop, target_size)
        crop_rgb = cv2.cvtColor(crop_resized, cv2.COLOR_BGR2RGB)

        pred_label, conf = predict_vehicle_class(crop_rgb, model, device=device)

        # Draw bbox + label on vis (BGR)
        cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
        label_text = f"{pred_label} ({conf:.2f})"
        cv2.putText(
            vis,
            label_text,
            (x1, max(y1 - 5, 0)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 255, 255),
            1,
            cv2.LINE_AA,
        )

    return vis

#### Generate the video

In [None]:
OUTPUT_DIR = PROJECT_ROOT / "outputs" / "videos"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

SEQ_ID = "MVI_20011"     # choose a sequence that exists in your data
START_FRAME = 1
END_FRAME = 150          # e.g., first 150 frames

# Probe first frame to get size
first_frame_bgr = annotate_frame_bgr_for_video(SEQ_ID, START_FRAME, model)
if first_frame_bgr is None:
    raise RuntimeError("Could not load the first frame to determine video size.")

height, width = first_frame_bgr.shape[:2]
fps = 10  # choose any reasonable FPS

output_path = OUTPUT_DIR / f"{SEQ_ID}_predictions.mp4"
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))

for frame_num in range(START_FRAME, END_FRAME + 1):
    frame_bgr = annotate_frame_bgr_for_video(SEQ_ID, frame_num, model)
    if frame_bgr is None:
        print(f"Skipping frame {frame_num} (not found or unreadable).")
        continue
    writer.write(frame_bgr)

writer.release()
print("Wrote annotated video to:", output_path)

#### Display the video inline

In [None]:
from IPython.display import Video

Video(str(output_path), embed=True)

#### Convert MP4 → GIF and Save

In [None]:
import imageio.v2 as imageio

gif_path = output_path.with_suffix(".gif")
print("Source MP4:", output_path)
print("GIF will be saved as:", gif_path)

# Read frames from MP4 and collect for GIF
cap = cv2.VideoCapture(str(output_path))
frames = []
frame_count = 0

while True:
    ok, frame_bgr = cap.read()
    if not ok:
        break
    frame_count += 1

    # Optional: keep every other frame (reduces GIF size; at the same fps
    # the GIF then plays about twice as fast as the MP4)
    if frame_count % 2 != 0:
        continue

    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    frames.append(frame_rgb)

cap.release()
print(f"Collected {len(frames)} frames for GIF.")

# Save GIF
if frames:
    imageio.mimsave(gif_path, frames, fps=10)
    print("GIF saved to:", gif_path)
else:
    print("No frames collected — check video file.")

#### Display the GIF Inline

In [None]:
from IPython.display import Image, display

print("Displaying GIF:", gif_path)
display(Image(filename=str(gif_path)))

## 8. Summary

- Trained a CNN for vehicle type classification on cropped DETRAC images.
- Achieved validation accuracy of **X%** (see above).
- Class-wise performance is summarized in the classification report and
  confusion matrix.
- Qualitative visualizations show that the model generally predicts
  reasonable labels on unseen frames.

In future work, we could:
- Use a pretrained backbone (ResNet, MobileNet) for better accuracy.
- Perform LMV vs HMV grouping and analyze traffic counts over time.
- Integrate detection and tracking for full vehicle-counting analytics.