# Sunflower vs Rose Classification - Interactive Tutorial

This notebook provides an interactive walkthrough of the flower classification project.

## 1. Setup and Imports

In [None]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Import custom modules
from data_loader import DataLoader as FlowerDataLoader, create_sample_dataset
from train import FlowerDataset, FlowerClassifier, get_transforms
from inference import FlowerPredictor
from utils import plot_sample_images, analyze_dataset_distribution, visualize_augmentations

# Seed every RNG the notebook draws from so that re-runs are repeatable.
# Python's `random` is included because section 15 picks the test images to
# display with random.sample; previously only torch and numpy were seeded.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Create Sample Dataset

For demonstration purposes, let's create a sample dataset.

In [None]:
# Point the project's data helper at ./data and run its directory setup
data_loader = FlowerDataLoader(data_dir='./data')
data_loader.setup_directories()

# Generate the demo images (num_samples=100 is forwarded to the helper)
create_sample_dataset(data_dir='./data', num_samples=100)

# Split into train/val/test with a 70% / 15% / 15% ratio
split_ratios = {'train_ratio': 0.7, 'val_ratio': 0.15, 'test_ratio': 0.15}
data_loader.prepare_dataset(**split_ratios)

## 3. Explore the Dataset

In [None]:
# Analyze dataset distribution
# (helper imported from the project's utils module; what exactly it reports for
# the ./data directory is not visible in this notebook — see utils.py)
analyze_dataset_distribution('./data')

In [None]:
# Plot sample images
# (num_samples=8 is forwarded to the utils helper; whether it is a total or a
# per-class count is not visible here — TODO confirm against utils.py)
plot_sample_images('./data', num_samples=8)

## 4. Data Augmentation Preview

In [None]:
# Get transforms
train_transform, val_transform = get_transforms()

# Get a sample image. Sorting makes the pick deterministic (glob order depends
# on the filesystem), and the explicit check replaces a bare IndexError with an
# actionable message when the dataset has not been prepared yet.
sunflower_dir = Path('./data/processed/train/sunflowers')
sunflower_images = sorted(sunflower_dir.glob('*.jpg'))
if not sunflower_images:
    raise FileNotFoundError(
        f"No .jpg images found in {sunflower_dir}; "
        "run section 2 to create and split the dataset first."
    )
sample_image_path = sunflower_images[0]

# Visualize augmentations
visualize_augmentations(sample_image_path, train_transform, num_augmentations=9)

## 5. Prepare Data Loaders

In [None]:
# Resolve (paths, labels) for each split, in train -> val -> test order
(train_paths, train_labels), (val_paths, val_labels), (test_paths, test_labels) = (
    data_loader.get_data_paths(split) for split in ('train', 'val', 'test')
)

print(f"Training samples: {len(train_paths)}")
print(f"Validation samples: {len(val_paths)}")
print(f"Test samples: {len(test_paths)}")

# Wrap each split in a dataset; only the training split gets train_transform,
# validation and test share val_transform
train_dataset = FlowerDataset(train_paths, train_labels, transform=train_transform)
val_dataset = FlowerDataset(val_paths, val_labels, transform=val_transform)
test_dataset = FlowerDataset(test_paths, test_labels, transform=val_transform)

# Batch the datasets; only the training loader shuffles
batch_size = 32
shared_loader_args = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_args)
val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_args)
test_loader = DataLoader(test_dataset, shuffle=False, **shared_loader_args)

## 6. Visualize a Batch

In [None]:
# Pull one batch from the training loader
images, labels = next(iter(train_loader))

# Per-channel mean/std reshaped to (3, 1, 1) so they broadcast over each image.
# Assumed to match the normalization applied by get_transforms — TODO confirm.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

class_names = ['Sunflower', 'Rose']

# 2x4 grid; zip stops at whichever runs out first, so at most 8 images are drawn
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
for ax, image, label in zip(axes.flat, images, labels):
    # Undo the normalization and clip to the displayable [0, 1] range
    restored = torch.clamp(image * std + mean, 0, 1)
    # Channels-first -> channels-last for imshow
    ax.imshow(restored.permute(1, 2, 0).numpy())
    ax.set_title(f'{class_names[label]}', fontsize=14, fontweight='bold')
    ax.axis('off')

plt.tight_layout()
plt.show()

## 7. Initialize Model

In [None]:
# Build the classifier and move it to the selected device
model = FlowerClassifier(pretrained=True).to(device)

# Gather (element count, requires_grad) for every parameter in a single pass
param_info = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_info)
trainable_params = sum(count for count, needs_grad in param_info if needs_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 8. Setup Training

In [None]:
# Hyperparameters
num_epochs = 10
learning_rate = 0.001

# Cross-entropy loss, Adam with weight decay, and a step schedule that
# multiplies the learning rate by 0.1 every 5 scheduler steps
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)
scheduler = optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=5,
    gamma=0.1,
)

## 9. Training Loop (Simplified)

In [None]:
from train import Trainer

# Bundle model, device, loaders, loss, optimizer and scheduler into the project's Trainer
trainer = Trainer(
    model, device, train_loader, val_loader, criterion, optimizer, scheduler
)

# Run training; the value fit() returns is reported below as the best validation accuracy
checkpoint_path = 'best_model.pth'
best_val_acc = trainer.fit(num_epochs=num_epochs, save_path=checkpoint_path)

