# CLIP zero-shot Evaluation
This short notebook implements the dataset split into base and novel categories (see project assignment) and runs the zero-shot evaluation with CLIP.
Feel free to copy the code contained in this notebook or to directly use this notebook as starting point for you project.

In [1]:
# we need to install clip as it is not pre-installed
# you are also free to use open_clip which provide more models
# https://github.com/mlfoundations/open_clip
%pip install openai_clip

Collecting openai_clip
  Downloading openai-clip-1.0.1.tar.gz (1.4 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m46.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from openai_clip)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: openai_clip
  Building wheel for openai_clip (setup.py) ... [?25l[?25hdone
  Created wheel for openai_clip: filename=openai_clip-1.0.1-py3-none-any.whl size=1368605 sha256=cfa5d8aab53f6e6f364f434110dd4004e474064f8a1c050bcbd959aa75991b04
  Stored in directory: /root/.cache/pip/wheels/0d/17/90/042948fd2e2a87f1dcf6db6d438cad015c49

In [2]:
import torch
import torchvision
import clip
from tqdm import tqdm

## Dataset Loading
Let's get the data directly from torchvision as we have seen during labs.

In [3]:
def get_data(data_dir="./data", transform=None):
    """Load Flowers102 train, validation and test sets.
    Args:
        data_dir (str): Directory where the dataset will be stored.
        transform (torch.Compose)
    Returns:
        tuple: A tuple containing the train, validation, and test sets.
    """
    train = torchvision.datasets.Flowers102(root=data_dir, split="train", download=True, transform=transform)
    val = torchvision.datasets.Flowers102(root=data_dir, split="val", download=True, transform=transform)
    test = torchvision.datasets.Flowers102(root=data_dir, split="test", download=True, transform=transform)
    return train, val, test

## Base and Novel categories
To split in base and novel categories we list all dataset classes, and count their number (we already know it's 102 but let's do it properly).
Then, we just allocate the first half to base categories and the remaining half to novel ones.
We can do this because we are simulating a real world application, but keep in mind this will not happen out there!

In [4]:
def base_novel_categories(dataset):
    # set returns the unique set of all dataset classes
    all_classes = set(dataset._labels)
    # and let's count them
    num_classes = len(all_classes)

    # here list(range(num_classes)) returns a list from 0 to num_classes - 1
    # then we slice the list in half and generate base and novel category lists
    base_classes = list(range(num_classes))[:num_classes//2]
    novel_classes = list(range(num_classes))[num_classes//2:]
    return base_classes, novel_classes

## Inspect Classes
Let's now visualize which are the base and novel classes.
To do so, we first get a dummy test set (without augmentations) as we are just interested in the dataset labels. Then, we split it useing `base_novel_categories`.
Finally, we use the hard-coded CLASS_NAMES to print the class in natural language.

> Note: the list of class names was only recently added to `torchvision.datasets.Flowers102`. To avoid useless errors that can occour to you, we decided to also provide such a list.

In [5]:
from os import name
_, _, tmp_test = get_data()
base_classes, novel_classes = base_novel_categories(tmp_test)
CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]
print("Base Class Names:", [(i, CLASS_NAMES[i]) for i in base_classes])
name_base = [CLASS_NAMES[i] for i in base_classes]
name_novel = [CLASS_NAMES[i] for i in novel_classes]
print("Novel Class Names:", [(i, CLASS_NAMES[i]) for i in novel_classes])

100%|██████████| 345M/345M [00:20<00:00, 16.7MB/s]
100%|██████████| 502/502 [00:00<00:00, 869kB/s]
100%|██████████| 15.0k/15.0k [00:00<00:00, 30.3MB/s]


Base Class Names: [(0, 'pink primrose'), (1, 'hard-leaved pocket orchid'), (2, 'canterbury bells'), (3, 'sweet pea'), (4, 'english marigold'), (5, 'tiger lily'), (6, 'moon orchid'), (7, 'bird of paradise'), (8, 'monkshood'), (9, 'globe thistle'), (10, 'snapdragon'), (11, "colt's foot"), (12, 'king protea'), (13, 'spear thistle'), (14, 'yellow iris'), (15, 'globe-flower'), (16, 'purple coneflower'), (17, 'peruvian lily'), (18, 'balloon flower'), (19, 'giant white arum lily'), (20, 'fire lily'), (21, 'pincushion flower'), (22, 'fritillary'), (23, 'red ginger'), (24, 'grape hyacinth'), (25, 'corn poppy'), (26, 'prince of wales feathers'), (27, 'stemless gentian'), (28, 'artichoke'), (29, 'sweet william'), (30, 'carnation'), (31, 'garden phlox'), (32, 'love in the mist'), (33, 'mexican aster'), (34, 'alpine sea holly'), (35, 'ruby-lipped cattleya'), (36, 'cape flower'), (37, 'great masterwort'), (38, 'siam tulip'), (39, 'lenten rose'), (40, 'barbeton daisy'), (41, 'daffodil'), (42, 'sword 

## Split Dataset
The next step is to actually split the dataset into the base and novel categories we extract from `base_novel_categories`.
To split the data we need the dataset (obviously) and the list of base classes. If the sample label is not part of the base categories, then it must be part of the novel ones.

In [6]:
def split_data(dataset, base_classes):
    # these two lists will store the sample indexes
    base_categories_samples = []
    novel_categories_samples = []

    # we create a set of base classes to compute the test below in O(1)
    # this is optional and can be removed
    base_set = set(base_classes)

    # here we iterate over sample labels and also get the correspondent sample index
    for sample_id, label in enumerate(dataset._labels):
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    # here we create the dataset subsets
    # the torch Subset is just a wrapper around the dataset
    # it simply stores the subset indexes and the original dataset (your_subset.dataset)
    # when asking for sample i in the subset, torch will look for its original position in the dataset and retrieve it
    # https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset
    base_dataset = torch.utils.data.Subset(dataset, base_categories_samples)
    novel_dataset = torch.utils.data.Subset(dataset, novel_categories_samples)
    return base_dataset, novel_dataset

## Extract k shots
As the dataset already provides 10 train and validation shots, we do not need to extract them.
Beaware that Few-Shot Adaptation papers must do this operation as most datasets count significantly more samples in both the training and validation sets.

## Load CLIP

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# available models = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
model, preprocess = clip.load("ViT-B/16", device=device)

# preprocess contains CLIP's pre-defined augmentations, let's inspect them!
preprocess

100%|███████████████████████████████████████| 335M/335M [00:07<00:00, 46.5MiB/s]


Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x7ea154856520>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

## Load and Prepare Data
Here we get the three dataset split and pass clip pre-defined augmentations.
Then, we compute base and novel categories (in this case is redundand as we already did it before).
Finally, se split the three datasets into base and novel categories.
As we want to use the novel categories only for the test set, we drop `train_novel` and `val_novel`.

In [8]:
# get the three datasets
train_set, val_set, test_set = get_data(transform=preprocess)

# split classes into base and novel
base_classes, novel_classes = base_novel_categories(train_set)

# split the three datasets
train_base, _ = split_data(train_set, base_classes)
val_base, _ = split_data(val_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)

## Compute Zero-Shot Predictions

In [None]:
@torch.no_grad() # we don't want gradients
def eval(model, dataset, categories, batch_size, device, label=""):
    # let's set the model in evaluation mode
    model.eval()

    # Remap labels into a contiguous set starting from zero
    contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

    # here we apply the standard CLIP template used for oxford flowers to all categories
    # and immediately tokenize each sentence (convert natural language into numbers - feel free to print the text input to inspect them)
    text_inputs = clip.tokenize(
        [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in categories]
    ).to(device)

    # we can encode the text features once as they are shared for all images
    # therefore we do it outside the evaluation loop
    text_features = model.encode_text(text_inputs)
    # and here we normalize them (standard pratice with CLIP)
    text_features /= text_features.norm(dim=-1, keepdim=True) # per avere norma 1 per calcolare cosine similarity

    # simple dataloader creation
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # here we store the number of correct predictions we will make
    correct_predictions = 0
    for image, target in tqdm(dataloader, desc=label):
        # base categories range from 0 to 50, whil novel ones from 51 to 101
        # therefore we must map categories to the [0, 50], otherwise we will have wrong predictions
        # Map targets in contiguous set starting from zero
        # Labels needs to be .long() in pytorch
        target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

        image = image.to(device)
        target = target.to(device)

        # forward image through CLIP image encoder
        image_features = model.encode_image(image)
        # and normalize
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # here cosine similarity between image and text features and keep the argmax for every row (every image)
        predicted_class = (image_features @ text_features.T).argmax(dim=-1)
        # now we check which are correct, and sum them (False == 0, True == 1)
        correct_predictions += (predicted_class == target).sum().item()

    # and now we compute the accuracy
    accuracy = correct_predictions / len(dataset)
    return accuracy

base_accuracy = eval(model=model, dataset=test_base, categories=base_classes, batch_size=128, device=device, label="🧠 Zero-shot evaluation on Base Classes")
novel_accuracy = eval(model=model, dataset=test_novel, categories=novel_classes, batch_size=128, device=device, label="🧠 Zero-shot evaluation on Novel Classes")

print()
print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")


🧠 Zero-shot evaluation on Base Classes: 100%|██████████| 20/20 [00:18<00:00,  1.07it/s]
🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 29/29 [00:35<00:00,  1.21s/it]


🔍 Base classes accuracy: 71.33%
🔍 Novel classes accuracy: 78.24%





## Harmonic Mean
Few-Shot Adaptations papers usually report the Harmonic Mean.
The harmonic mean tends to mitigate the impact of large outliers (base accuracy) and aggravate the impact of small ones (novel accuracy).
Thus, achieving very high base accuracies at the expense of the novel accuracy will be penalized by the HM.

In [None]:
def harmonic_mean(base_accuracy, novel_accuracy):
    numerator = 2
    denominator = 1 / base_accuracy + 1 / novel_accuracy
    hm = numerator / denominator
    return hm

print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")

🔍 Harmonic Mean: 74.62%


# Data Augmentation

In [None]:
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image

# Define data augmentation transformations
augmentation_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),  # Random crop and resize
    transforms.RandomHorizontalFlip(p=0.5),  # 50% chance to flip horizontally
    #transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Random color adjustments
    transforms.RandomSolarize(0.5, p=1),
    transforms.RandomRotation(degrees=15),  # Rotate image within ±15 degrees
    #transforms.ToTensor(),  # Convert PIL image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize for CLIP
])

# Validation transformations (no augmentation, just normalization)
validation_transforms = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to a fixed size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
from torch.utils.data import Dataset
from PIL import Image

# Example dataset class
class AugmentedImageTextDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data  # Assume data is a list of (image_path, label) tuples
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, label = self.data[idx]

        if self.transform:
          image = self.transform(image)
        return image, label
augmented_train_base = AugmentedImageTextDataset(data = train_base, transform = augmentation_transforms)

## Fine-Tuning of the textual linear layer

We fine-tune the last linear layer of the textual encoder for the classification of the base train data.

In [None]:
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision.transforms as transforms

def fine_tuning_linear_text(model, train_dataset, val_dataset, categories, lr, batch_size, num_epochs, device):

  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
  val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

  model = model.float()

  text_projection = model.text_projection

  # Freeze all parameters in the model
  for param in model.parameters():
      param.requires_grad = False

  # Unfreeze the projection layer
  text_projection.requires_grad = True

  criterion = torch.nn.CrossEntropyLoss()
  optimizer = Adam([text_projection], lr=lr)

  contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

  print("🧠 Fine-tuning training+validation on Base Classes")

  for epoch in range(num_epochs):

    # Training of one epoch

    model.train()
    # here we store the sum of all the computed losses through the all batches
    total_loss = 0
    # here we store the number of correct predictions we will make
    correct_predictions = 0

    for image, target in tqdm(train_dataloader):

          target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

          image = image.to(device).float()
          target = target.to(device)

          text_inputs = clip.tokenize(
          [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in categories]).to(device)

          text_features = model.encode_text(text_inputs).float()
          # and here we normalize them (standard pratice with CLIP)
          text_feature_norm = text_features.norm(dim=-1, keepdim=True)
          text_features = text_features/text_feature_norm # per avere norma 1 per calcolare cosine similarity

          # forward image through CLIP image encoder
          image_features = model.encode_image(image).float()
          # and normalize
          image_features_norm = image_features.norm(dim=-1, keepdim=True)
          image_features = image_features / image_features_norm # per avere norma 1 per calcolare cosine similarity

          # here cosine similarity between image and text features and keep the argmax for every row (every image)
          logits = image_features @ text_features.T

          loss = criterion(logits, target)
          total_loss += loss.item()

          # Backpropagation
          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
          optimizer.step()
          optimizer.zero_grad()

          # Training accuracy computation
          predicted_class = logits.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Training loss: {total_loss/ len(train_dataloader)}; Training accuracy: {correct_predictions / len(train_dataset)*100:.2f}%")
    print()

    # Validation of one epoch

    model.eval()
    total_loss = 0
    correct_predictions = 0

    for image, target in tqdm(val_dataloader):

          target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

          image = image.to(device).float()
          target = target.to(device)

          text_inputs = clip.tokenize(
          [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in categories]).to(device)

          text_features = model.encode_text(text_inputs).float()
          # and here we normalize them (standard pratice with CLIP)
          text_feature_norm = text_features.norm(dim=-1, keepdim=True)
          text_features = text_features/text_feature_norm # per avere norma 1 per calcolare cosine similarity

          # forward image through CLIP image encoder
          image_features = model.encode_image(image).float()
          # and normalize
          image_features_norm = image_features.norm(dim=-1, keepdim=True)
          image_features = image_features / image_features_norm # per avere norma 1 per calcolare cosine similarity

          # here cosine similarity between image and text features and keep the argmax for every row (every image)
          logits = image_features @ text_features.T

          loss = criterion(logits, target)
          total_loss += loss.item()

          # Validation accuracy computation
          predicted_class = logits.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Validation loss: {total_loss/ len(val_dataloader)}; Validation accuracy: {correct_predictions / len(val_dataset)*100:.2f}%")
    print()
    print("-----------------------------------------------------------------------------------------------")

  return(model)

model, _ = clip.load("ViT-B/16", device=device)
model_ft_text_layer = fine_tuning_linear_text(model=model, train_dataset=train_base, val_dataset=val_base, categories=base_classes, lr = 0.0001, batch_size=16, num_epochs=10, device=device)

🧠 Fine-tuning training+validation on Base Classes


100%|██████████| 32/32 [00:11<00:00,  2.84it/s]


Epoch 1, Training loss: 3.7992719411849976; Training accuracy: 60.39%



100%|██████████| 32/32 [00:11<00:00,  2.83it/s]


Epoch 1, Validation loss: 3.751607060432434; Validation accuracy: 76.27%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:11<00:00,  2.77it/s]


Epoch 2, Training loss: 3.7287665978074074; Training accuracy: 79.61%



100%|██████████| 32/32 [00:11<00:00,  2.84it/s]


Epoch 2, Validation loss: 3.7079309299588203; Validation accuracy: 84.31%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:11<00:00,  2.83it/s]


