In [None]:
"""
## Setup and Imports
"""
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# Import functions from our training script
from scripts.train import load_dataset, build_model


In [None]:
"""
## Configuration
"""
DATA_DIR = Path('../data')               # Base data directory (relative to the notebook's working directory)
IMG_SIZE = (224, 224)                    # Image dimensions passed to load_dataset and build_model
BATCH_SIZE = 8                           # Small batch for quick iteration
SEED = 42                                # Seed for the Python, NumPy and TensorFlow RNGs

# Dataset shuffling, weight initialisation and training are all random.
# Seeding every RNG up front keeps sample images and smoke-test metrics
# reproducible under Restart Kernel -> Run All.
tf.keras.utils.set_random_seed(SEED)

In [None]:
"""
## Visualize Sample Images
"""
# Load a small batch for visualization
train_ds = load_dataset(DATA_DIR / 'train', IMG_SIZE, BATCH_SIZE, shuffle=True)
class_names = train_ds.class_names

GRID_ROWS, GRID_COLS = 3, 3  # preview grid: up to 9 images

for images, labels in train_ds.take(1):
    # A batch may hold fewer images than grid cells (BATCH_SIZE is 8 here),
    # so only index images that exist; surplus cells are left blank.
    n_show = min(GRID_ROWS * GRID_COLS, int(images.shape[0]))
    fig, axes = plt.subplots(GRID_ROWS, GRID_COLS, figsize=(9, 9))
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= n_show:
            continue
        # NOTE(review): assumes load_dataset yields un-normalised 0-255 pixels;
        # if it rescales to [0, 1], casting to uint8 would blank the image — confirm.
        ax.imshow(images[i].numpy().astype('uint8'))
        # NOTE(review): assumes integer (sparse) class labels, not one-hot — confirm.
        ax.set_title(class_names[int(labels[i])])
    plt.show()

In [None]:
"""
## Prototype Model Architecture
"""
# Build the prototype model via build_model from scripts.train and print
# its Keras summary.
# NOTE(review): later cells call model.fit / model.evaluate and read an
# 'accuracy' entry from the training history, so build_model presumably
# returns a model already compiled with an accuracy metric — confirm.
model = build_model(IMG_SIZE)
model.summary()

In [None]:
"""
## Quick Training Run
"""
# Train on a few batches to verify the pipeline end to end.
# Validate on the held-out 'validation' split rather than on training
# batches, so val_loss / val_accuracy measure generalisation instead of
# fit to images the model may already have seen.
val_ds = load_dataset(DATA_DIR / 'validation', IMG_SIZE, BATCH_SIZE, shuffle=False)

# NOTE: re-running this cell keeps training the same `model` object;
# re-run the model-building cell first for a fresh start.
history = model.fit(
    train_ds.take(10),               # only 10 batches
    validation_data=val_ds.take(2),  # small held-out validation subset
    epochs=1,
)

In [None]:
"""
## Plot Loss and Accuracy Curves
"""
# Plot train vs. validation curves for each tracked quantity.
# With epochs=1 every history list has a single value. A marker-less solid
# line draws nothing visible for one point, so markers are always shown.
# NOTE(review): assumes history keys 'loss'/'val_loss' and
# 'accuracy'/'val_accuracy' (i.e. the model is compiled with metrics=['accuracy']).
curve_specs = (
    ('loss', 'train_loss', 'val_loss', 'Loss'),
    ('accuracy', 'train_acc', 'val_acc', 'Accuracy'),
)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
for ax, (key, train_label, val_label, title) in zip(axes, curve_specs):
    train_values = history.history[key]
    epochs_run = range(1, len(train_values) + 1)
    ax.plot(epochs_run, train_values, marker='o', label=train_label)
    ax.plot(epochs_run, history.history[f'val_{key}'], marker='o', label=val_label)
    ax.set(title=title, xlabel='Epoch', ylabel=title)
    ax.legend()
plt.show()

In [None]:
"""
## Evaluate on Validation Set
"""
# Evaluate on a small, unshuffled slice of the held-out validation split.
val_ds = load_dataset(DATA_DIR / 'validation', IMG_SIZE, BATCH_SIZE, shuffle=False)
val_subset = val_ds.take(5)  # small subset: first 5 dataset elements
eval_results = model.evaluate(val_subset)

# NOTE(review): positional indexing assumes evaluate returns [loss, accuracy],
# i.e. the model was compiled with exactly one metric — confirm in build_model.
subset_loss = eval_results[0]
subset_acc = eval_results[1]
print(f"Validation (subset) - Loss: {subset_loss:.4f}, Accuracy: {subset_acc:.4f}")