In [None]:
# Download the CIFAR-10 python archive directly from the dataset host.
# The URL is quoted so the shell cannot split the command on special
# characters, and -O fixes the local file name that the tar step extracts.
!wget "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" -O cifar-10-python.tar.gz
!tar -xvzf cifar-10-python.tar.gz

# %pip (instead of !pip) installs into the environment of the running kernel.
%pip install -q loguru
%pip install -q plotly
%pip install -q open_clip_torch

In [None]:
import open_clip

# Create the pretrained CLIP ViT-B-32 model plus the image transform bound to
# `preprocess`; the second returned value is not needed and is discarded.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32",
    pretrained="laion2b_s34b_b79k",
)
# The model is in train mode by default, which affects models with BatchNorm
# or stochastic depth, so switch to eval mode before moving it to the GPU.
model = model.eval().to("cuda")

# Tokenizer belonging to the same architecture name.
tokenizer = open_clip.get_tokenizer("ViT-B-32")

## Load data

In [None]:
import pickle

import numpy as np
import plotly.express as px
from loguru import logger
from PIL import Image
from torch.utils.data import Dataset


def unpickle(path):
    """Load one pickled CIFAR-10 file and return the object stored in it.

    The file is unpickled with ``encoding="bytes"``, so string keys of the
    stored dict come back as ``bytes`` (for example ``b"batch_label"``).

    Args:
        path: Location of the pickle file to read.

    Returns:
        The unpickled dict.

    Raises:
        OSError: If the file cannot be opened.
    """
    # SECURITY: pickle.load can execute arbitrary code while loading; only
    # point this at files obtained from a trusted source.
    with open(path, "rb") as fo:
        loaded_dict = pickle.load(fo, encoding="bytes")

    # Report success only after the file has actually been read.
    logger.info(f"File loaded: {path}")
    # .get() keeps files without a batch label loadable (logs None instead of
    # raising KeyError).
    logger.info(f"Loaded dict batch label: {loaded_dict.get(b'batch_label')}")

    return loaded_dict


class CIFAR10Dataset(Dataset):
    def __init__(self, paths, n_images: int | None = None):
        self.labels = []
        self.images = None

        for path in paths:
            data_batch = unpickle(path)

            self.labels += data_batch[b"labels"]
            if self.images is None:
                self.images = data_batch[b"data"]
            else:
                self.images = np.concat([self.images, data_batch[b"data"]])

        if n_images != None:
            self.labels = self.labels[0:n_images]
            self.images = self.images[0:n_images]

        logger.info("Dataset info:")
        logger.info(f"\tShape: {self.images.shape}")
        logger.info(f"\tSize: {self.images.nbytes / 10e6} MB")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        preprocessed_image = preprocess(
            Image.fromarray(image.reshape(3, 32, 32).transpose(1, 2, 0))
        )
        return preprocessed_image, label


# Small subsets for the zero-shot experiments: all five training batches are
# read and the first 1000 samples kept; the test batch is cut to 500 samples.
train_dataset = CIFAR10Dataset(
    [f"cifar-10-batches-py/data_batch_{batch_no}" for batch_no in range(1, 6)],
    1000,
)

test_dataset = CIFAR10Dataset(
    ["cifar-10-batches-py/test_batch"],
    n_images=500,
)

In [None]:
# Read the dataset metadata; keys are bytes because of encoding="bytes".
with open("cifar-10-batches-py/batches.meta", "rb") as fo:
    meta = pickle.loads(fo.read(), encoding="bytes")