Epoch 3, Training loss: 3.689993605017662; Training accuracy: 87.06%



100%|██████████| 32/32 [00:11<00:00,  2.86it/s]


Epoch 3, Validation loss: 3.680222935974598; Validation accuracy: 88.43%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:11<00:00,  2.84it/s]


Epoch 4, Training loss: 3.6637882068753242; Training accuracy: 92.55%



100%|██████████| 32/32 [00:11<00:00,  2.89it/s]


Epoch 4, Validation loss: 3.6623907387256622; Validation accuracy: 90.98%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:11<00:00,  2.85it/s]


Epoch 5, Training loss: 3.6464011296629906; Training accuracy: 94.12%



100%|██████████| 32/32 [00:11<00:00,  2.87it/s]


Epoch 5, Validation loss: 3.6505211740732193; Validation accuracy: 92.35%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:11<00:00,  2.85it/s]


Epoch 6, Training loss: 3.6343393474817276; Training accuracy: 94.51%



100%|██████████| 32/32 [00:11<00:00,  2.90it/s]


Epoch 6, Validation loss: 3.642144314944744; Validation accuracy: 91.96%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:11<00:00,  2.83it/s]


Epoch 7, Training loss: 3.625513881444931; Training accuracy: 93.92%



100%|██████████| 32/32 [00:11<00:00,  2.88it/s]


