# Zenith AI + PyTorch: Image Classification

**Train a CNN on CIFAR-10 with Zenith's high-performance DataLoader**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vibeswithkk/Zenith-dataplane/blob/main/notebooks/01_pytorch_cifar10.ipynb)

## What You'll Learn
- Install and use Zenith in 2 lines
- Load CIFAR-10 with Zenith DataLoader
- Train a CNN with zero-copy data transfer
- Compare speed vs PyTorch native DataLoader

---
## 1. Install Zenith

In [None]:
# Install zenith-ai (takes ~10 seconds)
# `!` is IPython/Jupyter shell-escape syntax: this line only runs inside a
# notebook, not as a plain Python script.
# NOTE(review): torchvision is installed here but no cell shown imports it.
!pip install zenith-ai torch torchvision datasets pyarrow --quiet

# Verify installation
# zenith.info() presumably prints version/backend details — confirm against
# the Zenith documentation.
import zenith
zenith.info()

---
## 2. Download CIFAR-10 Dataset (~170MB)

In [None]:
from datasets import load_dataset
import pyarrow.parquet as pq
import pyarrow as pa
import numpy as np
import os

# Output file consumed by the Zenith DataLoader cells below.
parquet_path = 'cifar10_train.parquet'

print("Downloading CIFAR-10...")
dataset = load_dataset("cifar10", split="train")
print(f"Dataset size: {len(dataset)} images")

# Convert to Parquet for Zenith
print("\nConverting to Parquet format...")

# Store each image as its raw uint8 pixel buffer (np.array of the PIL image,
# flattened). The training cell decodes these bytes as uint8 and reshapes
# them to (32, 32, 3), so the dtype is pinned explicitly here.
images = [np.asarray(row['img'], dtype=np.uint8).flatten().tobytes() for row in dataset]
# Read the label column directly: iterating whole rows would decode every
# image a second time just to pull out its label.
labels = list(dataset['label'])

# One row per image: binary pixel buffer + integer class label.
table = pa.table({
    'image': images,
    'label': labels,
})

pq.write_table(table, parquet_path)
print(f"Saved: {parquet_path}")

size_mb = os.path.getsize(parquet_path) / (1024 * 1024)
print(f"File size: {size_mb:.1f} MB")

---
## 3. Define CNN Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    """Simple CNN for CIFAR-10 classification.

    Three stages of 3x3 conv (padding=1) + ReLU + 2x2 max-pool, followed by a
    two-layer fully connected head with dropout. The head's input size
    (64 * 4 * 4) assumes 32x32 inputs, which the three pools reduce to 4x4.

    Args:
        num_classes: Number of output logits (default 10, as in CIFAR-10).
        dropout: Dropout probability applied after ``fc1`` (default 0.25).
    """

    def __init__(self, num_classes=10, dropout=0.25):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map a (B, 3, 32, 32) float batch to (B, num_classes) logits."""
        x = self.pool(F.relu(self.conv1(x)))  # 32x32 -> 16x16
        x = self.pool(F.relu(self.conv2(x)))  # 16x16 -> 8x8
        x = self.pool(F.relu(self.conv3(x)))  # 8x8 -> 4x4
        # Keep the batch dimension explicit. view(-1, 1024) would silently
        # fold extra spatial positions into the batch for non-32x32 inputs;
        # flattening from dim 1 makes fc1 raise a shape error instead.
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

# Place the model on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = SimpleCNN().to(device)

n_params = sum(param.numel() for param in model.parameters())
print(f"Model on: {device}")
print(f"Parameters: {n_params:,}")

---
## 4. Load Data with Zenith DataLoader

In [None]:
import zenith
import time

# Batched, shuffled reader over the Parquet file produced earlier.
# device='auto' leaves device selection to Zenith (intended to pick the GPU).
loader_options = dict(batch_size=64, shuffle=True, device='auto')
loader = zenith.DataLoader('cifar10_train.parquet', **loader_options)

print(f"DataLoader created: {loader}")
print(f"Device: {zenith.auto_device()}")

---
## 5. Training Loop with Zenith

In [None]:
import torch.optim as optim

