# Zenith AI vs PyTorch: Large Dataset Benchmark

**Fair comparison on the full CIFAR-10 training set (50,000 images)**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vibeswithkk/Zenith-dataplane/blob/main/notebooks/01_pytorch_cifar10.ipynb)

---
## 1. Install Dependencies

In [None]:
!pip install pyarrow datasets torch torchvision --quiet
print("Dependencies installed")

---
## 2. Download the Full CIFAR-10 Training Set (50k images)

In [None]:
from datasets import load_dataset
import pyarrow.parquet as pq
import pyarrow as pa
import numpy as np
import os

# Download the CIFAR-10 training split and store it as a Parquet file with two
# columns: 'image' (raw float32 bytes of the flattened image) and 'label'
# (int64 class index).
print("Downloading full CIFAR-10 dataset...")
dataset = load_dataset("cifar10", split="train")  # All 50k images
print(f"Dataset size: {len(dataset)} images")

# Convert to Parquet
print("Converting to Parquet...")

# Collect pixels and labels in a single pass over the dataset; a second pass
# just for the labels would presumably load every image again.
image_rows = []
label_values = []
for example in dataset:
    pixels = np.array(example['img'], dtype=np.float32)
    # Move the channel axis first, flatten, and divide by 255 (stays float32).
    image_rows.append(pixels.transpose(2, 0, 1).flatten() / 255.0)
    label_values.append(example['label'])

images = np.array(image_rows)
labels = np.array(label_values)
# Drop the temporary lists so only the stacked arrays stay in memory.
del image_rows, label_values

table = pa.table({
    'image': pa.array([img.tobytes() for img in images]),
    'label': pa.array(labels, type=pa.int64())
})
pq.write_table(table, 'cifar10_full.parquet')

size_mb = os.path.getsize('cifar10_full.parquet') / (1024 * 1024)
print(f"Saved: cifar10_full.parquet ({size_mb:.1f} MB)")

---
## 3. Define CNN Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    """Small three-stage CNN classifier producing 10 class logits.

    Each stage is a 3x3 convolution (padding 1) followed by ReLU and 2x2 max
    pooling. The first linear layer expects 64 * 4 * 4 features per sample,
    which is what a 3x32x32 input yields after three poolings.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Return logits of shape (batch, 10) for input of shape (batch, 3, 32, 32).

        Raises:
            RuntimeError: If the spatial size of ``x`` does not produce
                64 * 4 * 4 features per sample (raised by ``fc1``).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample instead of reshape(-1, 64 * 4 * 4): a wrongly
        # sized input then fails in fc1 rather than silently being folded
        # into a larger batch dimension.
        x = x.flatten(1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Device: {device}")

---
## 4. Zenith DataLoader (Arrow-based)

In [None]:
import pyarrow.parquet as pq
import numpy as np

class ZenithDataLoader:
    """Batch iterator over a Parquet file loaded through Arrow.

    The file is read into ``self.table`` when the loader is created. Every
    pass over the loader yields the result of ``self.table.take(...)`` for
    consecutive groups of at most ``batch_size`` row indices.
    """

    def __init__(self, path, batch_size=64, shuffle=True):
        """Read ``path`` and store the batching options.

        Args:
            path: Parquet file to read.
            batch_size: Maximum number of rows per batch; must be positive.
            shuffle: If true, row order is re-shuffled on every pass.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        # Reject bad sizes before touching the file: zero breaks range() and
        # len(), and a negative size would silently produce no batches.
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )
        self.path = path
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Use memory-mapped reading for large files
        self.table = pq.read_table(path, memory_map=True)
        self.num_rows = self.table.num_rows

    def __iter__(self):
        """Yield one batch at a time; the last batch may be smaller."""
        indices = np.arange(self.num_rows)
        if self.shuffle:
            # Draws from NumPy's global random state, in place.
            np.random.shuffle(indices)

        for start in range(0, self.num_rows, self.batch_size):
            end = min(start + self.batch_size, self.num_rows)
            yield self.table.take(indices[start:end])

    def __len__(self):
        """Return the number of batches per pass (ceiling division)."""
        return (self.num_rows + self.batch_size - 1) // self.batch_size

# Build the Arrow-backed loader over the converted dataset and report its size.
zenith_loader = ZenithDataLoader(
    'cifar10_full.parquet',
    batch_size=128,
)
print(f"Zenith DataLoader: {len(zenith_loader)} batches, {zenith_loader.num_rows} samples")

---
## 5. PyTorch DataLoader (Standard)

In [None]:
from torch.utils.data import DataLoader as TorchDataLoader, TensorDataset

# Build the baseline: the same Parquet data, decoded once up front into
# in-memory tensors and served by the standard PyTorch DataLoader.
print("Loading data for PyTorch DataLoader...")

# Load from parquet
table = pq.read_table('cifar10_full.parquet')
img_bytes = table.column('image').to_pylist()
labels_np = table.column('label').to_numpy()

# Each 'image' cell is decoded as float32 and viewed as a 3x32x32 image.
images_np = np.array([np.frombuffer(b, dtype=np.float32).reshape(3, 32, 32)
                      for b in img_bytes])

images_tensor = torch.from_numpy(images_np)
labels_tensor = torch.from_numpy(labels_np).long()

torch_loader = TorchDataLoader(
    TensorDataset(images_tensor, labels_tensor),
    batch_size=128,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only speeds up host-to-GPU copies, so request it
    # only when CUDA is present (PyTorch warns about it otherwise).
    pin_memory=torch.cuda.is_available()
)

print(f"PyTorch DataLoader: {len(torch_loader)} batches")

---
## 6. Benchmark: Zenith vs PyTorch

In [None]:
import torch.optim as optim
import time

def train_zenith(loader, model, criterion, optimizer, device):
    """Run one training pass over ``loader`` and return the mean batch loss.

    Every batch must provide ``column('image')`` (cells decoded here as the
    float32 bytes of one 3x32x32 image) and ``column('label')``.
    """
    model.train()
    running_loss = 0

    for arrow_batch in loader:
        raw_images = arrow_batch.column('image').to_pylist()
        label_array = arrow_batch.column('label').to_numpy()

        # Decode each byte string into a (3, 32, 32) array, then stack them.
        decoded = [np.frombuffer(raw, dtype=np.float32).reshape(3, 32, 32)
                   for raw in raw_images]
        image_array = np.array(decoded)

        inputs = torch.from_numpy(image_array).to(device)
        targets = torch.from_numpy(label_array).long().to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()

    return running_loss / len(loader)

def train_pytorch(loader, model, criterion, optimizer, device):
    """Run one training pass over ``loader`` and return the mean batch loss.

    ``loader`` must yield ``(images, labels)`` pairs of tensors; both are
    moved to ``device`` before the forward pass.
    """
    model.train()
    running_loss = 0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()

    return running_loss / len(loader)

# Benchmark Zenith: train a fresh model for five epochs, timing each epoch.
print("="*50)
print("ZENITH DATALOADER BENCHMARK")
print("="*50)

model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

zenith_times = []
for epoch in range(5):
    # perf_counter() is the clock intended for measuring intervals; unlike
    # time.time() it is not affected by system clock adjustments.
    start = time.perf_counter()
    loss = train_zenith(zenith_loader, model, criterion, optimizer, device)
    elapsed = time.perf_counter() - start
    zenith_times.append(elapsed)
    print(f"Epoch {epoch+1}: Loss={loss:.4f}, Time={elapsed:.2f}s")

zenith_avg = sum(zenith_times[1:]) / len(zenith_times[1:])  # Skip warmup
print(f"\nZenith avg (excl warmup): {zenith_avg:.2f}s")

In [None]:
# Benchmark PyTorch: train a fresh model for five epochs, timing each epoch.
print("\n" + "="*50)
print("PYTORCH DATALOADER BENCHMARK")
print("="*50)

model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

torch_times = []
for epoch in range(5):
    # perf_counter() is the clock intended for measuring intervals; unlike
    # time.time() it is not affected by system clock adjustments.
    start = time.perf_counter()
    loss = train_pytorch(torch_loader, model, criterion, optimizer, device)
    elapsed = time.perf_counter() - start
    torch_times.append(elapsed)
    print(f"Epoch {epoch+1}: Loss={loss:.4f}, Time={elapsed:.2f}s")

torch_avg = sum(torch_times[1:]) / len(torch_times[1:])  # Skip warmup
print(f"\nPyTorch avg (excl warmup): {torch_avg:.2f}s")

---
## 7. Final Results

In [None]:
import os

# Report measured dataset properties instead of hard-coded figures, so the
# summary stays correct if the data or loader settings are changed.
parquet_mb = os.path.getsize('cifar10_full.parquet') / (1024 * 1024)

print("\n" + "="*50)
print("BENCHMARK RESULTS - FULL CIFAR-10 (50k images)")
print("="*50)
print(f"Dataset: {zenith_loader.num_rows:,} images, {parquet_mb:.0f}MB Parquet")
print(f"Batch size: {zenith_loader.batch_size}")
print(f"Device: {device}")
print("-"*50)
print(f"Zenith DataLoader:  {zenith_avg:.2f}s per epoch")
print(f"PyTorch DataLoader: {torch_avg:.2f}s per epoch")
print("-"*50)

# Express the result as "<faster loader> is Nx faster", with N >= 1.
if zenith_avg < torch_avg:
    speedup = torch_avg / zenith_avg
    print(f"Result: Zenith is {speedup:.2f}x faster")
elif torch_avg < zenith_avg:
    speedup = zenith_avg / torch_avg
    print(f"Result: PyTorch is {speedup:.2f}x faster")
else:
    print("Result: Similar performance")

print("\nNote: Results may vary based on hardware and data size.")
print("Zenith excels with larger datasets and streaming from disk/cloud.")

---
## Summary

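This benchmark provides honest, reproducible results: both loaders train the same CNN on the same Parquet data for five epochs, and the reported figure is the mean epoch time with the first (warm-up) epoch excluded.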

**GitHub:** https://github.com/vibeswithkk/Zenith-dataplane