# Training a Faster R-CNN model for 2D Object Detection on Waymo Perception Dataset

This notebook walks through the steps required to train and evaluate a Faster R-CNN model on the Waymo Perception Open Dataset for 2D object detection, i.e. predicting bounding boxes on autonomous-vehicle (AV) camera images.

## 1. Explore the Data

The data that we are using comes from the [Waymo Perception Open Dataset](https://waymo.com/open/data/perception/).
You will need to fill out a form to gain access to this data.

The dataset itself lives in a Google Cloud Storage (GCS) bucket and is stored in Parquet format.
Each Parquet file covers one short driving segment (a snippet of video) of 198 unique frames, with one image per camera (cameras 1 to 5) for every frame.
Each segment is recorded from a moving vehicle in a single city, at a given time of day and in given weather conditions.

In [None]:
import pyarrow.parquet as pq
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import io
import matplotlib.patches as patches
import cv2
import numpy as np

In [None]:
def print_parquet_schema(local_path: str):
    # Read the Parquet file
    table = pq.read_table(local_path)
    # View the schema (to find the image column)
    print("Schema:\n", table.schema)
    print("Columns:\n", table.column_names)

Once you have access to the GCS bucket, download a Parquet file from the `camera_box` folder and the file with the same name from the `camera_image` folder. Files with matching names describe the same segment, so the boxes line up with the images.
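Matching filenames is what ties the two folders together. The sketch below pairs files by basename; `pair_segment_files` is a hypothetical helper, and it assumes you have already listed both folders (for example with `gcsfs` or on local disk).

```python
import os


def pair_segment_files(box_paths, image_paths):
    """Return (box_path, image_path) pairs whose basenames match.

    Assumes the same segment has the same filename in both folders.
    """
    images_by_name = {os.path.basename(p): p for p in image_paths}
    pairs = []
    for box_path in sorted(box_paths):
        name = os.path.basename(box_path)
        if name in images_by_name:
            pairs.append((box_path, images_by_name[name]))
    return pairs


# Made-up paths for illustration
boxes = ["camera_box/seg_a.parquet", "camera_box/seg_b.parquet"]
images = ["camera_image/seg_b.parquet", "camera_image/seg_c.parquet"]
print(pair_segment_files(boxes, images))  # [('camera_box/seg_b.parquet', 'camera_image/seg_b.parquet')]
```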
Let's investigate the schema:

In [None]:
camera_box_parquet = '/path/to/parquet'
camera_image_parquet = '/path/to/parquet'

print("=== Camera Box ===")
print_parquet_schema(camera_box_parquet)
print("=== Camera Image ===")
print_parquet_schema(camera_image_parquet)

Now, let's overlay the corresponding bounding boxes on a frame from the `camera_image` Parquet file:

In [None]:
# === STEP 1: Load the image file ===
# Load Parquet tables
camera_image_table = pq.read_table(camera_image_parquet)
camera_box_table = pq.read_table(camera_box_parquet)

# Inspect available columns
print("Image columns:", camera_image_table.column_names)
print("Box columns:", camera_box_table.column_names)

# Pick one row from camera image
image_row = camera_image_table.to_pandas().iloc[0]

# Extract image bytes
image_bytes = image_row['[CameraImageComponent].image']
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

# Get keys for matching
frame_timestamp = image_row['key.frame_timestamp_micros']
camera_name = image_row['key.camera_name']

# === STEP 2: Get matching bounding boxes for this image ===
box_df = camera_box_table.to_pandas()

# Filter boxes matching the same frame and camera
matching_boxes = box_df[
    (box_df['key.frame_timestamp_micros'] == frame_timestamp) &
    (box_df['key.camera_name'] == camera_name)
]

# === STEP 3: Draw image and overlay boxes ===
fig, ax = plt.subplots(1, figsize=(12, 8))
ax.imshow(image)

for _, box in matching_boxes.iterrows():
    cx = box['[CameraBoxComponent].box.center.x']
    cy = box['[CameraBoxComponent].box.center.y']
    w = box['[CameraBoxComponent].box.size.x']
    h = box['[CameraBoxComponent].box.size.y']
    
    # Convert from center to top-left
    x1 = cx - w / 2
    y1 = cy - h / 2

    # Create a rectangle
    rect = patches.Rectangle((x1, y1), w, h, linewidth=2,
                             edgecolor='red', facecolor='none')
    ax.add_patch(rect)

# Optional: show camera name and frame time
ax.set_title(f"Camera: {camera_name} | Timestamp: {frame_timestamp}")
plt.axis('off')
plt.show()

# Example Waymo class label map
label_map = {
    0: "Unknown",
    1: "Vehicle",
    2: "Pedestrian",
    3: "Sign",
    4: "Cyclist"
}

# Print class info for each bounding box
print("Bounding Box Classifications:")
for i, box in matching_boxes.iterrows():
    class_id = box['[CameraBoxComponent].type']
    class_name = label_map.get(class_id, f"Unknown ({class_id})")
    print(f"- Object {i}: Class ID = {class_id}, Label = {class_name}")
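The center-to-corner conversion above appears in several cells. A small helper keeps it in one place; this is a sketch, and the optional clipping to image bounds is an addition that the original cells do not do.

```python
def center_to_corners(cx, cy, w, h, img_w=None, img_h=None):
    """Convert a (center, size) box to (x1, y1, x2, y2) corners.

    If img_w and img_h are given, the corners are clipped to the image.
    """
    x1, y1 = cx - w / 2, cy - h / 2
    x2, y2 = cx + w / 2, cy + h / 2
    if img_w is not None and img_h is not None:
        x1, x2 = max(0.0, x1), min(float(img_w), x2)
        y1, y2 = max(0.0, y1), min(float(img_h), y2)
    return x1, y1, x2, y2


print(center_to_corners(100.0, 50.0, 40.0, 20.0))                          # (80.0, 40.0, 120.0, 60.0)
print(center_to_corners(10.0, 10.0, 40.0, 20.0, img_w=1920, img_h=1280))   # (0.0, 0.0, 30.0, 20.0)
```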


Let's inspect the contents of a Parquet file:

In [None]:
table = pq.read_table(camera_image_parquet)
print(f"Number of images: {table.num_rows}")
timestamps = table.column("key.frame_timestamp_micros").to_pylist()
print("Unique frames:", len(set(timestamps)))

We can also show that a Parquet file is essentially a video made up of individual frames:

In [None]:
# === Step 1: Load the parquet files ===
camera_image_table = pq.read_table(camera_image_parquet)
camera_box_table = pq.read_table(camera_box_parquet)

# Convert to DataFrames
df_img = camera_image_table.to_pandas()
df_box = camera_box_table.to_pandas()

# === Step 2: Filter to one camera (e.g. FRONT == 1) ===
CAMERA_ID = 1  # change if needed

df_img = df_img[df_img['key.camera_name'] == CAMERA_ID]
df_box = df_box[df_box['key.camera_name'] == CAMERA_ID]

# === Step 3: Get the first 20 frames in time order ===
first_20_frames = np.sort(df_img['key.frame_timestamp_micros'].unique())[:20]

# === Step 4: Plot each frame with bounding boxes ===
for i, timestamp in enumerate(first_20_frames):
    img_row = df_img[df_img['key.frame_timestamp_micros'] == timestamp].iloc[0]
    image_bytes = img_row['[CameraImageComponent].image']
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    boxes = df_box[df_box['key.frame_timestamp_micros'] == timestamp]

    # Plot image
    fig, ax = plt.subplots(1, figsize=(10, 6))
    ax.imshow(image)

    for _, box in boxes.iterrows():
        cx = box['[CameraBoxComponent].box.center.x']
        cy = box['[CameraBoxComponent].box.center.y']
        w = box['[CameraBoxComponent].box.size.x']
        h = box['[CameraBoxComponent].box.size.y']
        x1 = cx - w / 2
        y1 = cy - h / 2

        rect = patches.Rectangle((x1, y1), w, h, linewidth=2,
                                 edgecolor='red', facecolor='none')
        ax.add_patch(rect)

    ax.set_title(f"Frame {i+1} | Timestamp: {timestamp}")
    plt.axis('off')
    plt.show()

Or even generate an actual video:

In [None]:
# === Load parquet files ===
camera_image_table = pq.read_table(camera_image_parquet)
camera_box_table = pq.read_table(camera_box_parquet)

df_img = camera_image_table.to_pandas()
df_box = camera_box_table.to_pandas()

# === Filter to a single camera (e.g. FRONT) ===
CAMERA_ID = 1  # FRONT camera, typically
df_img = df_img[df_img['key.camera_name'] == CAMERA_ID]
df_box = df_box[df_box['key.camera_name'] == CAMERA_ID]

# === Sort by frame timestamp ===
df_img = df_img.sort_values('key.frame_timestamp_micros')
timestamps = df_img['key.frame_timestamp_micros'].unique()

# === Label map from Waymo class IDs ===
label_map = {
    0: "Unknown",
    1: "Vehicle",
    2: "Pedestrian",
    3: "Sign",
    4: "Cyclist"
}

# === Color map for class IDs (RGB: boxes are drawn before the RGB -> BGR conversion below) ===
color_map = {
    0: (128, 128, 128),   # Gray for Unknown
    1: (255, 0, 0),       # Red for Vehicle
    2: (0, 255, 0),       # Green for Pedestrian
    3: (0, 0, 255),       # Blue for Sign
    4: (255, 255, 0),     # Yellow for Cyclist
}

# === Get video frame size ===
sample_image = Image.open(io.BytesIO(df_img.iloc[0]['[CameraImageComponent].image'])).convert("RGB")
video_width, video_height = sample_image.size

# === Set up OpenCV video writer ===
fps = 10
video_writer = cv2.VideoWriter(
    "waymo_output_video.mp4",
    cv2.VideoWriter_fourcc(*'mp4v'),
    fps,
    (video_width, video_height)
)

# === Frame-by-frame processing ===
for timestamp in timestamps:
    # Load image
    img_row = df_img[df_img['key.frame_timestamp_micros'] == timestamp].iloc[0]
    image = Image.open(io.BytesIO(img_row['[CameraImageComponent].image'])).convert("RGB")
    img_np = np.array(image)

    # Get boxes for this frame
    boxes = df_box[df_box['key.frame_timestamp_micros'] == timestamp]

    for _, box in boxes.iterrows():
        cx = box['[CameraBoxComponent].box.center.x']
        cy = box['[CameraBoxComponent].box.center.y']
        w = box['[CameraBoxComponent].box.size.x']
        h = box['[CameraBoxComponent].box.size.y']
        x1 = int(cx - w / 2)
        y1 = int(cy - h / 2)
        x2 = int(cx + w / 2)
        y2 = int(cy + h / 2)

        class_id = box['[CameraBoxComponent].type']
        class_name = label_map.get(class_id, f"Unknown ({class_id})")
        color = color_map.get(class_id, (255, 255, 255))  # default white

        # Draw bounding box
        cv2.rectangle(img_np, (x1, y1), (x2, y2), color=color, thickness=2)

        # Prepare label text
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        text_size = cv2.getTextSize(class_name, font, font_scale, 1)[0]
        text_origin = (x1, max(y1 - 5, text_size[1]))

        # Draw background for text
        cv2.rectangle(img_np,
                      (text_origin[0], text_origin[1] - text_size[1]),
                      (text_origin[0] + text_size[0], text_origin[1] + 4),
                      color, thickness=-1)

        # Draw label
        cv2.putText(img_np, class_name, (text_origin[0], text_origin[1]),
                    font, font_scale, (0, 0, 0), thickness=1, lineType=cv2.LINE_AA)

    # Convert to BGR for OpenCV
    frame_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
    video_writer.write(frame_bgr)

# === Finalize video ===
video_writer.release()
print("Video saved as 'waymo_output_video.mp4'")


**Note: next steps, given more time:**

Due to the vast size of the Waymo dataset, and under tight time and compute constraints, it was not possible to examine the distribution of the image data, i.e. the class distribution across all the samples.

This analysis would be useful here: under-represented classes (likely cyclists) could be identified and potentially oversampled, so that the model performs better on them.
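As a starting point for that analysis, the sketch below counts boxes per class and derives a per-frame sampling weight from the rarest class in each frame. It assumes the per-frame label lists have already been built (e.g. by grouping the `[CameraBoxComponent].type` column by frame timestamp and camera); the helper names and example data are made up. Weights like these could then be passed to `torch.utils.data.WeightedRandomSampler`.

```python
from collections import Counter


def class_distribution(frame_labels):
    """Count boxes per Waymo type; `frame_labels` holds one label sequence per frame."""
    counts = Counter()
    for labels in frame_labels:
        counts.update(int(label) for label in labels)
    return counts


def frame_weights(frame_labels, counts):
    """Weight each frame by the inverse frequency of its rarest class (0 for empty frames)."""
    return [
        max((1.0 / counts[int(label)] for label in labels), default=0.0)
        for labels in frame_labels
    ]


# Made-up example: four frames of Waymo box types (1 = vehicle, 2 = pedestrian, 4 = cyclist)
frames = [[1, 1, 2], [1], [4, 1], []]
counts = class_distribution(frames)
print(counts)                          # Counter({1: 4, 2: 1, 4: 1})
print(frame_weights(frames, counts))   # [1.0, 0.25, 1.0, 0.0]
```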

## 2. Data Processing

From this part of the notebook onwards, the compute and storage requirements of training a Faster R-CNN model on this much image data call for a Google Cloud VM. I used (and recommend) a VM with an 800 GB disk attached to store the images.

You can instead store the image data in GCS and mount it on the VM using GCSFuse, but data loading during training will be slower.

**Notes about Google Cloud:**
- You can set up a free-tier Google Cloud account [here](https://cloud.google.com/free?hl=en).
- The free tier's initial $300 credit is enough to get going with a VM and storage, but it does not cover any GPUs used for training.
- Any costs incurred by running this notebook on Google Cloud are the user's responsibility.
- Once set up on GCP, you will need to request a higher SSD storage quota for the VM.
- When doing the data processing, make sure you are authenticated with the email address you used to gain access to the Waymo data; otherwise you will run into authentication errors. To do this, run `gcloud auth login` on the VM and choose the relevant Google account, then do the same for `gcloud auth application-default login`.

In this section, we will transform the Parquet files into `.pt` files for training, validation and testing, and save them to the VM's disk for downstream use.

In [None]:
import gcsfs
import pyarrow.parquet as pq
import pandas as pd
import torch
import os
import io
import numpy as np
from PIL import Image
import logging
import gc
import random
from tqdm import tqdm

# ========== CONFIG ==========
SOURCE_BUCKET = "waymo_open_dataset_v_2_0_1"
SPLIT = "training"
NUM_BATCHES = 4               
BATCH_ID = 0                  
OUTPUT_ROOT = "/tmp/waymo_data"
TARGET_PT_FILES = 18000      
RANDOM_SEED = 42
# ============================

fs = gcsfs.GCSFileSystem(token="google_default")

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(f"BATCH {BATCH_ID}")

# Global counter
global_pt_counter = 0


def list_parquet_files(prefix):
    files = fs.ls(prefix)
    return sorted([os.path.basename(f) for f in files if f.endswith(".parquet")])

def process_file(filename, max_to_save):
    global global_pt_counter

    if global_pt_counter >= max_to_save:
        return 0

    GCS_PREFIX = f"{SOURCE_BUCKET}/{SPLIT}"
    image_path = f"gs://{GCS_PREFIX}/camera_image/{filename}"
    box_path = f"gs://{GCS_PREFIX}/camera_box/{filename}"
    stats_path = f"gs://{GCS_PREFIX}/stats/{filename}"
    output_dir = os.path.join(OUTPUT_ROOT, SPLIT)
    os.makedirs(output_dir, exist_ok=True)

    if not fs.exists(box_path[5:]) or not fs.exists(stats_path[5:]):
        print(f"Missing box or stats for {filename}, skipping.")
        return 0

    try:
        with fs.open(image_path, 'rb') as f_img:
            df_img = pq.read_table(f_img).to_pandas()
        with fs.open(box_path, 'rb') as f_box:
            df_box = pq.read_table(f_box).to_pandas()
        with fs.open(stats_path, 'rb') as f_stats:
            df_stats = pq.read_table(f_stats).to_pandas()
    except Exception as e:
        print(f"Failed to read {filename}: {e}")
        return 0

    df_stats = df_stats[['key.frame_timestamp_micros', '[StatsComponent].location',
                         '[StatsComponent].time_of_day', '[StatsComponent].weather']]
    pairs = df_img[['key.frame_timestamp_micros', 'key.camera_name']].drop_duplicates()

    processed = 0

    for _, row in pairs.iterrows():
        if global_pt_counter >= max_to_save:
            break

        timestamp = row['key.frame_timestamp_micros']
        camera_id = row['key.camera_name']
        output_path = os.path.join(output_dir, f"frame_{timestamp}_{camera_id}.pt")

        if os.path.exists(output_path):
            continue

        try:
            img_row = df_img[
                (df_img['key.frame_timestamp_micros'] == timestamp) &
                (df_img['key.camera_name'] == camera_id)
            ].iloc[0]

            image = Image.open(io.BytesIO(img_row['[CameraImageComponent].image'])).convert("RGB")
            image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0

            boxes_df = df_box[
                (df_box['key.frame_timestamp_micros'] == timestamp) &
                (df_box['key.camera_name'] == camera_id)
            ]

            box_list, label_list = [], []
            for _, box in boxes_df.iterrows():
                cx = box['[CameraBoxComponent].box.center.x']
                cy = box['[CameraBoxComponent].box.center.y']
                w = box['[CameraBoxComponent].box.size.x']
                h = box['[CameraBoxComponent].box.size.y']
                x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
                box_list.append([x1, y1, x2, y2])
                label_list.append(box['[CameraBoxComponent].type'])

            stats_row = df_stats[df_stats['key.frame_timestamp_micros'] == timestamp]
            if stats_row.empty:
                continue

            meta = {
                "timestamp": int(timestamp),
                "camera_name": int(camera_id),
                "location": stats_row['[StatsComponent].location'].values[0],
                "time_of_day": stats_row['[StatsComponent].time_of_day'].values[0],
                "weather": stats_row['[StatsComponent].weather'].values[0],
                "split": SPLIT,
                "source_file": filename
            }

            sample = {
                "image": image_tensor,
                "boxes": torch.tensor(box_list, dtype=torch.float32),
                "labels": torch.tensor(label_list, dtype=torch.int64),
                "meta": meta
            }

            torch.save(sample, output_path)
            processed += 1
            global_pt_counter += 1

            del image_tensor, sample, image, img_row, boxes_df, box_list, label_list, stats_row
            gc.collect()

        except Exception as e:
            print(f"Error processing frame {timestamp} cam {camera_id}: {e}")

    del df_img, df_box, df_stats
    gc.collect()

    print(f"Finished {filename}, saved {processed} samples.")
    return processed

In [None]:
# List and shuffle all files
all_files = list_parquet_files(f"{SOURCE_BUCKET}/{SPLIT}/camera_image")

random.seed(RANDOM_SEED)
random.shuffle(all_files)

# Simulate batching
batch_files = [f for i, f in enumerate(all_files) if i % NUM_BATCHES == BATCH_ID]

print(f"Batch {BATCH_ID}: {len(batch_files)} files to process")

# Run batch loop
for fname in tqdm(batch_files):
    if global_pt_counter >= TARGET_PT_FILES:
        break
    process_file(fname, TARGET_PT_FILES)

print(f"Batch {BATCH_ID} complete. Total .pt files saved: {global_pt_counter}")


The code above shows how the data can be converted from Parquet into `.pt` files in batches and saved to a `/tmp` directory on the VM.

To process the remaining batches, increase `BATCH_ID` by 1 in the config cell and rerun both cells (rerunning the config cell also resets `global_pt_counter`).
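The batching relies on every batch run using the same file list, `NUM_BATCHES` and `RANDOM_SEED`: file i of the shuffled list goes to batch `i % NUM_BATCHES`, so each file lands in exactly one batch and batches can run in parallel without duplicating work. A standalone sketch of that split, using a hypothetical `shard_files` helper and made-up filenames:

```python
import random


def shard_files(files, num_batches, batch_id, seed=42):
    """Shuffle-then-stride split, equivalent to the batch loop above:
    file i of the shuffled list goes to batch i % num_batches."""
    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)
    return [f for i, f in enumerate(shuffled) if i % num_batches == batch_id]


files = [f"segment_{i}.parquet" for i in range(10)]
shards = [shard_files(files, num_batches=4, batch_id=b) for b in range(4)]
print([len(s) for s in shards])  # [3, 3, 2, 2]
```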

The Waymo dataset is already split into `training`, `validation` and `testing`, so you can change the value of `SPLIT` to reuse the same code for the other sets.

With 800 GB of storage, about 18,000 training samples, 5,000 validation samples and 4,000 testing samples will fit.
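Those numbers follow from the sample format: each `.pt` file stores the image as a float32 tensor (4 bytes per value), so one image costs roughly width × height × 3 × 4 bytes. The sketch below assumes Waymo's three front-facing cameras are 1920x1280 and the two side cameras 1920x886 (check this against your own files), that samples are spread evenly over the five cameras, and ignores boxes, metadata and `torch.save` overhead.

```python
# Assumed (width, height) of the five cameras: three front-facing, two side
CAMERA_SHAPES = [(1920, 1280)] * 3 + [(1920, 886)] * 2


def image_tensor_bytes(width, height, bytes_per_value, channels=3):
    """Raw size of one CHW image tensor."""
    return width * height * channels * bytes_per_value


def dataset_bytes(num_samples, bytes_per_value):
    """Estimated total size, assuming samples are spread evenly over the cameras."""
    per_sample = sum(image_tensor_bytes(w, h, bytes_per_value) for w, h in CAMERA_SHAPES) / len(CAMERA_SHAPES)
    return num_samples * per_sample


num_samples = 18_000 + 5_000 + 4_000
print(f"float32: {dataset_bytes(num_samples, 4) / 1e9:.0f} GB")  # 698 GB
print(f"uint8:   {dataset_bytes(num_samples, 1) / 1e9:.0f} GB")  # 175 GB
```

Under these assumptions, storing the image as uint8 (0-255) and dividing by 255 when loading would cut the disk footprint by about 4×, and reduce the amount of data read per training step by the same factor.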

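**NOTE:** Because Waymo data is used in public competitions, the `testing` split does not come with labels (i.e. bounding boxes), to prevent cheating. Instead, re-use the `validation` split and save the output to a `testing` directory (keep reading from the `validation` split in GCS, but point the output directory at `testing`). Change the random seed when processing this set, since the same seed would duplicate the files from your `validation` set into `testing`. A different seed only reorders the files, however, so some segments can still end up in both sets; to guarantee a clean split, skip the source files already used for validation.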
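A sketch of that exclusion step, with a hypothetical `pick_test_files` helper and made-up filenames. The set of already-used files could be rebuilt from the `meta["source_file"]` entry that `process_file` stores in every `.pt` sample. Splitting at the file (segment) level also keeps consecutive, near-identical frames from the same drive from appearing in both sets.

```python
import random


def pick_test_files(validation_files, used_for_validation, num_files, seed=123):
    """Choose validation-split Parquet files for a pseudo-test set,
    skipping any file (segment) already used to build the validation set."""
    used = set(used_for_validation)
    candidates = sorted(f for f in validation_files if f not in used)
    random.Random(seed).shuffle(candidates)
    return candidates[:num_files]


val_split = [f"segment_{i}.parquet" for i in range(6)]
already_used = {"segment_0.parquet", "segment_3.parquet"}
print(pick_test_files(val_split, already_used, num_files=3))
```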

**Running Batch Processing in Parallel**

If you would like a script to use instead, please find it under `scripts/process_batch.py`. This can be used along with `run_batch.sh` to launch multiple processing jobs at once on a GCP VM.

### Improvement Note:
The data processed here is sampled at random, so it likely mirrors the class distribution of the Waymo dataset, which is unfavourable for under-represented classes such as cyclists. The processing script could (and should) be modified to balance the class distribution, which would likely improve precision and recall for those classes.

## 3. Use Bayesian Optimisation to Find Optimal Hyperparameters

As covered in this course, Bayesian Optimisation is well suited to black-box optimisation problems, especially ones where each evaluation is slow, such as hyperparameter tuning.

We use Optuna to perform this BO task.

Optuna is an open-source, Python-based hyperparameter optimization framework that helps you automatically search for the best hyperparameter values for your machine learning models.

It is flexible, scalable, and often used for tasks like:
- Finding the best learning rate, dropout, or architecture settings
- Optimizing data preprocessing parameters
- Tuning black-box functions where manual tuning would be slow or infeasible

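Optuna's default sampler is the Tree-structured Parzen Estimator (TPE), which is a form of Bayesian optimization.
- Unlike classic grid or random search, Bayesian methods build a probabilistic model of the objective function and use it to choose the next promising set of hyperparameters.
- TPE splits the completed trials into a "good" group and a "bad" group by score, models the distribution of hyperparameter values in each group, and proposes values that are likely under the good model and unlikely under the bad one. This is how it balances exploration and exploitation.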
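To make the TPE idea concrete, here is a deliberately simplified one-dimensional sketch; it is not Optuna's implementation, which uses adaptive bandwidths, priors, and handles many parameters at once. The toy objective is made up and stands in for validation mAP as a function of log10(learning rate).

```python
import math
import random


def parzen_density(x, points, bandwidth):
    """Parzen (Gaussian kernel) density estimate at x from the observed points."""
    norm = 1.0 / (bandwidth * math.sqrt(2.0 * math.pi))
    total = sum(math.exp(-0.5 * ((x - p) / bandwidth) ** 2) for p in points)
    return norm * total / len(points)


def tpe_suggest(history, low, high, rng, gamma=0.25, n_candidates=24, bandwidth=0.1):
    """Suggest the next x in [low, high] from history = [(x, score), ...]; higher is better."""
    ranked = sorted(history, key=lambda h: h[1], reverse=True)
    n_good = max(1, math.ceil(gamma * len(ranked)))
    good = [x for x, _ in ranked[:n_good]]                         # l(x) is fitted to these
    bad = [x for x, _ in ranked[n_good:]] or [(low + high) / 2.0]  # g(x) is fitted to these
    # Sample candidates around the good points, clipped to the search range
    candidates = [
        min(high, max(low, rng.gauss(rng.choice(good), bandwidth)))
        for _ in range(n_candidates)
    ]
    # Keep the candidate with the largest l(x) / g(x) ratio
    return max(
        candidates,
        key=lambda c: parzen_density(c, good, bandwidth) / (parzen_density(c, bad, bandwidth) + 1e-12),
    )


def toy_objective(log_lr):
    # Made-up stand-in for "validation mAP vs log10(learning rate)", peaking at -3
    return -(log_lr + 3.0) ** 2


rng = random.Random(0)
history = [(x, toy_objective(x)) for x in (rng.uniform(-5.0, -2.0) for _ in range(5))]
for _ in range(20):
    x = tpe_suggest(history, -5.0, -2.0, rng)
    history.append((x, toy_objective(x)))

best_x, best_score = max(history, key=lambda h: h[1])
print(f"best log10(lr) found: {best_x:.2f}")
```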

### NOTE:
The hyperparameter tuning and the final training job will take __forever__ without a GPU.
As mentioned before, GPUs are not covered by the GCP free tier. I used one anyway, because without it I would still be waiting.
I can't recommend that anyone else pays for a GPU, but it does speed things up significantly.

I used an L4 GPU VM (g2-standard-12 on Google Cloud).

### First, let's define our custom Waymo Dataset:

In [None]:
import torch
from torch.utils.data import Dataset
import os

class WaymoDataset(Dataset):
    def __init__(self, gcs_prefix, file_list, transform=None, label_map=None):
        self.gcs_prefix = gcs_prefix.rstrip('/')
        self.file_list = file_list
        self.transform = transform
        self.label_map = label_map

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        for _ in range(len(self.file_list)):
            file_name = self.file_list[idx]
            file_path = os.path.join(self.gcs_prefix, file_name)

            try:
                with open(file_path, 'rb') as f:
                    sample = torch.load(f)
            except Exception as e:
                raise RuntimeError(f"Failed to load file: {file_path}, error: {e}")

            image = sample['image']
            boxes = sample['boxes']
            labels = sample['labels']

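            # Skip frames with no boxes at all
            if boxes.numel() == 0:
                idx = (idx + 1) % len(self.file_list)
                continue

            if self.label_map is not None:
                # Drop boxes whose Waymo type is not in label_map (e.g. unknown, sign)
                # instead of mapping them to 0, which torchvision reserves for background
                keep = torch.tensor([int(lbl) in self.label_map for lbl in labels], dtype=torch.bool)
                boxes, labels = boxes[keep], labels[keep]
                labels = torch.tensor([self.label_map[int(lbl)] for lbl in labels], dtype=torch.int64)

            # Drop degenerate boxes; torchvision requires positive width and height
            valid = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])
            boxes, labels = boxes[valid], labels[valid]

            # Skip frames left with no usable boxes
            if boxes.shape[0] == 0:
                idx = (idx + 1) % len(self.file_list)
                continue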

            target = {
                'boxes': boxes,
                'labels': labels
            }

            if self.transform:
                image = self.transform(image)

            return image, target

        raise RuntimeError("All samples in dataset have no boxes.")

### Loading in the Pre-Trained Faster R-CNN model

Although this course taught us to build CNNs from scratch, that was not feasible here given the complexity of the problem and the tight time constraints. Starting from a pre-trained model is, however, very common practice.

The pre-trained Faster R-CNN model with a ResNet-50 backbone and Feature Pyramid Network (FPN) is commonly used as a starting point for object detection tasks. Trained on the COCO dataset, it learns general visual features like edges, textures, and object shapes, which can transfer well to new datasets. In practice, the model is fine-tuned by replacing its classification head with a new one that matches the number of classes in the target dataset. This approach allows efficient adaptation to custom detection problems without training from scratch.


In [None]:
import torchvision
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def get_model(num_classes=4):  # 3 foreground classes + 1 background
    # Load a pre-trained Faster R-CNN model with a ResNet-50 backbone and FPN
    # `pretrained=True` is deprecated since torchvision 0.13; `weights` is the current API
    model = fasterrcnn_resnet50_fpn(weights="DEFAULT")

    # Replace the classification head with one for our dataset
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model

Let's declare some utils:

In [None]:
label_map = {
    1: 1,  # vehicle
    2: 2,  # pedestrian
    4: 3   # cyclist
}
# Waymo uses 4 for cyclists because 3 is reserved for signs, even though there are no sign labels in this dataset.
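The remapping also decides what happens to types outside the map. Since torchvision's detection models reserve label 0 for background, one option, sketched here with a hypothetical helper and made-up boxes, is to drop unmapped boxes (unknown, sign) rather than relabel them as 0:

```python
# Hypothetical name; same values as `label_map` above (0 stays reserved for background)
WAYMO_TO_MODEL = {1: 1, 2: 2, 4: 3}  # vehicle, pedestrian, cyclist


def remap_labels(boxes, labels, mapping=WAYMO_TO_MODEL):
    """Keep only boxes whose Waymo type appears in `mapping` and remap their labels."""
    kept_boxes, kept_labels = [], []
    for box, label in zip(boxes, labels):
        if int(label) in mapping:
            kept_boxes.append(box)
            kept_labels.append(mapping[int(label)])
    return kept_boxes, kept_labels


boxes = [[0, 0, 10, 10], [5, 5, 8, 9], [1, 1, 2, 2]]
labels = [1, 3, 4]  # vehicle, sign, cyclist
print(remap_labels(boxes, labels))  # ([[0, 0, 10, 10], [1, 1, 2, 2]], [1, 3])
```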

In [None]:
def save_checkpoint(model, optimizer, epoch, path):
    """
    Save the model state, optimizer state and epoch to `path` (a local path or a GCSFuse-mounted directory).
    """
    checkpoint = {
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'epoch': epoch
    }
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(checkpoint, path)

### Define our training loop for the hyperparameter experiments:

This is used in the experiments to find the best combination of hyperparameters.

In [None]:
import torch
from torch.utils.data import DataLoader
from torchmetrics.detection.mean_ap import MeanAveragePrecision
import os
from datetime import datetime

def list_local_files(directory):
    return sorted([f for f in os.listdir(directory) if f.endswith('.pt')])

def train_one_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    for batch_idx, (images, targets) in enumerate(dataloader):
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        total_loss += losses.item()

        now = datetime.now()
        print(f"Time: {now.strftime('%H:%M:%S')} - [Train] Batch {batch_idx + 1}/{len(dataloader)} - Loss: {losses.item():.4f}")

    return total_loss / len(dataloader)

def validate(model, dataloader, device):
    model.eval()
    metric = MeanAveragePrecision(iou_thresholds=[0.5])
    with torch.no_grad():
        for images, targets in dataloader:
            images = [img.to(device) for img in images]
            targets_cpu = [{k: v.cpu() for k, v in t.items()} for t in targets]

            outputs = model(images)
            outputs_cpu = [{k: v.cpu() for k, v in o.items()} for o in outputs]

            metric.update(outputs_cpu, targets_cpu)

    score = metric.compute()
    map_50 = score['map_50'].item()
    print(f"[Validation] mAP@0.5: {map_50:.4f}")
    return map_50

def run_training(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model(num_classes=4)  # 3 classes + background
    model.to(device)

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config["lr"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"]
    )

    train_files = list_local_files(config["train_prefix"])
    val_files = list_local_files(config["val_prefix"])
    train_files = train_files[:config["num_train_files"]]
    val_files = val_files[:config["num_val_files"]]

    # Uses the global Waymo-type -> model-class `label_map` defined above
    train_dataset = WaymoDataset(config["train_prefix"], train_files, label_map=label_map)
    val_dataset = WaymoDataset(config["val_prefix"], val_files, label_map=label_map)

    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True,
                              collate_fn=lambda x: tuple(zip(*x)), num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False,
                            collate_fn=lambda x: tuple(zip(*x)), num_workers=4, pin_memory=True)

    best_map = 0.0

    for epoch in range(config["epochs"]):
        print(f"Starting Epoch: {epoch + 1}/{config['epochs']}")
        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        val_map = validate(model, val_loader, device)

        if val_map > best_map:
            best_map = val_map
            save_checkpoint(model, optimizer, epoch, config["checkpoint_path"])

        print(f"Epoch {epoch + 1}: Train Loss = {train_loss:.4f}, Val mAP@0.5 = {val_map:.4f}")

    return best_map
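The validation step scores detections with mAP@0.5: a prediction only counts as a true positive if it has the right class and an intersection-over-union (IoU) of at least 0.5 with a ground-truth box. `MeanAveragePrecision` handles the matching and the precision-recall integration; this sketch only shows the IoU computation for boxes in the `(x1, y1, x2, y2)` format used above.

```python
def box_iou(a, b):
    """Intersection over union of two boxes in (x1, y1, x2, y2) format."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


gt = (0, 0, 10, 10)
print(box_iou(gt, (5, 0, 15, 10)))  # about 0.33: half-overlapping, not a match at 0.5
print(box_iou(gt, (1, 0, 11, 10)))  # about 0.82: counts as a match at IoU >= 0.5
```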


Declare the objective function for Optuna:

In [None]:
def objective(trial):
    config = {
        "lr": trial.suggest_float("lr", 1e-5, 1e-2, log=True),  # suggest_loguniform is deprecated
        "momentum": trial.suggest_float("momentum", 0.7, 0.99),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True),
        "batch_size": trial.suggest_categorical("batch_size", [2, 4, 8]),
        "epochs": 4,
        "train_prefix": "/tmp/waymo_data/training",
        "val_prefix": "/tmp/waymo_data/validation",
        # One checkpoint per trial, so later trials don't overwrite earlier ones
        "checkpoint_path": f"/tmp/waymo_data/checkpoints/trial_{trial.number}/best_model.pt",
        "num_train_files": 6000,
        "num_val_files": 2000
    }

    return run_training(config)

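This runs up to 10 trials, using Optuna to find the hyperparameter combination with the best mAP@0.5. The study also has a timeout of 36,000 seconds (10 hours) and stops at whichever limit is reached first; Optuna checks the timeout between trials, so a trial that is already running is allowed to finish.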

In [None]:
import optuna

n_trials=10
timeout=36000

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=n_trials, timeout=timeout)

print("Best trial:")
print(study.best_trial)

### Improvement Note:

If I had more time, I would have let the study run longer with more trials, to have a better chance of finding the optimal set of hyperparameters.

## 4. Training of Final Model

Now that we have our optimal hyperparameters, time to plug them into the final training job.
This training job will use the whole training dataset from earlier.
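Rather than copying values by hand, the tuned values can be merged into the final config. A sketch with a hypothetical `build_final_config` helper; in the notebook, `study.best_params` (a dict keyed by the names passed to `trial.suggest_*`) would be passed in place of the made-up values:

```python
def build_final_config(best_params, base_config):
    """Return a copy of base_config with the tuned hyperparameters filled in."""
    config = dict(base_config)
    for key in ("lr", "momentum", "weight_decay", "batch_size"):
        config[key] = best_params[key]
    return config


# Made-up values standing in for study.best_params
tuned = {"lr": 0.005, "momentum": 0.9, "weight_decay": 1e-4, "batch_size": 4}
base = {"epochs": 30, "early_stopping_patience": 5}
print(build_final_config(tuned, base))
```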

This training function is the same as `run_training`, except that it adds early stopping: if the validation mAP@0.5 does not improve for `early_stopping_patience` consecutive epochs (5 by default), training stops.
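The early-stopping rule can be checked in isolation by replaying it over a list of per-epoch validation scores. This sketch mirrors the logic in `run_final_training`; the helper name and the mAP values are made up:

```python
def early_stopping_epoch(val_maps, patience):
    """Replay the early-stopping rule over per-epoch val mAPs.

    Returns (best_epoch, stop_epoch), both 0-based.
    """
    best_map, best_epoch, epochs_no_improve = 0.0, 0, 0
    for epoch, val_map in enumerate(val_maps):
        if val_map > best_map:
            best_map, best_epoch, epochs_no_improve = val_map, epoch, 0
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= patience:
            return best_epoch, epoch
    return best_epoch, len(val_maps) - 1


scores = [0.30, 0.35, 0.40, 0.39, 0.40, 0.38, 0.41, 0.37]
print(early_stopping_epoch(scores, patience=2))  # (2, 4)
print(early_stopping_epoch(scores, patience=5))  # (6, 7)
```

With `patience=2`, training stops at epoch index 4, two epochs after the best score at index 2 (a tie does not count as an improvement). With `patience=5`, the later improvement at index 6 resets the counter and training runs to the end.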

In [None]:
def run_final_training(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model(num_classes=4)  # 3 classes + background
    model.to(device)

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config["lr"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"]
    )

    train_files = list_local_files(config["train_prefix"])
    val_files = list_local_files(config["val_prefix"])
    train_files = train_files[:config["num_train_files"]]
    val_files = val_files[:config["num_val_files"]]

    # Uses the global Waymo-type -> model-class `label_map` defined above
    train_dataset = WaymoDataset(config["train_prefix"], train_files, label_map=label_map)
    val_dataset = WaymoDataset(config["val_prefix"], val_files, label_map=label_map)

    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True,
                              collate_fn=lambda x: tuple(zip(*x)), num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False,
                            collate_fn=lambda x: tuple(zip(*x)), num_workers=4, pin_memory=True)

    best_map = 0.0
    best_epoch = 0
    patience = config.get("early_stopping_patience", 5)
    epochs_no_improve = 0

    for epoch in range(config["epochs"]):
        print(f"Starting Epoch: {epoch + 1}/{config['epochs']}")
        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        val_map = validate(model, val_loader, device)

        if val_map > best_map:
            best_map = val_map
            best_epoch = epoch
            save_checkpoint(model, optimizer, epoch, config["checkpoint_path"])
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        print(f"Epoch {epoch + 1}: Train Loss = {train_loss:.4f}, Val mAP@0.5 = {val_map:.4f}")

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered at epoch {epoch + 1}")
            break

    print(f"Best model at epoch {best_epoch + 1} with mAP@0.5 = {best_map:.4f}")
    return best_map


In [None]:
final_config = {
    # Best values found by the Optuna study above (or paste them in by hand)
    "lr": study.best_params["lr"],
    "momentum": study.best_params["momentum"],
    "weight_decay": study.best_params["weight_decay"],
    "batch_size": study.best_params["batch_size"],
    "epochs": 30,
    "checkpoint_path": "/tmp/waymo_data/checkpoints/final/best_model.pt",
    "train_prefix": "/tmp/waymo_data/training",
    "val_prefix": "/tmp/waymo_data/validation",
    "test_prefix": "/tmp/waymo_data/testing",
    "num_train_files": 18000, #check you have this many files first
    "num_val_files": 5000, #check you have this many files first
    "num_test_files": 4000, #check you have this many files first
    "early_stopping_patience": 5, # if no improvement for 5 consecutive epochs, stop training early.
}

# Train
run_final_training(final_config)

## 5. Evaluate Final Model

Run the model against the testing dataset to see how it performs!

This runs the evaluation using torchmetrics' `MeanAveragePrecision` and plots a bar chart of per-class AP.

In [None]:
def test_model(checkpoint_path, test_prefix, num_test_files=1000, batch_size=4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Loading best model from checkpoint...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = get_model(num_classes=4)
    model.load_state_dict(checkpoint["model_state"])
    model.to(device)
    model.eval()

    test_files = list_local_files(test_prefix)
    test_files = test_files[:num_test_files]

    # Uses the global Waymo-type -> model-class `label_map` defined above
    test_dataset = WaymoDataset(test_prefix, test_files, label_map=label_map)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             collate_fn=lambda x: tuple(zip(*x)), num_workers=4, pin_memory=True)

    print("Running Testing:")
    metric = MeanAveragePrecision(class_metrics=True)
    with torch.no_grad():
        for images, targets in test_loader:
            images = [img.to(device) for img in images]
            targets_cpu = [{k: v.cpu() for k, v in t.items()} for t in targets]
            outputs = model(images)
            outputs_cpu = [{k: v.cpu() for k, v in o.items()} for o in outputs]
            metric.update(outputs_cpu, targets_cpu)

    results = metric.compute()
    print("Final Test Set Evaluation:")
    for k, v in results.items():
        if isinstance(v, torch.Tensor):
            if v.ndim == 0:
                print(f"{k}: {v.item():.4f}")
            else:
                print(f"{k}: {v}")
        else:
            print(f"{k}: {v}")


    # Plot per-class AP
    if "classes" in results and "map_per_class" in results:
        # Use the class IDs reported by torchmetrics, not positional indices
        class_ids = results["classes"].tolist()
        ap_values = results["map_per_class"].tolist()

        plt.figure(figsize=(8, 5))
        plt.bar(class_ids, ap_values)
        plt.xlabel("Class ID")
        plt.ylabel("AP")
        plt.title("Per-Class Average Precision (AP)")
        plt.savefig("per_class_ap.png")
        print("Saved per-class AP plot to per_class_ap.png")

    return results

In [None]:
test_model(
    checkpoint_path=final_config["checkpoint_path"],
    test_prefix=final_config["test_prefix"],
    num_test_files=final_config["num_test_files"],
    batch_size=final_config["batch_size"]
)

This prints a set of metrics; for example, my training and evaluation run gave:

Final Test Set Evaluation:
- map: 0.3047
- map_50: 0.5129
- map_75: 0.3142
- map_small: 0.0564
- map_medium: 0.2959
- map_large: 0.5829
- mar_1: 0.1566
- mar_10: 0.3460
- mar_100: 0.3959
- mar_small: 0.1212
- mar_medium: 0.4199
- mar_large: 0.6822
- map_per_class: tensor([0.3848, 0.2978, 0.2314])
- mar_100_per_class: tensor([0.4387, 0.3537, 0.3952])
- classes: tensor([1, 2, 3], dtype=torch.int32)

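It also saves a PNG chart of per-class AP to `per_class_ap.png`. Classes 1, 2 and 3 are vehicle, pedestrian and cyclist, following the `label_map` remapping.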

## Final Notes:

- The biggest bottlenecks here are time and money. It is entirely possible to attach a 100 TB Parallelstore volume to a multi-GPU VM and train on all of the training data provided by Waymo; it would just be very expensive, and running the hyperparameter experiments and training the final model would still take a long time.
- Within the time bottleneck, the biggest factor is actually how fast samples can be loaded during training, so optimizations here would be very impactful. Distributed training would certainly help.

### Final, Final Note:
If you used any GCP resources, delete, destroy and wipe before you get charged :)