# Dataset

In [None]:
import os
import torchvision
import gdown
import json
import PIL

def download_ucf101(root, download):
    """Download and extract the UCF-101 mid-frame images into ``<root>/ucf101``.

    Uses the same ``(root, download=...)`` call shape as the torchvision dataset
    classes so it can sit next to them in ``datasets_list``.

    Args:
        root: directory under which the ``ucf101`` folder is created.
        download: when False nothing is fetched or extracted.
    """
    if not download:
        return
    target_dir = os.path.join(root, 'ucf101')
    # AMLDataset reads the images from <root>/ucf101/UCF-101-midframes; when that
    # folder already exists, skip the costly re-extraction of the archive.
    if os.path.isdir(os.path.join(target_dir, 'UCF-101-midframes')):
        return
    torchvision.datasets.utils.download_and_extract_archive(
        'https://drive.google.com/file/d/10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O',
        target_dir,
        filename='UCF-101-midframes.zip',
    )

# Registry of supported datasets. The train/val/test splits follow CoOp:
# https://github.com/KaiyangZhou/CoOp/blob/main/DATASETS.md
# Each value is a 4-tuple:
#   [0] downloader, called by AMLDataset as downloader(root, download=True)
#   [1] sub-folder of root in which the downloader stores the data
#   [2] image folder, relative to that sub-folder
#   [3] Google Drive file id of the CoOp split.json (fetched with gdown)
datasets_list = {
    'caltech101': (torchvision.datasets.Caltech101,     'caltech101',       '101_ObjectCategories', '1hyarUivQE36mY6jSomru6Fjd-JzwcCzN'),
    'oxfordpets': (torchvision.datasets.OxfordIIITPet,  'oxford-iiit-pet',  'images',               '1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs'),
    'flowers102': (torchvision.datasets.Flowers102,     'flowers-102',      'jpg',                  '1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT'),
    'food101'   : (torchvision.datasets.Food101,        'food-101',         'images',               '1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl'),
    'dtd'       : (torchvision.datasets.DTD,            'dtd', os.path.join('dtd', 'images'),       '1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x'),
    'eurosat'   : (torchvision.datasets.EuroSAT,        'eurosat',          '2750',                 '1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o'),
    'ucf101'    : (download_ucf101,                     'ucf101',           'UCF-101-midframes',    '1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y'),
}


