# Zenith AI + PyTorch: Image Classification

**Train a CNN on CIFAR-10 with Zenith's high-performance DataLoader**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vibeswithkk/Zenith-dataplane/blob/main/notebooks/01_pytorch_cifar10.ipynb)

---
## 1. Install Dependencies

In [None]:
!pip install pyarrow datasets torch torchvision --quiet
print("Dependencies installed")

---
## 2. Download CIFAR-10 Dataset

In [None]:
from datasets import load_dataset
import pyarrow.parquet as pq
import pyarrow as pa
import numpy as np

import os

print("Downloading CIFAR-10 (subset for demo)...")
dataset = load_dataset("cifar10", split="train[:10000]")
print(f"Dataset size: {len(dataset)} images")

# Convert to Parquet
print("Converting to Parquet...")

# Walk the dataset once, collecting pixels and labels together. A second
# pass just for the labels would fetch (and decode) every example again.
pixel_rows = []
label_values = []
for example in dataset:
    # Move the channel axis first (the training cells reshape each row back
    # to 3x32x32), flatten to one float32 row per image, and scale by 1/255.
    channels_first = np.array(example['img'], dtype=np.float32).transpose(2, 0, 1)
    pixel_rows.append(channels_first.flatten() / 255.0)
    label_values.append(example['label'])

images = np.array(pixel_rows)
labels = np.array(label_values)

# Each image is stored as the raw bytes of its float32 row (a binary
# column); labels are stored as int64.
table = pa.table({
    'image': pa.array([img.tobytes() for img in images]),
    'label': pa.array(labels, type=pa.int64())
})
pq.write_table(table, 'cifar10_train.parquet')

size_mb = os.path.getsize('cifar10_train.parquet') / (1024 * 1024)
print(f"Saved: cifar10_train.parquet ({size_mb:.1f} MB)")