Epoch 7, Validation loss: 3.636342190206051; Validation accuracy: 94.12%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:11<00:00,  2.88it/s]


Epoch 8, Training loss: 3.6201296970248222; Training accuracy: 96.67%



100%|██████████| 32/32 [00:11<00:00,  2.90it/s]


Epoch 8, Validation loss: 3.631933718919754; Validation accuracy: 94.71%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:11<00:00,  2.87it/s]


Epoch 9, Training loss: 3.6144131645560265; Training accuracy: 96.27%



100%|██████████| 32/32 [00:11<00:00,  2.87it/s]


Epoch 9, Validation loss: 3.6280209496617317; Validation accuracy: 93.73%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:11<00:00,  2.86it/s]


Epoch 10, Training loss: 3.6095264106988907; Training accuracy: 96.86%



100%|██████████| 32/32 [00:11<00:00,  2.90it/s]

Epoch 10, Validation loss: 3.624827913939953; Validation accuracy: 94.31%

-----------------------------------------------------------------------------------------------





Then we evaluate the fine-tuned CLIP in both base test set (few-shot evaluation) and novel test set (zero-shot evaluation).

In [None]:
base_accuracy = eval(model=model_ft_text_layer, dataset=test_base, categories=base_classes, batch_size=128, device=device, label="🧠 Few-shot evaluation on Base Classes")
novel_accuracy = eval(model=model_ft_text_layer, dataset=test_novel, categories=novel_classes, batch_size=128, device=device, label="🧠 Zero-shot evaluation on Novel Classes")
print()
print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")

🧠 Few-shot evaluation on Base Classes: 100%|██████████| 20/20 [00:31<00:00,  1.58s/it]
🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 29/29 [00:45<00:00,  1.58s/it]


🔍 Base classes accuracy: 92.72%
🔍 Novel classes accuracy: 63.85%





## Fine-Tuning of the visual linear layer

We fine-tune the last linear layer of the visual encoder for the classification of the base train data.

In [None]:
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision.transforms as transforms

def fine_tuning_linear_visual(model, train_dataset, val_dataset, categories, lr, batch_size, num_epochs, device):

  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
  val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

  model = model.float()

  visual_projection = model.visual.proj
  text_projection = model.text_projection

  # Freeze all parameters in the model
  for param in model.parameters():
      param.requires_grad = False

  # Unfreeze the projection layer
  visual_projection.requires_grad = True
  #text_projection.requires_grad = True

  criterion = torch.nn.CrossEntropyLoss()
  optimizer = Adam([visual_projection], lr=lr)

  contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}



  print("🧠 Fine-tuning training+validation on Base Classes")

  for epoch in range(num_epochs):

    # Training of one epoch

    model.train()
    # here we store the sum of all the computed losses through the all batches
    total_loss = 0
    # here we store the number of correct predictions we will make
    correct_predictions = 0

    for image, target in tqdm(train_dataloader):
          text_inputs = clip.tokenize(
          [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in categories]).to(device)

          text_features = model.encode_text(text_inputs).float()
          # and here we normalize them (standard pratice with CLIP)
          text_feature_norm = text_features.norm(dim=-1, keepdim=True)
          text_features = text_features/text_feature_norm # per avere norma 1 per calcolare cosine similarity

          target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

          image = image.to(device).float()
          target = target.to(device)

          # forward image through CLIP image encoder
          image_features = model.encode_image(image).float()
          # and normalize
          image_features_norm = image_features.norm(dim=-1, keepdim=True)
          image_features = image_features / image_features_norm # per avere norma 1 per calcolare cosine similarity

          # here cosine similarity between image and text features and keep the argmax for every row (every image)
          logits = image_features @ text_features.T

          loss = criterion(logits, target)
          total_loss += loss.item()

          # Backpropagation
          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
          optimizer.step()
          optimizer.zero_grad()

          # Training accuracy computation
          predicted_class = logits.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Training loss: {total_loss/ len(train_dataloader)}; Training accuracy: {correct_predictions / len(train_dataset)*100:.2f}%")
    print()

    # Validation of one epoch

    model.eval()
    total_loss = 0
    correct_predictions = 0

    for image, target in tqdm(val_dataloader):

          target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

          image = image.to(device).float()
          target = target.to(device)

          # forward image through CLIP image encoder
          image_features = model.encode_image(image).float()
          # and normalize
          image_features_norm = image_features.norm(dim=-1, keepdim=True)
          image_features = image_features / image_features_norm # per avere norma 1 per calcolare cosine similarity

          # here cosine similarity between image and text features and keep the argmax for every row (every image)
          logits = image_features @ text_features.T

          loss = criterion(logits, target)
          total_loss += loss.item()

          # Validation accuracy computation
          predicted_class = logits.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Validation loss: {total_loss/ len(val_dataloader)}; Validation accuracy: {correct_predictions / len(val_dataset)*100:.2f}%")
    print()
    print("-----------------------------------------------------------------------------------------------")

  return(model)

model, _ = clip.load("ViT-B/16", device=device)
model_ft_visual_layer = fine_tuning_linear_visual(model=model, train_dataset=augmented_train_base, val_dataset=val_base, categories=base_classes, lr = 0.0001, batch_size=16, num_epochs=30, device=device)


TypeError: fine_tuning_linear_visual() got an unexpected keyword argument 'augmented_train_dataset'

