In [None]:
import os
import numpy as np
from typing import List, Dict, Tuple, Union, Optional

from datasets import load_metric
from transformers import ViTImageProcessor, AutoConfig, ViTModel
from transformers import TrainingArguments, Trainer
from transformers.modeling_outputs import ImageClassifierOutput

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms

# Expose only GPUs 0, 1 and 2 to this process.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0,1,2"})
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In this notebook, we'll walk through the process of building an image classification model using the Vision Transformer (ViT) architecture. We'll use the Hugging Face Transformers library to create and train the model.

### Step 1: Preprocessing with ViTImageProcessor

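First, we need to set up the ViTImageProcessor to preprocess our image data. The processor loads the pre-trained preprocessing configuration for the ViT model, such as the expected input size and the normalization statistics (an image model has no tokenizer).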

### Step 2: Data Collation Function
To prepare our data for training, we define a collation function collate_fn that stacks pixel values (images) and their corresponding labels into tensors within a dictionary. This function will be used during data loading.

### Step 3: Classification Head
We create a classification head for our ViT model. This head consists of linear layers and dropout for mapping the model's output features to class logits.

### Step 4: Image Classification Model
Now, we define the main image classification model. It consists of three parts: the pre-trained configuration, the ViT backbone, and the classification head.

### Step 5: Model Forward Pass
The forward method of our image classification model takes pixel values as input and returns logits. It also computes the loss if labels are provided based on the problem type and number of classes.

In [None]:
# Image processor of the pre-trained ViT checkpoint; its `size` entry is read
# later to pick the resize resolution of the dataset transform.
processor = ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224",
    cache_dir="/archive/turganbay/.huggingface",
)

def collate_fn(examples: List[Tuple[torch.Tensor, int]]) -> Dict[str, torch.Tensor]:
    """
    Collate a list of (image, label) examples into a single batch dictionary.

    Args:
        examples (List[Tuple[torch.Tensor, int]]): Examples to batch. Each example is
            indexed positionally: element 0 is the image tensor and element 1 is its
            label. All image tensors must have the same shape so they can be stacked.

    Returns:
        Dict[str, torch.Tensor]: A dictionary with two entries:
            "pixel_values": the images stacked along a new leading batch dimension.
            "labels": a 1-D tensor holding the labels in the same order.

    Raises:
        RuntimeError: Raised by ``torch.stack`` when ``examples`` is empty or the
            image tensors have mismatching shapes.
    """
    # Keys match the keyword arguments of the model's forward() so the batch can be
    # unpacked directly into the model call.
    return {
        "pixel_values": torch.stack([image for image, _ in examples]),
        "labels": torch.tensor([label for _, label in examples]),
    }


class ClassificationHead(torch.nn.Module):
    """
    Classification head applied to the first ([CLS]) token of a token sequence.

    The head is dropout -> linear -> tanh -> dropout -> linear, mapping a
    ``config.hidden_size`` feature vector to ``num_classes`` logits.

    Args:
        config (AutoConfig): Backbone configuration. ``hidden_size`` and
            ``hidden_dropout_prob`` are read at construction time;
            ``initializer_range`` is used as the std of the weight initialisation
            (0.02 is used when the config does not define it).
        num_classes (int): Number of output logits.
    """

    def __init__(self, config: AutoConfig, num_classes: int=100):
        super().__init__()

        self.config = config

        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = torch.nn.Linear(config.hidden_size, num_classes)

        # The head is always trained from scratch, so initialise it here instead of
        # relying on callers to remember to invoke init_weights().
        self.init_weights()

    def init_weights(self):
        """Re-initialise both linear layers: N(0, initializer_range) weights, zero biases."""
        std = getattr(self.config, "initializer_range", 0.02)
        # Cover the output projection as well as the dense layer; previously only
        # `dense` was initialised and `out_proj` kept PyTorch's default init.
        for layer in (self.dense, self.out_proj):
            layer.weight.data.normal_(mean=0.0, std=std)
            if layer.bias is not None:
                layer.bias.data.zero_()

    def forward(self, features, **kwargs):
        """
        Compute class logits.

        Args:
            features: 3-D tensor indexed as (batch, token, hidden); only the token at
                position 0 is used.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Tensor of shape (batch, num_classes) with unnormalised logits.
        """
        x = features[:, 0, :]  # CLS token
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x

class ImageClassification(nn.Module):
    """
    Image classifier made of a pre-trained ViT backbone and a `ClassificationHead`.

    Args:
        num_classes (int): Number of output logits produced by the head.
        backbone (str): Model identifier or path passed to ``from_pretrained`` for
            both the configuration and the ViT weights.
        cache_dir (Optional[str]): Directory used by ``from_pretrained`` to cache
            downloaded files. Defaults to the path this notebook has always used.
    """

    def __init__(
        self,
        num_classes: int,
        backbone: str,
        cache_dir: Optional[str] = "/archive/turganbay/.huggingface",
    ) -> None:
        super().__init__()

        self.num_classes = num_classes
        self.config = AutoConfig.from_pretrained(backbone, cache_dir=cache_dir)
        # The config must be passed by keyword: extra positional arguments of
        # from_pretrained() are forwarded to the model constructor, so a positional
        # config would be consumed as the constructor's next argument instead.
        self.vit = ViTModel.from_pretrained(backbone, config=self.config, cache_dir=cache_dir)
        self.classifier = ClassificationHead(self.config, num_classes)

    def _resolve_problem_type(self, labels):
        """Return config.problem_type, or infer it from num_classes and the label dtype."""
        problem_type = self.config.problem_type
        if problem_type is not None:
            return problem_type
        if self.num_classes == 1:
            return "regression"
        if self.num_classes > 1 and labels.dtype in (torch.long, torch.int):
            return "single_label_classification"
        return "multi_label_classification"

    def _compute_loss(self, logits, labels):
        """Compute the loss matching the resolved problem type."""
        labels = labels.to(logits.device)
        problem_type = self._resolve_problem_type(labels)

        if problem_type == "regression":
            loss_fct = nn.MSELoss()
            if self.num_classes == 1:
                return loss_fct(logits.squeeze(), labels.squeeze())
            return loss_fct(logits, labels)
        if problem_type == "single_label_classification":
            loss_fct = nn.CrossEntropyLoss()
            return loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
        if problem_type == "multi_label_classification":
            loss_fct = nn.BCEWithLogitsLoss()
            return loss_fct(logits, labels)
        # Fail loudly instead of silently returning no loss for a misconfigured model.
        raise ValueError(f"Unsupported problem_type: {problem_type!r}")

    def forward(
        self,
        pixel_values,
        head_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        interpolate_pos_encoding=None,
        return_dict=True,
    ):
        """
        Run the backbone and the head, optionally computing a loss.

        Args:
            pixel_values: Batch of images passed to the ViT backbone.
            head_mask, output_attentions, output_hidden_states,
            interpolate_pos_encoding: Forwarded unchanged to the backbone.
            labels: Optional targets. When given, a loss is computed: the loss
                function follows ``config.problem_type`` when it is set, and is
                otherwise inferred (MSE for one class, cross-entropy for integer
                labels with several classes, BCE-with-logits in all other cases).
            return_dict: When true, return an ``ImageClassifierOutput``; when false,
                return a tuple. ``None`` falls back to ``config.use_return_dict``.

        Returns:
            ``ImageClassifierOutput(loss, logits, hidden_states, attentions)``, or the
            tuple ``(loss?, logits, *backbone_outputs[1:])`` where ``loss`` is only
            present when labels were given.

        Raises:
            ValueError: If ``config.problem_type`` is set to an unsupported value and
                labels are provided.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        # First backbone output is the per-token sequence the head reads its
        # position-0 token from.
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        loss = self._compute_loss(logits, labels) if labels is not None else None

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

### Step 6: Data Preparation

In this step, we'll prepare the dataset for training. We'll use the CIFAR-100 dataset, which contains 100 different classes of images. We define a transformation pipeline for our dataset. This pipeline resizes the images to a specified size and converts them to tensors.

#### Loading CIFAR-100 Dataset
We load the CIFAR-100 dataset, specifying the data directory, setting train=True for the training split and train=False for the test split. The download=True flag ensures that the dataset is downloaded if it's not already available locally. We apply the previously defined transformation to the dataset. To create a train-validation split, we split the training dataset into two subsets: one for training and one for validation. We use the random_split function from PyTorch to achieve this.

With our data prepared and split, we are ready to proceed with training and evaluation in the next steps.

In [None]:
# Input resolution taken from the processor of the pre-trained checkpoint.
size = processor.size["height"]
transform = transforms.Compose([
    # Pass (h, w) explicitly: a bare int would only fix the shorter edge, which
    # yields non-square outputs for non-square inputs.
    transforms.Resize((size, size)),
    transforms.ToTensor(),
    # Apply the same normalization statistics the checkpoint's processor is
    # configured with, so the backbone sees inputs in the range the processor
    # itself would produce.
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])

train_dataset = datasets.CIFAR100("/archive/turganbay/cifar", train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR100("/archive/turganbay/cifar", train=False, download=True, transform=transform)

# 90/10 train/validation split; a seeded generator keeps the split identical
# across notebook runs so results are comparable.
N = len(train_dataset)
T = int(N*0.9)
train_dataset, val_dataset = random_split(
    train_dataset,
    [T, N-T],
    generator=torch.Generator().manual_seed(42),
)

### Step 7: Evaluation Metrics

In this step, we'll define and compute evaluation metrics for our image classification model. We'll use common metrics: accuracy, recall, precision, and F1 score. We start by loading each metric by name and storing the loaded metric objects in a dictionary keyed by metric name: "accuracy," "recall," "precision," and "f1."

Next, we define a function to compute these metrics given the evaluation predictions. The function takes eval_pred, which is a tuple containing logits and labels. It calculates the predicted class labels, computes the selected metrics, and returns the results as a dictionary.

In [None]:
# Names of the metrics reported during evaluation.
metrics = ["accuracy", "recall", "precision", "f1"]
# Loaded metric objects, keyed by metric name.
metric = {met: load_metric(met) for met in metrics}


def compute_metrics(eval_pred: Tuple[np.ndarray, np.ndarray], average: str = "micro") -> Dict[str, float]:
    """
    Compute evaluation metrics given prediction logits and true labels.

    Every metric named in the module-level ``metrics`` list is computed with the
    matching object from the module-level ``metric`` dictionary.

    Args:
        eval_pred (Tuple[np.ndarray, np.ndarray]): A pair of (logits, labels). If the
            first element is itself a tuple or list of arrays, its first entry is
            taken as the logits.
        average (str): Averaging mode forwarded to every metric except accuracy.
            Defaults to "micro", the value this function has always used. Note that
            for single-label multi-class predictions micro-averaged precision, recall
            and F1 all equal accuracy; pass "macro" for per-class-balanced scores.

    Returns:
        Dict[str, float]: A dictionary mapping each metric name to its value.
    """
    logits, labels = eval_pred
    # Predictions can arrive as a tuple when the model returns more than logits.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)

    metric_res = {}
    for met in metrics:
        # Accuracy takes no averaging mode; the other metrics need one.
        extra_kwargs = {} if met == "accuracy" else {"average": average}
        scores = metric[met].compute(predictions=predictions, references=labels, **extra_kwargs)
        metric_res[met] = scores[met]

    return metric_res

### Step 8: Model Training

In this step, we'll set up the training process for our image classification model. We'll use the Hugging Face `Trainer` class for this purpose.

First, we initialize our image classification model. We specify the number of classes (100 in this case) and the backbone architecture (Google Vision Transformer - `google/vit-base-patch16-224`).

We configure the Hugging Face Trainer class to handle our training process. We specify various training arguments, including batch sizes, logging intervals, learning rate, and more.

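Finally, we start the training process using the configured Trainer. The trainer.train() method trains for the specified number of steps (max_steps), evaluating and saving a checkpoint every 200 steps. Because load_best_model_at_end is enabled, the checkpoint with the best evaluation result (by default the lowest evaluation loss) is reloaded when training finishes.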

Now, your image classification model is ready to be trained!

In [None]:
# 100 output classes on top of the pre-trained ViT backbone.
model = ImageClassification(num_classes=100, backbone="google/vit-base-patch16-224")
model.to(device)

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        # Output and experiment tracking
        output_dir="/archive/turganbay/vit_clf_models/model",
        report_to="wandb",
        run_name="vit_clf",
        # Schedule: log, evaluate and checkpoint on the same 200-step cadence
        max_steps=3000,
        warmup_steps=200,
        logging_steps=200,
        evaluation_strategy="steps",
        save_strategy="steps",
        save_steps=200,
        # Optimisation
        optim="adamw_torch",
        learning_rate=1e-5,
        weight_decay=0.01,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        # Best-checkpoint selection: lower is better
        load_best_model_at_end=True,
        greater_is_better=False,
    ),
    data_collator=collate_fn,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

trainer.train()