In [None]:
def plot_example(dataset: Dataset, index: int):
    """Display one sample of ``dataset`` with its class id and name as title.

    Args:
        dataset: Dataset whose items are ``(image, label)`` pairs. The image
            is permuted with ``(1, 2, 0)`` before plotting, i.e. it is
            expected channel-first.
        index: Position of the sample to show.
    """
    img, label = dataset[index]

    label_name = meta[b"label_names"][label]
    # Names loaded with encoding="bytes" are bytes; decode them so the title
    # reads "cat" rather than "b'cat'".
    if isinstance(label_name, bytes):
        label_name = label_name.decode("utf-8")

    # NOTE(review): the image has already gone through `preprocess`; the
    # factor 128 is kept unchanged and is presumably a rough rescaling for
    # display only — confirm.
    px.imshow(
        img.permute(1, 2, 0) * 128,
        title=f"Class {label} - {label_name}",
        height=400,
    ).show()


# Show test samples 1 through 5.
for sample_idx in range(1, 6):
    plot_example(test_dataset, sample_idx)

In [None]:
import open_clip
import torch
from sklearn.metrics import accuracy_score, classification_report
from torch.utils.data import DataLoader


def eval_clip(
    dataset: CIFAR10Dataset, classes: list[str], batch_size: int, shuffle: bool = False
) -> list[int]:
    """Zero-shot classify ``dataset`` with CLIP and log a classification report.

    Every image is assigned the class whose text embedding has the highest
    similarity to the image embedding. Uses the notebook-level ``model`` and
    ``tokenizer`` and runs on the ``"cuda"`` device.

    Args:
        dataset: Dataset yielding ``(preprocessed_image, label)`` pairs.
        classes: One text prompt per class, ordered by class id.
        batch_size: Number of images per forward pass.
        shuffle: Whether the DataLoader shuffles the samples.

    Returns:
        Predicted class ids in the order the samples were processed.
    """
    logger.info("Tokenizing classes")
    tokenized_classes = tokenizer(classes).to("cuda")
    y_true = []
    y_pred = []

    logger.info("Generating DataLoader")
    # Honour the caller's dataset and batch size rather than notebook globals.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    # The prompts are identical for every batch, so embed them only once.
    with torch.no_grad(), torch.autocast("cuda"):
        text_features = model.encode_text(tokenized_classes)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    logger.info("Starting evaluation")
    for batch_idx, (X, y) in enumerate(dataloader):
        X = X.to("cuda")
        logger.info(f"\tProcessing batch {batch_idx}")
        with torch.no_grad(), torch.autocast("cuda"):
            image_features = model.encode_image(X)
            image_features /= image_features.norm(dim=-1, keepdim=True)

            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        # Labels are only needed on the CPU for the metrics.
        batch_pred = text_probs.argmax(dim=-1).cpu().tolist()
        batch_true = y.tolist()
        logger.info(f"\tBatch accuracy: {accuracy_score(batch_true, batch_pred)}")

        y_pred += batch_pred
        y_true += batch_true

    logger.info("Classification Report:")
    logger.info(
        "\n"
        + classification_report(
            y_true, y_pred, labels=list(range(len(classes))), target_names=classes
        )
    )

    return y_pred

In [None]:
# Decode the bytes label names into strings and evaluate CLIP with them.
classes = [raw_name.decode("utf-8") for raw_name in meta[b"label_names"]]
logger.info(f"Classes: {classes}")

y_pred = eval_clip(test_dataset, classes, batch_size=128)

In [None]:
# One descriptive prompt per class, in class-id order. All prompts follow the
# same "a photo of <article> <name>, which is <category>" template so that the
# wording is the only thing that differs between classes.
prompted_classes = [
    "a photo of an airplane, which is a vehicle",
    "a photo of an automobile, which is a vehicle",
    "a photo of a bird, which is an animal",
    "a photo of a cat, which is an animal",
    "a photo of a deer, which is an animal",
    "a photo of a dog, which is an animal",
    "a photo of a frog, which is an animal",
    "a photo of a horse, which is an animal",
    "a photo of a ship, which is a vehicle",
    "a photo of a truck, which is a vehicle",
]

logger.info(f"Classes: {prompted_classes}")
y_pred = eval_clip(test_dataset, prompted_classes, 128)