Then we evaluate the fine-tuned CLIP in both base test set (few-shot evaluation) and novel test set (zero-shot evaluation).

In [None]:
base_accuracy = eval(model=model_ft_visual_layer, dataset=test_base, categories=base_classes, batch_size=128, device=device, label="🧠 Few-shot evaluation on Base Classes")
novel_accuracy = eval(model=model_ft_visual_layer, dataset=test_novel, categories=novel_classes, batch_size=128, device=device, label="🧠 Zero-shot evaluation on Novel Classes")
print()
print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")

🧠 Few-shot evaluation on Base Classes: 100%|██████████| 20/20 [00:31<00:00,  1.59s/it]
🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 29/29 [00:46<00:00,  1.60s/it]


🔍 Base classes accuracy: 82.81%
🔍 Novel classes accuracy: 52.29%





# Using CLIP original loss

In [None]:
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision.transforms as transforms

temperature = 0.07

def fine_tuning_linear_visual(model, train_dataset, val_dataset, categories, lr, batch_size, num_epochs, device):

  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
  val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

  model = model.float()

  visual_projection = model.visual.proj
  text_projection = model.text_projection

  # Freeze all parameters in the model
  for param in model.parameters():
      param.requires_grad = False

  # Unfreeze the projection layer
  visual_projection.requires_grad = True
  #text_projection.requires_grad = True

  criterion = torch.nn.CrossEntropyLoss()
  optimizer = Adam([visual_projection], lr=lr)

  contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}



  print("🧠 Fine-tuning training+validation on Base Classes")

  for epoch in range(num_epochs):

    # Training of one epoch

    model.train()
    # here we store the sum of all the computed losses through the all batches
    total_loss = 0
    # here we store the number of correct predictions we will make
    correct_predictions = 0

    for image, target in tqdm(train_dataloader):
          text_inputs = clip.tokenize(
          [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in target]).to(device)

          text_features = model.encode_text(text_inputs).float()
          # and here we normalize them (standard pratice with CLIP)
          text_feature_norm = text_features.norm(dim=-1, keepdim=True)
          text_features = text_features/text_feature_norm # per avere norma 1 per calcolare cosine similarity


          target = torch.arange(len(target))

          image = image.to(device).float()
          target = target.to(device)

          # forward image through CLIP image encoder
          image_features = model.encode_image(image).float()
          # and normalize
          image_features_norm = image_features.norm(dim=-1, keepdim=True)
          image_features = image_features / image_features_norm # per avere norma 1 per calcolare cosine similarity

          # here cosine similarity between image and text features and keep the argmax for every row (every image)
          logits = (image_features @ text_features.T)* torch.exp(torch.tensor([temperature]).to(device))

          loss = (criterion(logits, target) +criterion(logits.T, target))/2
          total_loss += loss.item()

          # Backpropagation
          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
          optimizer.step()
          optimizer.zero_grad()

          # Training accuracy computation
          predicted_class = logits.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Training loss: {total_loss/ len(train_dataloader)}; Training accuracy: {correct_predictions / len(train_dataset)*100:.2f}%")
    print()

    # Validation of one epoch

    model.eval()
    total_loss = 0
    correct_predictions = 0

    for image, target in tqdm(val_dataloader):
          text_inputs = clip.tokenize(
          [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in categories]).to(device)

          text_features = model.encode_text(text_inputs).float()
          # and here we normalize them (standard pratice with CLIP)
          text_feature_norm = text_features.norm(dim=-1, keepdim=True)
          text_features = text_features/text_feature_norm # per avere norma 1 per calcolare cosine similarity

          target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

          image = image.to(device).float()
          target = target.to(device)

          # forward image through CLIP image encoder
          image_features = model.encode_image(image).float()
          # and normalize
          image_features_norm = image_features.norm(dim=-1, keepdim=True)
          image_features = image_features / image_features_norm # per avere norma 1 per calcolare cosine similarity

          # here cosine similarity between image and text features and keep the argmax for every row (every image)
          logits = image_features @ text_features.T

          loss = criterion(logits, target)
          total_loss += loss.item()

          # Validation accuracy computation
          predicted_class = logits.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Validation loss: {total_loss/ len(val_dataloader)}; Validation accuracy: {correct_predictions / len(val_dataset)*100:.2f}%")
    print()
    print("-----------------------------------------------------------------------------------------------")

  return(model)

model, _ = clip.load("ViT-B/16", device=device)
model_ft_visual_layer = fine_tuning_linear_visual(model=model, train_dataset=train_base, val_dataset=val_base, categories=base_classes, lr = 0.00001, batch_size=16, num_epochs=20, device=device)


🧠 Fine-tuning training+validation on Base Classes


100%|██████████| 32/32 [00:07<00:00,  4.02it/s]


Epoch 1, Training loss: 2.669394940137863; Training accuracy: 71.18%



100%|██████████| 32/32 [00:10<00:00,  2.92it/s]


Epoch 1, Validation loss: 3.8241214007139206; Validation accuracy: 73.14%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:08<00:00,  3.75it/s]


Epoch 2, Training loss: 2.6547181755304337; Training accuracy: 70.20%



100%|██████████| 32/32 [00:10<00:00,  3.01it/s]


Epoch 2, Validation loss: 3.809515818953514; Validation accuracy: 73.33%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:08<00:00,  3.91it/s]


Epoch 3, Training loss: 2.6390961185097694; Training accuracy: 70.39%



100%|██████████| 32/32 [00:10<00:00,  3.13it/s]


Epoch 3, Validation loss: 3.7945312410593033; Validation accuracy: 71.96%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:07<00:00,  4.19it/s]


Epoch 4, Training loss: 2.6234960108995438; Training accuracy: 70.59%



100%|██████████| 32/32 [00:10<00:00,  3.16it/s]


Epoch 4, Validation loss: 3.780126266181469; Validation accuracy: 70.20%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:08<00:00,  3.91it/s]


Epoch 5, Training loss: 2.6073362827301025; Training accuracy: 70.59%



100%|██████████| 32/32 [00:10<00:00,  3.20it/s]


Epoch 5, Validation loss: 3.766261711716652; Validation accuracy: 68.63%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:07<00:00,  4.07it/s]


Epoch 6, Training loss: 2.595405198633671; Training accuracy: 69.41%



100%|██████████| 32/32 [00:10<00:00,  3.13it/s]


Epoch 6, Validation loss: 3.7536500096321106; Validation accuracy: 68.24%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:07<00:00,  4.05it/s]


Epoch 7, Training loss: 2.5816031396389008; Training accuracy: 69.80%



100%|██████████| 32/32 [00:10<00:00,  3.14it/s]


Epoch 7, Validation loss: 3.741886779665947; Validation accuracy: 67.25%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:08<00:00,  3.94it/s]


Epoch 8, Training loss: 2.5735903084278107; Training accuracy: 70.20%



100%|██████████| 32/32 [00:10<00:00,  3.12it/s]


Epoch 8, Validation loss: 3.731088675558567; Validation accuracy: 67.45%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:07<00:00,  4.15it/s]


Epoch 9, Training loss: 2.5596833154559135; Training accuracy: 72.35%



100%|██████████| 32/32 [00:10<00:00,  3.09it/s]


Epoch 9, Validation loss: 3.721024088561535; Validation accuracy: 67.06%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:07<00:00,  4.01it/s]


