# Zenith vs PyTorch: 1GB Dataset Benchmark

Official benchmark using the optimized zenith-ai package (v0.3.2 or later).

In [None]:
# Install the packages this notebook uses. The leading "!" is an IPython shell
# escape, so this line only works inside a notebook/IPython kernel.
!pip install zenith-ai torch pyarrow --quiet
import zenith
# NOTE(review): presumably prints Zenith version/environment details — confirm
# against the zenith-ai documentation.
zenith.info()

In [None]:
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader as TorchDataLoader, TensorDataset
import time
import os

# Benchmark configuration shared by every cell below.
NUM_SAMPLES = 100_000  # number of synthetic samples generated and trained on
BATCH_SIZE = 256       # batch size used by both data loaders
EPOCHS = 3             # epochs timed per loader

# NOTE(review): presumably selects the best available torch device — confirm
# against the zenith-ai documentation.
device = zenith.auto_device()

print(f"Device: {device}")
print(f"Samples: {NUM_SAMPLES:,}")

In [None]:
# Generate ~1 GB of synthetic data: NUM_SAMPLES random 3x32x32 float32 images
# with an int64 class label in [0, 10) for each.
print("Generating synthetic data...")
data_images = np.random.rand(NUM_SAMPLES, 3, 32, 32).astype(np.float32)
data_labels = np.random.randint(0, 10, NUM_SAMPLES).astype(np.int64)
print(f"Shape: {data_images.shape}, Size: {data_images.nbytes/1e9:.2f} GB")

# Save as Parquet with one flattened float32 list per sample.
# The list column is built straight from the NumPy buffer. Going through
# .tolist() would create hundreds of millions of Python float objects and
# silently widen the values to float64 (Arrow would infer list<double>).
print("Saving Parquet...")
features_per_sample = int(np.prod(data_images.shape[1:]))
flat_features = pa.array(data_images.reshape(-1))  # float32 values, row-major
row_offsets = np.arange(NUM_SAMPLES + 1, dtype=np.int64) * features_per_sample
if NUM_SAMPLES * features_per_sample <= np.iinfo(np.int32).max:
    # Regular Arrow lists use 32-bit offsets.
    features_column = pa.ListArray.from_arrays(
        pa.array(row_offsets.astype(np.int32)), flat_features
    )
else:
    # Too many values for 32-bit offsets: fall back to 64-bit "large" lists.
    features_column = pa.LargeListArray.from_arrays(
        pa.array(row_offsets), flat_features
    )
table = pa.table({
    'features': features_column,
    'label': data_labels
})
pq.write_table(table, 'data_1gb.parquet')
print(f"Saved: {os.path.getsize('data_1gb.parquet')/1e6:.0f} MB")

