# Hybrid CNN + ViT Traffic Sign Recognition
End-to-end notebook covering dataset download/selection, data loaders, training, evaluation, and TFLite export.
**Set EPOCHS higher (e.g., 50) for full training.**


## 1) Install dependencies
Uncomment the pip line if running in a clean environment. TensorFlow needs Python 3.10-3.12.


In [None]:
# Install the project requirements only when they are missing (e.g. a fresh
# environment): uncomment the line below and run this cell once.
# !pip install -r requirements.txt


## 2) Locate or download the GTSRB dataset
Uses data/dataset_path.txt if present (written by scripts/download_dataset.py). Falls back to KaggleHub download.


In [None]:
from pathlib import Path

# Resolve the GTSRB dataset root: prefer the pointer file written by
# scripts/download_dataset.py, otherwise download via KaggleHub and record the
# resulting path in the pointer file for later runs.
DATASET_POINTER = Path("data/dataset_path.txt")

# Read the pointer with the same encoding used when writing it below. An empty
# or whitespace-only pointer would otherwise become Path("") == Path("."), which
# "exists" and would silently point the notebook at the working directory.
pointer_text = (
    DATASET_POINTER.read_text(encoding="utf-8").strip() if DATASET_POINTER.exists() else ""
)

if pointer_text:
    dataset_root = Path(pointer_text)
    print(f"Using dataset from pointer: {dataset_root}")
else:
    print("dataset_path.txt not found or empty; attempting KaggleHub download (requires Kaggle auth).")
    import kagglehub

    dataset_root = Path(kagglehub.dataset_download("meowmeowmeowmeowmeow/gtsrb-german-traffic-sign"))
    DATASET_POINTER.parent.mkdir(parents=True, exist_ok=True)
    DATASET_POINTER.write_text(str(dataset_root), encoding="utf-8")
    print(f"Wrote pointer to {DATASET_POINTER}")

if not dataset_root.exists():
    raise FileNotFoundError(f"Dataset path does not exist: {dataset_root}")

train_dir = dataset_root / "Train"
# Fail here with a clear message instead of later inside iterdir() or the loader.
if not train_dir.is_dir():
    raise FileNotFoundError(f"Expected a 'Train' directory under the dataset root: {train_dir}")
print(f"Train dir: {train_dir}")


## 3) Quick dataset sanity check
Counts the class folders, reports image counts for the first five classes, and checks whether test data (a `Test` folder and `Test.csv`) is present.


In [None]:
# Quick look at the training tree: number of class folders, file counts for a
# few of them, and whether test data is present.
# Sort the folders so the "first 5" sample is stable across machines
# (Path.iterdir() order is filesystem-dependent). Purely numeric names
# (class ids such as "0".."42") are ordered numerically; anything else follows,
# alphabetically.
class_dirs = sorted(
    (p for p in train_dir.iterdir() if p.is_dir()),
    key=lambda p: (not p.name.isdigit(), int(p.name) if p.name.isdigit() else 0, p.name),
)
print(f"Found {len(class_dirs)} classes")
sample_counts = {p.name: len(list(p.glob('*'))) for p in class_dirs[:5]}
print("Sample class counts (first 5):", sample_counts)
test_dir = dataset_root / "Test"
print("Has Test directory:", test_dir.exists())
has_test_csv = any((dataset_root / name).exists() for name in ("Test.csv", "test.csv"))
print("Has Test.csv:", has_test_csv)


## 4) Build TensorFlow dataloaders
Augmentation + normalization are applied inside the pipeline.


In [None]:
import tensorflow as tf
from tsr.data import DatasetConfig, load_gtsrb_datasets

# Image size is shared with the model config in the next cell.
IMG_SIZE = 224
BATCH_SIZE = 64

data_cfg = DatasetConfig(
    img_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
)
train_ds, val_ds, test_ds, num_classes = load_gtsrb_datasets(dataset_root, data_cfg)
print("num_classes:", num_classes)

# Pull a single batch to confirm the tensor shapes produced by the pipeline.
first_batch = train_ds.take(1)
for images, labels in first_batch:
    print("Batch shapes:", images.shape, labels.shape)


## 5) Build hybrid CNN + Transformer model
MobileNetV2 backbone + transformer encoder over feature-map tokens.


In [None]:
from tsr.model import ModelConfig, build_hybrid_cnn_vit

# Build the hybrid model with the backbone frozen; set backbone_trainable=True
# to fine-tune it.
model_cfg = ModelConfig(
    img_size=IMG_SIZE,
    num_classes=num_classes,
    backbone_trainable=False,
)
model = build_hybrid_cnn_vit(model_cfg)

# NOTE(review): categorical_crossentropy assumes one-hot labels from
# load_gtsrb_datasets — confirm against tsr.data.
adam = tf.keras.optimizers.Adam(learning_rate=3e-4)
tracked_metrics = [
    "accuracy",
    tf.keras.metrics.Precision(name="precision"),
    tf.keras.metrics.Recall(name="recall"),
]
model.compile(optimizer=adam, loss="categorical_crossentropy", metrics=tracked_metrics)
model.summary()


## 6) Train
Set EPOCHS higher (e.g., 50) for best accuracy; a small number keeps the demo quick.


In [None]:
from pathlib import Path

EPOCHS = 3  # increase for real training
OUT_DIR = Path("outputs")
# parents=True also creates OUT_DIR itself.
for subdir in ("logs", "checkpoints"):
    (OUT_DIR / subdir).mkdir(parents=True, exist_ok=True)

callbacks = [
    # NOTE(review): in some Keras versions restore_best_weights only takes effect
    # when training actually stops early; the best epoch is always on disk in
    # checkpoints/best.keras via ModelCheckpoint below.
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2),
    tf.keras.callbacks.ModelCheckpoint(
        filepath=str(OUT_DIR / "checkpoints" / "best.keras"),
        monitor="val_loss",
        save_best_only=True,
    ),
    tf.keras.callbacks.TensorBoard(log_dir=str(OUT_DIR / "logs")),
]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=callbacks,
)

saved_model_dir = OUT_DIR / "saved_model"
# Keras 3 (TF >= 2.16) rejects model.save() on a path without a .keras/.h5
# extension. model.export() (Keras >= 2.13) writes a TF SavedModel directory,
# which is what TFLiteConverter.from_saved_model consumes in the export cell.
# On older Keras, model.save() with a directory path produces a SavedModel.
if hasattr(model, "export"):
    model.export(str(saved_model_dir))
else:
    model.save(saved_model_dir)
print("SavedModel:", saved_model_dir)


## 7) Evaluate on test set (if present)


In [None]:
# Evaluate on the held-out test split when the loader produced one.
if test_ds is None:
    print("No test set detected; skipping test evaluation.")
else:
    # return_dict=True pairs each value with its metric name directly. Zipping
    # model.metrics_names with the returned list is fragile: in Keras 3 the
    # compiled metrics are reported under a single "compile_metrics" name, so
    # the zip silently drops or mislabels precision/recall/accuracy.
    metrics = model.evaluate(test_ds, verbose=2, return_dict=True)
    print("Test metrics:", metrics)


## 8) Export TensorFlow Lite
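By default this applies dynamic-range quantization (`Optimize.DEFAULT` with no representative dataset). For full INT8 quantization, also set `converter.representative_dataset` to a generator of sample inputs.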


In [None]:
# Convert the SavedModel exported by the training cell into a TFLite flatbuffer.
tflite_path = OUT_DIR / "model.tflite"

converter = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_dir))
# Optimize.DEFAULT without a representative dataset gives dynamic-range quantization.
# For full int8: set representative_dataset using load_gtsrb_datasets(..., apply_preprocessing=False)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
flatbuffer = converter.convert()

tflite_path.write_bytes(flatbuffer)
print("TFLite model written to:", tflite_path)


## Next steps
- Increase `EPOCHS`, set `backbone_trainable=True` for fine-tuning, and monitor `outputs/logs` in TensorBoard.
- For full INT8 export, pass `int8=True` in scripts/export_tflite.py or set a representative dataset here.
- Pair this classifier with a detector/ROI cropper for real-world traffic scenes.
