In [None]:
from google.colab import files

# Ask the user to pick local files (presumably utils.py, which the next cell
# imports) and echo the name and byte size of everything received.
uploaded = files.upload()

for file_name, file_bytes in uploaded.items():
    print('User uploaded file "{name}" with length {length} bytes'.format(
        name=file_name, length=len(file_bytes)))


In [None]:
# Re-execute utils so a freshly uploaded utils.py takes effect without a
# kernel restart; the reloaded module object is shown as the cell output.
import utils
import importlib

importlib.reload(utils)

# ViT + Custom BBox Head for Serengeti Animal Detection

This notebook implements a Vision Transformer (ViT) with custom classification and bounding box regression heads for the Serengeti wildlife dataset.

## Phase 1: Data Preparation & Exploration

### Setup and Imports

In [None]:
# Third-party scientific / plotting stack.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Project helpers from utils.py: path/config constants, data loading,
# tf dataset builders and a visualization helper.
from utils import (
    BATCH_SIZE,
    DATASET_BASE_PATH,
    IMAGE_SIZE,
    LABELS_PATH,
    build_tf_dataset,
    denormalize_bbox,
    download_dataset,
    load_labels,
    prepare_label_data,
    split_dataset,
)

# Apply seaborn's default styling to every matplotlib figure below.
sns.set_theme()
print("Imports complete!")

### Download Dataset (if needed)

The cell below downloads the dataset from Kaggle and moves it to `/content/dataset`. Skip it if the dataset is already available locally.

In [None]:
# Download the dataset from Kaggle only when it is not already on disk, so a
# "Restart & Run All" does not fetch it again every time.
# NOTE(review): assumes DATASET_BASE_PATH is the directory that
# download_dataset() populates (the markdown above mentions /content/dataset);
# confirm in utils.py.
import os

if os.path.isdir(DATASET_BASE_PATH) and os.listdir(DATASET_BASE_PATH):
    print(f"Dataset already present at {DATASET_BASE_PATH}; skipping download.")
else:
    download_dataset()

## Data Exploration

### Load and Inspect Raw Labels

In [None]:
# Read the raw annotation table and take a first look at its layout.
labels_df = load_labels(LABELS_PATH)

print(f"Dataset shape: {labels_df.shape}")
print(f"Columns: {labels_df.columns.tolist()}")

labels_df.head(n=10)

### Check for Missing Values

Let's check if there are any missing values, especially in the bounding box columns (`a1`, `a2`, `a3`, `a4`) and `animal_count`.

In [None]:
# Count missing entries per column of the raw table (isna is the
# canonical alias of isnull).
missing_per_column = labels_df.isna().sum()

print("Missing values per column:")
print(missing_per_column)
print(f"\nTotal rows: {labels_df.shape[0]}")

In [None]:
# How are the images spread across animal counts?
print("Animal count distribution:")
print(labels_df["animal_count"].value_counts().sort_index())

# Split rows into frames containing at least one animal vs. empty frames.
# NOTE(review): rows whose animal_count is NaN or negative land in neither
# group, so the two counts below need not add up to the total row count.
animal_counts = labels_df["animal_count"]
has_animals = labels_df.loc[animal_counts > 0]
no_animals = labels_df.loc[animal_counts == 0]

print(f"\nRows with animals (animal_count > 0): {len(has_animals)}")
print(f"Rows without animals (animal_count = 0): {len(no_animals)}")

In [None]:
# Sanity-check the four bbox columns on rows that should contain animals.
bbox_cols = ["a1", "a2", "a3", "a4"]
bbox_values = has_animals[bbox_cols]

print("Bounding box analysis for rows WITH animals (animal_count > 0):")
print("-" * 50)

# Per-column count of NaN coordinates.
bbox_nulls = bbox_values.isnull().sum()
print(f"NaN values in bbox columns:\n{bbox_nulls}")

# A row whose four coordinates are all exactly 0 is treated as a missing box.
all_zero_rows = bbox_values.eq(0).all(axis=1)
bbox_zeros = all_zero_rows.sum()
print(f"\nRows where ALL bbox values are 0: {bbox_zeros}")

# Invalid = any NaN coordinate, or an all-zero box; everything else is valid.
invalid_bbox_mask = bbox_values.isnull().any(axis=1) | all_zero_rows
invalid_count = invalid_bbox_mask.sum()
valid_count = len(has_animals) - invalid_count