Epoch 10, Training loss: 2.5493284538388252; Training accuracy: 70.78%



100%|██████████| 32/32 [00:10<00:00,  3.14it/s]


Epoch 10, Validation loss: 3.711524799466133; Validation accuracy: 66.27%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:07<00:00,  4.16it/s]


Epoch 11, Training loss: 2.5419431999325752; Training accuracy: 71.18%



100%|██████████| 32/32 [00:10<00:00,  3.13it/s]


Epoch 11, Validation loss: 3.7026090547442436; Validation accuracy: 67.06%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:07<00:00,  4.02it/s]


Epoch 12, Training loss: 2.5296263098716736; Training accuracy: 71.96%



100%|██████████| 32/32 [00:10<00:00,  3.15it/s]


Epoch 12, Validation loss: 3.6941885501146317; Validation accuracy: 67.06%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:07<00:00,  4.16it/s]


Epoch 13, Training loss: 2.522714301943779; Training accuracy: 72.16%



100%|██████████| 32/32 [00:10<00:00,  3.11it/s]


Epoch 13, Validation loss: 3.686231814324856; Validation accuracy: 67.06%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:07<00:00,  4.04it/s]


Epoch 14, Training loss: 2.5152296647429466; Training accuracy: 73.33%



100%|██████████| 32/32 [00:10<00:00,  3.15it/s]


Epoch 14, Validation loss: 3.6786944568157196; Validation accuracy: 66.67%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:08<00:00,  3.98it/s]


Epoch 15, Training loss: 2.5057100653648376; Training accuracy: 76.27%



100%|██████████| 32/32 [00:10<00:00,  3.13it/s]


Epoch 15, Validation loss: 3.671561114490032; Validation accuracy: 67.06%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:07<00:00,  4.08it/s]


Epoch 16, Training loss: 2.4966844469308853; Training accuracy: 73.73%



100%|██████████| 32/32 [00:10<00:00,  3.12it/s]


Epoch 16, Validation loss: 3.6647297367453575; Validation accuracy: 67.65%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:07<00:00,  4.02it/s]


Epoch 17, Training loss: 2.4932354390621185; Training accuracy: 71.76%



100%|██████████| 32/32 [00:10<00:00,  3.13it/s]


Epoch 17, Validation loss: 3.658332198858261; Validation accuracy: 67.65%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:07<00:00,  4.14it/s]


Epoch 18, Training loss: 2.483542248606682; Training accuracy: 75.49%



100%|██████████| 32/32 [00:10<00:00,  3.12it/s]


Epoch 18, Validation loss: 3.652178570628166; Validation accuracy: 68.04%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:07<00:00,  4.02it/s]


Epoch 19, Training loss: 2.4799063205718994; Training accuracy: 73.92%



100%|██████████| 32/32 [00:10<00:00,  3.14it/s]


Epoch 19, Validation loss: 3.6463252305984497; Validation accuracy: 68.63%

-----------------------------------------------------------------------------------------------


100%|██████████| 32/32 [00:07<00:00,  4.15it/s]


Epoch 20, Training loss: 2.4733694940805435; Training accuracy: 75.69%



100%|██████████| 32/32 [00:10<00:00,  3.13it/s]

Epoch 20, Validation loss: 3.640775181353092; Validation accuracy: 69.41%

-----------------------------------------------------------------------------------------------





In [None]:
base_accuracy = eval(model=model_ft_visual_layer, dataset=test_base, categories=base_classes, batch_size=128, device=device, label="🧠 Few-shot evaluation on Base Classes")
novel_accuracy = eval(model=model_ft_visual_layer, dataset=test_novel, categories=novel_classes, batch_size=128, device=device, label="🧠 Zero-shot evaluation on Novel Classes")
print()
print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")

🧠 Few-shot evaluation on Base Classes: 100%|██████████| 20/20 [00:29<00:00,  1.47s/it]
🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 29/29 [00:43<00:00,  1.50s/it]


🔍 Base classes accuracy: 69.07%
🔍 Novel classes accuracy: 63.47%





# Layer Norm

In [None]:
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision.transforms as transforms

def fine_tuning_linear_visual(model, train_dataset, val_dataset, categories, lr, batch_size, num_epochs, device):

  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
  val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

  model = model.float()

  visual_projection = model.visual.proj
  text_projection = model.text_projection

  # Freeze all parameters in the model
  for name, param in model.named_parameters():
    if "ln" not in name:  # Check for LayerNorm parameters
        param.requires_grad = False

# Verify which parameters are trainable
  trainable_params = [param for name, param in model.named_parameters() if param.requires_grad]

  criterion = torch.nn.CrossEntropyLoss()
  optimizer = Adam(trainable_params, lr=lr)

  contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}



  print("🧠 Fine-tuning training+validation on Base Classes")

  for epoch in range(num_epochs):

    # Training of one epoch

    model.train()
    # here we store the sum of all the computed losses through the all batches
    total_loss = 0
    # here we store the number of correct predictions we will make
    correct_predictions = 0

    for image, target in tqdm(train_dataloader):
          text_inputs = clip.tokenize(
          [f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in categories]).to(device)

          text_features = model.encode_text(text_inputs).float()
          # and here we normalize them (standard pratice with CLIP)
          text_feature_norm = text_features.norm(dim=-1, keepdim=True)
          text_features = text_features/text_feature_norm # per avere norma 1 per calcolare cosine similarity

          target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

          image = image.to(device).float()
          target = target.to(device)

          # forward image through CLIP image encoder
          image_features = model.encode_image(image).float()
          # and normalize
          image_features_norm = image_features.norm(dim=-1, keepdim=True)
          image_features = image_features / image_features_norm # per avere norma 1 per calcolare cosine similarity

          # here cosine similarity between image and text features and keep the argmax for every row (every image)
          logits = image_features @ text_features.T

          loss = criterion(logits, target)
          total_loss += loss.item()

          # Backpropagation
          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
          optimizer.step()
          optimizer.zero_grad()

          # Training accuracy computation
          predicted_class = logits.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Training loss: {total_loss/ len(train_dataloader)}; Training accuracy: {correct_predictions / len(train_dataset)*100:.2f}%")
    print()

    # Validation of one epoch

    model.eval()
    total_loss = 0
    correct_predictions = 0

    for image, target in tqdm(val_dataloader):

          target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

          image = image.to(device).float()
          target = target.to(device)

          # forward image through CLIP image encoder
          image_features = model.encode_image(image).float()
          # and normalize
          image_features_norm = image_features.norm(dim=-1, keepdim=True)
          image_features = image_features / image_features_norm # per avere norma 1 per calcolare cosine similarity

          # here cosine similarity between image and text features and keep the argmax for every row (every image)
          logits = image_features @ text_features.T

          loss = criterion(logits, target)
          total_loss += loss.item()

          # Validation accuracy computation
          predicted_class = logits.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Validation loss: {total_loss/ len(val_dataloader)}; Validation accuracy: {correct_predictions / len(val_dataset)*100:.2f}%")
    novel_accuracy = eval(model=model, dataset=test_novel, categories=novel_classes, batch_size=32, device=device, label="🧠 Zero-shot evaluation on Novel Classes")
    print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")
    print()
    print("-----------------------------------------------------------------------------------------------")

  return(model)

