# Linear probe using CLIP features

In our previous tutorial, `Interacting_with_CLIP.ipynb`, we evaluated CLIP in a zero-shot setting: the model predicts the label whose text features have the highest cosine similarity with the image features.
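As a quick reminder, the zero-shot prediction rule can be sketched in a few lines. This is a minimal sketch that assumes `image_features` of shape (N, D) from `model.encode_image` and `text_features` of shape (C, D) from `model.encode_text`, with one text prompt per class; the function name `zero_shot_predict` is hypothetical, and the toy example uses random tensors rather than real CLIP outputs.

```python
import torch
import torch.nn.functional as F

def zero_shot_predict(image_features: torch.Tensor, text_features: torch.Tensor) -> torch.Tensor:
    """Return the predicted class index for each image.

    image_features: (N, D) image embeddings
    text_features:  (C, D) embeddings of one text prompt per class
    """
    # L2-normalize so that the dot product equals cosine similarity
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    # (N, C) cosine similarities between every image and every class prompt
    similarity = image_features @ text_features.T
    # the most similar class prompt is the prediction
    return similarity.argmax(dim=-1)

# toy example with random "features" (not real CLIP outputs)
torch.manual_seed(0)
text_features = torch.randn(3, 8)
# each image is a slightly perturbed copy of one class embedding
image_features = text_features[[2, 0, 1]] + 0.01 * torch.randn(3, 8)
print(zero_shot_predict(image_features, text_features))
```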

In this tutorial, we cover another approach to using pretrained models for classification tasks: the linear probe.
Unlike zero-shot classification, a linear probe requires training on a labeled training dataset.
However, to keep the training cost low, we train only a linear classifier on top of the frozen pretrained model.
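In PyTorch, "frozen" boils down to computing features without gradients for the backbone and handing only the classifier's parameters to the optimizer. The sketch below uses a small hypothetical stand-in `backbone` (a toy network, not CLIP) to illustrate that one optimizer step updates the linear classifier while leaving the frozen backbone untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# toy stand-in for a pretrained encoder (hypothetical; CLIP plays this role in the notebook)
backbone = nn.Sequential(nn.Linear(32, 16), nn.ReLU())
backbone.requires_grad_(False)  # freeze: no gradients for backbone parameters
backbone.eval()

# the linear probe: the only trainable module
classifier = nn.Linear(16, 4)
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

x = torch.randn(8, 32)
y = torch.randint(0, 4, (8,))

# snapshot parameters before the update
backbone_before = [p.clone() for p in backbone.parameters()]
classifier_before = [p.detach().clone() for p in classifier.parameters()]

with torch.no_grad():  # features from the frozen backbone
    features = backbone(x)
loss = criterion(classifier(features), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

backbone_changed = any(not torch.equal(a, b) for a, b in zip(backbone_before, backbone.parameters()))
classifier_changed = any(not torch.equal(a, b) for a, b in zip(classifier_before, classifier.parameters()))
print("backbone changed:", backbone_changed)      # False
print("classifier changed:", classifier_changed)  # True
```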

Side Note: Linear probing is not new. We did something similar in the CNN transfer learning tutorial, where we froze the main CNN and trained only a linear classifier. The name 'linear probe' is common in the self-supervised learning literature, where it emphasizes that only the linear classifier is trained while the main network stays frozen.


In [None]:
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

In [None]:
import numpy as np
import torch

print("Torch version:", torch.__version__)


## Load model

In [None]:
import clip

clip.available_models()

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model, preprocess = clip.load("ViT-B/16")
model = model.to(device).eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

## Setting up train and test dataset

In [None]:
# We will evaluate CLIP on a conventional image classification dataset (CIFAR10)

from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from tqdm import tqdm

cifar10_train = CIFAR10('data', transform=preprocess, download=True, train=True)
cifar10_test = CIFAR10('data', transform=preprocess, download=True, train=False)

train_loader = DataLoader(cifar10_train, batch_size=100, shuffle=True, num_workers=2)
test_loader = DataLoader(cifar10_test, batch_size=100, shuffle=False, num_workers=2)


## Linear probe option 1: using torch

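In the [CLIP paper](https://arxiv.org/pdf/2103.00020), the authors use the image features taken before they are projected into the shared image-text embedding space for the linear probe.

To do so, we need to remove the projection layer (to be specific, the projection weight `model.visual.proj`) from the model. In CLIP's ViT image encoder this weight is applied only when `proj` is not `None`, so setting it to `None` makes `encode_image` return the pre-projection features. This trick applies to the ViT variants; the ResNet-based CLIP models do not have this attribute.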
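Because the projection weight is needed again for zero-shot classification later, one option is to remove it only temporarily. Below is a minimal sketch with a hypothetical helper `without_projection`; it assumes only that the encoder object has a `proj` attribute that is skipped when set to `None`. A dummy object stands in for `model.visual` so the snippet runs on its own.

```python
from contextlib import contextmanager
from types import SimpleNamespace

@contextmanager
def without_projection(visual):
    """Temporarily set `visual.proj` to None and restore it afterwards."""
    saved_proj = visual.proj
    visual.proj = None
    try:
        yield visual
    finally:
        # restore the projection even if an exception was raised inside the block
        visual.proj = saved_proj

# dummy stand-in for model.visual (hypothetical, for illustration only)
visual = SimpleNamespace(proj="projection-weight")
with without_projection(visual):
    print(visual.proj)  # None
print(visual.proj)      # projection-weight
```

With the real model this would read `with without_projection(model.visual): feats = model.encode_image(x).float()`, ideally inside `torch.no_grad()`. The notebook below instead keeps a reference in `visual_proj` and restores it by hand, which works equally well as long as the restore is not forgotten.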

In [None]:
# See: https://github.com/openai/CLIP/blob/main/clip/model.py

sample_image = cifar10_test[0][0].unsqueeze(0).to(device)  # (1, 3, 224, 224)
print(sample_image.shape)

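# before removing the projection weight: features live in the shared image-text embedding space
with torch.no_grad():
    out_before = model.encode_image(sample_image).float()
print(out_before.shape)  # torch.Size([1, 512]) for ViT-B/16

# after removing the projection weight (keep a reference so it can be restored later)
visual_proj = model.visual.proj
model.visual.proj = None

with torch.no_grad():
    out_after = model.encode_image(sample_image).float()
print(out_after.shape)  # torch.Size([1, 768]) for ViT-B/16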

In [None]:
import torch.nn as nn
import torch.optim as optim

# input size = pre-projection feature dimension (768 for ViT-B/16)
feature_dim = out_after.shape[1]
linear_classifier = nn.Linear(feature_dim, 10).to(device)
optimizer = optim.Adam(linear_classifier.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(3):
    linear_classifier.train()
    for x, y in tqdm(train_loader):
        x, y = x.to(device), y.to(device)
        # 1. extract image features with the frozen CLIP encoder (no gradients needed) and convert them to float
        with torch.no_grad():
            image_feature = model.encode_image(x).float()
        # 2. compute logits using linear_classifier and the image features
        logits = linear_classifier(image_feature)
        # 3. compute the loss using criterion
        loss = criterion(logits, y)
        # 4. update the classifier parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # run evaluation every epoch
    linear_classifier.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in tqdm(test_loader):
            x, y = x.to(device), y.to(device)
            image_feature = model.encode_image(x).float()
            logits = linear_classifier(image_feature)
            loss = criterion(logits, y)
            test_loss += loss.item() * len(y)
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += len(y)
    test_loss = test_loss / total
    test_acc = correct / total

    print()
    print(f"[Epoch {epoch+1}] test_loss: {test_loss:.4f}, test_acc: {test_acc * 100:.2f}%")




## Linear probe option 2: using external library

Another way to train a linear classifier on top of the learned features is to first extract the image features for all images once and then train the classifier with an external library such as scikit-learn.

This makes it easy to use more sophisticated optimization algorithms implemented in scikit-learn, such as [L-BFGS](https://ko.wikipedia.org/wiki/L-BFGS), a [quasi-Newton method](https://en.wikipedia.org/wiki/Quasi-Newton_method).

In fact, the [CLIP paper](https://arxiv.org/pdf/2103.00020) uses this approach for its linear probe evaluation (see Appendix A.3):

> "We train a logistic regression classifier using scikit-learn’s L-BFGS implementation, with maximum 1,000 iterations"

See the [scikit-learn `LogisticRegression` documentation](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html).

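However, if you need data augmentation, the first approach is preferable: precomputed features are fixed, so every pass over the data sees exactly the same inputs, whereas option 1 encodes the images on the fly and can therefore feed freshly augmented views in every epoch.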
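For comparison, PyTorch also ships an L-BFGS optimizer, `torch.optim.LBFGS`, so a similar full-batch logistic regression on precomputed features can be sketched directly in torch. This is an illustration only, not scikit-learn's implementation: the helper name `fit_logreg_lbfgs`, the explicit L2 penalty `l2_penalty`, and the synthetic data standing in for CLIP features are all assumptions.

```python
import torch
import torch.nn as nn

def fit_logreg_lbfgs(features, labels, num_classes, l2_penalty=1e-4, max_iter=100):
    """Fit multinomial logistic regression on fixed features with full-batch L-BFGS."""
    classifier = nn.Linear(features.shape[1], num_classes)
    optimizer = torch.optim.LBFGS(classifier.parameters(), max_iter=max_iter,
                                  line_search_fn="strong_wolfe")
    criterion = nn.CrossEntropyLoss()

    def closure():
        # L-BFGS re-evaluates the loss several times per step, so it needs a closure
        optimizer.zero_grad()
        loss = criterion(classifier(features), labels)
        loss = loss + l2_penalty * classifier.weight.pow(2).sum()  # L2 penalty on the weights
        loss.backward()
        return loss

    optimizer.step(closure)  # runs up to max_iter L-BFGS iterations
    return classifier

# synthetic, well-separated toy data standing in for CLIP features
torch.manual_seed(0)
centers = torch.randn(3, 16) * 5
labels = torch.arange(3).repeat_interleave(50)
features = centers[labels] + torch.randn(150, 16)

clf = fit_logreg_lbfgs(features, labels, num_classes=3)
with torch.no_grad():
    acc = (clf(features).argmax(dim=1) == labels).float().mean().item()
print(f"train acc: {acc * 100:.2f}%")
```

In practice the scikit-learn route used below is simpler; the torch version mainly shows what "L-BFGS on frozen features" means mechanically.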

In [None]:
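import numpy as np

def extract_features(loader):
    """Encode every image in `loader` with the frozen CLIP image encoder.

    Because model.visual.proj was set to None above, these are the
    pre-projection features used for the linear probe.
    """
    features, labels = [], []
    with torch.no_grad():
        for x, y in tqdm(loader):
            image_feature = model.encode_image(x.to(device)).float()
            features.append(image_feature.cpu().numpy())
            labels.append(y.numpy())
    return np.concatenate(features), np.concatenate(labels)

# extract train image features and labels as numpy arrays
train_features, train_labels = extract_features(train_loader)

print()
print(train_features.shape)
print(train_labels.shape)

# extract test image features and labels as numpy arrays
test_features, test_labels = extract_features(test_loader)

print()
print(test_features.shape)
print(test_labels.shape)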

In [None]:
from sklearn.linear_model import LogisticRegression

C = 0.1  # inverse of the L2 regularization strength: smaller C means stronger regularization
logistic_regression = LogisticRegression(solver="lbfgs", max_iter=1000, C=C)
logistic_regression.fit(train_features, train_labels)

test_pred = logistic_regression.predict(test_features)
test_acc = (test_pred == test_labels).sum() / len(test_labels)
print(f"Test acc: {test_acc * 100:.2f}%")
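The value `C = 0.1` above is fixed by hand. Since `C` controls the regularization strength, a natural refinement is to choose it on a held-out validation split of the training features rather than on the test set. The sketch below is one way to do that; the helper name `select_C`, the candidate grid, and the 80/20 split are assumptions, and data from `make_classification` stand in for `train_features` and `train_labels`.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def select_C(features, labels, candidates=(0.001, 0.01, 0.1, 1.0, 10.0), seed=0):
    """Return the C with the best accuracy on a held-out validation split."""
    x_tr, x_val, y_tr, y_val = train_test_split(
        features, labels, test_size=0.2, random_state=seed, stratify=labels)
    best_C, best_acc = None, -1.0
    for C in candidates:
        clf = LogisticRegression(solver="lbfgs", max_iter=1000, C=C)
        clf.fit(x_tr, y_tr)
        acc = clf.score(x_val, y_val)  # mean accuracy on the validation split
        if acc > best_acc:
            best_C, best_acc = C, acc
    return best_C, best_acc

# synthetic stand-in for (train_features, train_labels)
features, labels = make_classification(n_samples=500, n_features=20, n_informative=10,
                                       n_classes=3, random_state=0)
best_C, best_acc = select_C(features, labels)
print(f"best C: {best_C}, validation acc: {best_acc * 100:.2f}%")
```

After picking `C` this way, one would refit on the full training features and evaluate once on the test features.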

## Exercise: Linear probe vs zero-shot classification on CIFAR100

1. Compute the zero-shot classification accuracy of CLIP on CIFAR100, as in the tutorial `7_1_Interacting_with_CLIP.ipynb`.

2. Implement linear probe evaluation on CIFAR100 (option 2, using scikit-learn).

3. Compare the results.

In [None]:
from torchvision.datasets import CIFAR100

cifar100_train = CIFAR100('data', transform=preprocess, download=True, train=True)
cifar100_test = CIFAR100('data', transform=preprocess, download=True, train=False)

train_loader = DataLoader(cifar100_train, batch_size=100, shuffle=True, num_workers=2)
test_loader = DataLoader(cifar100_test, batch_size=100, shuffle=False, num_workers=2)

In [None]:
# TODO 1: zero-shot classification

# make sure to restore model.visual.proj to visual_proj for zero-shot classification
model.visual.proj = visual_proj


print()
print(f"Accuracy: {accuracy * 100:.2f}%")

In [None]:
# TODO 2: linear probe evaluation option 2

# make sure to remove model.visual.proj for linear probe
model.visual.proj = None


print(f"Test acc: {test_acc * 100:.2f}%")