In [11]:
import json
import os
from functools import partial
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from pilot_academy.data.io import read_fn
from pilot_academy.data.sampling import balanced_sampling_with_augmentation_tags

In [2]:
# Paths
# PROJECT_DIR keeps its original hard-coded default; set the PILOT_ACADEMY_DIR
# environment variable to run the notebook from a different checkout.
PROJECT_DIR = Path(os.environ.get("PILOT_ACADEMY_DIR", "/workspaces/pilot_academy"))
DATA_DIR = PROJECT_DIR / "data"
OUT_DIR = PROJECT_DIR / "outputs"
CONFIG_DIR = PROJECT_DIR / "config"

# Dataset layout: <DATA_DIR>/<DATASET_NAME>/raw/{images/, *.csv, *.json}
DATASET_NAME = "all_towns_with_weather"
DS_DIR = DATA_DIR / DATASET_NAME / "raw"
IMAGES_DIR = DS_DIR / "images"

df_path = DS_DIR / "df_annotations.csv"
label_to_id_path = DS_DIR / "label_to_id.json"
action_to_id_path = DS_DIR / "action_to_id.json"

In [None]:
# Load annotations, indexed by the frame_index column
df_annotations = pd.read_csv(df_path, index_col="frame_index")

# Load mappings. Explicit UTF-8 so decoding does not depend on the OS locale.
with open(action_to_id_path, encoding="utf-8") as f:
    action_to_id = json.load(f)

# label_to_id.json stores composite labels as "part|part|..." string keys;
# split them back into tuple keys and coerce ids to int.
with open(label_to_id_path, encoding="utf-8") as f:
    label_to_id = {
        tuple(key.split("|")): int(value)
        for key, value in json.load(f).items()
    }


# Collect images
image_paths = sorted(IMAGES_DIR.glob("*.png"))

# map filename -> full path
paths_by_name = {p.name: p for p in image_paths}

# Every image referenced by the annotations must exist on disk.
needed = set(df_annotations["image_name"])
available = set(paths_by_name)

missing = needed - available
if missing:
    # Set iteration order for strings is hash-randomized per process, so pick
    # the smallest name to make the reported example stable across runs.
    # key=str keeps this safe if image_name contains non-string values (e.g. NaN).
    first_missing = min(missing, key=str)
    raise FileNotFoundError(
        f"{len(missing)} missing image(s). First: {first_missing}"
    )


# Configuration passed through to read_fn.
# NOTE(review): crop_ymin presumably drops image rows above y=100 - confirm
# against pilot_academy.data.io.read_fn.
crop_ymin = 100
grayscale = True

configured_read_fn = partial(
    read_fn,
    crop_ymin=crop_ymin,
    grayscale=grayscale,
)

In [12]:
# Use df_annotations instead of df
df = df_annotations

CLASS_COL = "class_id"   # change if your column name differs
SEED = 42                # shared by every sampling experiment below

print("Original class distribution:")
print(df[CLASS_COL].value_counts().sort_index())
print()


def run_sampling_test(title, **sampling_kwargs):
    """Print a banner for one sampling experiment, then run it.

    Every experiment samples ``df`` on ``CLASS_COL`` with the shared ``SEED``.
    ``sampling_kwargs`` holds the per-experiment options
    (``n_samples_per_class``, ``balancing_mode``, ``exclude_classes``) and is
    forwarded unchanged to ``balanced_sampling_with_augmentation_tags``.

    If ``exclude_classes`` is given, any value that does not occur in
    ``df[CLASS_COL]`` is reported with a warning before sampling.

    Returns whatever ``balanced_sampling_with_augmentation_tags`` returns.
    """
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)

    excluded = sampling_kwargs.get("exclude_classes")
    if excluded is not None:
        # exclude_classes must match the dtype/values in df[CLASS_COL]; a value
        # of the wrong type (e.g. a string name vs integer ids) may silently
        # match nothing, so surface it instead of hiding it.
        present = set(df[CLASS_COL].unique())
        unmatched = [c for c in excluded if c not in present]
        if unmatched:
            print(
                f"WARNING: exclude_classes {unmatched} match no value in "
                f"df[{CLASS_COL!r}]."
            )

    return balanced_sampling_with_augmentation_tags(
        df=df,
        class_column=CLASS_COL,
        seed=SEED,
        **sampling_kwargs,
    )


# Test 1: Manual n_samples_per_class
balanced_df1 = run_sampling_test(
    "TEST 1: Manual n_samples_per_class=500",
    n_samples_per_class=500,
)

# Tests 2-5: let balancing_mode choose the per-class target
balanced_df2 = run_sampling_test(
    "TEST 2: balancing_mode='min' (subsample all to smallest class)",
    balancing_mode='min',
)
balanced_df3 = run_sampling_test(
    "TEST 3: balancing_mode='max' (augment all to largest class)",
    balancing_mode='max',
)
balanced_df4 = run_sampling_test(
    "TEST 4: balancing_mode='mean'",
    balancing_mode='mean',
)
balanced_df5 = run_sampling_test(
    "TEST 5: balancing_mode='median'",
    balancing_mode='median',
)

# Test 6: Combining with class selection.
# IMPORTANT: exclude_classes must match the dtype/values in df[CLASS_COL].
# In the recorded run class_id held integer ids (1-9), so the string label
# 'u_turn' likely excludes nothing; to exclude by numeric id use e.g.
# EXCLUDE_CLASSES = [3]. TODO: confirm which id corresponds to 'u_turn'.
EXCLUDE_CLASSES = ['u_turn']
balanced_df6 = run_sampling_test(
    "TEST 6: balancing_mode='mean' with class exclusion",
    exclude_classes=EXCLUDE_CLASSES,
    balancing_mode='mean',
)


Original class distribution:
class_id
1     1135
2     1185
3     1005
4     1180
5      570
6     1082
7     1622
8     1176
9    44144
Name: count, dtype: int64


TEST 1: Manual n_samples_per_class=500

Balanced Sampling Summary
Target samples per class: 500
Number of classes: 9
Total samples: 4500

Per-class breakdown:
------------------------------------------------------------
Class 9:
  Original:  44144 | Subsampled to 500
  Final:    500 | Augmented:      0
Class 2:
  Original:   1185 | Subsampled to 500
  Final:    500 | Augmented:      0
Class 7:
  Original:   1622 | Subsampled to 500
  Final:    500 | Augmented:      0
Class 8:
  Original:   1176 | Subsampled to 500
  Final:    500 | Augmented:      0
Class 4:
  Original:   1180 | Subsampled to 500
  Final:    500 | Augmented:      0
Class 3:
  Original:   1005 | Subsampled to 500
  Final:    500 | Augmented:      0
Class 6:
  Original:   1082 | Subsampled to 500
  Final:    500 | Augmented:      0
Class 5:
  Original:    570