In [1]:
import clip
import torch
from torch import nn
from clip.model import CLIP
from clip.simple_tokenizer import SimpleTokenizer

In [28]:
class TextEncoder(nn.Module):
    """CLIP text tower that consumes token *embeddings* rather than token ids.

    The transformer, positional embedding, final layer norm and text
    projection are taken (by reference, not copied) from ``clip_model``, so
    prompts assembled in embedding space can be encoded directly.
    """

    # Attributes borrowed from the CLIP model; the order is kept stable because
    # it determines the registration order of parameters and sub-modules.
    _SHARED_ATTRS = ("transformer", "positional_embedding", "ln_final", "text_projection", "dtype")

    def __init__(self, clip_model: CLIP):
        super().__init__()
        for attr in self._SHARED_ATTRS:
            setattr(self, attr, getattr(clip_model, attr))

    def forward(self, soft_prompt_emb, prompt_token_ids):
        """Encode prompt embeddings into projected text features.

        Args:
            soft_prompt_emb: prompt embeddings, [num_classes, num_tokens, dim].
            prompt_token_ids: token ids of the same prompts; only used to find
                the position of the EOT token in each row.

        Returns:
            One projected feature vector per prompt (row of ``soft_prompt_emb``).
        """
        pos_emb = self.positional_embedding.type(self.dtype)
        hidden = soft_prompt_emb + pos_emb

        # The transformer is fed sequence-first input:
        # [num_classes, num_tokens, dim] -> [num_tokens, num_classes, dim] and back.
        hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden).type(self.dtype)

        # Read the feature at the EOT position of every prompt. This relies on
        # the EOT token having the largest id in each row of prompt_token_ids.
        eot_positions = prompt_token_ids.argmax(dim=-1)
        row_index = torch.arange(hidden.shape[0])
        return hidden[row_index, eot_positions] @ self.text_projection

In [37]:
class SoftPrompt(nn.Module):
    """Per-class prompt embeddings built around a shared, learnable context.

    Every class prompt is tokenized as ``"X X ... X <class name>."`` and then
    split in embedding space into three pieces:

    * ``prefix_emb``  - the first token (SOT), frozen buffer;
    * ``ctx``         - ``nctx`` learnable vectors that replace the ``X``
      placeholders and are shared by all classes;
    * ``suffix_emb``  - everything after the placeholders (class tokens, EOT
      and whatever the tokenizer appends), frozen buffer.

    Args:
        class_names: sized collection of class-name strings; underscores are
            replaced by spaces.
        clip_model: model providing ``dtype``, ``ln_final.weight`` (whose
            length is used as the embedding width) and ``token_embedding``.
        tokenizer_fn: callable mapping one prompt string to a 2-D tensor of
            token ids with a single row (e.g. ``clip.tokenize``).
        nctx: number of learnable context vectors.

    Raises:
        ValueError: if ``class_names`` is empty.
    """

    def __init__(self, class_names, clip_model: CLIP, tokenizer_fn, nctx=16) -> None:
        super().__init__()

        if len(class_names) == 0:
            raise ValueError("class_names must contain at least one class name")

        self.nctx = nctx
        self.num_classes = len(class_names)

        dtype = clip_model.dtype
        # Create everything on the device of the embedding table so that a
        # CLIP model living on the GPU can be wrapped without a device mismatch.
        device = clip_model.token_embedding.weight.device
        dim, = clip_model.ln_final.weight.shape
        self.ctx = nn.Parameter(
            torch.normal(mean=0, std=0.02, size=(nctx, dim), dtype=dtype, device=device)
        )

        # Placeholder words mark the slots later filled by ``ctx``. The slicing
        # below assumes the tokenizer maps each 'X' to exactly one token.
        dummy_texts = " ".join(['X'] * nctx)

        class_names = [name.replace("_", " ") for name in class_names]
        prompt_texts = [f'{dummy_texts} {name}.' for name in class_names]

        prompt_token_ids = torch.cat([tokenizer_fn(p) for p in prompt_texts]).to(device)
        with torch.no_grad():
            embeddings = clip_model.token_embedding(prompt_token_ids).to(dtype)

        # Non-persistent buffer: follows ``.to(device)`` like the embedding
        # buffers below, without adding a new key to ``state_dict()``.
        self.register_buffer("prompt_token_ids", prompt_token_ids, persistent=False)
        self.register_buffer("prefix_emb", embeddings[:, :1, :])  # SOT, shape: [num_classes, 1, dim]
        self.register_buffer("suffix_emb", embeddings[:, 1+nctx:, :])  # Class tokens and EOT, shape: [num_classes, *, dim]

    def forward(self):
        """Assemble the prompts.

        Returns:
            Tuple ``(soft_prompts, prompt_token_ids)`` where ``soft_prompts`` is
            the concatenation ``[prefix_emb, ctx, suffix_emb]`` along the token
            axis and ``prompt_token_ids`` are the ids of the placeholder prompts.
        """
        ctx = self.ctx
        ctx = ctx.unsqueeze(0).expand(self.num_classes, -1, -1)  # shape: [num_classes, nctx, dim]
        soft_prompts = torch.cat([self.prefix_emb, ctx, self.suffix_emb], dim=1)
        return soft_prompts, self.prompt_token_ids

