# Patch Embedding and Normalization Test Notebook

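This notebook sanity-checks the patch-embedding and normalization pipeline on CIFAR-100. It loads a training batch through the default transforms and verifies the shape of the resulting patch tensor.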

In [1]:
# 1. Import Required Libraries and Modules
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
import torch
import torchvision
from src.utils.preprocess import PatchEmbed, normalize_patches
from src.data.transforms import get_default_transforms

## 2. Load CIFAR-100 Batch and Apply Patch Embedding Transform

In [2]:
# Loading parameters, named instead of inlined so they can be tuned in one place.
DATA_ROOT = "./datasets"  # relative path: resolved against the kernel's working directory
IMG_SIZE = 32             # resolution handed to the transform (CIFAR-100 images are natively 32x32)
BATCH_SIZE = 8            # small on purpose: only a single batch is inspected below

transform = get_default_transforms("CIFAR100", img_size=IMG_SIZE)
cifar100_train = torchvision.datasets.CIFAR100(
    root=DATA_ROOT, train=True, transform=transform, download=True
)
# shuffle=False so the inspected batch is identical on every Restart & Run All.
loader = torch.utils.data.DataLoader(cifar100_train, batch_size=BATCH_SIZE, shuffle=False)

# The transform is expected to emit one patch tensor per image, which the default
# collate stacks into (batch, num_patches, embed_dim) — TODO confirm against
# get_default_transforms.
patches, labels = next(iter(loader))
print("Patches shape:", patches.shape)
print("Labels shape:", labels.shape)

100%|██████████| 169M/169M [00:59<00:00, 2.86MB/s] 



Patches shape: torch.Size([8, 64, 48])
Labels shape: torch.Size([8])


## 3. Verify Output Shape

In [3]:
# Expected patch-grid geometry. The values are derived rather than hard-coded,
# so the checks stay correct if the image or patch size is changed.
img_size = 32      # must match the img_size passed to get_default_transforms
patch_size = 4     # patch side length the transform is expected to use
in_channels = 3    # CIFAR-100 images are RGB

expected_num_patches = (img_size // patch_size) ** 2            # 8x8 = 64
# Assumes each patch is flattened raw pixels (C * P * P) — TODO confirm against the transform.
expected_embed_dim = in_channels * patch_size * patch_size      # 3*4*4 = 48

# Check the rank first, so a malformed tensor fails with a clear message rather than an IndexError.
assert patches.ndim == 3, f"Expected a 3-D (batch, num_patches, embed_dim) tensor, got shape {tuple(patches.shape)}"
assert patches.shape[0] == labels.shape[0], f"Batch size mismatch: {patches.shape[0]} patch sets vs {labels.shape[0]} labels"
assert patches.shape[1] == expected_num_patches, f"Expected num_patches {expected_num_patches}, got {patches.shape[1]}"
assert patches.shape[2] == expected_embed_dim, f"Expected embed_dim {expected_embed_dim}, got {patches.shape[2]}"
print(f"Output shape verified: {patches.shape} (batch, num_patches, embed_dim)")

Output shape verified: torch.Size([8, 64, 48]) (batch, num_patches, embed_dim)