---
## 3. Define CNN Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    """Small three-stage convolutional classifier that emits 10 logits.

    Every stage is a 3x3 convolution (padding 1) followed by ReLU and a 2x2
    max-pool. The flattening step expects 64 channels of 4x4 features, which
    is what three poolings leave from a 3-channel 32x32 input.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: channel widths 3 -> 32 -> 64 -> 64.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head: flattened features -> 512 hidden units -> 10 classes.
        self.fc1 = nn.Linear(in_features=64 * 4 * 4, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=10)
        self.dropout = nn.Dropout(p=0.25)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        # reshape (rather than view) also copes with non-contiguous tensors.
        flat = x.reshape(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(self.dropout(hidden))

# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = SimpleCNN().to(device)
param_count = sum(p.numel() for p in model.parameters())
print(f"Device: {device}")
print(f"Parameters: {param_count:,}")

---
## 4. Zenith DataLoader (Pure Python)

In [None]:
import pyarrow.parquet as pq
import numpy as np

class ZenithDataLoader:
    """Mini-batch iterator over a Parquet file held in memory as an Arrow table.

    The whole file is read once in the constructor. Each pass over the
    loader yields Arrow tables of up to ``batch_size`` rows, optionally in a
    fresh random order.

    Args:
        path: Location of the Parquet file to read.
        batch_size: Maximum number of rows per batch; must be at least 1.
        shuffle: When true, rows are visited in a new random order on every
            pass (drawn from NumPy's global random state).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, path, batch_size=64, shuffle=True):
        # Fail fast, before any file I/O: a zero batch size would only blow
        # up later inside range()/__len__, and a negative one would make the
        # loader silently yield nothing.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )
        self.table = pq.read_table(path)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_rows = self.table.num_rows

    def __iter__(self):
        """Yield successive batches; the final batch may be shorter."""
        indices = np.arange(self.num_rows)
        if self.shuffle:
            np.random.shuffle(indices)

        for start in range(0, self.num_rows, self.batch_size):
            # Slicing past the end is clamped, so no explicit min() is needed.
            yield self.table.take(indices[start:start + self.batch_size])

    def __len__(self):
        """Return the number of batches per pass (ceiling division)."""
        return (self.num_rows + self.batch_size - 1) // self.batch_size

# Build the loader over the Parquet file written in the download step.
loader = ZenithDataLoader(path='cifar10_train.parquet', batch_size=64)
batch_count = len(loader)
print(f"DataLoader: {batch_count} batches")

---
## 5. Training Loop

In [None]:
import torch.optim as optim
import time

# Cross-entropy loss on the logits, optimised with Adam at a 1e-3 step size.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train_epoch(loader, model, criterion, optimizer, device):
    """Train ``model`` for one pass over ``loader`` and report loss/accuracy.

    Args:
        loader: Iterable of Arrow batches exposing an ``image`` column (raw
            float32 bytes, 3*32*32 values per row) and a ``label`` column.
        model: Network to optimise; it is switched to training mode.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        A ``(mean_batch_loss, accuracy_percent)`` tuple. Both values are
        ``0.0`` when the loader yields no batches, rather than raising
        ``ZeroDivisionError``.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for batch in loader:
        # Convert Arrow columns to Python/NumPy objects.
        img_bytes = batch.column('image').to_pylist()
        labels_np = batch.column('label').to_numpy()

        # Rebuild each image from its raw bytes; np.array copies the
        # read-only frombuffer views into one fresh array.
        images_np = np.array([np.frombuffer(b, dtype=np.float32).reshape(3, 32, 32)
                              for b in img_bytes])

        images = torch.from_numpy(images_np).to(device)
        labels = torch.from_numpy(labels_np).long().to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        num_batches += 1

    # Counting batches in the loop (instead of len(loader)) also lets plain
    # iterables without __len__ be used as loaders.
    if num_batches == 0:
        return 0.0, 0.0

    accuracy = 100. * correct / total if total else 0.0
    return total_loss / num_batches, accuracy

print("Training with Zenith DataLoader...")
print("-" * 40)

# time.perf_counter() is a monotonic, high-resolution clock meant for
# measuring intervals; time.time() can jump if the system clock is adjusted.
zenith_times = []
for epoch in range(3):
    start = time.perf_counter()
    loss, acc = train_epoch(loader, model, criterion, optimizer, device)
    elapsed = time.perf_counter() - start
    zenith_times.append(elapsed)
    print(f"Epoch {epoch+1}: Loss={loss:.4f}, Acc={acc:.2f}%, Time={elapsed:.2f}s")

# Mean wall time per epoch, reused by the results cell.
zenith_avg = sum(zenith_times) / len(zenith_times)
print("-" * 40)
print(f"Zenith avg time: {zenith_avg:.2f}s per epoch")

---
## 6. Benchmark vs PyTorch DataLoader

In [None]:
from torch.utils.data import DataLoader as TorchDataLoader, TensorDataset

print("Loading data for PyTorch DataLoader...")

# Read the same Parquet file back and decode every image up front, so the
# PyTorch loader serves ready-made tensors.
table = pq.read_table('cifar10_train.parquet')
img_bytes = table.column('image').to_pylist()
labels_np = table.column('label').to_numpy()

# One (3, 32, 32) float32 view per stored byte string, copied into one array.
decoded_images = [
    np.frombuffer(blob, dtype=np.float32).reshape(3, 32, 32)
    for blob in img_bytes
]
images_np = np.array(decoded_images)

images_tensor = torch.from_numpy(images_np)
labels_tensor = torch.from_numpy(labels_np).long()

torch_dataset = TensorDataset(images_tensor, labels_tensor)
torch_loader = TorchDataLoader(torch_dataset, batch_size=64, shuffle=True)

# Start the comparison run from freshly initialised weights and optimizer state.
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train_epoch_torch(loader, model, criterion, optimizer, device):
    """Train ``model`` for one pass over a loader of ``(images, labels)`` pairs.

    Args:
        loader: Iterable yielding ``(images, labels)`` tensor pairs.
        model: Network to optimise; it is switched to training mode.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        A ``(mean_batch_loss, accuracy_percent)`` tuple. Both values are
        ``0.0`` when the loader yields no batches, rather than raising
        ``ZeroDivisionError``.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        num_batches += 1

    # Counting batches in the loop (instead of len(loader)) also lets plain
    # iterables without __len__ be used as loaders.
    if num_batches == 0:
        return 0.0, 0.0

    accuracy = 100. * correct / total if total else 0.0
    return total_loss / num_batches, accuracy

print("\nTraining with PyTorch DataLoader...")
print("-" * 40)

# time.perf_counter() is a monotonic, high-resolution clock meant for
# measuring intervals; time.time() can jump if the system clock is adjusted.
torch_times = []
for epoch in range(3):
    start = time.perf_counter()
    loss, acc = train_epoch_torch(torch_loader, model, criterion, optimizer, device)
    elapsed = time.perf_counter() - start
    torch_times.append(elapsed)
    print(f"Epoch {epoch+1}: Loss={loss:.4f}, Acc={acc:.2f}%, Time={elapsed:.2f}s")

# Mean wall time per epoch, reused by the results cell.
torch_avg = sum(torch_times) / len(torch_times)
print("-" * 40)
print(f"PyTorch avg time: {torch_avg:.2f}s per epoch")

---
## 7. Results Comparison

In [None]:
banner = "=" * 50
print(banner)
print("BENCHMARK RESULTS")
print(banner)
print(f"Zenith DataLoader:  {zenith_avg:.2f}s per epoch")
print(f"PyTorch DataLoader: {torch_avg:.2f}s per epoch")
print()

# Report whichever loader had the lower average epoch time, as a ratio.
if torch_avg > zenith_avg:
    speedup = torch_avg / zenith_avg
    print(f"Zenith is {speedup:.2f}x faster")
else:
    slowdown = zenith_avg / torch_avg
    print(f"PyTorch is {slowdown:.2f}x faster")
    print("(Note: For small datasets, overhead may dominate)")

---
## Summary

This notebook demonstrated:
1. Loading CIFAR-10 data with Arrow/Parquet
2. Training a CNN with Zenith DataLoader
3. A like-for-like benchmark against PyTorch's built-in DataLoader (same data, model, and batch size)

**GitHub:** https://github.com/vibeswithkk/Zenith-dataplane