model, _ = clip.load("ViT-B/16", device=device)
model_ft_visual_layer = fine_tuning_linear_visual(model=model, train_dataset=train_base, val_dataset=val_base, categories=base_classes, lr = 0.0001, batch_size=32, num_epochs=30, device=device)

🧠 Fine-tuning training+validation on Base Classes


100%|██████████| 16/16 [00:18<00:00,  1.17s/it]


Epoch 1, Training loss: 3.836524084210396; Training accuracy: 69.02%



100%|██████████| 16/16 [00:06<00:00,  2.38it/s]


Epoch 1, Validation loss: 3.8312969505786896; Validation accuracy: 71.57%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.73it/s]


🔍 Novel classes accuracy: 78.35%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:18<00:00,  1.13s/it]


Epoch 2, Training loss: 3.8295159488916397; Training accuracy: 69.41%



100%|██████████| 16/16 [00:06<00:00,  2.53it/s]


Epoch 2, Validation loss: 3.8246214538812637; Validation accuracy: 72.16%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.70it/s]


🔍 Novel classes accuracy: 78.37%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:17<00:00,  1.11s/it]


Epoch 3, Training loss: 3.82290743291378; Training accuracy: 69.22%



100%|██████████| 16/16 [00:06<00:00,  2.44it/s]


Epoch 3, Validation loss: 3.818163588643074; Validation accuracy: 73.14%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.74it/s]


🔍 Novel classes accuracy: 78.16%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:17<00:00,  1.11s/it]


Epoch 4, Training loss: 3.816444754600525; Training accuracy: 70.20%



100%|██████████| 16/16 [00:06<00:00,  2.54it/s]


Epoch 4, Validation loss: 3.811863273382187; Validation accuracy: 72.94%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.73it/s]


🔍 Novel classes accuracy: 77.97%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:17<00:00,  1.12s/it]


Epoch 5, Training loss: 3.810079589486122; Training accuracy: 70.59%



100%|██████████| 16/16 [00:06<00:00,  2.55it/s]


Epoch 5, Validation loss: 3.8056256771087646; Validation accuracy: 72.94%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.70it/s]


🔍 Novel classes accuracy: 77.88%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:17<00:00,  1.11s/it]


Epoch 6, Training loss: 3.8037253618240356; Training accuracy: 71.18%



100%|██████████| 16/16 [00:06<00:00,  2.51it/s]


Epoch 6, Validation loss: 3.7994285374879837; Validation accuracy: 72.75%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.73it/s]


🔍 Novel classes accuracy: 77.88%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:17<00:00,  1.10s/it]


Epoch 7, Training loss: 3.797354355454445; Training accuracy: 71.18%



100%|██████████| 16/16 [00:06<00:00,  2.48it/s]


Epoch 7, Validation loss: 3.793176233768463; Validation accuracy: 72.55%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.73it/s]


🔍 Novel classes accuracy: 77.69%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:17<00:00,  1.11s/it]


Epoch 8, Training loss: 3.7909667789936066; Training accuracy: 71.18%



100%|██████████| 16/16 [00:06<00:00,  2.51it/s]


Epoch 8, Validation loss: 3.7869236767292023; Validation accuracy: 72.94%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.72it/s]


🔍 Novel classes accuracy: 77.58%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:18<00:00,  1.13s/it]


Epoch 9, Training loss: 3.7844999879598618; Training accuracy: 70.98%



100%|██████████| 16/16 [00:06<00:00,  2.54it/s]


Epoch 9, Validation loss: 3.78067210316658; Validation accuracy: 73.33%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.69it/s]


🔍 Novel classes accuracy: 77.58%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:17<00:00,  1.11s/it]


Epoch 10, Training loss: 3.778153195977211; Training accuracy: 71.18%



100%|██████████| 16/16 [00:06<00:00,  2.49it/s]


Epoch 10, Validation loss: 3.7743978649377823; Validation accuracy: 73.53%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.73it/s]


🔍 Novel classes accuracy: 77.45%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:17<00:00,  1.11s/it]


Epoch 11, Training loss: 3.7716043144464493; Training accuracy: 72.55%



100%|██████████| 16/16 [00:06<00:00,  2.38it/s]


Epoch 11, Validation loss: 3.7680599689483643; Validation accuracy: 73.33%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.72it/s]


🔍 Novel classes accuracy: 77.23%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:17<00:00,  1.11s/it]


Epoch 12, Training loss: 3.7650596499443054; Training accuracy: 72.94%



100%|██████████| 16/16 [00:06<00:00,  2.54it/s]


Epoch 12, Validation loss: 3.7617155462503433; Validation accuracy: 73.53%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.73it/s]


🔍 Novel classes accuracy: 76.93%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:17<00:00,  1.12s/it]


Epoch 13, Training loss: 3.758386805653572; Training accuracy: 73.73%



100%|██████████| 16/16 [00:06<00:00,  2.53it/s]


Epoch 13, Validation loss: 3.75533726811409; Validation accuracy: 73.53%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.71it/s]


🔍 Novel classes accuracy: 76.96%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:17<00:00,  1.11s/it]


Epoch 14, Training loss: 3.751645013689995; Training accuracy: 73.73%



100%|██████████| 16/16 [00:06<00:00,  2.49it/s]


Epoch 14, Validation loss: 3.7488829493522644; Validation accuracy: 73.14%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.72it/s]


🔍 Novel classes accuracy: 76.99%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:17<00:00,  1.11s/it]


Epoch 15, Training loss: 3.7448973804712296; Training accuracy: 74.31%



100%|██████████| 16/16 [00:06<00:00,  2.38it/s]


Epoch 15, Validation loss: 3.7423578649759293; Validation accuracy: 73.92%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.73it/s]


🔍 Novel classes accuracy: 76.96%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:17<00:00,  1.11s/it]


Epoch 16, Training loss: 3.7381471693515778; Training accuracy: 75.29%



100%|██████████| 16/16 [00:06<00:00,  2.54it/s]


Epoch 16, Validation loss: 3.735743820667267; Validation accuracy: 74.12%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.72it/s]


🔍 Novel classes accuracy: 76.74%

-----------------------------------------------------------------------------------------------


100%|██████████| 16/16 [00:17<00:00,  1.11s/it]


Epoch 17, Training loss: 3.7310997396707535; Training accuracy: 77.25%



100%|██████████| 16/16 [00:06<00:00,  2.53it/s]


Epoch 17, Validation loss: 3.729115381836891; Validation accuracy: 75.69%


🧠 Zero-shot evaluation on Novel Classes: 100%|██████████| 115/115 [00:42<00:00,  2.71it/s]


🔍 Novel classes accuracy: 76.47%

-----------------------------------------------------------------------------------------------


 56%|█████▋    | 9/16 [00:10<00:08,  1.18s/it]


KeyboardInterrupt: 

# Simple prompt tuning

In [9]:
import torch
import torch.nn as nn

class ModulationMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.mlp(x)


In [16]:
import os.path as osp
from collections import OrderedDict
import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast



from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

_tokenizer = _Tokenizer()




class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype


    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x


