In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, Callback

import matplotlib.pyplot as plt
import seaborn as sns

import wandb 
from pytorch_lightning.loggers import WandbLogger
from kaggle_secrets import UserSecretsClient

## Initialization

This init_wandb securely connects to Weights & Biases using Kaggle secrets storage. It configures project metadata and API keys to enable experiment tracking. 

In [2]:
def init_wandb(project_name="vision-transformer-cifar10", config=None):
    try:
        user_secrets = UserSecretsClient()
        
        wandb_api_key = user_secrets.get_secret("wandb")
        os.environ['WANDB_API_KEY'] = wandb_api_key
        
        wandb.login(key=wandb_api_key)
        
        run = wandb.init(
            project=project_name,
            config=config,
            tags=["ViT", "CIFAR-10"],
            notes="Vision Transformer implementation on CIFAR-10"
        )
        
        print("✅ W&B successfully initialized")
        return run
    
    except Exception as e:
        print(f"❌ Error initializing W&B: {str(e)}")
        return None

This function fixes random seeds for PyTorch, NumPy, and Lightning while configuring CUDA to prioritize reproducibility over performance.

In [3]:
def set_seed(seed: int = 42):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    np.random.seed(seed)
    pl.seed_everything(seed)

This function converts input images into sequences of flattened patches. For 32x32 CIFAR-10 images with patch_size=4, this creates 64 non-overlapping 16-dimension vectors. This restructuring adapts images for transformer-style sequence processing.

In [4]:
def img_to_patch(x, patch_size, flatten_channels=True):
    B, C, H, W = x.shape
    x = x.reshape(B, C, H//patch_size, patch_size, W//patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)
    x = x.flatten(1, 2)
    
    if flatten_channels:
        x = x.flatten(2, 4)
    
    return x

## AttentionBlock

AttentionBlock implements multi-head self-attention with residual connections. Layer normalization precedes both attention and feed-forward operations. The hidden dimension expansion (embed_dim → hidden_dim → embed_dim) in the MLP provides non-linear transformation capacity.

In [5]:
class AttentionBlock(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        super().__init__()
        
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        inp_x = self.layer_norm_1(x)
        x = x + self.attn(inp_x, inp_x, inp_x)[0]
        x = x + self.linear(self.layer_norm_2(x))
        return x

## VisionTransformer

VisionTransformer handles end-to-end processing:

- Linear projection of flattened patches
- Prepending [CLS] token
- Adding learnable positional embeddings
- Processing through transformer layers
- Classifying via [CLS] token

In [6]:
class VisionTransformer(nn.Module):
    def __init__(self, 
                 embed_dim=256, 
                 hidden_dim=512, 
                 num_channels=3, 
                 num_heads=8, 
                 num_layers=6, 
                 num_classes=10, 
                 patch_size=4, 
                 num_patches=64, 
                 dropout=0.2):
        super().__init__()
        
        self.patch_size = patch_size
        
        self.input_layer = nn.Linear(num_channels * (patch_size**2), embed_dim)
        
        self.transformer = nn.Sequential(
            *[AttentionBlock(embed_dim, hidden_dim, num_heads, dropout) 
              for _ in range(num_layers)]
        )
        
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
        
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1+num_patches, embed_dim))
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x = img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)
        
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:,:T+1]
        
        x = self.dropout(x)
        x = x.transpose(0, 1)
        x = self.transformer(x)
        
        cls = x[0]
        out = self.mlp_head(cls)
        return out

## ViTLightning

ViTLightning standardizes training via:

- Cross-entropy loss calculation
- Accuracy metric tracking
- AdamW optimizer with cosine LR decay
- Automatic hyperparameter logging
- Encapsulates all training logic while remaining configurable through model_kwargs.

In [7]:
class ViTLightning(pl.LightningModule):
    def __init__(self, model_kwargs, lr=3e-4):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)
        self.lr = lr
        
        if self.logger:
            self.logger.log_hyperparams(
                params={
                    **model_kwargs, 
                    'learning_rate': lr
                }
            )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self(imgs)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self(imgs)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self(imgs)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(), 
            lr=self.lr, 
            weight_decay=1e-4
        )
        
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, 
            T_max=180 
        )
        
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',
                'frequency': 1
            }
        }

## Callback

This function logs sample predictions post-validation. It denormalizes images, compares predictions with ground truth, and visualizes results in W&B.

In [8]:
class WandbVisualizationCallback(pl.Callback):
    def __init__(self, val_loader, classes):
        super().__init__()
        self.val_loader = val_loader
        self.classes = classes

    def on_validation_epoch_end(self, trainer, pl_module):
        pl_module.eval()
        with torch.no_grad():
            images, labels = next(iter(self.val_loader))
            images = images.to(pl_module.device)
            labels = labels.to(pl_module.device)

            preds = pl_module(images)
            pred_labels = torch.argmax(preds, dim=1)

            mean = torch.tensor([0.49139968, 0.48215841, 0.44653091])
            std = torch.tensor([0.24703223, 0.24348513, 0.26158784])
            images = images * std.view(1, 3, 1, 1).to(images.device) + mean.view(1, 3, 1, 1).to(images.device)
            images = torch.clamp(images, 0, 1)

            prediction_images = []
            for idx in range(min(16, len(images))):
                img = images[idx].cpu().permute(1, 2, 0).numpy()
                true_label = self.classes[labels[idx].item()]
                pred_label = self.classes[pred_labels[idx].item()]
                
                plt.figure(figsize=(4,4))
                plt.imshow(img)
                plt.title(f'True: {true_label}\nPred: {pred_label}')
                plt.axis('off')
                
                prediction_image = wandb.Image(img, caption=f'Epoch {trainer.current_epoch}: True {true_label}, Pred {pred_label}')
                prediction_images.append(prediction_image)
                plt.close()

            trainer.logger.experiment.log({
                "predictions": prediction_images,
                "epoch": trainer.current_epoch
            })

