In [1]:
!pip install transformers datasets evaluate --quiet

In [2]:
import os
import torch
import numpy as np
import evaluate
import timm
from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.optim import AdamW
import torch.nn as nn
from PIL import Image

def save_checkpoint(model, output_path):
    """Save the model state_dict to the specified output path."""
    torch.save(model.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
    print(f"Checkpoint saved at {output_path}")

def compute_metrics(eval_pred):
    """Compute accuracy using the evaluate package."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy_metric = evaluate.load("accuracy")
    return accuracy_metric.compute(predictions=predictions, references=labels)

class CustomTrainer:
    def __init__(self, model, criterion, optimizer, train_loader, eval_loader, compute_metrics, device, fp16=False):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.eval_loader = eval_loader
        self.compute_metrics = compute_metrics
        self.device = device
        self.fp16 = fp16
        if self.fp16:
            self.scaler = torch.amp.GradScaler()
        else:
            self.scaler = None

    def train_epoch(self, epoch):
        self.model.train()
        epoch_loss = 0.0
        for step, batch in enumerate(self.train_loader):
            inputs = batch["pixel_values"].to(self.device)
            labels = batch["labels"].to(self.device)
            self.optimizer.zero_grad()
            if self.fp16:
                with torch.cuda.amp.autocast():
                    logits = self.model(inputs)
                    loss = self.criterion(logits, labels)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                logits = self.model(inputs)
                loss = self.criterion(logits, labels)
                loss.backward()
                self.optimizer.step()
            epoch_loss += loss.item()
            if (step + 1) % 100 == 0:
                print(f"Step {step+1}/{len(self.train_loader)} - Loss: {loss.item():.4f}")
        avg_loss = epoch_loss / len(self.train_loader)
        print(f"Epoch {epoch} training loss: {avg_loss:.4f}")
        return avg_loss

    def evaluate(self):
        self.model.eval()
        all_logits = []
        all_labels = []
        with torch.no_grad():
            for batch in self.eval_loader:
                inputs = batch["pixel_values"].to(self.device)
                labels = batch["labels"].to(self.device)
                logits = self.model(inputs)
                all_logits.append(logits.cpu())
                all_labels.append(labels.cpu())
        all_logits = torch.cat(all_logits, dim=0).numpy()
        all_labels = torch.cat(all_labels, dim=0).numpy()
        metrics = self.compute_metrics((all_logits, all_labels))
        print("Evaluation Metrics:", metrics)
        return metrics

def get_dataloaders(train_dataset, eval_dataset, train_batch_size, eval_batch_size):
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
    eval_loader = DataLoader(eval_dataset, batch_size=eval_batch_size, shuffle=False)
    return train_loader, eval_loader

def load_components(config):
    dataset = load_dataset(config["dataset_name"])

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])

    def transform_fn(example):
        img = example["img"]
        if isinstance(img, list):
            img = img[0]
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
        img = img.convert("RGB")
        tensor = transform(img)
        if tensor.ndim == 2:
            tensor = tensor.unsqueeze(0).repeat(3, 1, 1)
        example["pixel_values"] = tensor
        example["labels"] = example["fine_label"]
        return example

    dataset = dataset.map(transform_fn)

    split_datasets = dataset["train"].train_test_split(test_size=config["validation_split"], seed=config["seed"])
    train_dataset = split_datasets["train"]
    eval_dataset = split_datasets["test"]

    train_dataset.set_format("torch", columns=["pixel_values", "labels"])
    eval_dataset.set_format("torch", columns=["pixel_values", "labels"])

    sample = train_dataset[0]
    print("\n--- Sample from train dataset after transformation ---")
    print("Sample 'pixel_values' type:", type(sample["pixel_values"]))
    print("Sample 'pixel_values' shape:", sample["pixel_values"].shape)
    print("Sample label:", sample["labels"])

    model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=config["num_labels"])
    return train_dataset, eval_dataset, model


def main():
    config = {
        "dataset_name": "cifar100",
        "num_labels": 100,
        "validation_split": 0.1,
        "seed": 42,
        "output_dir": "./vit-cifar100-checkpoints",
        "per_device_train_batch_size": 16,
        "per_device_eval_batch_size": 32,
        "num_train_epochs": 1,
        "learning_rate": 5e-5,
        "weight_decay": 0.05,
        "fp16": True,
    }
    os.makedirs(config["output_dir"], exist_ok=True)
    train_dataset, eval_dataset, model = load_components(config)
    train_loader, eval_loader = get_dataloaders(train_dataset, eval_dataset,
                                                config["per_device_train_batch_size"],
                                                config["per_device_eval_batch_size"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
    trainer = CustomTrainer(model, criterion, optimizer, train_loader, eval_loader, compute_metrics, device, fp16=config["fp16"])

    best_accuracy = 0.0
    num_epochs = config["num_train_epochs"]
    for epoch in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch}/{num_epochs}")
        trainer.train_epoch(epoch)
        metrics = trainer.evaluate()
        current_metric = metrics.get("accuracy")
        if current_metric is not None and current_metric > best_accuracy:
            best_accuracy = current_metric
            checkpoint_path = os.path.join(config["output_dir"], f"best_model_{current_metric:.4f}")
            os.makedirs(checkpoint_path, exist_ok=True)
            save_checkpoint(model, checkpoint_path)

if __name__ == "__main__":
    main()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]


--- Sample from train dataset after transformation ---
Sample 'pixel_values' type: <class 'torch.Tensor'>
Sample 'pixel_values' shape: torch.Size([3, 224, 224])
Sample label: tensor(47)

Epoch 1/1


  with torch.cuda.amp.autocast():


Step 100/2813 - Loss: 2.3994
Step 200/2813 - Loss: 0.3684
Step 300/2813 - Loss: 0.6844
Step 400/2813 - Loss: 0.4878
Step 500/2813 - Loss: 0.7856
Step 600/2813 - Loss: 2.6139
Step 700/2813 - Loss: 0.4106
Step 800/2813 - Loss: 0.6817
Step 900/2813 - Loss: 0.1805
Step 1000/2813 - Loss: 1.0217
Step 1100/2813 - Loss: 0.6897
Step 1200/2813 - Loss: 0.8167
Step 1300/2813 - Loss: 0.4020
Step 1400/2813 - Loss: 0.6790
Step 1500/2813 - Loss: 0.9043
Step 1600/2813 - Loss: 0.5107
Step 1700/2813 - Loss: 0.5574
Step 1800/2813 - Loss: 0.5841
Step 1900/2813 - Loss: 1.2514
Step 2000/2813 - Loss: 0.9650
Step 2100/2813 - Loss: 0.5962
Step 2200/2813 - Loss: 0.2033
Step 2300/2813 - Loss: 0.7170
Step 2400/2813 - Loss: 0.4085
Step 2500/2813 - Loss: 1.6168
Step 2600/2813 - Loss: 0.7278
Step 2700/2813 - Loss: 0.7440
Step 2800/2813 - Loss: 0.5184
Epoch 1 training loss: 0.8288
Evaluation Metrics: {'accuracy': 0.8458}
Checkpoint saved at ./vit-cifar100-checkpoints/best_model_0.8458