class PromptLearner(nn.Module):
    def __init__(self, clip_model, classnames, n_ctx, ctx_init, class_token_position, csc=False):
        super().__init__()
        n_cls = len(classnames)
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        vis_dim = clip_model.visual.output_dim
        self.meta_net = ModulationMLP(input_dim=vis_dim, hidden_dim=vis_dim//2, output_dim=ctx_dim)

        # Use given words to initialize context vectors
        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).to(clip_model.token_embedding.weight.device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
            #we take prompt_prefix as the context that we give
        else:
            if csc:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim)

            torch.nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f"Initial context: '{prompt_prefix}'")
        print(f"Number of context words (tokens): {n_ctx}")

        # These are the `prompts` we want to optimize
        self.ctx = nn.Parameter(ctx_vectors) #we don't want this since we want to have just the parameters of the meta net

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        # print("+++")
        # print("Prompts:")
        # for p in prompts:
        #     print(p)
        # print("+++")

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(clip_model.token_embedding.weight.device)

        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens
        self.class_token_position = class_token_position

    def construct_prompts(self, ctx, prefix, suffix, label=None):
        # dim0 is either batch_size (during training) or n_cls (during testing)
        # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
        # prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
        # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)

        if label is not None:
            prefix = prefix[label]
            suffix = suffix[label]

        prompts = torch.cat(
            [
                prefix,  # (dim0, 1, dim)
                ctx,     # (dim0, n_ctx, dim)
                suffix,  # (dim0, *, dim)
            ],
            dim=1,
        )
        return prompts

    def forward(self, im_features):
        prefix = self.token_prefix
        suffix = self.token_suffix
        ctx = self.ctx
        bias = self.meta_net(im_features)
        bias = bias.unsqueeze(1)           # (batch, 1, ctx_dim)
        ctx = ctx.unsqueeze(0)             # (1, n_ctx, ctx_dim)
        ctx_shifted = ctx + bias            # (batch, n_ctx, ctx_dim)

        prompts = []
        for ctx_shifted_i in ctx_shifted:
            ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
            pts_i = self.construct_prompts(ctx_i, prefix, suffix)  # (n_cls, n_tkn, ctx_dim)
            prompts.append(pts_i)
        prompts = torch.stack(prompts)
        print(prompts.shape)


        return prompts




In [17]:
class OurCLIP(nn.Module):
    def __init__(self, classnames, n_ctx, ctx_init, class_token_position, csc=False):
        super().__init__()
        clip_model, _ = clip.load("ViT-B/16", device=device)
        # clip_model = clip_model.cpu()
        clip_model = clip_model.float()

        self.prompt_learner = PromptLearner(clip_model, classnames, n_ctx, ctx_init, class_token_position, csc=csc)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale

    def forward(self, image, label = None):
        image_features = self.image_encoder(image)

        prompts = self.prompt_learner(image_features)
        tokenized_prompts = self.tokenized_prompts
        text_features = self.text_encoder(prompts, tokenized_prompts)

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits

In [18]:
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision.transforms as transforms
def cocoop( train_dataset, val_dataset, categories,  batch_size, num_epochs, device, n_ctx=4,
    ctx_init="a photo of a type of flower, the ",
    class_token_position="end",
    csc=False):

  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
  val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

  model = OurCLIP(
        classnames=name_base, n_ctx=n_ctx, ctx_init=ctx_init, class_token_position=class_token_position, csc=csc
    ).to(device)

  print("Turning off gradients in both the image and the text encoder")
  for name, param in model.named_parameters():
        if "prompt_learner" not in name:
            param.requires_grad_(False)

  print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
  print(f"Total trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")




  trainable_params = [param for name, param in model.named_parameters() if param.requires_grad]

  criterion = torch.nn.CrossEntropyLoss()
  optimizer = Adam(trainable_params, lr=0.002,  weight_decay=0.0005)

# Verify which parameters are trainable

  contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}



  print("🧠 Fine-tuning training+validation on Base Classes")

  for epoch in range(num_epochs):

    # Training of one epoch

    model.train()
    # here we store the sum of all the computed losses through the all batches
    total_loss = 0
    # here we store the number of correct predictions we will make
    correct_predictions = 0

    for image, target in tqdm(train_dataloader):


          target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

          image = image.to(device).float()
          target = target.to(device)
          output = model(image)

          # forward image through CLIP image encoder

          loss = criterion(output, target)
          total_loss += loss.item()

          # Backpropagation
          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
          optimizer.step()
          optimizer.zero_grad()

          # Training accuracy computation
          predicted_class = output.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Training loss: {total_loss/ len(train_dataloader)}; Training accuracy: {correct_predictions / len(train_dataset)*100:.2f}%")
    print()

    # Validation of one epoch

    model.eval()
    total_loss = 0
    correct_predictions = 0

    for image, target in tqdm(val_dataloader):

          target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

          image = image.to(device).float()
          target = target.to(device)

          # forward image through CLIP image encoder
          output= model(image)

          loss = criterion(output, target)
          total_loss += loss.item()

          # Validation accuracy computation
          predicted_class = output.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Validation loss: {total_loss/ len(val_dataloader)}; Validation accuracy: {correct_predictions / len(val_dataset)*100:.2f}%")
    print()
    print("-----------------------------------------------------------------------------------------------")

  return(model)

model_cocoop = cocoop( train_dataset=train_base, val_dataset=val_base, categories=base_classes, batch_size=32, num_epochs=30, device=device)

Initial context: 'a photo of a type of flower, the '
Number of context words (tokens): 9
Turning off gradients in both the image and the text encoder
Total parameters: 124,591,361
Total trainable parameters: 267,520
🧠 Fine-tuning training+validation on Base Classes


  0%|          | 0/16 [00:00<?, ?it/s]


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 51 but got size 1 for tensor number 1 in the list.

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision.transforms as transforms

class ModulationMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.mlp(x)

_tokenizer = _Tokenizer()




class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype


    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x


class PromptLearner(nn.Module):
    def __init__(self, clip_model, classnames, n_ctx, ctx_init, class_token_position, csc=False):
        super().__init__()
        n_cls = len(classnames)
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        vis_dim = clip_model.visual.output_dim
        self.meta_net = ModulationMLP(input_dim=vis_dim, hidden_dim=vis_dim//2, output_dim=ctx_dim)

        # Use given words to initialize context vectors
        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).to(clip_model.token_embedding.weight.device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
            #we take prompt_prefix as the context that we give
        else:
            if csc:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim)

            torch.nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f"Initial context: '{prompt_prefix}'")
        print(f"Number of context words (tokens): {n_ctx}")

        # These are the `prompts` we want to optimize
        self.ctx = nn.Parameter(ctx_vectors) #we don't want this since we want to have just the parameters of the meta net

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        # print("+++")
        # print("Prompts:")
        # for p in prompts:
        #     print(p)
        # print("+++")

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(clip_model.token_embedding.weight.device)

        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens
        self.class_token_position = class_token_position

    def construct_prompts(self, ctx, prefix, suffix, label=None):
        # dim0 is either batch_size (during training) or n_cls (during testing)
        # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
        # prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
        # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)

        if label is not None:
            prefix = prefix[label]
            suffix = suffix[label]

        prompts = torch.cat(
            [
                prefix,  # (dim0, 1, dim)
                ctx,     # (dim0, n_ctx, dim)
                suffix,  # (dim0, *, dim)
            ],
            dim=1,
        )
        return prompts

    def forward(self, im_features):
        prefix = self.token_prefix
        suffix = self.token_suffix
        ctx = self.ctx
        bias = self.meta_net(im_features)
        bias = bias.unsqueeze(1)           # (batch, 1, ctx_dim)
        ctx = ctx.unsqueeze(0)             # (1, n_ctx, ctx_dim)
        ctx_shifted = ctx + bias            # (batch, n_ctx, ctx_dim)
        prompts = []
        for ctx_shifted_i in ctx_shifted:
            ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
            pts_i = self.construct_prompts(ctx_i, prefix, suffix)  # (n_cls, n_tkn, ctx_dim)
            prompts.append(pts_i)
        prompts = torch.stack(prompts)

        return prompts