print(f"\nValid bounding boxes: {valid_count}")
print(f"Invalid bounding boxes: {invalid_count}")

### Prepare Filtered Data

Now let's use `prepare_label_data()` to filter for images with animals and prepare the data for training.

In [None]:
# Unpack the five outputs of the utils filtering/encoding helper
# (see prepare_label_data in utils.py for their exact contents).
filepaths, labels, bbox_array, label_encoder, positives = prepare_label_data(labels_df)
encoder_classes = label_encoder.classes_

print(f"Filtered dataset size: {len(filepaths)}")
print(f"Number of classes: {len(encoder_classes)}")
print(f"Classes: {encoder_classes}")

In [None]:
# Re-check the prepared bbox array: layout, rows with NaNs, rows that are
# entirely zero, and per-column summary statistics.
print("Bounding box array shape:", bbox_array.shape)
print("Bounding box dtype:", bbox_array.dtype)

nan_rows = np.isnan(bbox_array).any(axis=1)
zero_rows = (bbox_array == 0).all(axis=1)
nan_count = nan_rows.sum()
all_zeros_count = zero_rows.sum()

print(f"\nRows with NaN in bbox: {nan_count}")
print(f"Rows with all-zero bbox: {all_zeros_count}")

# Column-wise (axis=0) statistics over all prepared boxes.
print("\nBounding box statistics:")
print(f"  Min values: {np.min(bbox_array, axis=0)}")
print(f"  Max values: {np.max(bbox_array, axis=0)}")
print(f"  Mean values: {np.mean(bbox_array, axis=0).round(2)}")

### Class Distribution

Let's visualize the distribution of animal classes in our dataset.

In [None]:
# Decode the encoded labels back to class names and count each class.
class_names = label_encoder.inverse_transform(labels)
class_counts = pd.Series(class_names).value_counts()

print("Class distribution:")
print(class_counts)

# Bar chart of per-class counts, drawn through the explicit Axes API.
fig, ax = plt.subplots(figsize=(12, 6))
class_counts.plot(kind="bar", ax=ax, color="steelblue", edgecolor="black")
ax.set(
    xlabel="Animal Class",
    ylabel="Count",
    title="Distribution of Animal Classes in Dataset",
)
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
fig.tight_layout()
plt.show()

# Quantify the imbalance between the most and least frequent classes.
top_count, bottom_count = class_counts.max(), class_counts.min()
print(f"\nMost common class: {class_counts.idxmax()} ({top_count} samples)")
print(f"Least common class: {class_counts.idxmin()} ({bottom_count} samples)")
print(f"Imbalance ratio: {top_count / bottom_count:.2f}x")

### Visualize Sample Images with Bounding Boxes

Let's look at some sample images with their annotated bounding boxes to understand the data quality.

In [None]:
import matplotlib.patches as patches
from PIL import Image

def visualize_sample(idx, filepaths, labels, bbox_array, label_encoder):
    """Show one image with its annotated bounding box drawn on top.

    Args:
        idx: Position of the sample in ``filepaths``, ``labels`` and
            ``bbox_array``.
        filepaths: Sequence of image file paths.
        labels: Sequence of encoded class labels accepted by
            ``label_encoder.inverse_transform``.
        bbox_array: Array whose row ``idx`` unpacks into four values
            ``(x1, y1, x2, y2)``. NOTE(review): assumed to be corner
            coordinates in pixels (as the original comment stated) - confirm
            against ``prepare_label_data``.
        label_encoder: Fitted encoder used to recover the class name.

    Any exception raised while opening or drawing the image is reported with
    a printed message rather than raised, so one unreadable file does not
    abort a preview loop. A figure created before the failure is closed so it
    is not rendered as an empty plot.
    """
    filepath = filepaths[idx]
    bbox = bbox_array[idx]
    class_name = label_encoder.inverse_transform([labels[idx]])[0]

    fig = None
    try:
        # The context manager releases the file handle once matplotlib has
        # taken a copy of the pixels (imshow converts PIL images to arrays).
        with Image.open(filepath) as img:
            fig, ax = plt.subplots(figsize=(8, 8))
            ax.imshow(img)

        # Rectangle is anchored at (x1, y1) and sized by width/height.
        x1, y1, x2, y2 = bbox
        rect = patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=2, edgecolor='lime', facecolor='none'
        )
        ax.add_patch(rect)

        ax.set_title(f"Class: {class_name}\nBBox: [{x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}]")
        ax.axis('off')
        plt.show()

    except Exception as e:  # deliberate best-effort: report and move on
        print(f"Error loading image {filepath}: {e}")

    finally:
        # Free the figure; on success plt.show() has already displayed it.
        if fig is not None:
            plt.close(fig)