class AMLDataset(torchvision.datasets.VisionDataset):
    """Image-classification dataset that uses the CoOp train/val/test splits.

    The raw images are downloaded with the callable registered for
    ``dataset_name`` in ``datasets_list``. The split definition ``split.json``
    is fetched from Google Drive with ``gdown`` on first use and cached. Each
    split entry is ``[impath, label, classname]``, with ``impath`` relative
    to the image folder.

    Args:
        dataset_name: key into ``datasets_list``.
        root: directory under which the dataset folder is created.
        split: which entry of ``split.json`` to expose ('train', 'val' or 'test').
        transforms: joint callable applied as ``transforms(img, label)``. It is
            only used when ``transform`` and ``target_transform`` are both None.
        transform: callable applied to the RGB PIL image.
        target_transform: callable applied to the integer label.

    Raises:
        KeyError: if ``dataset_name`` is not in ``datasets_list`` or ``split``
            is not a key of ``split.json``.
    """

    def __init__(self, dataset_name, root, split: str='train', transforms=None, transform=None, target_transform=None):
        downloader, data_subdir, img_subdir, split_file_id = datasets_list[dataset_name]

        # Download the raw data (a torchvision dataset class or a custom downloader).
        downloader(root, download=True)

        # The downloaders put the data in a sub-folder of ``root``. Use that
        # sub-folder as the dataset root and keep everything inside it.
        root = os.path.join(root, data_subdir)
        super().__init__(root, transforms, transform, target_transform)

        # Images live further down inside the dataset root.
        self.img_folder = os.path.join(root, img_subdir)

        # Fetch the CoOp split file once and cache it next to the data.
        split_file_path = os.path.join(root, 'split.json')
        if not os.path.exists(split_file_path):
            gdown.download(id=split_file_id, output=split_file_path)

        # dict[str('train'|'val'|'test'), list[[str(impath), int(label), str(classname)]]]
        with open(split_file_path, 'r', encoding='utf-8') as f:
            data_source = json.load(f)

        self._items = data_source[split]
        # Class metadata always comes from the training split, so every split
        # reports the same label -> class-name mapping.
        self._num_classes = self.get_num_classes(data_source['train'])
        self._lab2cname, self._classnames = self.get_lab2cname(data_source['train'])

    def __len__(self):
        return len(self._items)

    def __getitem__(self, index):
        """Return ``(image, label)`` for split entry ``index`` after transforms."""
        # Import the submodule explicitly: a bare ``import PIL`` does not
        # guarantee that ``PIL.Image`` is loaded.
        from PIL import Image

        impath, label, _classname = self._items[index]

        # Read through a context manager so the file handle is released
        # immediately (convert() forces the pixel data to load). Force RGB so
        # grayscale/palette images have the same channel count as the rest.
        with open(os.path.join(self.img_folder, impath), 'rb') as f:
            img = Image.open(f).convert('RGB')

        if self.transform is not None or self.target_transform is not None:
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                label = self.target_transform(label)
        elif self.transforms is not None:
            # Only a joint ``transforms`` callable was supplied.
            img, label = self.transforms(img, label)

        return img, label

    @property
    def lab2cname(self):
        """dict mapping integer label -> class name (from the training split)."""
        return self._lab2cname

    @property
    def classnames(self):
        """Class names ordered by ascending label."""
        return self._classnames

    @property
    def num_classes(self):
        """Largest training label plus one."""
        return self._num_classes

    @staticmethod
    def get_num_classes(data_source):
        """Count classes as the largest label plus one.

        Args:
            data_source (list): split entries of the form ``[impath, label, classname]``.

        Raises:
            ValueError: if ``data_source`` is empty.
        """
        return max(label for _, label, _ in data_source) + 1

    @staticmethod
    def get_lab2cname(data_source):
        """Build the label-to-classname mapping and the label-ordered class list.

        Args:
            data_source (list): split entries of the form ``[impath, label, classname]``.

        Returns:
            tuple: ``(mapping, classnames)``, where ``mapping`` is a dict
            label -> classname and ``classnames`` lists names by ascending label.
        """
        mapping = {label: classname for _, label, classname in data_source}
        classnames = [mapping[label] for label in sorted(mapping)]
        return mapping, classnames


class Caltech101(AMLDataset):
    """Caltech-101 with the CoOp splits (see ``AMLDataset`` for the arguments).

    If the regular torchvision download fails with an HTTP error, the two
    archives are fetched from a mirror hosted by Terry, and construction is
    retried once.
    """

    # (url, archive filename, md5) of the mirrored archives; both are extracted
    # into <root>/caltech101, in this order.
    _MIRROR_ARCHIVES = (
        ('https://drive.google.com/file/d/1IFqrvpdbrpmI6DPntopcPY6svPu04uYD',
         '101_ObjectCategories.tar.gz',
         'b224c7392d521a49829488ab0f1120d9'),
        ('https://drive.google.com/file/d/1sW96Lj6yLIujKpopd8tBrIO_NCaKBy5d',
         'Annotations.tar',
         '6f83eeb1f24d99cab4eb377263132c91'),
    )

    def __init__(self, root, *args, **kwargs):
        from urllib.error import HTTPError
        try:
            super().__init__('caltech101', root, *args, **kwargs)
        except HTTPError:
            self._fetch_mirror(root)
            super().__init__('caltech101', root, *args, **kwargs)

    @classmethod
    def _fetch_mirror(cls, root):
        """Download and extract every mirrored archive into ``<root>/caltech101``."""
        from torchvision.datasets.utils import download_and_extract_archive
        destination = os.path.join(root, 'caltech101')
        for url, archive_name, checksum in cls._MIRROR_ARCHIVES:
            download_and_extract_archive(
                url,
                destination,
                filename=archive_name,
                md5=checksum,
            )