## Part 2: Linear Probing

In [None]:
# Hyperparameters of the linear-probe experiment.
epochs = 10  # number of passes over the training set
learning_rate = 0.001  # learning rate handed to the Adam optimizer
batch_size = 512  # samples per batch for the train and test DataLoaders
patience = 3  # NOTE(review): not referenced in the visible cells — presumably intended for early stopping; confirm
device = "cuda"  # device the models and batches are moved to

In [None]:
from torch import nn


class LinearProbeModel(nn.Module):
    """A single fully connected layer that maps feature vectors to class logits."""

    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        # The attribute name is part of the state_dict keys, so it is kept.
        self.linear_layer = nn.Linear(
            in_features=input_dim, out_features=num_classes
        )

    def forward(self, x):
        """Return the logits for ``x``, assumed to be [batch, feature_dim]."""
        logits = self.linear_layer(x)
        return logits


# Probe with 512 input features and 10 output classes, moved to `device`.
# NOTE(review): 512 presumably matches the width of the CLIP image features — confirm.
linear_probe_model = LinearProbeModel(input_dim=512, num_classes=10).to(device)

In [None]:
import open_clip

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(linear_probe_model.parameters(), lr=learning_rate)

# Load the complete train and test splits (no n_images limit).
train_dataset = CIFAR10Dataset(
    [
        "cifar-10-batches-py/data_batch_1",
        "cifar-10-batches-py/data_batch_2",
        "cifar-10-batches-py/data_batch_3",
        "cifar-10-batches-py/data_batch_4",
        "cifar-10-batches-py/data_batch_5",
    ]
)

test_dataset = CIFAR10Dataset(["cifar-10-batches-py/test_batch"])

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Shuffling does not change the evaluation result; a fixed order keeps the
# per-batch log lines comparable between epochs.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

model = model.to(device)


def _encode_images(images):
    """Return L2-normalised float32 CLIP image embeddings for one batch."""
    # CLIP stays frozen: no gradients, mixed precision only for the encoder.
    with torch.no_grad(), torch.autocast("cuda"):
        features = model.encode_image(images)
        features /= features.norm(dim=-1, keepdim=True)
    # Cast to float32 so the probe trains and evaluates in full precision.
    return features.float()


# NOTE: the encoder is frozen, so the embeddings are recomputed unchanged in
# every epoch; caching them once would make the epochs much cheaper.
for epoch in range(epochs):
    # training
    linear_probe_model.train()
    for batch_idx, (X, y) in enumerate(train_dataloader):
        X, y = X.to(device), y.to(device)
        image_features = _encode_images(X)

        # The probe runs outside autocast: fp16 autocast without gradient
        # scaling can underflow gradients, and backward/step do not belong
        # inside the autocast region.
        optimizer.zero_grad()
        logits = linear_probe_model(image_features)
        loss = loss_fn(logits, y)

        logger.info(
            f"[Training] Epoch {epoch + 1} | Batch {batch_idx + 1} | "
            f"Loss = {round(loss.item(), 4)}"
        )

        # Backpropagation
        loss.backward()
        optimizer.step()

    logger.info("Evaluating in test_dataset")

    y_true = []
    y_pred = []

    linear_probe_model.eval()
    with torch.no_grad():
        for batch_idx, (X, y) in enumerate(test_dataloader):
            X = X.to(device)

            logits = linear_probe_model(_encode_images(X))
            batch_pred = logits.argmax(dim=-1).cpu().tolist()
            # Labels are only needed on the CPU for the metrics.
            batch_true = y.tolist()

            y_pred += batch_pred
            y_true += batch_true
            logger.info(
                f"[Test] Epoch {epoch + 1} | Batch {batch_idx + 1} "
                f"| accuracy = {accuracy_score(batch_true, batch_pred)}"
            )

    logger.info(f"[Test] Epoch {epoch + 1} accuracy: {accuracy_score(y_true, y_pred)}")