# Visualize a few random samples. The global seed is kept so the selection
# (and NumPy's global RNG state for later cells) is unchanged; the sample
# count is capped because choice(..., replace=False) raises ValueError when
# asked for more items than exist.
N_PREVIEW_SAMPLES = 6

np.random.seed(42)
sample_indices = np.random.choice(
    len(filepaths), size=min(N_PREVIEW_SAMPLES, len(filepaths)), replace=False
)

for idx in sample_indices:
    visualize_sample(idx, filepaths, labels, bbox_array, label_encoder)

## Build TensorFlow Datasets

Now let's build the TensorFlow datasets and create train/validation/test splits.

In [None]:
# Wrap the prepared paths, labels and boxes into a TensorFlow dataset
# (the pipeline itself is defined in utils.build_tf_dataset).
ds = build_tf_dataset(filepaths, labels, bbox_array)

print("Dataset created successfully")
print(f"Image size: {IMAGE_SIZE}")
print(f"Batch size: {BATCH_SIZE}")

In [None]:
# Split the dataset with the utils helper and unpack the three partitions.
total_size = len(positives)
splits = split_dataset(ds, total_size=total_size)
train_ds, val_ds, test_ds = splits["train"], splits["val"], splits["test"]

# NOTE(review): these sizes are recomputed here from hard-coded 80/10/10
# fractions rather than read back from split_dataset; confirm they match the
# fractions split_dataset actually uses.
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - (train_size + val_size)

print("Dataset splits (80/10/10):")
print(f"  Training:   {train_size} samples")
print(f"  Validation: {val_size} samples")
print(f"  Test:       {test_size} samples")
print(f"  Total:      {total_size} samples")

In [None]:
# Verify the training pipeline by inspecting a single batch: shapes, dtypes,
# pixel value range and a few example targets.
for batch_images, batch_labels in train_ds.take(1):
    # Convert the image batch to NumPy once; min and max reuse the array
    # instead of converting the whole batch twice.
    images_np = batch_images.numpy()
    print("Sample batch inspection:")
    print(f"  Image batch shape: {batch_images.shape}")
    print(f"  Image dtype: {batch_images.dtype}")
    print(f"  Image value range: [{images_np.min():.3f}, {images_np.max():.3f}]")
    print(f"  Class labels shape: {batch_labels['class_output'].shape}")
    print(f"  Bbox output shape: {batch_labels['bbox_output'].shape}")
    print(f"  Sample class labels: {batch_labels['class_output'].numpy()[:5]}")
    print(f"  Sample bbox (normalized): {batch_labels['bbox_output'].numpy()[0]}")
    break
else:
    # Only reached when the loop body never ran, i.e. train_ds yielded nothing.
    print("train_ds is empty - no batch to inspect.")

## Data Summary

Summary of the prepared dataset for ViT training.

In [None]:
# Assemble the Phase 1 summary as a list of lines and emit it in one print.
divider = "=" * 60
summary_lines = [
    divider,
    "PHASE 1 COMPLETE: Data Preparation Summary",
    divider,
    "\nDataset Statistics:",
    f"  Total samples (with animals): {total_size}",
    f"  Number of classes: {len(label_encoder.classes_)}",
    f"  Classes: {list(label_encoder.classes_)}",
    "\nImage Configuration:",
    f"  Image size: {IMAGE_SIZE}",
    f"  Batch size: {BATCH_SIZE}",
    "\nData Splits:",
    f"  Training:   {train_size} samples ({train_size/total_size*100:.1f}%)",
    f"  Validation: {val_size} samples ({val_size/total_size*100:.1f}%)",
    f"  Test:       {test_size} samples ({test_size/total_size*100:.1f}%)",
    "\nBounding Box Info:",
    "  Format: [x1, y1, x2, y2] (normalized 0-1 after preprocessing)",
    f"  All samples have valid bounding boxes: {nan_count == 0 and all_zeros_count == 0}",
    divider,
]
print("\n".join(summary_lines))