class OxfordIIITPet(AMLDataset):
    """Oxford-IIIT Pet with the CoOp splits (arguments as for ``AMLDataset``, minus the name)."""

    _DATASET_KEY = 'oxfordpets'

    def __init__(self, *args, **kwargs):
        super().__init__(self._DATASET_KEY, *args, **kwargs)


class Flowers102(AMLDataset):
    """Oxford Flowers-102 with the CoOp splits (arguments as for ``AMLDataset``, minus the name)."""

    _DATASET_KEY = 'flowers102'

    def __init__(self, *args, **kwargs):
        super().__init__(self._DATASET_KEY, *args, **kwargs)


class Food101(AMLDataset):
    """Food-101 with the CoOp splits (arguments as for ``AMLDataset``, minus the name)."""

    _DATASET_KEY = 'food101'

    def __init__(self, *args, **kwargs):
        super().__init__(self._DATASET_KEY, *args, **kwargs)


class DTD(AMLDataset):
    """Describable Textures with the CoOp splits (arguments as for ``AMLDataset``, minus the name)."""

    _DATASET_KEY = 'dtd'

    def __init__(self, *args, **kwargs):
        super().__init__(self._DATASET_KEY, *args, **kwargs)


class EuroSAT(AMLDataset):
    """EuroSAT with the CoOp splits (arguments as for ``AMLDataset``, minus the name)."""

    _DATASET_KEY = 'eurosat'

    def __init__(self, *args, **kwargs):
        super().__init__(self._DATASET_KEY, *args, **kwargs)


class UCF101(AMLDataset):
    """UCF-101 mid-frames with the CoOp splits (arguments as for ``AMLDataset``, minus the name)."""

    _DATASET_KEY = 'ucf101'

    def __init__(self, *args, **kwargs):
        super().__init__(self._DATASET_KEY, *args, **kwargs)


class FGVCAircraft(torchvision.datasets.FGVCAircraft):
    """torchvision FGVC-Aircraft exposing the same helpers as ``AMLDataset``.

    The data is downloaded by default. Callers may still pass ``download=...``
    as a keyword to override that.
    """

    def __init__(self, root, *args, **kwargs):
        # setdefault instead of a hard-coded keyword, so an explicit
        # ``download=`` from the caller does not raise a duplicate-keyword TypeError.
        kwargs.setdefault('download', True)
        super().__init__(root, *args, **kwargs)
        self._lab2cname = dict(enumerate(self.classes))

    @property
    def lab2cname(self):
        """dict mapping label index -> class name."""
        return self._lab2cname

    @property
    def classnames(self):
        """Class names ordered by label index."""
        return self.classes

    @property
    def num_classes(self):
        return len(self.classes)

In [None]:
# -p: no error when the folder already exists (re-running the cell). The absolute
# path matches dataset_path even if the working directory has been changed.
!mkdir -p /content/datasets

In [None]:
# Root directory passed as ``root`` to every dataset class; presumably a Google Colab runtime.
dataset_path = "/content/datasets"

In [None]:
# %pip installs into the environment of the running kernel, whereas !pip uses
# whichever pip comes first on the shell PATH.
%pip install torchvision ftfy regex tqdm
%pip install git+https://github.com/openai/CLIP.git


Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.6.0->torchvision)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.6.0->torchvision)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch==2.6.0->torc

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import clip
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import random_split
import clip
import os.path as osp

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

_tokenizer = _Tokenizer()
# Load CLIP model (only its preprocessing transform is used from here on).
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

# Load the Caltech101 *training* split. split.json also defines 'val' and 'test'.
# The test split must stay unseen while the prompts are tuned, so it is not used here.
transform = preprocess
train_set = Caltech101(root=dataset_path, split='train', transform=preprocess)