In [None]:
# Model
class CNN(nn.Module):
    """Small convolutional classifier: two conv/pool stages and two linear layers.

    The first linear layer expects 64*8*8 features, i.e. 3x32x32 inputs after
    two 2x2 poolings. The network outputs 10 logits per sample.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order so that parameter names and
        # random initialisation order stay stable.
        self.c1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.c2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(in_features=64 * 8 * 8, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=10)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        stage1 = self.pool(F.relu(self.c1(x)))
        stage2 = self.pool(F.relu(self.c2(stage1)))
        # Flatten each sample to the width the first linear layer expects.
        flat = stage2.reshape(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

print("Model defined")

In [None]:
# ZENITH BENCHMARK (optimized path)
# Trains a fresh CNN for EPOCHS epochs on batches from zenith.DataLoader and
# records the elapsed time of each epoch in z_times.
print("="*50)
print("ZENITH DATALOADER (Optimized with Prefetch)")
print("="*50)

zenith_loader = zenith.DataLoader(
    'data_1gb.parquet',
    batch_size=BATCH_SIZE,
    shuffle=True,
    device=device,
    prefetch_factor=4  # Prefetch 4 batches ahead
)

model = CNN().to(device)
opt = optim.Adam(model.parameters())
crit = nn.CrossEntropyLoss()

z_times = []
for ep in range(EPOCHS):
    model.train()
    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted mid-run.
    t0 = time.perf_counter()
    total_loss = 0

    for batch in zenith_loader:
        # Use optimized to_torch() - single method call.
        # The result is indexed by column name below ('features', 'label').
        data = batch.to_torch()

        # Reshape flattened features back into 3x32x32 images
        features = data['features']
        if features.dim() == 2:
            features = features.view(-1, 3, 32, 32)

        x = features.to(device)
        y = data['label'].to(device)

        opt.zero_grad()
        loss = crit(model(x), y)
        loss.backward()
        opt.step()
        # .item() copies the scalar loss back to the host every step.
        total_loss += loss.item()

    z_times.append(time.perf_counter()-t0)
    print(f"Epoch {ep+1}: Loss={total_loss/len(zenith_loader):.4f}, Time={z_times[-1]:.2f}s")

# Skip the first (warm-up) epoch when averaging. With a single epoch there is
# nothing left after skipping, so average over everything that ran instead of
# dividing by zero.
z_steady = z_times[1:] or z_times
z_avg = sum(z_steady)/len(z_steady) if z_steady else float("nan")
print(f"\nZenith avg: {z_avg:.2f}s")

In [None]:
# PYTORCH BENCHMARK
# Trains a fresh CNN for EPOCHS epochs on the in-memory arrays through the
# stock torch DataLoader and records the elapsed time of each epoch in pt_times.
print("\n" + "="*50)
print("PYTORCH DATALOADER")
print("="*50)

pt_loader = TorchDataLoader(
    TensorDataset(torch.from_numpy(data_images), torch.from_numpy(data_labels)),
    batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True
)

model = CNN().to(device)
opt = optim.Adam(model.parameters())
# Define the loss here as well so this cell does not depend on the Zenith cell
# having been run first.
crit = nn.CrossEntropyLoss()

pt_times = []
for ep in range(EPOCHS):
    model.train()
    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted mid-run.
    t0 = time.perf_counter()
    total_loss = 0

    for x, y in pt_loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        loss = crit(model(x), y)
        loss.backward()
        opt.step()
        # .item() copies the scalar loss back to the host every step.
        total_loss += loss.item()

    pt_times.append(time.perf_counter()-t0)
    print(f"Epoch {ep+1}: Loss={total_loss/len(pt_loader):.4f}, Time={pt_times[-1]:.2f}s")

# Skip the first (warm-up) epoch when averaging. With a single epoch there is
# nothing left after skipping, so average over everything that ran instead of
# dividing by zero.
pt_steady = pt_times[1:] or pt_times
pt_avg = sum(pt_steady)/len(pt_steady) if pt_steady else float("nan")
print(f"\nPyTorch avg: {pt_avg:.2f}s")

In [None]:
# RESULTS
# Print a summary comparing the average per-epoch time of both loaders.
heavy_rule = "=" * 50
light_rule = "-" * 50

print("\n" + heavy_rule)
print("BENCHMARK RESULTS - 1GB DATASET")
print(heavy_rule)
print(f"Zenith version: {zenith.__version__}")
print(f"Dataset: {NUM_SAMPLES:,} samples (~1GB)")
print(f"Device: {device}")
print(f"Prefetch: 4 batches")
print(light_rule)
print(f"Zenith:  {z_avg:.2f}s per epoch")
print(f"PyTorch: {pt_avg:.2f}s per epoch")
print(light_rule)

# Report the speed-up of whichever loader had the lower average epoch time;
# a tie is reported on the PyTorch side.
zenith_wins = z_avg < pt_avg
if not zenith_wins:
    print(f"PyTorch is {z_avg/pt_avg:.2f}x faster")
else:
    print(f"Zenith is {pt_avg/z_avg:.2f}x faster")