# Cats vs Dogs Classification - Data Exploration

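This notebook explores the Cats vs Dogs dataset for binary image classification: dataset sizes, class balance, sample images, and augmentation effects, followed by a quick look at the model architecture and a forward-pass smoke test.

**Note:** the setup cell resolves the project root as the parent of the current working directory, so run this notebook from a directory directly under the project root (e.g. `notebooks/`).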

In [None]:
# Standard-library imports.
# NOTE(review): `os` is not referenced in the visible cells of this notebook.
import os
import sys
from pathlib import Path

# Add project root to path
# The project root is taken to be the parent of the current working directory,
# i.e. the notebook is expected to run from a folder directly under the root.
# This must happen BEFORE the `src.*` imports below, which rely on the root
# being on sys.path.
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

# Third-party imports.
# NOTE(review): `transforms` is not referenced in the visible cells.
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torchvision import transforms

# Project-local imports (resolved through the sys.path entry added above).
from src.data.dataset import CatsDogsDataset
from src.data.preprocessing import get_train_transforms, get_val_transforms

# Record the runtime environment in the notebook output.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 1. Load Dataset

In [None]:
# Build one dataset object per split from the processed-data directory.
# NOTE(review): every split, including 'train', is given get_val_transforms();
# presumably so that exploration shows un-augmented images -- confirm intent.
data_dir = project_root / 'data' / 'processed'

train_dataset, val_dataset, test_dataset = (
    CatsDogsDataset(str(data_dir), transform=get_val_transforms(), split=split_name)
    for split_name in ('train', 'val', 'test')
)

# Report how many samples each split contains (one line per split).
print(
    f"Training samples: {len(train_dataset)}",
    f"Validation samples: {len(val_dataset)}",
    f"Test samples: {len(test_dataset)}",
    sep="\n",
)

## 2. Class Distribution

In [None]:
# Per-class sample counts for the training split
train_dist = train_dataset.get_class_distribution()
print(f"Training class distribution: {train_dist}")

# Bar chart of the counts, drawn through the explicit Figure/Axes interface.
# The dict views are materialised as lists before being handed to matplotlib.
class_labels = list(train_dist.keys())
class_counts = list(train_dist.values())

fig, ax = plt.subplots(figsize=(8, 4))
ax.bar(class_labels, class_counts, color=['#FF6B6B', '#4ECDC4'])
ax.set_title('Class Distribution (Training Set)')
ax.set_xlabel('Class')
ax.set_ylabel('Count')
fig.tight_layout()
plt.show()

## 3. Sample Images

In [None]:
# Visualize sample images
def denormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization so an image tensor can be displayed.

    Computes ``tensor * std + mean`` with the statistics broadcast over the
    spatial dimensions.

    Args:
        tensor: Image tensor laid out as (C, H, W), or a batch (N, C, H, W).
        mean: Per-channel means used during normalization. The defaults are
            the commonly used ImageNet statistics.
            NOTE(review): assumed to match what the project's transforms
            apply -- confirm against src.data.preprocessing.
        std: Per-channel standard deviations, same length as ``mean``.

    Returns:
        A new tensor of the same shape as ``tensor``; the input is not
        modified.
    """
    # Build the statistics on the input's device so CUDA tensors work too.
    # They stay float32 (as before) so dtype promotion is unchanged.
    mean_t = torch.as_tensor(mean, dtype=torch.float32, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=torch.float32, device=tensor.device).view(-1, 1, 1)
    return tensor * std_t + mean_t

# 2x5 grid showing the first samples of the training split with their labels.
class_names = ['Cat', 'Dog']
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

for sample_idx, panel in enumerate(axes.flat):
    # Panels beyond the end of the dataset are left empty.
    if sample_idx < len(train_dataset):
        image_tensor, label_idx = train_dataset[sample_idx]
        # Undo normalization, move channels last for matplotlib, and clamp
        # to the displayable [0, 1] range.
        pixels = np.clip(denormalize(image_tensor).permute(1, 2, 0).numpy(), 0, 1)
        panel.imshow(pixels)
        panel.set_title(f'{class_names[label_idx]}')
    panel.axis('off')

plt.suptitle('Sample Training Images', fontsize=14)
plt.tight_layout()
plt.show()

## 4. Data Augmentation Examples

In [None]:
# Show augmentation effects: the first training image is pushed through the
# training transform several times and shown next to the original.
# (get_train_transforms is already imported in the setup cell at the top, so
# it is not re-imported here.)

if len(train_dataset.samples) > 0:
    # Each entry of `samples` unpacks into (path, second element); only the
    # path is needed. A named throwaway avoids shadowing IPython's `_`.
    sample_path, _sample_label = train_dataset.samples[0]

    # Image.open is lazy; the context manager releases the file handle
    # deterministically. convert() returns an independent in-memory image,
    # so `original_img` stays usable after the file is closed.
    with Image.open(sample_path) as source_img:
        original_img = source_img.convert('RGB')

    fig, axes = plt.subplots(2, 4, figsize=(14, 7))
    train_transform = get_train_transforms()

    # First panel: the un-augmented image, resized for display.
    axes[0, 0].imshow(original_img.resize((224, 224)))
    axes[0, 0].set_title('Original')
    axes[0, 0].axis('off')

    # Remaining panels: a fresh application of the training transform each
    # time (presumably randomized, so the panels should differ -- confirm).
    for i, ax in enumerate(axes.flat[1:], start=1):
        augmented = train_transform(original_img)
        augmented = denormalize(augmented)
        augmented = augmented.permute(1, 2, 0).numpy()
        augmented = np.clip(augmented, 0, 1)
        ax.imshow(augmented)
        ax.set_title(f'Augmented {i}')
        ax.axis('off')

    plt.suptitle('Data Augmentation Examples', fontsize=14)
    plt.tight_layout()
    plt.show()
else:
    print("No samples available for augmentation demo")

## 5. Model Architecture

In [None]:
from src.models.cnn_model import CatsDogsCNN, get_model_summary, count_parameters

# Instantiate the two-class classifier used in the rest of the notebook.
model = CatsDogsCNN(num_classes=2)

# Print the summary returned by the project helper.
architecture_report = get_model_summary(model)
print(architecture_report)

# Print the parameter count with thousands separators.
parameter_total = count_parameters(model)
print(f"\nTotal Parameters: {parameter_total:,}")

## 6. Test Forward Pass

In [None]:
# Smoke-test the model: push a random batch through it and inspect the result.
batch_size = 4
input_shape = (batch_size, 3, 224, 224)
dummy_input = torch.randn(*input_shape)

# Evaluation mode with gradient tracking disabled for both calls.
model.eval()
with torch.no_grad():
    output = model(dummy_input)
    probs = model.predict_proba(dummy_input)

# One print call, one item per line: shapes, raw outputs, and argmax classes.
print(
    f"Input shape: {dummy_input.shape}",
    f"Output shape: {output.shape}",
    f"Output logits:\n{output}",
    f"\nProbabilities:\n{probs}",
    f"Predicted classes: {probs.argmax(dim=1).tolist()}",
    sep="\n",
)

## Summary

This notebook covered:
1. Loading and exploring the Cats vs Dogs dataset
2. Visualizing class distribution
3. Viewing sample images
4. Data augmentation examples
5. Model architecture overview
6. Testing forward pass

Next steps:
- Train the model using `python src/models/train.py`
- Track experiments with MLflow
- Deploy using Docker