## Data Visualization

Displays 10 random training samples. Serves as sanity check for data loading and augmentation pipeline

In [9]:
def visualize_cifar10_dataset(train_loader):
    plt.figure(figsize=(15, 10))
    classes = train_loader.dataset.dataset.classes

    for i in range(10):
        plt.subplot(2, 5, i+1)
        img, label = train_loader.dataset[i]
        plt.imshow(img.permute(1, 2, 0))
        plt.title(classes[label])
        plt.axis('off')

    plt.suptitle('CIFAR-10 Dataset Samples')
    plt.tight_layout()
    plt.show()

In [10]:
def inference_and_visualize(model, test_loader, classes):
    model.eval()
    plt.figure(figsize=(15, 10))
    
    with torch.no_grad():
        for i in range(10):
            plt.subplot(2, 5, i+1)
            img, true_label = test_loader.dataset[i]
            img = img.unsqueeze(0).to(model.device)
            
            pred = model(img)
            pred_label = torch.argmax(pred, dim=1).item()
            
            plt.imshow(img.squeeze().cpu().permute(1, 2, 0))
            plt.title(f'True: {classes[true_label]}\nPred: {classes[pred_label]}')
            plt.axis('off')
    
    plt.suptitle('Inference Predictions')
    plt.tight_layout()
    plt.show()

- Separate transforms for train (augmented) vs test
- 45k/5k train/val split via random sampling
- Shuffled training loader with pinned memory
- 128 batch size balancing memory/throughput
- Normalization parameters match CIFAR-10 channel statistics.

In [11]:
def prepare_cifar10_data(batch_size=128):
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize([0.49139968, 0.48215841, 0.44653091], 
                             [0.24703223, 0.24348513, 0.26158784])
    ])
    
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.49139968, 0.48215841, 0.44653091], 
                             [0.24703223, 0.24348513, 0.26158784])
    ])
    
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, 
                                                 download=True, transform=train_transform)
    val_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, 
                                               download=True, transform=test_transform)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, 
                                                download=True, transform=test_transform)
    
    set_seed(42)
    train_set, _ = random_split(train_dataset, [45000, 5000])
    _, val_set = random_split(val_dataset, [45000, 5000])
    
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, 
                              num_workers=4, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, 
                            num_workers=4, drop_last=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, 
                             num_workers=4, drop_last=False)
    
    return train_loader, val_loader, test_loader

## Training Loop

In [12]:
def train_vision_transformer(model_kwargs, lr=3e-4, max_epochs=180):
    set_seed(42)
    
    train_loader, val_loader, test_loader = prepare_cifar10_data()
    
    classes = train_loader.dataset.dataset.classes
    
    wandb_config = {
        **model_kwargs,
        'learning_rate': lr,
        'max_epochs': max_epochs,
        'batch_size': train_loader.batch_size
    }
    run = init_wandb(config=wandb_config)
    
    wandb_logger = WandbLogger(
        project='vision-transformer-cifar10',
        config=wandb_config
    )
    
    trainer = pl.Trainer(
        default_root_dir='./checkpoints/vit',
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
        max_epochs=max_epochs,
        logger=wandb_logger,
        callbacks=[
            ModelCheckpoint(save_weights_only=True, mode='max', monitor='val_acc'),
            LearningRateMonitor('epoch'),
            WandbVisualizationCallback(val_loader, classes)
        ]
    )
    
    model = ViTLightning(model_kwargs, lr=lr)
    
    trainer.fit(model, train_loader, val_loader)
    
    test_results = trainer.test(model, test_loader)
    
    wandb.finish()
    
    return model, test_results

In [13]:
model_kwargs = {
    'embed_dim': 256,
    'hidden_dim': 512,
    'num_heads': 8,
    'num_layers': 6,
    'patch_size': 4,
    'num_channels': 3,
    'num_patches': 64,
    'num_classes': 10,
    'dropout': 0.2
}

In [14]:
model, results = train_vision_transformer(model_kwargs)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 43.3MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33meva-koroleva[0m ([33mml-samurai[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Tracking run with wandb version 0.19.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20250325_233702-amlggrm1[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33melectric-dream-4[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/ml-samurai/vision-transformer-cifar10[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/ml-samurai/vision-transformer-cifar10/runs/amlggrm1[0m


✅ W&B successfully initialized


/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:               epoch ▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇█
[34m[1mwandb[0m:            lr-AdamW █████████▇▇▇▇▇▇▆▆▆▅▅▄▄▄▃▃▃▃▃▂▂▂▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:            test_acc ▁
[34m[1mwandb[0m:           test_loss ▁
[34m[1mwandb[0m:           train_acc ▁▁▃▄▄▅▅▆▆▇▅▇▆▆▆▇▆▇▇▇█▇▇▇████████████████
[34m[1mwandb[0m:          train_loss █▇▇▇▆▅▅▅▅▅▃▃▃▃▃▃▃▃▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m: trainer/global_step ▁▁▁▁▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇██
[34m[1mwandb[0m:             val_acc ▁▂▃▄▅▅▆▆▆▆▇▇▇▇▇▇████████████████████████
[34m[1mwandb[0m:            val_loss █▇▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:               epoch 180
[34m[1mwandb[0m:            lr-AdamW 0.0
[34m[1mwandb[0m:            test_acc 0.764
[34m[