In [1]:
import clip
import torch
from torch import nn
from clip.model import CLIP
from clip.simple_tokenizer import SimpleTokenizer

In [2]:
class TextEncoder(nn.Module):
    """Text tower of a CLIP model that consumes embeddings instead of token ids.

    The transformer, positional embedding, final layer norm and text
    projection are taken (not copied) from ``clip_model``.
    """

    def __init__(self, clip_model: CLIP):
        super().__init__()
        self.dtype = clip_model.dtype
        # Keep parameters and sub-modules in their original relative order so
        # the resulting state_dict layout does not change.
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

    def forward(self, soft_prompt_emb, prompt_token_ids):
        """Encode prompt embeddings into one projected feature vector per prompt.

        Args:
            soft_prompt_emb: prompt embeddings, laid out as
                [num_classes, num_tokens, dim].
            prompt_token_ids: token ids of the same prompts; the position of
                the largest id in each row (the EOT token) selects which
                token's features are returned.

        Returns:
            The features at each prompt's EOT position, multiplied by
            ``text_projection``.
        """
        pos_emb = self.positional_embedding.type(self.dtype)
        hidden = soft_prompt_emb + pos_emb

        # [num_classes, num_tokens, dim] -> [num_tokens, num_classes, dim] for
        # the transformer, then straight back again.
        hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden).type(self.dtype)

        # One (row, EOT position) pair per prompt.
        eot_positions = prompt_token_ids.argmax(dim=-1)
        row_index = torch.arange(hidden.shape[0])
        return hidden[row_index, eot_positions] @ self.text_projection

In [3]:
class SoftPrompt(nn.Module):
    """Learnable context vectors shared by a fixed set of class prompts.

    Every class gets the prompt ``"X X ... X <class name>."`` where the
    ``nctx`` placeholder ``X`` tokens are swapped at forward time for the
    trainable ``ctx`` vectors. The embeddings of the start token and of the
    class-name/end/padding tokens are computed once and stored as buffers.

    Args:
        class_names: sized collection of class names; underscores are
            replaced by spaces before the prompt is built.
        clip_model: model providing ``dtype``, ``ln_final`` (its weight length
            is used as the embedding width) and ``token_embedding``.
        tokenizer_fn: callable mapping one prompt string to a tensor of token
            ids that can be concatenated along dim 0 (e.g. ``clip.tokenize``).
            Each placeholder ``X`` is assumed to become exactly one token.
        nctx: number of learnable context tokens.

    Raises:
        ValueError: if ``class_names`` is empty.
    """

    def __init__(self, class_names, clip_model: CLIP, tokenizer_fn, nctx=16) -> None:
        super().__init__()

        class_names = [name.replace("_", " ") for name in class_names]
        if not class_names:
            raise ValueError("class_names must contain at least one class name")

        self.nctx = nctx
        self.num_classes = len(class_names)

        dtype = clip_model.dtype
        # Everything created here must live where the embedding table lives,
        # otherwise the lookup below (and torch.cat in forward) fails as soon
        # as the CLIP model sits on a GPU.
        device = clip_model.token_embedding.weight.device
        dim, = clip_model.ln_final.weight.shape

        # Sample in float32 on the CPU (same random stream regardless of the
        # target device/dtype), then move/cast once.
        ctx_init = torch.normal(mean=0.0, std=0.02, size=(nctx, dim), dtype=torch.float32)
        self.ctx = nn.Parameter(ctx_init.to(device=device, dtype=dtype))

        dummy_texts = " ".join(['X'] * nctx)
        prompt_texts = [f'{dummy_texts} {name}.' for name in class_names]

        token_ids = torch.cat([tokenizer_fn(p) for p in prompt_texts]).to(device)
        with torch.no_grad():
            embeddings = clip_model.token_embedding(token_ids).to(dtype)

        # Non-persistent buffer: follows .to()/.cuda() with the module but is
        # kept out of state_dict, so existing checkpoints still load.
        self.register_buffer("prompt_token_ids", token_ids, persistent=False)
        self.register_buffer("prefix_emb", embeddings[:, :1, :])  # SOT, shape: [num_classes, 1, dim]
        self.register_buffer("suffix_emb", embeddings[:, 1 + nctx:, :])  # Class tokens and EOT, shape: [num_classes, *, dim]

    def forward(self):
        """Build the prompt embeddings for every class.

        Returns:
            tuple: ``(soft_prompts, prompt_token_ids)`` where ``soft_prompts``
            is the concatenation ``[prefix, ctx, suffix]`` along the token
            axis and ``prompt_token_ids`` are the token ids of the
            placeholder prompts.
        """
        # The same context vectors are shared by all classes.
        ctx = self.ctx.unsqueeze(0).expand(self.num_classes, -1, -1)  # shape: [num_classes, nctx, dim]
        soft_prompts = torch.cat([self.prefix_emb, ctx, self.suffix_emb], dim=1)
        return soft_prompts, self.prompt_token_ids

In [4]:
class CLIPWithSoftPrompt(nn.Module):
    """CLIP image/text matching where the text side uses learnable soft prompts.

    Args:
        class_names: class names to build one prompt per class from.
        clip_model: CLIP model whose visual encoder, text layers and
            ``logit_scale`` are reused.
        clip_tokenize_fn: tokenizer passed on to ``SoftPrompt``
            (e.g. ``clip.tokenize``).
        nctx: number of learnable context tokens in each prompt.
    """

    def __init__(self, class_names, clip_model: CLIP, clip_tokenize_fn, nctx=16):
        super().__init__()
        self.create_soft_prompt = SoftPrompt(class_names, clip_model, clip_tokenize_fn, nctx=nctx)
        self.visual_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype

    def forward(self, image):
        """Score a batch of images against every class prompt.

        Args:
            image: batch of images; cast to the model dtype before encoding.

        Returns:
            tuple: ``(logits, image_features, text_features)``. Both feature
            tensors are L2-normalised along the last dimension and
            ``logits = exp(logit_scale) * image_features @ text_features.T``.
        """
        image_features = self.visual_encoder(image.to(self.dtype))

        # The prompts are rebuilt on every call so they always reflect the
        # current context vectors.
        soft_prompt_emb, prompt_token_ids = self.create_soft_prompt()
        text_features = self.text_encoder(soft_prompt_emb, prompt_token_ids)

        # Unit-normalise both sides so the product below is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits, image_features, text_features

In [5]:
# Pin the device: without `device=`, clip.load picks CUDA when it is available,
# while the rest of this notebook feeds CPU tensors. Loading on the CPU keeps
# Restart & Run All working identically on every machine.
clip_model, clip_preprocess = clip.load('RN50', device='cpu')
class_names = ['elephant', 'person', 'fish', 'bird']
model = CLIPWithSoftPrompt(class_names, clip_model, clip.tokenize)

In [6]:
# Smoke test: a random batch of 16 RGB images at 224x224 goes through the model.
# `logts` receives the logits; the other two are the normalised features.
x = torch.rand(16, 3, 224, 224)
logts, im_feats, text_feats = model(x)

In [7]:
# Normalised image features returned by the model: one row per image in the batch.
im_feats.shape

torch.Size([16, 1024])