# Loss and optimizer shared by every training epoch below:
# cross-entropy on the model's logits, Adam with a 1e-3 learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train_epoch(loader, model, criterion, optimizer, device):
    """Train ``model`` for one pass over ``loader``.

    Each batch's ``to_numpy()`` result must provide an ``'image'`` sequence
    of raw uint8 pixel buffers (32*32*3 bytes each, HWC layout, as written by
    the Parquet conversion cell) and a ``'label'`` NumPy integer array.

    Args:
        loader: Iterable of batches, e.g. a ``zenith.DataLoader``.
        model: Module mapping (B, 3, 32, 32) float images to (B, C) logits.
        criterion: Loss returning the *mean* over the batch, e.g.
            ``nn.CrossEntropyLoss()`` with its default reduction.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, accuracy)``: the per-sample loss averaged over the
        whole epoch, and training accuracy in percent.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for batch in loader:
        # Materialize the batch as NumPy data.
        data = batch.to_numpy()

        # Decode each byte buffer to (32, 32, 3), stack to (B, 32, 32, 3),
        # then convert to (B, 3, 32, 32) floats in [0, 1].
        images = np.array([np.frombuffer(img, dtype=np.uint8).reshape(32, 32, 3)
                           for img in data['image']])
        images = torch.from_numpy(images).permute(0, 3, 1, 2).float() / 255.0
        labels = torch.from_numpy(data['label']).long()

        images, labels = images.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Stats. criterion returns the batch *mean*, so weight it by the batch
        # size; dividing by the sample count below then gives a true
        # per-sample mean (and a short final batch is weighted correctly).
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        total += batch_size
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / total, 100. * correct / total

# Run a fixed number of epochs, timing each one.
NUM_EPOCHS = 3

print("Training with Zenith DataLoader...")
print("-" * 40)

zenith_times = []
for epoch_idx in range(NUM_EPOCHS):
    t0 = time.time()
    loss, acc = train_epoch(loader, model, criterion, optimizer, device)
    elapsed = time.time() - t0
    zenith_times.append(elapsed)
    print(f"Epoch {epoch_idx + 1}: Loss={loss:.4f}, Acc={acc:.2f}%, Time={elapsed:.2f}s")

print("-" * 40)
mean_epoch_time = sum(zenith_times) / len(zenith_times)
print(f"Average time per epoch: {mean_epoch_time:.2f}s")

---
## 6. Benchmark: Zenith vs PyTorch DataLoader

In [None]:
from torch.utils.data import DataLoader as TorchDataLoader, TensorDataset
import pyarrow.parquet as pq
import time

# For the timing to say anything about the loaders themselves, both must read
# the SAME images (all rows of the Parquet file), use the same batch size and
# shuffling, and hand back comparable batches in host memory.
print("Preparing PyTorch DataLoader...")

# Build the PyTorch baseline from the Parquet file Zenith reads, kept as raw
# uint8 pixel buffers (one row of 32*32*3 bytes per image), i.e. undecoded
# like the Zenith batches.
bench_table = pq.read_table('cifar10_train.parquet')
pixels = np.stack([np.frombuffer(buf, dtype=np.uint8)
                   for buf in bench_table.column('image').to_pylist()])
images_pt = torch.from_numpy(pixels)
labels_pt = torch.tensor(bench_table.column('label').to_pylist())

torch_loader = TorchDataLoader(
    TensorDataset(images_pt, labels_pt),
    batch_size=64,
    shuffle=True,
)
zenith_loader = zenith.DataLoader('cifar10_train.parquet', batch_size=64, shuffle=True)

n_passes = 3

print("\nBenchmarking data loading speed...")
print("=" * 50)

# PyTorch DataLoader: fetch batches only. No device transfer, because the
# Zenith loop below does not transfer either.
pytorch_samples = 0
start = time.perf_counter()
for _ in range(n_passes):
    for _batch_images, batch_labels in torch_loader:
        pytorch_samples += batch_labels.size(0)
pytorch_time = time.perf_counter() - start

# Zenith DataLoader: materialize each batch as NumPy.
zenith_samples = 0
start = time.perf_counter()
for _ in range(n_passes):
    for batch in zenith_loader:
        zenith_samples += len(batch.to_numpy()['label'])
zenith_time = time.perf_counter() - start

pytorch_rate = pytorch_samples / pytorch_time
zenith_rate = zenith_samples / zenith_time
print(f"PyTorch DataLoader: {pytorch_time:.3f}s ({pytorch_rate:,.0f} img/s)")
print(f"Zenith DataLoader:  {zenith_time:.3f}s ({zenith_rate:,.0f} img/s)")
if pytorch_samples != zenith_samples:
    print(f"WARNING: loaders yielded different sample counts "
          f"({pytorch_samples} vs {zenith_samples})")

# Compare throughput rather than raw wall time, so that a difference in the
# number of samples read cannot show up as a speedup.
speedup = zenith_rate / pytorch_rate
if speedup >= 1:
    print(f"\nZenith is {speedup:.1f}x faster")
else:
    print(f"\nZenith is {1 / speedup:.1f}x slower")

---
## Summary

You've learned:
1. Install Zenith with `pip install zenith-ai`
2. Load data with `zenith.DataLoader()`
3. Zero-copy conversion with `batch.to_numpy()`
4. Train a PyTorch model on batches from Zenith, and benchmark Zenith's loading throughput against PyTorch's DataLoader

### Next Steps
- Try with your own datasets
- Use GPU: `device="cuda"`
- Scale to larger datasets (S3, etc.)

**GitHub:** https://github.com/vibeswithkk/Zenith-dataplane