In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, models
from PIL import Image
import wandb
import numpy as np
from tqdm import tqdm
import random

class Config:
    # Dataset paths
    DATA_DIR = "/tiny-imagenet-200"

    # Reduce dataset size (set to None to use full dataset)
    TRAIN_SUBSET_SIZE = 20000 
    VAL_SUBSET_SIZE = 2000    
    TEST_SUBSET_SIZE = 2000  

    # Training hyperparameters
    BATCH_SIZE = 64
    NUM_EPOCHS = 10
    LEARNING_RATE = 0.001
    NUM_WORKERS = 4
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model
    MODEL_NAME = "resnet18"  
    NUM_CLASSES = 200
    PRETRAINED = True

    # W&B
    WANDB_PROJECT = "tiny-imagenet-assignment"
    WANDB_RUN_NAME = "resnet18-baseline"

class TinyImageNetDataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None):
        """
        Args:
            root_dir: Root directory of tiny-imagenet-200
            split: 'train', 'val', or 'test'
            transform: Transformations to apply
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.samples = []
        self.class_to_idx = {}

        if split == 'train':
            self._load_train_data()
        elif split == 'val':
            self._load_val_data()
        elif split == 'test':
            self._load_test_data()

    def _load_train_data(self):
        train_dir = os.path.join(self.root_dir, 'train')

        # Check if train directory exists
        if not os.path.exists(train_dir):
            raise FileNotFoundError(f"Train directory not found: {train_dir}")

        classes = sorted([d for d in os.listdir(train_dir)
                         if os.path.isdir(os.path.join(train_dir, d))])

        print(f"Found {len(classes)} classes in training set")

        for idx, class_name in enumerate(classes):
            self.class_to_idx[class_name] = idx
            class_dir = os.path.join(train_dir, class_name, 'images')

            # Try without 'images' subfolder if it doesn't exist
            if not os.path.exists(class_dir):
                class_dir = os.path.join(train_dir, class_name)

            if os.path.exists(class_dir):
                for img_name in os.listdir(class_dir):
                    if img_name.lower().endswith(('.jpeg', '.jpg', '.png')):
                        img_path = os.path.join(class_dir, img_name)
                        self.samples.append((img_path, idx))

    def _load_val_data(self):
        val_dir = os.path.join(self.root_dir, 'val')

        # Check if val directory exists
        if not os.path.exists(val_dir):
            raise FileNotFoundError(f"Val directory not found: {val_dir}")

        # Load class mapping from train
        train_dir = os.path.join(self.root_dir, 'train')
        classes = sorted([d for d in os.listdir(train_dir)
                         if os.path.isdir(os.path.join(train_dir, d))])
        for idx, class_name in enumerate(classes):
            self.class_to_idx[class_name] = idx

        # Check if val has the same structure as train (class folders)
        val_classes = [d for d in os.listdir(val_dir)
                      if os.path.isdir(os.path.join(val_dir, d)) and d in self.class_to_idx]

        if len(val_classes) > 0:
            # Val has class folder structure (like train)
            print(f"Val set has class folder structure with {len(val_classes)} classes")
            for class_name in val_classes:
                class_idx = self.class_to_idx[class_name]
                class_dir = os.path.join(val_dir, class_name, 'images')

                # Try without 'images' subfolder if it doesn't exist
                if not os.path.exists(class_dir):
                    class_dir = os.path.join(val_dir, class_name)

                if os.path.exists(class_dir):
                    for img_name in os.listdir(class_dir):
                        if img_name.lower().endswith(('.jpeg', '.jpg', '.png')):
                            img_path = os.path.join(class_dir, img_name)
                            self.samples.append((img_path, class_idx))
        else:
            # Try val_annotations.txt format
            val_annotations = os.path.join(val_dir, 'val_annotations.txt')

            if os.path.exists(val_annotations):
                print("Using val_annotations.txt")
                with open(val_annotations, 'r') as f:
                    for line in f:
                        parts = line.strip().split('\t')
                        img_name = parts[0]
                        class_name = parts[1]
                        img_path = os.path.join(val_dir, 'images', img_name)

                        if not os.path.exists(img_path):
                            img_path = os.path.join(val_dir, img_name)

                        if os.path.exists(img_path) and class_name in self.class_to_idx:
                            self.samples.append((img_path, self.class_to_idx[class_name]))
            else:
                # Just load all images from val directory without labels
                print("Warning: No class structure or annotations found. Loading images without proper labels.")
                print("Using train class mapping for available val class folders")

                # Try to load from images subfolder or directly from val
                images_dir = os.path.join(val_dir, 'images')
                if not os.path.exists(images_dir):
                    images_dir = val_dir

                for img_name in os.listdir(images_dir):
                    if img_name.lower().endswith(('.jpeg', '.jpg', '.png')):
                        img_path = os.path.join(images_dir, img_name)
                        # Assign label 0 as placeholder (not ideal but allows loading)
                        self.samples.append((img_path, 0))

    def _load_test_data(self):
        test_dir = os.path.join(self.root_dir, 'test')

        # Check if test directory exists
        if not os.path.exists(test_dir):
            print("Test directory not found, using val as test")
            self._load_val_data()
            return

        # Load class mapping from train
        train_dir = os.path.join(self.root_dir, 'train')
        classes = sorted([d for d in os.listdir(train_dir)
                         if os.path.isdir(os.path.join(train_dir, d))])
        for idx, class_name in enumerate(classes):
            self.class_to_idx[class_name] = idx

        # Check if test has class folder structure
        test_classes = [d for d in os.listdir(test_dir)
                       if os.path.isdir(os.path.join(test_dir, d)) and d in self.class_to_idx]

        if len(test_classes) > 0:
            # Test has class folder structure
            print(f"Test set has class folder structure with {len(test_classes)} classes")
            for class_name in test_classes:
                class_idx = self.class_to_idx[class_name]
                class_dir = os.path.join(test_dir, class_name, 'images')

                if not os.path.exists(class_dir):
                    class_dir = os.path.join(test_dir, class_name)

                if os.path.exists(class_dir):
                    for img_name in os.listdir(class_dir):
                        if img_name.lower().endswith(('.jpeg', '.jpg', '.png')):
                            img_path = os.path.join(class_dir, img_name)
                            self.samples.append((img_path, class_idx))
        else:
            # Load images without labels (typical test set)
            images_dir = os.path.join(test_dir, 'images')
            if not os.path.exists(images_dir):
                images_dir = test_dir

            for img_name in os.listdir(images_dir):
                if img_name.lower().endswith(('.jpeg', '.jpg', '.png')):
                    img_path = os.path.join(images_dir, img_name)
                    self.samples.append((img_path, 0))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

# ===========================
# Data Transforms
# ===========================
def get_transforms():
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])

    val_test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])

    return train_transform, val_test_transform

# ===========================
# Model Setup
# ===========================
def get_model(model_name, num_classes, pretrained=True):
    if model_name == 'resnet18':
        model = models.resnet18(pretrained=pretrained)
    elif model_name == 'resnet34':
        model = models.resnet34(pretrained=pretrained)
    elif model_name == 'resnet50':
        model = models.resnet50(pretrained=pretrained)
    else:
        raise ValueError(f"Model {model_name} not supported")

    # Modify final layer for Tiny ImageNet (200 classes)
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, num_classes)

    return model

# ===========================
# Training Functions
# ===========================
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(dataloader, desc='Training')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        pbar.set_postfix({'loss': loss.item(), 'acc': 100 * correct / total})

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total

    return epoch_loss, epoch_acc

def evaluate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Evaluating')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            pbar.set_postfix({'loss': loss.item(), 'acc': 100 * correct / total})

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total

    return epoch_loss, epoch_acc

# ===========================
# Main Training Function
# ===========================
def main():
    config = Config()

    # Initialize W&B
    wandb.init(
        project=config.WANDB_PROJECT,
        name=config.WANDB_RUN_NAME,
        config={
            "model": config.MODEL_NAME,
            "batch_size": config.BATCH_SIZE,
            "epochs": config.NUM_EPOCHS,
            "learning_rate": config.LEARNING_RATE,
            "pretrained": config.PRETRAINED,
            "train_subset_size": config.TRAIN_SUBSET_SIZE,
            "val_subset_size": config.VAL_SUBSET_SIZE,
            "test_subset_size": config.TEST_SUBSET_SIZE,
        }
    )

    print(f"Using device: {config.DEVICE}")

    # Get transforms
    train_transform, val_test_transform = get_transforms()

    # Load datasets
    print("Loading datasets...")
    train_dataset = TinyImageNetDataset(
        config.DATA_DIR, split='train', transform=train_transform
    )
    val_dataset = TinyImageNetDataset(
        config.DATA_DIR, split='val', transform=val_test_transform
    )
    test_dataset = TinyImageNetDataset(
        config.DATA_DIR, split='test', transform=val_test_transform
    )

    # Create subsets to reduce dataset size
    if config.TRAIN_SUBSET_SIZE:
        train_indices = random.sample(range(len(train_dataset)),
                                     min(config.TRAIN_SUBSET_SIZE, len(train_dataset)))
        train_dataset = Subset(train_dataset, train_indices)

    if config.VAL_SUBSET_SIZE:
        val_indices = random.sample(range(len(val_dataset)),
                                   min(config.VAL_SUBSET_SIZE, len(val_dataset)))
        val_dataset = Subset(val_dataset, val_indices)

    if config.TEST_SUBSET_SIZE:
        test_indices = random.sample(range(len(test_dataset)),
                                    min(config.TEST_SUBSET_SIZE, len(test_dataset)))
        test_dataset = Subset(test_dataset, test_indices)

    print(f"Train size: {len(train_dataset)}")
    print(f"Val size: {len(val_dataset)}")
    print(f"Test size: {len(test_dataset)}")

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset, batch_size=config.BATCH_SIZE,
        shuffle=True, num_workers=config.NUM_WORKERS
    )
    val_loader = DataLoader(
        val_dataset, batch_size=config.BATCH_SIZE,
        shuffle=False, num_workers=config.NUM_WORKERS
    )
    test_loader = DataLoader(
        test_dataset, batch_size=config.BATCH_SIZE,
        shuffle=False, num_workers=config.NUM_WORKERS
    )

    # Initialize model
    print(f"Initializing {config.MODEL_NAME}...")
    model = get_model(config.MODEL_NAME, config.NUM_CLASSES, config.PRETRAINED)
    model = model.to(config.DEVICE)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    # Training loop
    best_val_acc = 0.0

    for epoch in range(config.NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{config.NUM_EPOCHS}")

        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, config.DEVICE
        )

        # Validate
        val_loss, val_acc = evaluate(
            model, val_loader, criterion, config.DEVICE
        )

        # Test
        test_loss, test_acc = evaluate(
            model, test_loader, criterion, config.DEVICE
        )

        # Update learning rate
        scheduler.step()

        # Log to W&B
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "test_loss": test_loss,
            "test_acc": test_acc,
            "learning_rate": optimizer.param_groups[0]['lr']
        })

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
            }, 'best_model.pth')

            # Log model to W&B
            wandb.save('best_model.pth')

    print(f"\nTraining completed! Best validation accuracy: {best_val_acc:.2f}%")

    # Log final model as artifact
    artifact = wandb.Artifact(
        name=f"{config.MODEL_NAME}-tiny-imagenet",
        type="model",
        description=f"Trained {config.MODEL_NAME} on Tiny ImageNet"
    )
    artifact.add_file('best_model.pth')
    wandb.log_artifact(artifact)

    wandb.finish()

if __name__ == "__main__":
    main()

Using device: cuda
Loading datasets...
Found 200 classes in training set
Val set has class folder structure with 200 classes
Train size: 20000
Val size: 2000
Test size: 2000
Initializing resnet18...
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 194MB/s]



Epoch 1/10


Training: 100%|██████████| 313/313 [32:58<00:00,  6.32s/it, loss=4.49, acc=5.63]
Evaluating: 100%|██████████| 32/32 [03:17<00:00,  6.18s/it, loss=4.63, acc=9.75]
Evaluating: 100%|██████████| 32/32 [03:07<00:00,  5.87s/it, loss=7.34, acc=0.3]


Train Loss: 4.7263, Train Acc: 5.63%
Val Loss: 4.2179, Val Acc: 9.75%
Test Loss: 7.8573, Test Acc: 0.30%

Epoch 2/10


Training: 100%|██████████| 313/313 [02:07<00:00,  2.45it/s, loss=3.85, acc=12.3]
Evaluating: 100%|██████████| 32/32 [00:09<00:00,  3.36it/s, loss=4.31, acc=15]
Evaluating: 100%|██████████| 32/32 [00:08<00:00,  3.83it/s, loss=7.14, acc=1.3]


Train Loss: 4.0838, Train Acc: 12.34%
Val Loss: 3.8575, Val Acc: 15.00%
Test Loss: 7.5138, Test Acc: 1.30%

Epoch 3/10


Training: 100%|██████████| 313/313 [02:08<00:00,  2.44it/s, loss=3.52, acc=16.9]
Evaluating: 100%|██████████| 32/32 [00:09<00:00,  3.35it/s, loss=4.98, acc=20.6]
Evaluating: 100%|██████████| 32/32 [00:08<00:00,  3.83it/s, loss=9.2, acc=0.55]


Train Loss: 3.7756, Train Acc: 16.88%
Val Loss: 3.4977, Val Acc: 20.60%
Test Loss: 9.8222, Test Acc: 0.55%

Epoch 4/10


Training: 100%|██████████| 313/313 [02:08<00:00,  2.44it/s, loss=4.04, acc=20.5]
Evaluating: 100%|██████████| 32/32 [00:09<00:00,  3.32it/s, loss=3.58, acc=24.4]
Evaluating: 100%|██████████| 32/32 [00:08<00:00,  3.64it/s, loss=9.14, acc=0.25]


Train Loss: 3.5412, Train Acc: 20.55%
Val Loss: 3.3118, Val Acc: 24.45%
Test Loss: 9.6236, Test Acc: 0.25%

Epoch 5/10


Training: 100%|██████████| 313/313 [02:08<00:00,  2.44it/s, loss=3.14, acc=23.6]
Evaluating: 100%|██████████| 32/32 [00:09<00:00,  3.32it/s, loss=3.68, acc=27.6]
Evaluating: 100%|██████████| 32/32 [00:09<00:00,  3.29it/s, loss=10.4, acc=0.35]


Train Loss: 3.3562, Train Acc: 23.56%
Val Loss: 3.1178, Val Acc: 27.65%
Test Loss: 10.6641, Test Acc: 0.35%

Epoch 6/10


Training: 100%|██████████| 313/313 [02:08<00:00,  2.44it/s, loss=2.45, acc=33.4]
Evaluating: 100%|██████████| 32/32 [00:09<00:00,  3.49it/s, loss=3.99, acc=37.7]
Evaluating: 100%|██████████| 32/32 [00:09<00:00,  3.34it/s, loss=11, acc=0.3]


Train Loss: 2.8568, Train Acc: 33.45%
Val Loss: 2.6249, Val Acc: 37.70%
Test Loss: 11.5770, Test Acc: 0.30%

Epoch 7/10


Training: 100%|██████████| 313/313 [02:09<00:00,  2.42it/s, loss=3.7, acc=37.2]
Evaluating: 100%|██████████| 32/32 [00:08<00:00,  3.79it/s, loss=3.88, acc=39]
Evaluating: 100%|██████████| 32/32 [00:09<00:00,  3.28it/s, loss=11.2, acc=0.35]


Train Loss: 2.6658, Train Acc: 37.25%
Val Loss: 2.5563, Val Acc: 38.95%
Test Loss: 11.7989, Test Acc: 0.35%

Epoch 8/10


Training: 100%|██████████| 313/313 [02:08<00:00,  2.44it/s, loss=2.1, acc=38.5]
Evaluating: 100%|██████████| 32/32 [00:08<00:00,  3.59it/s, loss=3.44, acc=38.9]
Evaluating: 100%|██████████| 32/32 [00:09<00:00,  3.49it/s, loss=11.7, acc=0.3]


Train Loss: 2.5910, Train Acc: 38.53%
Val Loss: 2.5128, Val Acc: 38.90%
Test Loss: 12.2037, Test Acc: 0.30%

Epoch 9/10


Training: 100%|██████████| 313/313 [02:08<00:00,  2.44it/s, loss=2.74, acc=39.9]
Evaluating: 100%|██████████| 32/32 [00:09<00:00,  3.49it/s, loss=3.61, acc=41]
Evaluating: 100%|██████████| 32/32 [00:08<00:00,  3.66it/s, loss=10.9, acc=0.4]


Train Loss: 2.5029, Train Acc: 39.94%
Val Loss: 2.4639, Val Acc: 40.95%
Test Loss: 11.3821, Test Acc: 0.40%

Epoch 10/10


Training: 100%|██████████| 313/313 [02:09<00:00,  2.43it/s, loss=2.38, acc=41.1]
Evaluating: 100%|██████████| 32/32 [00:09<00:00,  3.32it/s, loss=3.64, acc=40.9]
Evaluating: 100%|██████████| 32/32 [00:08<00:00,  3.88it/s, loss=10.9, acc=0.4]


Train Loss: 2.4512, Train Acc: 41.08%
Val Loss: 2.4501, Val Acc: 40.90%
Test Loss: 11.5431, Test Acc: 0.40%

Training completed! Best validation accuracy: 40.95%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
learning_rate,████▂▂▂▂▂▁
test_acc,▁█▃▁▂▁▂▁▂▂
test_loss,▂▁▄▄▆▇▇█▇▇
train_acc,▁▂▃▄▅▆▇▇██
train_loss,█▆▅▄▄▂▂▁▁▁
val_acc,▁▂▃▄▅▇████
val_loss,█▇▅▄▄▂▁▁▁▁

0,1
epoch,10.0
learning_rate,1e-05
test_acc,0.4
test_loss,11.54309
train_acc,41.075
train_loss,2.45125
val_acc,40.9
val_loss,2.45011
