# CLIP zero-shot Evaluation
This short notebook implements the dataset split into base and novel categories (see the project assignment) and runs the zero-shot evaluation with CLIP.
Feel free to copy the code contained in this notebook or to use the notebook directly as a starting point for your project.

In [1]:
# we need to install clip as it is not pre-installed
# you are also free to use open_clip, which provides more models
# https://github.com/mlfoundations/open_clip
%pip install openai_clip

Collecting openai_clip
  Downloading openai-clip-1.0.1.tar.gz (1.4 MB)
  Preparing metadata (setup.py) ... done
Collecting ftfy (from openai_clip)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
Building wheels for collected packages: openai_clip
  Building wheel for openai_clip (setup.py) ... done
  Created wheel for openai_clip: filename=openai_clip-1.0.1-py3-none-any.whl size=1368605 sha256=5abae0045e0d2b09ee1b0777e914e

In [2]:
import torch
import torchvision
import clip
from tqdm import tqdm

## Dataset Loading
Let's get the data directly from torchvision as we have seen during labs.

In [3]:
def get_data(data_dir="./data", transform=None):
    """Load Flowers102 train, validation and test sets.
    Args:
        data_dir (str): Directory where the dataset will be stored.
        transform (callable, optional): Transform applied to every image (e.g. CLIP's preprocess).
    Returns:
        tuple: A tuple containing the train, validation, and test sets.
    """
    train = torchvision.datasets.Flowers102(root=data_dir, split="train", download=True, transform=transform)
    val = torchvision.datasets.Flowers102(root=data_dir, split="val", download=True, transform=transform)
    test = torchvision.datasets.Flowers102(root=data_dir, split="test", download=True, transform=transform)
    return train, val, test

## Base and Novel categories
To split the classes into base and novel categories, we list all dataset classes and count them (we already know there are 102, but let's do it properly).
Then, we simply assign the first half to the base categories and the remaining half to the novel ones.
We can do this because we are simulating a real-world application, but keep in mind that such a clean split will not be available in practice!

In [4]:
def base_novel_categories(dataset):
    # set returns the unique set of all dataset classes
    all_classes = set(dataset._labels)
    # and let's count them
    num_classes = len(all_classes)

    # here list(range(num_classes)) returns a list from 0 to num_classes - 1
    # then we slice the list in half and generate base and novel category lists
    base_classes = list(range(num_classes))[:num_classes//2]
    novel_classes = list(range(num_classes))[num_classes//2:]
    return base_classes, novel_classes
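
As a quick sanity check of the slicing above, the sketch below applies the same half-split to plain class counts. `split_half` is a hypothetical helper written for illustration only, not part of the notebook; it shows that when the number of classes is odd, the novel group receives the extra class.

```python
def split_half(num_classes):
    """Hypothetical helper mirroring the slicing in base_novel_categories."""
    class_ids = list(range(num_classes))
    # floor division: the first half may be one element shorter than the second
    return class_ids[:num_classes // 2], class_ids[num_classes // 2:]

base, novel = split_half(102)
print(len(base), len(novel))    # 51 51
print(base[-1], novel[0])       # 50 51

base_odd, novel_odd = split_half(5)
print(base_odd, novel_odd)      # [0, 1] [2, 3, 4]
```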

## Inspect Classes
Let's now visualize which are the base and novel classes.
To do so, we first get a dummy test set (without augmentations) as we are just interested in the dataset labels. Then, we split it using `base_novel_categories`.
Finally, we use the hard-coded CLASS_NAMES to print the class names in natural language.

> Note: the list of class names was only recently added to `torchvision.datasets.Flowers102`. To spare you avoidable errors with older versions, we decided to also provide the list here.

In [5]:
_, _, tmp_test = get_data()
base_classes, novel_classes = base_novel_categories(tmp_test)
CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]
print("Base Class Names:", [(i, CLASS_NAMES[i]) for i in base_classes])
print("Novel Class Names:", [(i, CLASS_NAMES[i]) for i in novel_classes])

100%|██████████| 345M/345M [00:16<00:00, 20.7MB/s]
100%|██████████| 502/502 [00:00<00:00, 984kB/s]
100%|██████████| 15.0k/15.0k [00:00<00:00, 25.8MB/s]


Base Class Names: [(0, 'pink primrose'), (1, 'hard-leaved pocket orchid'), (2, 'canterbury bells'), (3, 'sweet pea'), (4, 'english marigold'), (5, 'tiger lily'), (6, 'moon orchid'), (7, 'bird of paradise'), (8, 'monkshood'), (9, 'globe thistle'), (10, 'snapdragon'), (11, "colt's foot"), (12, 'king protea'), (13, 'spear thistle'), (14, 'yellow iris'), (15, 'globe-flower'), (16, 'purple coneflower'), (17, 'peruvian lily'), (18, 'balloon flower'), (19, 'giant white arum lily'), (20, 'fire lily'), (21, 'pincushion flower'), (22, 'fritillary'), (23, 'red ginger'), (24, 'grape hyacinth'), (25, 'corn poppy'), (26, 'prince of wales feathers'), (27, 'stemless gentian'), (28, 'artichoke'), (29, 'sweet william'), (30, 'carnation'), (31, 'garden phlox'), (32, 'love in the mist'), (33, 'mexican aster'), (34, 'alpine sea holly'), (35, 'ruby-lipped cattleya'), (36, 'cape flower'), (37, 'great masterwort'), (38, 'siam tulip'), (39, 'lenten rose'), (40, 'barbeton daisy'), (41, 'daffodil'), (42, 'sword 

## Split Dataset
The next step is to actually split the dataset into the base and novel categories we extract from `base_novel_categories`.
To split the data we need the dataset (obviously) and the list of base classes. If the sample label is not part of the base categories, then it must be part of the novel ones.

In [6]:
def split_data(dataset, base_classes):
    # these two lists will store the sample indexes
    base_categories_samples = []
    novel_categories_samples = []

    # we create a set of base classes to compute the test below in O(1)
    # this is optional and can be removed
    base_set = set(base_classes)

    # here we iterate over sample labels and also get the corresponding sample index
    for sample_id, label in enumerate(dataset._labels):
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    # here we create the dataset subsets
    # the torch Subset is just a wrapper around the dataset
    # it simply stores the subset indexes and the original dataset (your_subset.dataset)
    # when asking for sample i in the subset, torch will look for its original position in the dataset and retrieve it
    # https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset
    base_dataset = torch.utils.data.Subset(dataset, base_categories_samples)
    novel_dataset = torch.utils.data.Subset(dataset, novel_categories_samples)
    return base_dataset, novel_dataset
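
To see the index bookkeeping in isolation, here is a minimal sketch on a toy label list, without any torch dependency. The labels and the `partition_indices` name are made up for illustration; the loop is the same as in `split_data`, minus the `Subset` wrapping.

```python
def partition_indices(labels, base_classes):
    """Return (base_indices, novel_indices) for a flat list of integer labels."""
    base_set = set(base_classes)  # O(1) membership test
    base_idx, novel_idx = [], []
    for sample_id, label in enumerate(labels):
        (base_idx if label in base_set else novel_idx).append(sample_id)
    return base_idx, novel_idx

toy_labels = [0, 3, 1, 2, 0, 3]          # hypothetical labels for six samples
base_idx, novel_idx = partition_indices(toy_labels, base_classes=[0, 1])
print(base_idx)   # [0, 2, 4]
print(novel_idx)  # [1, 3, 5]
```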

## Extract k shots
As the dataset already provides 10 training and 10 validation shots per class, we do not need to extract them.
Be aware that few-shot adaptation papers must perform this step, as most datasets contain significantly more samples in both the training and validation sets.
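
For datasets that do not ship with a fixed number of shots, the extraction step might look like the sketch below. This is an assumption about how one could do it, not code from the assignment: `sample_k_shots`, its parameters, and the toy labels are all hypothetical, and papers differ in how they seed and sample.

```python
import random
from collections import defaultdict

def sample_k_shots(labels, k, seed=0):
    """Pick up to k sample indices per class from a flat list of labels."""
    rng = random.Random(seed)  # local RNG so the selection is reproducible
    per_class = defaultdict(list)
    for sample_id, label in enumerate(labels):
        per_class[label].append(sample_id)

    selected = []
    for label in sorted(per_class):
        indices = per_class[label]
        # classes with fewer than k samples contribute everything they have
        selected.extend(rng.sample(indices, min(k, len(indices))))
    return sorted(selected)

toy_labels = [0] * 5 + [1] * 3 + [2] * 1   # hypothetical, imbalanced labels
shots = sample_k_shots(toy_labels, k=2)
print(len(shots))  # 5: two from class 0, two from class 1, one from class 2
```

The returned indices could then be wrapped in `torch.utils.data.Subset`, in the same way `split_data` wraps its index lists.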

## Load CLIP

In [16]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# available models = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
model, preprocess = clip.load("ViT-B/16", device=device)

# preprocess contains CLIP's pre-defined augmentations, let's inspect them!
preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x78d88c8723e0>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

## Load and Prepare Data
Here we load the three dataset splits and pass CLIP's pre-defined preprocessing as the transform.
Then, we compute the base and novel categories (this is redundant here, as we already did it before).
Finally, we split the three datasets into base and novel categories.
As we want to use the novel categories only for the test set, we drop `train_novel` and `val_novel`.

In [8]:
# get the three datasets
train_set, val_set, test_set = get_data(transform=preprocess)

# split classes into base and novel
base_classes, novel_classes = base_novel_categories(train_set)

# split the three datasets
train_base, _ = split_data(train_set, base_classes)
val_base, _ = split_data(val_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)

## Compute Zero-Shot Predictions

In [17]:
@torch.no_grad() # we don't want gradients
def eval(model, dataset, categories, batch_size, device, label=""):
    # let's set the model in evaluation mode
    model.eval()

    # Remap labels into a contiguous set starting from zero
    contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

    # here we apply the standard CLIP template used for oxford flowers to all categories
    # and immediately tokenize each sentence (convert natural language into numbers - feel free to print the text input to inspect them)
    text_inputs = clip.tokenize(
        [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in categories]
    ).to(device)

    # we can encode the text features once as they are shared for all images
    # therefore we do it outside the evaluation loop
    text_features = model.encode_text(text_inputs)
    # and here we normalize them (standard practice with CLIP)
    text_features /= text_features.norm(dim=-1, keepdim=True) # unit norm, so the dot product below is the cosine similarity

    # simple dataloader creation
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # here we store the number of correct predictions we will make
    correct_predictions = 0
    for image, target in tqdm(dataloader, desc=label):
        # base categories range from 0 to 50, while novel ones range from 51 to 101
        # the text features are indexed from 0, so we remap the original labels
        # to a contiguous range starting from zero; otherwise novel targets would never match
        # class indices must be of type long in PyTorch
        target = torch.tensor([contig_cat2idx[t.item()] for t in target], dtype=torch.long)

        image = image.to(device)
        target = target.to(device)

        # forward image through CLIP image encoder
        image_features = model.encode_image(image)
        # and normalize
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # here we compute the cosine similarity between image and text features and keep the argmax of every row (every image)
        predicted_class = (image_features @ text_features.T).argmax(dim=-1)
        # now we check which are correct, and sum them (False == 0, True == 1)
        correct_predictions += (predicted_class == target).sum().item()

    # and now we compute the accuracy
    accuracy = correct_predictions / len(dataset)
    return accuracy

base_accuracy = eval(model=model, dataset=test_base, categories=base_classes, batch_size=128, device=device, label="🧠 Zero-shot evaluation on Base Classes")
novel_accuracy = eval(model=model, dataset=test_novel, categories=novel_classes, batch_size=128, device=device, label="🧠 Zero-shot evaluation on Novel Classes")

print()
print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")


🧠 Zero-shot evaluation on Base Classes: 100%|██████████| 20/20 [00:18<00:00,  1.09it/s]
🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 29/29 [00:27<00:00,  1.05it/s]


🔍 Base classes accuracy: 71.33%
🔍 Novel classes accuracy: 78.24%
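
The decision rule inside the evaluation loop can be reproduced on tiny hand-made vectors. The sketch below uses plain Python lists instead of CLIP embeddings, so the numbers are purely illustrative; `l2_normalize`, `zero_shot_predict`, and the toy features are hypothetical names. It also shows why the original labels must be remapped to positions in `categories` before they are compared with the argmax.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit Euclidean norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def zero_shot_predict(image_feature, text_features):
    """Return the index of the text feature with the highest cosine similarity."""
    img = l2_normalize(image_feature)
    sims = [sum(a * b for a, b in zip(img, l2_normalize(t))) for t in text_features]
    return max(range(len(sims)), key=sims.__getitem__)

# toy "novel" categories with original label ids 51, 52, 53
categories = [51, 52, 53]
contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

text_features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # one row per category
image_feature = [0.2, 3.0]                              # points mostly along the second axis
target = 52                                             # original dataset label

pred = zero_shot_predict(image_feature, text_features)
print(pred, contig_cat2idx[target])   # 1 1
```

Without the remapping, the prediction `1` would be compared against the raw label `52` and counted as wrong even though it selects the right prompt.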





## Harmonic Mean
Few-shot adaptation papers usually report the harmonic mean (HM) of the base and novel accuracies.
The harmonic mean is dominated by the smaller of its two arguments: it dampens the contribution of the larger value (typically the base accuracy) and amplifies that of the smaller one (typically the novel accuracy).
Thus, achieving a very high base accuracy at the expense of the novel accuracy is penalized by the HM.

In [None]:
def harmonic_mean(base_accuracy, novel_accuracy):
    numerator = 2
    denominator = 1 / base_accuracy + 1 / novel_accuracy
    hm = numerator / denominator
    return hm

print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")

🔍 Harmonic Mean: 74.62%
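
To make the penalty concrete, the sketch below compares the arithmetic and harmonic means on made-up accuracy pairs. The values are illustrative only and are not results from this notebook.

```python
def harmonic_mean(a, b):
    return 2 / (1 / a + 1 / b)

def arithmetic_mean(a, b):
    return (a + b) / 2

# two hypothetical (base, novel) accuracy pairs with the same arithmetic mean
balanced = (0.50, 0.50)
imbalanced = (0.90, 0.10)

for base_acc, novel_acc in (balanced, imbalanced):
    am = arithmetic_mean(base_acc, novel_acc)
    hm = harmonic_mean(base_acc, novel_acc)
    print(f"base={base_acc:.2f} novel={novel_acc:.2f} AM={am:.2f} HM={hm:.2f}")
# base=0.50 novel=0.50 AM=0.50 HM=0.50
# base=0.90 novel=0.10 AM=0.50 HM=0.18
```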


# Data Augmentation

In [9]:
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image

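# Define data augmentation transformations.
# These run on the output of `preprocess` (the wrapped dataset already applies it),
# so the input here is a normalized tensor rather than a PIL image.
augmentation_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),  # Random crop, then resize back to 224x224
    transforms.RandomHorizontalFlip(p=0.5),  # 50% chance to flip horizontally
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Random color adjustments
    transforms.RandomRotation(degrees=15),  # Rotate image within ±15 degrees
    # transforms.ToTensor(),  # Not needed: the input is already a tensor
    # ImageNet mean/std (these differ from the CLIP statistics used by `preprocess`);
    # since the input is already normalized, this normalizes the image a second time
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])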

# Validation transformations (no augmentation, just normalization)
validation_transforms = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to a fixed size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Example dataset class
class AugmentedImageTextDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data  # Any indexable dataset yielding (image, label) pairs, e.g. a torch Subset
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, label = self.data[idx]

        # apply the extra transform on top of whatever the wrapped dataset already applied
        if self.transform:
            image = self.transform(image)
        return image, label

augmented_train_base = AugmentedImageTextDataset(data=train_base, transform=augmentation_transforms)

## Fine-Tuning of the textual linear layer

We fine-tune the final projection of the text encoder (`model.text_projection`) to classify the base training data, keeping all other parameters frozen.

In [10]:
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision.transforms as transforms

def fine_tuning_linear_text(model, train_dataset, val_dataset, categories, lr, batch_size, num_epochs, device):

  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
  val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

  model = model.float()

  text_projection = model.text_projection

  # Freeze all parameters in the model
  for param in model.parameters():
      param.requires_grad = False

  # Unfreeze the projection layer
  text_projection.requires_grad = True

  criterion = torch.nn.CrossEntropyLoss()
  optimizer = Adam([text_projection], lr=lr)

  text_inputs = clip.tokenize(
          [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in categories]).to(device)

  contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

  print("🧠 Fine-tuning training+validation on Base Classes")

  for epoch in range(num_epochs):

    # Training of one epoch

    model.train()
    # here we store the sum of the losses computed across all batches
    total_loss = 0
    # here we store the number of correct predictions we will make
    correct_predictions = 0

    for image, target in tqdm(train_dataloader):

          target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

          image = image.to(device).float()
          target = target.to(device)

          text_features = model.encode_text(text_inputs).float()
          # and here we normalize them (standard practice with CLIP)
          text_feature_norm = text_features.norm(dim=-1, keepdim=True)
          text_features = text_features/text_feature_norm # unit norm, so the dot product below is the cosine similarity

          # forward image through CLIP image encoder
          image_features = model.encode_image(image).float()
          # and normalize
          image_features_norm = image_features.norm(dim=-1, keepdim=True)
          image_features = image_features / image_features_norm # unit norm, so the dot product below is the cosine similarity

          # cosine similarity between image and text features, used as logits for the loss
          logits = image_features @ text_features.T

          loss = criterion(logits, target)
          total_loss += loss.item()

          # Backpropagation
          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
          optimizer.step()
          optimizer.zero_grad()

          # Training accuracy computation
          predicted_class = logits.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Training loss: {total_loss/ len(train_dataloader)}; Training accuracy: {correct_predictions / len(train_dataset)*100:.2f}%")
    print()

    # Validation of one epoch

    model.eval()
    total_loss = 0
    correct_predictions = 0

    # no gradients are needed for validation
    with torch.no_grad():
      # the text features do not change within the validation pass,
      # so we encode and normalize them once, outside the loop
      text_features = model.encode_text(text_inputs).float()
      text_features = text_features / text_features.norm(dim=-1, keepdim=True)

      for image, target in tqdm(val_dataloader):

          target = torch.tensor([contig_cat2idx[t.item()] for t in target], dtype=torch.long)

          image = image.to(device).float()
          target = target.to(device)

          # forward image through CLIP image encoder
          image_features = model.encode_image(image).float()
          # and normalize to unit norm
          image_features = image_features / image_features.norm(dim=-1, keepdim=True)

          # cosine similarity between image and text features
          logits = image_features @ text_features.T

          loss = criterion(logits, target)
          total_loss += loss.item()

          # Validation accuracy computation
          predicted_class = logits.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Validation loss: {total_loss/ len(val_dataloader)}; Validation accuracy: {correct_predictions / len(val_dataset)*100:.2f}%")
    print()
    print("-----------------------------------------------------------------------------------------------")

  return model

model, _ = clip.load("ViT-B/16", device=device)
model_ft_text_layer = fine_tuning_linear_text(model=model, train_dataset=train_base, val_dataset=val_base, categories=base_classes, lr = 0.0001, batch_size=16, num_epochs=10, device=device)

🧠 Fine-tuning training+validation on Base Classes


100%|██████████| 32/32 [00:11<00:00,  2.80it/s]


Epoch 1, Training loss: 3.7989226058125496; Training accuracy: 61.96%



100%|██████████| 32/32 [00:09<00:00,  3.28it/s]


Epoch 1, Validation loss: 3.7515273690223694; Validation accuracy: 75.10%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:10<00:00,  3.18it/s]


Epoch 2, Training loss: 3.7296567857265472; Training accuracy: 80.98%



100%|██████████| 32/32 [00:10<00:00,  3.11it/s]


Epoch 2, Validation loss: 3.7075588777661324; Validation accuracy: 84.12%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:10<00:00,  3.05it/s]


Epoch 3, Training loss: 3.6899445801973343; Training accuracy: 86.86%



100%|██████████| 32/32 [00:10<00:00,  3.12it/s]


Epoch 3, Validation loss: 3.6803678944706917; Validation accuracy: 89.80%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:10<00:00,  3.07it/s]


Epoch 4, Training loss: 3.6643574833869934; Training accuracy: 90.39%



100%|██████████| 32/32 [00:09<00:00,  3.21it/s]


Epoch 4, Validation loss: 3.662455826997757; Validation accuracy: 90.78%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:10<00:00,  3.09it/s]


Epoch 5, Training loss: 3.647159107029438; Training accuracy: 91.76%



100%|██████████| 32/32 [00:09<00:00,  3.23it/s]


Epoch 5, Validation loss: 3.651083379983902; Validation accuracy: 91.76%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:10<00:00,  3.04it/s]


Epoch 6, Training loss: 3.6351843774318695; Training accuracy: 95.29%



100%|██████████| 32/32 [00:10<00:00,  3.15it/s]


Epoch 6, Validation loss: 3.642608106136322; Validation accuracy: 91.96%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:10<00:00,  3.15it/s]


Epoch 7, Training loss: 3.6270277574658394; Training accuracy: 95.69%



100%|██████████| 32/32 [00:10<00:00,  3.17it/s]


Epoch 7, Validation loss: 3.6362961530685425; Validation accuracy: 93.53%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:10<00:00,  3.20it/s]


Epoch 8, Training loss: 3.620182439684868; Training accuracy: 95.29%



100%|██████████| 32/32 [00:09<00:00,  3.23it/s]


Epoch 8, Validation loss: 3.632796697318554; Validation accuracy: 92.94%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:10<00:00,  3.14it/s]


Epoch 9, Training loss: 3.6151634082198143; Training accuracy: 97.06%



100%|██████████| 32/32 [00:09<00:00,  3.21it/s]


Epoch 9, Validation loss: 3.6281269192695618; Validation accuracy: 92.94%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:10<00:00,  3.20it/s]


Epoch 10, Training loss: 3.6095001697540283; Training accuracy: 96.47%



100%|██████████| 32/32 [00:09<00:00,  3.24it/s]

Epoch 10, Validation loss: 3.625461630523205; Validation accuracy: 93.73%

-----------------------------------------------------------------------------------------------
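
One observation on the log above: the loss stays close to ln(51) ≈ 3.93, the cross-entropy of a uniform guess over the 51 base classes, even as accuracy climbs. A plausible explanation, stated here as an assumption rather than a verified diagnosis, is that the logits are raw cosine similarities confined to [-1, 1], whereas CLIP's own `forward` multiplies them by `model.logit_scale.exp()` before the softmax. The sketch below uses made-up similarity values and an assumed scale of 100 to illustrate the effect in plain Python.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy of a single sample, computed with the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

num_classes = 51
# hypothetical cosine similarities: the correct class (index 0) is only slightly ahead
cosines = [0.35] + [0.25] * (num_classes - 1)

unscaled = cross_entropy(cosines, target=0)
scaled = cross_entropy([100.0 * c for c in cosines], target=0)  # 100 is an assumed scale

print(f"uniform baseline: {math.log(num_classes):.3f}")
print(f"unscaled logits:  {unscaled:.3f}")
print(f"scaled logits:    {scaled:.3f}")
```

With these toy values the prediction is correct in both cases, yet the unscaled loss lands just below the uniform baseline while the scaled loss is close to zero. If this is indeed what happens in the loop above, multiplying `logits` by the model's logit scale before the loss would be one thing to try.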





Then we evaluate the fine-tuned CLIP on both the base test set (few-shot evaluation) and the novel test set (zero-shot evaluation).

In [15]:
base_accuracy = eval(model=model_ft_text_layer, dataset=test_base, categories=base_classes, batch_size=128, device=device, label="🧠 Few-shot evaluation on Base Classes")
novel_accuracy = eval(model=model_ft_text_layer, dataset=test_novel, categories=novel_classes, batch_size=128, device=device, label="🧠 Zero-shot evaluation on Novel Classes")
print()
print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")

🧠 Few-shot evaluation on Base Classes: 100%|██████████| 20/20 [00:28<00:00,  1.43s/it]
🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 29/29 [00:42<00:00,  1.46s/it]


🔍 Base classes accuracy: 92.03%
🔍 Novel classes accuracy: 61.43%
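
Plugging the accuracies printed in this notebook into the harmonic mean gives a quick before/after comparison. The inputs are the rounded percentages reported above, so the results may differ from the unrounded ones in the last digit.

```python
def harmonic_mean(base_accuracy, novel_accuracy):
    return 2 / (1 / base_accuracy + 1 / novel_accuracy)

# (base, novel) accuracies as printed earlier in this notebook
results = {
    "zero-shot CLIP": (0.7133, 0.7824),
    "text projection fine-tuned": (0.9203, 0.6143),
}

for name, (base_acc, novel_acc) in results.items():
    print(f"{name}: HM = {harmonic_mean(base_acc, novel_acc) * 100:.1f}%")
```

This gives roughly 74.6% for zero-shot CLIP and 73.7% after fine-tuning: under the HM, the gain of about 21 points on base classes does not compensate for the loss of about 17 points on novel classes.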





## Fine-Tuning of the visual linear layer

We fine-tune the final projection of the visual encoder (`model.visual.proj`) to classify the base training data, keeping all other parameters frozen.

In [None]:
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision.transforms as transforms

def fine_tuning_linear_visual(model, train_dataset, val_dataset, categories, lr, batch_size, num_epochs, device):

  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
  val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

  model = model.float()

  visual_projection = model.visual.proj

  # Freeze all parameters in the model
  for param in model.parameters():
      param.requires_grad = False

  # Unfreeze the projection layer
  visual_projection.requires_grad = True

  criterion = torch.nn.CrossEntropyLoss()
  optimizer = Adam([visual_projection], lr=lr)

  contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

  text_inputs = clip.tokenize(
          [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in categories]).to(device)

  text_features = model.encode_text(text_inputs).float()
  # and here we normalize them (standard practice with CLIP)
  text_feature_norm = text_features.norm(dim=-1, keepdim=True)
  text_features = text_features/text_feature_norm # unit norm, so the dot product below is the cosine similarity

  print("🧠 Fine-tuning training+validation on Base Classes")

  for epoch in range(num_epochs):

    # Training of one epoch

    model.train()
    # here we store the sum of the losses computed across all batches
    total_loss = 0
    # here we store the number of correct predictions we will make
    correct_predictions = 0

    for image, target in tqdm(train_dataloader):

          target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

          image = image.to(device).float()
          target = target.to(device)

          # forward image through CLIP image encoder
          image_features = model.encode_image(image).float()
          # and normalize
          image_features_norm = image_features.norm(dim=-1, keepdim=True)
          image_features = image_features / image_features_norm # unit norm, so the dot product below is the cosine similarity

          # cosine similarity between image and text features, used as logits for the loss
          logits = image_features @ text_features.T

          loss = criterion(logits, target)
          total_loss += loss.item()

          # Backpropagation
          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
          optimizer.step()
          optimizer.zero_grad()

          # Training accuracy computation
          predicted_class = logits.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Training loss: {total_loss/ len(train_dataloader)}; Training accuracy: {correct_predictions / len(train_dataset)*100:.2f}%")
    print()

    # Validation of one epoch

    model.eval()
    total_loss = 0
    correct_predictions = 0

    # no gradients are needed for validation
    with torch.no_grad():
      for image, target in tqdm(val_dataloader):

          target = torch.tensor([contig_cat2idx[t.item()] for t in target], dtype=torch.long)

          image = image.to(device).float()
          target = target.to(device)

          # forward image through CLIP image encoder
          image_features = model.encode_image(image).float()
          # and normalize to unit norm
          image_features = image_features / image_features.norm(dim=-1, keepdim=True)

          # cosine similarity between image and text features
          logits = image_features @ text_features.T

          loss = criterion(logits, target)
          total_loss += loss.item()

          # Validation accuracy computation
          predicted_class = logits.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Validation loss: {total_loss/ len(val_dataloader)}; Validation accuracy: {correct_predictions / len(val_dataset)*100:.2f}%")
    print()
    print("-----------------------------------------------------------------------------------------------")

  return model

model, _ = clip.load("ViT-B/16", device=device)
model_ft_visual_layer = fine_tuning_linear_visual(model=model, train_dataset=augmented_train_base, val_dataset=val_base, categories=base_classes, lr = 0.0001, batch_size=128, num_epochs=50, device=device)


🧠 Fine-tuning training+validation on Base Classes


100%|██████████| 4/4 [00:14<00:00,  3.61s/it]


Epoch 1, Training loss: 3.8768633604049683; Training accuracy: 47.65%



100%|██████████| 4/4 [00:08<00:00,  2.18s/it]


Epoch 1, Validation loss: 3.8230097889900208; Validation accuracy: 72.35%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.55s/it]


Epoch 2, Training loss: 3.8614810705184937; Training accuracy: 52.16%



100%|██████████| 4/4 [00:07<00:00,  1.91s/it]


Epoch 2, Validation loss: 3.808045983314514; Validation accuracy: 72.35%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:13<00:00,  3.46s/it]


Epoch 3, Training loss: 3.8457061648368835; Training accuracy: 55.88%



100%|██████████| 4/4 [00:07<00:00,  1.92s/it]


Epoch 3, Validation loss: 3.793490946292877; Validation accuracy: 71.18%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.58s/it]


Epoch 4, Training loss: 3.832327425479889; Training accuracy: 54.51%



100%|██████████| 4/4 [00:07<00:00,  1.99s/it]


Epoch 4, Validation loss: 3.7798545956611633; Validation accuracy: 69.41%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.62s/it]


Epoch 5, Training loss: 3.816689193248749; Training accuracy: 53.73%



100%|██████████| 4/4 [00:07<00:00,  2.00s/it]


Epoch 5, Validation loss: 3.7670430541038513; Validation accuracy: 66.47%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.58s/it]


Epoch 6, Training loss: 3.803203582763672; Training accuracy: 53.53%



100%|██████████| 4/4 [00:07<00:00,  1.91s/it]


Epoch 6, Validation loss: 3.755210518836975; Validation accuracy: 65.10%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:13<00:00,  3.31s/it]


Epoch 7, Training loss: 3.7915972471237183; Training accuracy: 51.37%



100%|██████████| 4/4 [00:07<00:00,  1.90s/it]


Epoch 7, Validation loss: 3.7439956665039062; Validation accuracy: 63.92%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.61s/it]


Epoch 8, Training loss: 3.7758986949920654; Training accuracy: 53.92%



100%|██████████| 4/4 [00:08<00:00,  2.17s/it]


Epoch 8, Validation loss: 3.733534514904022; Validation accuracy: 62.75%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.51s/it]


Epoch 9, Training loss: 3.763637900352478; Training accuracy: 55.69%



100%|██████████| 4/4 [00:07<00:00,  1.93s/it]


Epoch 9, Validation loss: 3.723497986793518; Validation accuracy: 62.55%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.58s/it]


Epoch 10, Training loss: 3.7542428970336914; Training accuracy: 55.69%



100%|██████████| 4/4 [00:07<00:00,  1.97s/it]


Epoch 10, Validation loss: 3.714277505874634; Validation accuracy: 61.37%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:13<00:00,  3.49s/it]


Epoch 11, Training loss: 3.7430036067962646; Training accuracy: 54.12%



100%|██████████| 4/4 [00:07<00:00,  1.92s/it]


Epoch 11, Validation loss: 3.705909252166748; Validation accuracy: 62.16%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.57s/it]


Epoch 12, Training loss: 3.7317751049995422; Training accuracy: 57.45%



100%|██████████| 4/4 [00:08<00:00,  2.06s/it]


Epoch 12, Validation loss: 3.69808030128479; Validation accuracy: 62.16%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.52s/it]


Epoch 13, Training loss: 3.724349558353424; Training accuracy: 56.67%



100%|██████████| 4/4 [00:07<00:00,  1.92s/it]


Epoch 13, Validation loss: 3.69068443775177; Validation accuracy: 61.76%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:13<00:00,  3.49s/it]


Epoch 14, Training loss: 3.7136645317077637; Training accuracy: 57.65%



100%|██████████| 4/4 [00:07<00:00,  1.91s/it]


Epoch 14, Validation loss: 3.6835858821868896; Validation accuracy: 61.76%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.61s/it]


Epoch 15, Training loss: 3.7069685459136963; Training accuracy: 56.47%



100%|██████████| 4/4 [00:08<00:00,  2.11s/it]


Epoch 15, Validation loss: 3.67710018157959; Validation accuracy: 61.96%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.56s/it]


Epoch 16, Training loss: 3.6963196992874146; Training accuracy: 59.61%



100%|██████████| 4/4 [00:08<00:00,  2.01s/it]


Epoch 16, Validation loss: 3.6711063385009766; Validation accuracy: 62.35%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.63s/it]


Epoch 17, Training loss: 3.68886661529541; Training accuracy: 58.63%



100%|██████████| 4/4 [00:07<00:00,  1.93s/it]


Epoch 17, Validation loss: 3.6655802130699158; Validation accuracy: 61.96%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:13<00:00,  3.40s/it]


Epoch 18, Training loss: 3.6832094192504883; Training accuracy: 57.84%



100%|██████████| 4/4 [00:07<00:00,  1.93s/it]


Epoch 18, Validation loss: 3.6604424715042114; Validation accuracy: 62.16%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.61s/it]


Epoch 19, Training loss: 3.675031363964081; Training accuracy: 58.43%



100%|██████████| 4/4 [00:08<00:00,  2.16s/it]


Epoch 19, Validation loss: 3.6552688479423523; Validation accuracy: 62.94%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.54s/it]


Epoch 20, Training loss: 3.6689572930336; Training accuracy: 60.98%



100%|██████████| 4/4 [00:07<00:00,  1.92s/it]


Epoch 20, Validation loss: 3.6505054235458374; Validation accuracy: 63.33%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.61s/it]


Epoch 21, Training loss: 3.659847855567932; Training accuracy: 64.90%



100%|██████████| 4/4 [00:07<00:00,  1.93s/it]


Epoch 21, Validation loss: 3.6457875967025757; Validation accuracy: 63.53%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.52s/it]


Epoch 22, Training loss: 3.6535385251045227; Training accuracy: 61.18%



100%|██████████| 4/4 [00:07<00:00,  1.93s/it]


Epoch 22, Validation loss: 3.641604244709015; Validation accuracy: 63.92%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.62s/it]


Epoch 23, Training loss: 3.6504332423210144; Training accuracy: 63.33%



100%|██████████| 4/4 [00:08<00:00,  2.02s/it]


Epoch 23, Validation loss: 3.6374924778938293; Validation accuracy: 64.51%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.56s/it]


Epoch 24, Training loss: 3.642791211605072; Training accuracy: 66.27%



100%|██████████| 4/4 [00:07<00:00,  1.93s/it]


Epoch 24, Validation loss: 3.6338298320770264; Validation accuracy: 64.12%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:13<00:00,  3.39s/it]


Epoch 25, Training loss: 3.637945294380188; Training accuracy: 64.12%



100%|██████████| 4/4 [00:07<00:00,  1.91s/it]


Epoch 25, Validation loss: 3.630530297756195; Validation accuracy: 64.71%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.65s/it]


Epoch 26, Training loss: 3.633759915828705; Training accuracy: 62.94%



100%|██████████| 4/4 [00:08<00:00,  2.17s/it]


Epoch 26, Validation loss: 3.6273223757743835; Validation accuracy: 65.10%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.58s/it]


Epoch 27, Training loss: 3.626526117324829; Training accuracy: 65.29%



100%|██████████| 4/4 [00:07<00:00,  2.00s/it]


Epoch 27, Validation loss: 3.6243260502815247; Validation accuracy: 65.29%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.59s/it]


Epoch 28, Training loss: 3.622128188610077; Training accuracy: 66.86%



100%|██████████| 4/4 [00:07<00:00,  1.93s/it]


Epoch 28, Validation loss: 3.621335208415985; Validation accuracy: 65.29%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:13<00:00,  3.38s/it]


Epoch 29, Training loss: 3.617633283138275; Training accuracy: 66.86%



100%|██████████| 4/4 [00:07<00:00,  1.92s/it]


Epoch 29, Validation loss: 3.6186075806617737; Validation accuracy: 65.29%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.60s/it]


Epoch 30, Training loss: 3.6129109263420105; Training accuracy: 66.08%



100%|██████████| 4/4 [00:08<00:00,  2.12s/it]


Epoch 30, Validation loss: 3.6161869168281555; Validation accuracy: 65.88%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.54s/it]


Epoch 31, Training loss: 3.606734037399292; Training accuracy: 68.04%



100%|██████████| 4/4 [00:07<00:00,  1.92s/it]


Epoch 31, Validation loss: 3.613726496696472; Validation accuracy: 67.06%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.60s/it]


Epoch 32, Training loss: 3.606319785118103; Training accuracy: 65.49%



100%|██████████| 4/4 [00:07<00:00,  1.94s/it]


Epoch 32, Validation loss: 3.611573040485382; Validation accuracy: 67.06%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.50s/it]


Epoch 33, Training loss: 3.604689836502075; Training accuracy: 65.49%



100%|██████████| 4/4 [00:07<00:00,  1.97s/it]


Epoch 33, Validation loss: 3.6095250844955444; Validation accuracy: 67.45%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.63s/it]


Epoch 34, Training loss: 3.5984089970588684; Training accuracy: 67.65%



100%|██████████| 4/4 [00:08<00:00,  2.02s/it]


Epoch 34, Validation loss: 3.607182264328003; Validation accuracy: 68.43%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.57s/it]


Epoch 35, Training loss: 3.595183551311493; Training accuracy: 69.80%



100%|██████████| 4/4 [00:07<00:00,  1.93s/it]


Epoch 35, Validation loss: 3.6052370071411133; Validation accuracy: 69.22%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:13<00:00,  3.32s/it]


Epoch 36, Training loss: 3.5911683440208435; Training accuracy: 70.39%



100%|██████████| 4/4 [00:07<00:00,  1.92s/it]


Epoch 36, Validation loss: 3.603569269180298; Validation accuracy: 69.02%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.61s/it]


Epoch 37, Training loss: 3.5873671770095825; Training accuracy: 68.24%



100%|██████████| 4/4 [00:08<00:00,  2.22s/it]


Epoch 37, Validation loss: 3.601924240589142; Validation accuracy: 69.80%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.61s/it]


Epoch 38, Training loss: 3.5832549333572388; Training accuracy: 73.33%



100%|██████████| 4/4 [00:07<00:00,  1.97s/it]


Epoch 38, Validation loss: 3.6002699732780457; Validation accuracy: 69.80%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.61s/it]


Epoch 39, Training loss: 3.582227349281311; Training accuracy: 72.35%



100%|██████████| 4/4 [00:07<00:00,  1.94s/it]


Epoch 39, Validation loss: 3.5990249514579773; Validation accuracy: 70.98%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:13<00:00,  3.41s/it]


Epoch 40, Training loss: 3.581490397453308; Training accuracy: 70.00%



100%|██████████| 4/4 [00:07<00:00,  1.92s/it]


Epoch 40, Validation loss: 3.5976768136024475; Validation accuracy: 72.94%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.66s/it]


Epoch 41, Training loss: 3.5755075216293335; Training accuracy: 70.98%



100%|██████████| 4/4 [00:08<00:00,  2.16s/it]


Epoch 41, Validation loss: 3.596402168273926; Validation accuracy: 72.55%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.52s/it]


Epoch 42, Training loss: 3.573099374771118; Training accuracy: 72.55%



100%|██████████| 4/4 [00:07<00:00,  1.95s/it]


Epoch 42, Validation loss: 3.594589054584503; Validation accuracy: 72.94%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.64s/it]


Epoch 43, Training loss: 3.5738479495048523; Training accuracy: 71.76%



100%|██████████| 4/4 [00:07<00:00,  1.92s/it]


Epoch 43, Validation loss: 3.5940175652503967; Validation accuracy: 73.14%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.52s/it]


Epoch 44, Training loss: 3.570216476917267; Training accuracy: 72.55%



100%|██████████| 4/4 [00:07<00:00,  1.92s/it]


Epoch 44, Validation loss: 3.593148171901703; Validation accuracy: 73.73%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.64s/it]


Epoch 45, Training loss: 3.56357878446579; Training accuracy: 74.51%



100%|██████████| 4/4 [00:08<00:00,  2.04s/it]


Epoch 45, Validation loss: 3.59201443195343; Validation accuracy: 74.31%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.62s/it]


Epoch 46, Training loss: 3.5658059120178223; Training accuracy: 73.53%



100%|██████████| 4/4 [00:07<00:00,  1.91s/it]


Epoch 46, Validation loss: 3.5903924703598022; Validation accuracy: 74.51%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:13<00:00,  3.39s/it]


Epoch 47, Training loss: 3.5635019540786743; Training accuracy: 74.51%



100%|██████████| 4/4 [00:07<00:00,  1.93s/it]


Epoch 47, Validation loss: 3.589009463787079; Validation accuracy: 74.12%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.64s/it]


Epoch 48, Training loss: 3.5581817626953125; Training accuracy: 75.69%



100%|██████████| 4/4 [00:08<00:00,  2.14s/it]


Epoch 48, Validation loss: 3.5885097980499268; Validation accuracy: 75.10%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.60s/it]


Epoch 49, Training loss: 3.5601168274879456; Training accuracy: 75.69%



100%|██████████| 4/4 [00:08<00:00,  2.05s/it]


Epoch 49, Validation loss: 3.587945818901062; Validation accuracy: 76.08%

-----------------------------------------------------------------------------------------------


100%|██████████| 4/4 [00:14<00:00,  3.64s/it]


Epoch 50, Training loss: 3.558187425136566; Training accuracy: 78.24%



100%|██████████| 4/4 [00:07<00:00,  1.93s/it]

Epoch 50, Validation loss: 3.5875794887542725; Validation accuracy: 75.69%

-----------------------------------------------------------------------------------------------
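
Among the epochs shown in the log above, validation accuracy peaks at epoch 49 (76.08%) rather than at the final epoch (75.69%), while validation loss keeps decreasing. The following is a minimal sketch, not part of the original training loop, of how such a log could be scanned for the epoch with the best validation accuracy. The names `LOG_PATTERN`, `best_validation_epoch`, and `sample_log` are hypothetical; the sample lines are copied from the last three epochs of the output above.

```python
import re

# Matches the validation lines printed by the training loop above, e.g.
# "Epoch 49, Validation loss: 3.58...; Validation accuracy: 76.08%"
LOG_PATTERN = re.compile(
    r"Epoch (\d+), Validation loss: ([\d.]+); Validation accuracy: ([\d.]+)%"
)


def best_validation_epoch(log_text):
    """Return (epoch, val_loss, val_accuracy_percent) for the best epoch.

    "Best" means highest validation accuracy; ties go to the earliest epoch.
    Returns None if no validation line is found.
    """
    records = [
        (int(epoch), float(loss), float(acc))
        for epoch, loss, acc in LOG_PATTERN.findall(log_text)
    ]
    if not records:
        return None
    return max(records, key=lambda record: record[2])


# Last three validation lines from the run above (hypothetical variable name).
sample_log = """
Epoch 48, Validation loss: 3.5885097980499268; Validation accuracy: 75.10%
Epoch 49, Validation loss: 3.587945818901062; Validation accuracy: 76.08%
Epoch 50, Validation loss: 3.5875794887542725; Validation accuracy: 75.69%
"""

best = best_validation_epoch(sample_log)
print(f"Best epoch: {best[0]} (validation accuracy {best[2]:.2f}%)")
```

In practice one would more likely track the best validation accuracy inside the training loop and keep a copy of the corresponding weights; parsing the printed log is only a quick after-the-fact check.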





Then we evaluate the fine-tuned CLIP on both the base test set (few-shot evaluation) and the novel test set (zero-shot evaluation).

In [None]:
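# `eval` is the evaluation helper defined earlier in this notebook (it shadows Python's built-in `eval`).
# It is called once per split: base classes (few-shot) and novel classes (zero-shot).
base_accuracy = eval(model=model_ft_visual_layer, dataset=test_base, categories=base_classes, batch_size=128, device=device, label="🧠 Few-shot evaluation on Base Classes")
novel_accuracy = eval(model=model_ft_visual_layer, dataset=test_novel, categories=novel_classes, batch_size=128, device=device, label="🧠 Zero-shot evaluation on Novel Classes")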
print()
print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")

🧠 Few-shot evaluation on Base Classes: 100%|██████████| 20/20 [00:30<00:00,  1.51s/it]
🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 29/29 [00:43<00:00,  1.51s/it]


🔍 Base classes accuracy: 76.26%
🔍 Novel classes accuracy: 46.74%
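
The two accuracies above pull in different directions: fine-tuning helps on base classes, but novel-class accuracy stays much lower. A single number that penalises this imbalance is the harmonic mean of the two, which is commonly reported in base-to-novel evaluations. This section does not state which summary metric the assignment requires, so the sketch below is only an illustration; `harmonic_mean`, `base_acc`, and `novel_acc` are hypothetical names, and the two values are copied from the output above.

```python
def harmonic_mean(base_acc: float, novel_acc: float) -> float:
    """Harmonic mean of two accuracies given as fractions in [0, 1]."""
    total = base_acc + novel_acc
    if total == 0:
        # Both accuracies are zero; avoid dividing by zero.
        return 0.0
    return 2 * base_acc * novel_acc / total


# Accuracies from the run above, expressed as fractions.
base_acc = 0.7626
novel_acc = 0.4674

hm = harmonic_mean(base_acc, novel_acc)
print(f"🔍 Harmonic mean: {hm * 100:.2f}%")
```

Unlike the arithmetic mean, the harmonic mean is dominated by the lower of the two values, so a model cannot score well by improving base accuracy alone.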