class OurCLIP(nn.Module):
    def __init__(self, classnames, n_ctx, ctx_init, class_token_position, csc=False):
        super().__init__()
        clip_model, _ = clip.load("ViT-B/16", device=device)
        # clip_model = clip_model.cpu()
        clip_model = clip_model.float()

        self.prompt_learner = PromptLearner(clip_model, classnames, n_ctx, ctx_init, class_token_position, csc=csc)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale

    def forward(self, image, label = None):
        image_features = self.image_encoder(image)

        prompts = self.prompt_learner(image_features)
        tokenized_prompts = self.tokenized_prompts
        text_features = self.text_encoder(prompts, tokenized_prompts)

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits



def cocoop( train_dataset, val_dataset, categories,  batch_size, num_epochs, device, n_ctx=4,
    ctx_init="a photo of a type of flower, the ",
    class_token_position="end",
    csc=False):

  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
  val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

  model = OurCLIP(
        classnames=name_base, n_ctx=n_ctx, ctx_init=ctx_init, class_token_position=class_token_position, csc=csc
    ).to(device)

  print("Turning off gradients in both the image and the text encoder")
  for name, param in model.named_parameters():
        if "prompt_learner" not in name:
            param.requires_grad_(False)

  print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
  print(f"Total trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")




  trainable_params = [param for name, param in model.named_parameters() if param.requires_grad]

  criterion = torch.nn.CrossEntropyLoss()
  optimizer = Adam(trainable_params, lr=0.002,  weight_decay=0.0005)

# Verify which parameters are trainable

  contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}



  print("🧠 Fine-tuning training+validation on Base Classes")

  for epoch in range(num_epochs):

    # Training of one epoch

    model.train()
    # here we store the sum of all the computed losses through the all batches
    total_loss = 0
    # here we store the number of correct predictions we will make
    correct_predictions = 0

    for image, target in tqdm(train_dataloader):


          target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

          image = image.to(device).float()
          target = target.to(device)
          output = model(image)

          # forward image through CLIP image encoder

          loss = criterion(output, target)
          total_loss += loss.item()

          # Backpropagation
          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
          optimizer.step()
          optimizer.zero_grad()

          # Training accuracy computation
          predicted_class = output.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Training loss: {total_loss/ len(train_dataloader)}; Training accuracy: {correct_predictions / len(train_dataset)*100:.2f}%")
    print()

    # Validation of one epoch

    model.eval()
    total_loss = 0
    correct_predictions = 0

    for image, target in tqdm(val_dataloader):

          target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

          image = image.to(device).float()
          target = target.to(device)

          # forward image through CLIP image encoder
          output= model(image)

          loss = criterion(output, target)
          total_loss += loss.item()

          # Validation accuracy computation
          predicted_class = output.argmax(dim=-1)
          correct_predictions += (predicted_class == target).sum().item()

    print(f"Epoch {epoch + 1}, Validation loss: {total_loss/ len(val_dataloader)}; Validation accuracy: {correct_predictions / len(val_dataset)*100:.2f}%")
    print()
    print("-----------------------------------------------------------------------------------------------")

  return(model)

model_cocoop = cocoop( train_dataset=train_base, val_dataset=val_base, categories=base_classes, batch_size=32, num_epochs=30, device=device)


In [None]:
def get_optimizer(model, lr, wd, momentum):
    optimizer = torch.optim.SGD([
        {"params": model.parameters()}
    ], lr=lr, weight_decay=wd, momentum=momentum)

    return optimizer

def main_coop(
    batch_size=16,
    num_classes=10,
    device="cuda:0",
    learning_rate=0.002,
    weight_decay=0.0005,
    momentum=0.9,
    epochs=2,
    run_name="exp1",
    n_ctx=4,
    ctx_init="",
    class_token_position="end",
    csc=False,
):
    # Create a logger for the experiment
    writer = SummaryWriter(log_dir=f"runs/{run_name}")

    # Get dataloaders
    train_loader, val_loader, test_loader = get_data(dataset_name, transform=preprocess, batch_size=batch_size)
    classnames, _ = embed_dataset_classnames(dataset_name)

    # Instantiate the network and move it to the chosen device (GPU)
    net = OurCLIP(
        classnames=classnames, n_ctx=n_ctx, ctx_init=ctx_init, class_token_position=class_token_position, csc=csc
    ).to(device)

    print("Turning off gradients in both the image and the text encoder")
    for name, param in net.named_parameters():
        if "prompt_learner" not in name:
            param.requires_grad_(False)

    print(f"Total parameters: {sum(p.numel() for p in net.parameters()):,}")
    print(f"Total trainable parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad):,}")

    # Instantiate the optimizer
    optimizer = get_optimizer(net, learning_rate, weight_decay, momentum)

    # Define the cost function
    cost_function = get_cost_function()

    # Computes evaluation results before training
    print("Before training:")
    train_loss, train_accuracy = test_step(net, train_loader, cost_function)
    val_loss, val_accuracy = test_step(net, val_loader, cost_function)
    test_loss, test_accuracy = test_step(net, test_loader, cost_function)

    # Log to TensorBoard
    log_values(writer, -1, train_loss, train_accuracy, "train")
    log_values(writer, -1, val_loss, val_accuracy, "validation")
    log_values(writer, -1, test_loss, test_accuracy, "test")

    print(f"\tTraining loss {train_loss:.5f}, Training accuracy {train_accuracy:.2f}")
    print(f"\tValidation loss {val_loss:.5f}, Validation accuracy {val_accuracy:.2f}")
    print(f"\tTest loss {test_loss:.5f}, Test accuracy {test_accuracy:.2f}")

    # For each epoch, train the network and then compute evaluation results
    for e in range(epochs):
        train_loss, train_accuracy = training_step(net, train_loader, optimizer, cost_function)
        val_loss, val_accuracy = test_step(net, val_loader, cost_function)

        log_values(writer, e, train_loss, train_accuracy, "train")
        log_values(writer, e, val_loss, val_accuracy, "validation")

    # Compute final evaluation results
    print("After training:")
    train_loss, train_accuracy = test_step(net, train_loader, cost_function)
    val_loss, val_accuracy = test_step(net, val_loader, cost_function)
    test_loss, test_accuracy = test_step(net, test_loader, cost_function)

    log_values(writer, epochs, train_loss, train_accuracy, "train")
    log_values(writer, epochs, val_loss, val_accuracy, "validation")
    log_values(writer, epochs, test_loss, test_accuracy, "test")
    print(f"\tTraining loss {train_loss:.5f}, Training accuracy {train_accuracy:.2f}")
    print(f"\tValidation loss {val_loss:.5f}, Validation accuracy {val_accuracy:.2f}")
    print(f"\tTest loss {test_loss:.5f}, Test accuracy {test_accuracy:.2f}")

    # Closes the logger
    writer.close()