In [38]:
class CLIPWithSoftPrompt(nn.Module):
    """CLIP classifier whose class prompts are produced by a ``SoftPrompt``.

    Args:
        class_names: class-name strings, one prompt is built per class.
        clip_model: CLIP model supplying the visual encoder, the text tower
            pieces, ``logit_scale`` and ``dtype``.
        clip_tokenize_fn: tokenizer callable forwarded to ``SoftPrompt``
            (e.g. ``clip.tokenize``).
        nctx: number of learnable context vectors forwarded to ``SoftPrompt``;
            the default matches ``SoftPrompt``'s own default.
    """

    def __init__(self, class_names, clip_model: CLIP, clip_tokenize_fn, nctx=16):
        super().__init__()
        self.create_soft_prompt = SoftPrompt(class_names, clip_model, clip_tokenize_fn, nctx=nctx)
        self.visual_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype

    def forward(self, image):
        """Score images against every class prompt.

        Args:
            image: batch of images; it is cast to the model dtype before being
                passed to the visual encoder.

        Returns:
            Logits: ``exp(logit_scale)`` times the cosine similarity between
            each image feature and each class text feature (one row per image,
            one column per class).
        """
        image_features = self.visual_encoder(image.to(self.dtype))

        # The text features are rebuilt on every call so that gradients can
        # reach the learnable context inside ``create_soft_prompt``.
        soft_prompt_emb, prompt_token_ids = self.create_soft_prompt()
        text_features = self.text_encoder(soft_prompt_emb, prompt_token_ids)

        # L2-normalise both sides so the matrix product is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits

In [39]:
# clip.load picks CUDA by default when it is available. The rest of this
# notebook feeds CPU tensors to the model, so pin the device explicitly to keep
# a fresh "Run All" working on GPU machines as well.
clip_model, clip_preprocess = clip.load('RN50', device='cpu')
class_names = ['elephant', 'person', 'fish', 'bird']
model = CLIPWithSoftPrompt(class_names, clip_model, clip.tokenize)

In [42]:
# Smoke test: push a random tensor shaped like a batch of 16 images
# (3 x 224 x 224) through the model. The displayed logits have one row per
# input and one column per entry of `class_names`.
x = torch.rand((16, 3, 224, 224))
model(x)

tensor([[18.9530, 19.4049, 19.7601, 19.3778],
        [18.9777, 19.3840, 19.6736, 19.3275],
        [18.4510, 18.7984, 19.1546, 18.8448],
        [19.0224, 19.4500, 19.8114, 19.4139],
        [18.5623, 18.9927, 19.2871, 18.9639],
        [18.9982, 19.4180, 19.7362, 19.3876],
        [18.9137, 19.3398, 19.5940, 19.2800],
        [18.8552, 19.2172, 19.5546, 19.1923],
        [18.8760, 19.3108, 19.6193, 19.2605],
        [18.7870, 19.1620, 19.5427, 19.1951],
        [18.8915, 19.2890, 19.5915, 19.2640],
        [18.6203, 19.0552, 19.3533, 19.0138],
        [18.7040, 19.1012, 19.4249, 19.0835],
        [19.2366, 19.6312, 19.9429, 19.5913],
        [18.7571, 19.1356, 19.4448, 19.1265],
        [18.9356, 19.3609, 19.6696, 19.3360]], grad_fn=<MmBackward0>)

In [20]:
import os
import pandas as pd

# classes.txt is read as space-separated "<class_id> <class_name>" rows with no header.
class_names_df = pd.read_csv(
    os.path.join('datasets', 'CUB', 'CUB_200_2011', 'classes.txt'),
    sep=' ',
    header=None,
    names=['class_id', 'class_name'],
)

# Keep only the text after the last '.' of each name and turn underscores into spaces.
class_names = (
    class_names_df['class_name']
    .str.rsplit('.', n=1).str[-1]
    .str.replace('_', ' ', regex=False)
    .tolist()
)

In [21]:
# Display the parsed class names to check the clean-up by eye.
class_names

['Black footed Albatross',
 'Laysan Albatross',
 'Sooty Albatross',
 'Groove billed Ani',
 'Crested Auklet',
 'Least Auklet',
 'Parakeet Auklet',
 'Rhinoceros Auklet',
 'Brewer Blackbird',
 'Red winged Blackbird',
 'Rusty Blackbird',
 'Yellow headed Blackbird',
 'Bobolink',
 'Indigo Bunting',
 'Lazuli Bunting',
 'Painted Bunting',
 'Cardinal',
 'Spotted Catbird',
 'Gray Catbird',
 'Yellow breasted Chat',
 'Eastern Towhee',
 'Chuck will Widow',
 'Brandt Cormorant',
 'Red faced Cormorant',
 'Pelagic Cormorant',
 'Bronzed Cowbird',
 'Shiny Cowbird',
 'Brown Creeper',
 'American Crow',
 'Fish Crow',
 'Black billed Cuckoo',
 'Mangrove Cuckoo',
 'Yellow billed Cuckoo',
 'Gray crowned Rosy Finch',
 'Purple Finch',
 'Northern Flicker',
 'Acadian Flycatcher',
 'Great Crested Flycatcher',
 'Least Flycatcher',
 'Olive sided Flycatcher',
 'Scissor tailed Flycatcher',
 'Vermilion Flycatcher',
 'Yellow bellied Flycatcher',
 'Frigatebird',
 'Northern Fulmar',
 'Gadwall',
 'American Goldfinch',
 'Euro