print(f"\nBest validation accuracy: {best_val_acc:.2f}%")

## 10. Visualize Training Progress

In [None]:
from IPython.display import Image, display

# Ask the trainer to plot its metrics with save_path set, then render that PNG inline
metrics_plot_path = 'training_metrics.png'
trainer.plot_metrics(save_path=metrics_plot_path)
display(Image(metrics_plot_path))

## 11. Evaluate on Test Set

In [None]:
# Load best model. map_location remaps the stored tensors onto the device this
# kernel is using, so a checkpoint written on a GPU still loads on a CPU-only
# machine (without it torch.load tries to restore the original device).
checkpoint = torch.load('best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Create trainer for test evaluation. test_loader is passed in both loader
# positions and None in the optimizer position; presumably validate() only
# reads the second loader and never touches the optimizer — confirm in train.py.
test_trainer = Trainer(model, device, test_loader, test_loader, criterion, None)
test_loss, test_acc, all_preds, all_labels = test_trainer.validate()

print(f"\nTest Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")

## 12. Confusion Matrix

In [None]:
from train import plot_confusion_matrix

# Plot the confusion matrix from the test labels/predictions collected in
# section 11; save_path is passed so the figure can be shown from disk below.
plot_confusion_matrix(all_labels, all_preds, class_names, save_path='confusion_matrix.png')

# Display
# NOTE: Image and display here are the names imported from IPython.display in section 10.
display(Image('confusion_matrix.png'))

## 13. Classification Report

In [None]:
from sklearn.metrics import classification_report

# Header first, then the per-class metrics table for the test predictions
print("\nClassification Report:")
report_text = classification_report(all_labels, all_preds, target_names=class_names)
print(report_text)

## 14. Make Predictions on New Images

In [None]:
# Build a predictor from the saved checkpoint
predictor = FlowerPredictor('best_model.pth')

# Classify the first image of the test split
test_image_path = test_paths[0]
pred_class, confidence, probs = predictor.predict(test_image_path)

print(f"\nPrediction: {pred_class}")
print(f"Confidence: {confidence:.2%}")
print(f"\nProbabilities:")
# Pair each probability with a display name; assumes probs follows the order
# of class_names — TODO confirm against inference.py
for class_label, class_prob in zip(class_names, probs):
    print(f"  {class_label}: {class_prob:.2%}")

## 15. Visualize Predictions

In [None]:
# Visualize predictions for multiple test images
import random

# Aliased so PIL's Image does not rebind the name `Image`, which the earlier
# display(Image(...)) cells expect to be IPython.display.Image when re-run.
from PIL import Image as PILImage

# Sample indices rather than paths: the true label is then a direct lookup
# (no list.index scan, no ambiguity if a path appears twice), and capping the
# sample size avoids a ValueError when the test split has fewer than 6 images.
num_to_show = min(6, len(test_paths))
sample_indices = random.sample(range(len(test_paths)), num_to_show)
sample_test_paths = [test_paths[idx] for idx in sample_indices]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

for ax, idx in zip(axes.flat, sample_indices):
    img_path = test_paths[idx]

    # Get prediction
    pred_class, confidence, probs = predictor.predict(img_path)

    # Load and display image; the context manager closes the file handle
    with PILImage.open(img_path) as img:
        ax.imshow(img)

    # Get true label
    true_label = class_names[test_labels[idx]]

    # Set title with color coding
    # NOTE(review): this compares the predictor's class string with class_names
    # ('Sunflower'/'Rose'); confirm the predictor returns the same spelling/case.
    color = 'green' if pred_class == true_label else 'red'
    ax.set_title(f"True: {true_label}\nPred: {pred_class} ({confidence:.1%})",
                color=color, fontweight='bold', fontsize=12)
    ax.axis('off')

# Hide grid cells left empty when fewer than six images were available
for unused_ax in list(axes.flat)[num_to_show:]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()

## 16. Model Analysis

Find misclassified examples to understand model weaknesses.

In [None]:
from utils import get_misclassified_images

# Run the project's utils helper over the test loader to collect misclassified examples.
# The structure of the returned value is not visible in this notebook — check
# utils.py before indexing into `misclassified`.
misclassified = get_misclassified_images(model, test_loader, device, class_names)

## 17. Save and Export Model

In [None]:
# Model is already saved as 'best_model.pth'

# Optional: Export to ONNX format
# Switch the module to evaluation mode before exporting
model.eval()
# Example input for the export: a batch of one 3-channel 224x224 tensor
dummy_input = torch.randn(1, 3, 224, 224).to(device)

# dynamic_axes marks dimension 0 of both input and output as variable, so the
# exported model is not pinned to the batch size of dummy_input
torch.onnx.export(model, dummy_input, "flower_classifier.onnx",
                 input_names=['input'],
                 output_names=['output'],
                 dynamic_axes={'input': {0: 'batch_size'},
                              'output': {0: 'batch_size'}})

# The check mark was previously mis-encoded ("âœ“", UTF-8 bytes read as cp1252)
print("✓ Model exported to ONNX format: flower_classifier.onnx")

## Conclusion

This notebook demonstrated:
- Dataset preparation and exploration
- Data augmentation techniques
- Model training with PyTorch
- Model evaluation and visualization
- Making predictions on new images

You can now use this trained model to classify sunflower and rose images!