100%|███████████████████████████████████████| 335M/335M [00:04<00:00, 81.3MiB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1IFqrvpdbrpmI6DPntopcPY6svPu04uYD
From (redirected): https://drive.usercontent.google.com/download?id=1IFqrvpdbrpmI6DPntopcPY6svPu04uYD&confirm=t&uuid=eae6d48f-e277-474d-b8eb-26bfd261b7ae
To: /content/datasets/caltech101/101_ObjectCategories.tar.gz
100%|██████████| 132M/132M [00:03<00:00, 42.0MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1sW96Lj6yLIujKpopd8tBrIO_NCaKBy5d
From (redirected): https://drive.usercontent.google.com/download?id=1sW96Lj6yLIujKpopd8tBrIO_NCaKBy5d&confirm=t&uuid=a83a6862-2a95-47c4-b861-31f1d24037c9
To: /content/datasets/caltech101/Annotations.tar
100%|██████████| 14.0M/14.0M [00:00<00:00, 27.3MB/s]
Downloading...
From: https://drive.google.com/uc?id=1hyarUivQE36mY6jSomru6Fjd-JzwcCzN
To: /content/datasets/caltech101/split.json
100%|██████████| 809k/809k [00:00<00:00, 8.64MB/s]


In [None]:
# train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=8)


In [None]:
# Hold out 20% of train_set for validation. The seeded generator makes the split
# identical across kernel restarts, so validation numbers are comparable.
num_train = int(0.8 * len(train_set))
num_val = len(train_set) - num_train
train_dataset, val_dataset = random_split(
    train_set, [num_train, num_val], generator=torch.Generator().manual_seed(0)
)

# Do not spawn more loader workers than there are CPUs; extra workers only add overhead.
num_workers = min(8, os.cpu_count() or 1)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=num_workers)



In [None]:
# Class names ordered by label index (AMLDataset derives them from the 'train' entry of split.json).
classnames = train_set.classnames

In [None]:
class TextEncoder(nn.Module):
    """CLIP text tower that consumes pre-computed prompt embeddings.

    It reuses the transformer, positional embedding, final LayerNorm and text
    projection of ``clip_model``. It takes token *embeddings* (so learned
    context vectors can be injected) instead of token ids.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.dtype = clip_model.dtype
        self.text_projection = clip_model.text_projection
        self.ln_final = clip_model.ln_final
        self.positional_embedding = clip_model.positional_embedding
        self.transformer = clip_model.transformer

    def forward(self, prompts, tokenized_prompts):
        """Encode prompt embeddings into one projected feature vector per prompt.

        Args:
            prompts: prompt token embeddings, indexed as (prompt, position, channel).
            tokenized_prompts: token ids of the same prompts. The position of the
                largest id in each row is the one that gets pooled (in CLIP's
                tokenizer that is presumably the end-of-text token).
        """
        hidden = prompts + self.positional_embedding.type(self.dtype)
        # Swap the first two axes around the transformer call (CLIP's
        # transformer presumably expects sequence-first input) and swap them back.
        hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden).type(self.dtype)

        pooled_positions = tokenized_prompts.argmax(dim=-1)
        row_index = torch.arange(hidden.shape[0])
        pooled = hidden[row_index, pooled_positions]
        return pooled @ self.text_projection

In [None]:
class PromptLearner(nn.Module):
    """Learnable prompt context for CoOp-style prompt tuning.

    For every class it builds the token-embedding sequence of the text
    ``"<context> <classname>."``. The context part is a learnable parameter
    (``self.ctx``). The embeddings of the first token (``token_prefix``) and
    of everything after the context (``token_suffix``: class name, '.',
    remaining tokens) are frozen buffers.

    Args:
        classnames: list of class names; underscores are replaced by spaces.
        clip_model: loaded CLIP model (provides dtype, token embedding, input size).
        n_ctx: number of learnable context tokens. Ignored when ``ctx_init`` is
            given; then it is the number of tokens ``ctx_init`` encodes to.
        ctx_init: optional initial context text; underscores become spaces.
        class_token_position: "end", "middle" or "front".
        csc: if True, learn one context per class (class-specific context).
        input_size: expected image resolution; must equal the CLIP model's.

    Raises:
        ValueError: if ``input_size`` differs from the CLIP input resolution or
            ``class_token_position`` is not supported.
    """

    _VALID_POSITIONS = ("end", "middle", "front")

    def __init__(self, classnames, clip_model, n_ctx=16, ctx_init="", class_token_position="end", csc=False, input_size=224):
        super().__init__()
        self.n_cls = len(classnames)
        self.n_ctx = n_ctx
        self.ctx_init = ctx_init
        self.class_token_position = class_token_position
        self.csc = csc
        self.input_size = input_size

        # Fail at construction rather than on the first forward pass.
        if class_token_position not in self._VALID_POSITIONS:
            raise ValueError(f"Invalid class_token_position: {class_token_position}")

        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        device = clip_model.token_embedding.weight.device

        # An explicit raise (unlike assert) is not stripped under ``python -O``.
        if self.input_size != clip_imsize:
            raise ValueError(f"cfg_imsize ({self.input_size}) must equal to clip_imsize ({clip_imsize})")

        if self.ctx_init:
            ctx_init = self.ctx_init.replace("_", " ")
            # Count BPE tokens, not whitespace-separated words: one word can
            # encode to several tokens, and token_suffix below must start right
            # after the last context token (name_lens is counted the same way).
            self.n_ctx = len(_tokenizer.encode(ctx_init))
            prompt = clip.tokenize(ctx_init).to(device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            # Skip position 0 (the prefix token) and take the context tokens.
            ctx_vectors = embedding[0, 1 : 1 + self.n_ctx, :]
            prompt_prefix = ctx_init
        else:
            if self.csc:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(self.n_cls, self.n_ctx, ctx_dim, dtype=dtype, device=device)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(self.n_ctx, ctx_dim, dtype=dtype, device=device)
            nn.init.normal_(ctx_vectors, std=0.02)
            # Placeholder words whose embeddings are replaced by self.ctx in forward().
            prompt_prefix = " ".join(["X"] * self.n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {self.n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)

        classnames = [name.replace("_", " ") for name in classnames]
        self.name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # Frozen pieces around the learnable context.
        self.register_buffer("token_prefix", embedding[:, :1, :])
        self.register_buffer("token_suffix", embedding[:, 1 + self.n_ctx :, :])

        # Non-persistent buffer: follows .to()/.cuda() with the module but is not
        # written to state_dict, so saved checkpoints keep the same keys.
        self.register_buffer("tokenized_prompts", tokenized_prompts, persistent=False)

    def forward(self):
        """Assemble per-class prompt embeddings.

        Returns:
            tuple: ``(prompts, tokenized_prompts)``. ``prompts`` has one row
            per class and is ready for ``TextEncoder``.
        """
        ctx = self.ctx
        if ctx.dim() == 2:
            # Generic context: share the same vectors across all classes.
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix

        if self.class_token_position == "end":
            prompts = torch.cat([prefix, ctx, suffix], dim=1)

        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                # The suffix starts with the class-name tokens.
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompt = torch.cat([prefix_i, ctx_i_half1, class_i, ctx_i_half2, suffix_i], dim=1)
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompt = torch.cat([prefix_i, class_i, ctx_i, suffix_i], dim=1)
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError(f"Invalid class_token_position: {self.class_token_position}")

        return prompts, self.tokenized_prompts

In [None]:
class CustomCLIP(nn.Module):
    """CLIP classifier whose class prompts come from a learnable ``PromptLearner``.

    Args:
        classnames: class names; one prompt (and one logit column) per class.
        clip_model: loaded CLIP model. Its visual encoder, text-tower pieces
            and ``logit_scale`` are reused.
        n_ctx, ctx_init, csc, class_token_position, input_size: keyword-only
            options forwarded to ``PromptLearner``. The defaults give 16
            generic, randomly initialised context tokens, the class token at
            the end, and a 224-px input.
    """

    def __init__(self, classnames, clip_model, *, n_ctx=16, ctx_init="", csc=False,
                 class_token_position="end", input_size=224):
        super().__init__()
        self.prompt_learner = PromptLearner(
            classnames=classnames,
            clip_model=clip_model,
            n_ctx=n_ctx,
            ctx_init=ctx_init,
            csc=csc,
            class_token_position=class_token_position,
            input_size=input_size,
        )
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype

    def forward(self, image):
        """Return scaled cosine similarities between each image and each class prompt."""
        image_features = self.image_encoder(image.type(self.dtype))
        prompts, tokenized_prompts = self.prompt_learner()
        text_features = self.text_encoder(prompts, tokenized_prompts)

        # L2-normalise so the dot product below is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # logit_scale is stored in log space.
        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits

In [None]:
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, _ = clip.load("ViT-B/16", device=device)

model = CustomCLIP(classnames=classnames, clip_model=clip_model).to(device)

# Freeze everything except the prompt learner. The optimizer only receives the
# prompt-learner parameters, but leaving requires_grad=True on the CLIP weights
# makes every backward pass compute (and keep accumulating) gradients for all
# of them. Gradients still flow *through* the frozen encoders to the context.
for name, param in model.named_parameters():
    if "prompt_learner" not in name:
        param.requires_grad_(False)


Initializing a generic context
Initial context: "X X X X X X X X X X X X X X X X"
Number of context words (tokens): 16


In [None]:

# Training hyper-parameters.
MAX_EPOCH = 10   # number of passes over train_loader
LR = 0.002       # initial SGD learning rate

# Only the prompt-learner parameters are handed to the optimiser.
optimizer = optim.SGD(params=model.prompt_learner.parameters(), lr=LR, momentum=0.9)
# Cosine-anneal the learning rate over MAX_EPOCH scheduler steps.
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=MAX_EPOCH)
criterion = nn.CrossEntropyLoss()


In [None]:
def save_prompt_learner(model, path="output/coop_prompt.pth"):
    """Save ``model.prompt_learner``'s state dict to *path* as ``{"state_dict": ...}``.

    Parent directories are created as needed. A bare filename is written to
    the current working directory.
    """
    directory = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save({"state_dict": model.prompt_learner.state_dict()}, path)
    print(f"Prompt learner saved to {path}")

In [None]:
PRINT_FREQ = 5  # print the current batch loss every PRINT_FREQ steps
for epoch in range(MAX_EPOCH):
    model.train()
    running_loss, correct, total, num_batches = 0.0, 0, 0, 0

    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        predicted = logits.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if (i + 1) % PRINT_FREQ == 0:
            print(f"Epoch [{epoch+1}/{MAX_EPOCH}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

    # running_loss is a sum over batches; report the mean per-batch loss so the
    # number is comparable with the per-step losses printed above.
    mean_loss = running_loss / max(num_batches, 1)
    acc = 100. * correct / max(total, 1)
    print(f"Epoch {epoch+1}: Train Loss = {mean_loss:.4f}, Accuracy = {acc:.2f}%")

    # The cosine schedule is defined over epochs (T_max=MAX_EPOCH).
    scheduler.step()

# Evaluate once on the held-out validation subset after training.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Validation Accuracy: {100.0 * correct / max(total, 1):.2f}%")
save_prompt_learner(model)



Epoch [1/10], Step [5/62], Loss: 0.0172
Epoch [1/10], Step [10/62], Loss: 0.0515
Epoch [1/10], Step [15/62], Loss: 0.1022
Epoch [1/10], Step [20/62], Loss: 0.0741
Epoch [1/10], Step [25/62], Loss: 0.0239
Epoch [1/10], Step [30/62], Loss: 0.0173
Epoch [1/10], Step [35/62], Loss: 0.1002
Epoch [1/10], Step [40/62], Loss: 0.1300
Epoch [1/10], Step [45/62], Loss: 0.1168
Epoch [1/10], Step [50/62], Loss: 0.2025
Epoch [1/10], Step [55/62], Loss: 0.0912
Epoch [1/10], Step [60/62], Loss: 0.1484
Epoch 1: Train Loss = 6.0818, Accuracy = 96.45%
Epoch [2/10], Step [5/62], Loss: 0.0341
Epoch [2/10], Step [10/62], Loss: 0.1082
Epoch [2/10], Step [15/62], Loss: 0.0670
Epoch [2/10], Step [20/62], Loss: 0.2025
Epoch [2/10], Step [25/62], Loss: 0.0565
Epoch [2/10], Step [30/62], Loss: 0.1329
Epoch [2/10], Step [35/62], Loss: 0.1022
Epoch [2/10], Step [40/62], Loss: 0.0895
Epoch [2/10], Step [45/62], Loss: 0.0719
Epoch [2/10], Step [50/62], Loss: 0.1646
Epoch [2/10], Step [55/62], Loss: 0.0844
Epoch [2/10

In [None]:
# Clone only when the repository is not already there, so the cell can be re-run.
# The absolute path keeps this independent of the current working directory.
![ -d /content/CoOp ] || git clone https://github.com/KaiyangZhou/CoOp.git /content/CoOp

Cloning into 'CoOp'...
remote: Enumerating objects: 455, done.[K
remote: Counting objects: 100% (250/250), done.[K
remote: Compressing objects: 100% (51/51), done.[K
remote: Total 455 (delta 217), reused 199 (delta 199), pack-reused 205 (from 1)[K
Receiving objects: 100% (455/455), 1.40 MiB | 19.65 MiB/s, done.
Resolving deltas: 100% (266/266), done.


In [None]:
# Absolute path: a relative `%cd CoOp/` fails when the cell is run a second time.
%cd /content/CoOp

/content/CoOp


In [None]:
# Run CoOp's interpret_prompt.py on the saved prompt learner (written by
# save_prompt_learner to output/coop_prompt.pth relative to /content). For each
# learned context vector it prints the 5 nearest vocabulary tokens and their
# distances.
!python interpret_prompt.py /content/output/coop_prompt.pth 5

Return the top-5 matched words
100%|███████████████████████████████████████| 256M/256M [01:04<00:00, 3.97MiB/s]
Size of token embedding: torch.Size([49408, 512])
Size of context: torch.Size([16, 512])
Size of distance matrix: torch.Size([16, 49408])
1: ['weaknesses</w>', 'losses</w>', 'alright</w>', 'and', 'aaaaa</w>'] ['0.6468', '0.6477', '0.6496', '0.6518', '0.6522']
2: ['troupe</w>', 'decatur</w>', 'aqu', 'arun</w>', 'katz</w>'] ['0.6009', '0.6046', '0.6048', '0.6059', '0.6063']
3: ['bcfc</w>', 'phill', 'fri</w>', 'strike', 'fiancÃ©</w>'] ['0.5692', '0.5703', '0.5709', '0.5725', '0.5737']
4: ['olives</w>', 'certain</w>', 'boi</w>', 'pelicans</w>', 'elles</w>'] ['0.5862', '0.5864', '0.5911', '0.5926', '0.5930']
5: ['ophthal', 'mia</w>', 'kra', 'ials</w>', 'volcanoes</w>'] ['0.6959', '0.6996', '0.7025', '0.7026', '0.7038']
6: ['worlds</w>', 'bas</w>', 'aff</w>', 'period</w>', 'litres</w>'] ['0.7474', '0.7508', '0.7516', '0.7530', '0.7535']
7: ['terrible</w>', 'for', 'loses</w>', 'rip<