# Installs, Imports, Preps

In [None]:
# !pip install git+https://github.com/openai/CLIP.git
# !pip install transformers
# !pip install yacs

In [None]:
import clip
import torch
import transformers
import torchvision
import gc
import os
import torch.nn as nn
import torch.optim as optim
# import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from config import get_cfg_defaults
from utils import build_lr_scheduler

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Project configuration, built from the local `config` module's defaults.
cfg = get_cfg_defaults()

# Directory for the cached image embeddings / labels written below.
# `exist_ok=True` makes this idempotent and avoids the check-then-create race.
os.makedirs('checkpoints', exist_ok=True)

# Image Encoding

## Load data

In [None]:
# Load the pretrained CLIP backbone and its matching preprocessing pipeline,
# placing it on `device` explicitly (same as clip.load's own default choice).
model, transform = clip.load(cfg.MODEL.NAME, device=device)
model.to(device)

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [None]:
# Flowers102 training split, preprocessed with CLIP's own image transform.
# First run only (data not on disk yet) — add download=True:
# dataset = torchvision.datasets.Flowers102(root='', split='train', transform=transform, download=True)
dataset = torchvision.datasets.Flowers102(
    root='',
    split='train',
    transform=transform,
)

I kept running into out-of-memory (OOM) errors, so the dataset is split into subsets that are encoded one at a time.

In [None]:
# Split the dataset into N_SUBSETS contiguous chunks that are encoded one at a time.
# The chunk size is derived from the dataset length (ceil division) rather than the
# hard-coded 204, and the last chunk is clipped so no index runs past the end
# (the old version produced an out-of-range extra chunk when len % 5 != 0).
N_SUBSETS = 5
chunk_size = max(1, -(-len(dataset) // N_SUBSETS))  # ceil(len / N_SUBSETS), at least 1
subsets = [
    Subset(dataset, range(start, min(start + chunk_size, len(dataset))))
    for start in range(0, len(dataset), chunk_size)
]

In [None]:
# One non-shuffled DataLoader per subset, so each chunk is encoded in dataset order.
dataloaders = []
for chunk in subsets:
    dataloaders.append(
        DataLoader(chunk, shuffle=False, batch_size=cfg.DATASET.IMG_BATCH_SIZE)
    )

## Get image embs

Confirm image shape

In [None]:
# Fetch the first sample once and report what the transform produces.
first_image, first_label = dataset[0]
print(f"Image shape: {first_image.shape}")
print(f"Label type:  {type(first_label)}")

Image shape: torch.Size([3, 224, 224])
Label type:  <class 'int'>


Encode

In [None]:
def free_gpu_cache():
    """Run Python's garbage collector, then release PyTorch's cached, unused CUDA memory.

    Collection goes first so tensors that are only reachable through garbage
    are freed before the CUDA cache is emptied. Returns None.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

In [None]:
# Encode every subset and cache its embeddings/labels to disk.
# `torch.no_grad()` is essential here: CLIP's parameters require grad at this
# point, so without it autograd keeps every batch's activations alive until the
# subset is stacked. That growth is the source of the out-of-memory errors.
# Each batch's embeddings are moved to the CPU, so the GPU only ever holds one
# batch. The saved files are then plain, device-agnostic tensors.
with torch.no_grad():
    for i, dataloader in enumerate(dataloaders):
        img_embs, lbls = [], []
        for images, labels in tqdm(dataloader):
            img_embs.append(model.encode_image(images.to(device)).cpu())
            lbls.append(labels)
        free_gpu_cache()
        img_embs, lbls = torch.vstack(img_embs), torch.hstack(lbls)
        print(f"Image embedding shape: {img_embs.shape}")
        print(f"Labels shape:          {lbls.shape}")
        torch.save(img_embs, f'checkpoints/{cfg.DATASET.NAME}_image_embs_{i}.pt')
        torch.save(lbls,     f'checkpoints/{cfg.DATASET.NAME}_labels_{i}.pt')

100%|██████████| 13/13 [00:01<00:00,  9.95it/s]


Image embedding shape: torch.Size([204, 512])
Labels shape:          torch.Size([204])


100%|██████████| 13/13 [00:01<00:00,  9.69it/s]


Image embedding shape: torch.Size([204, 512])
Labels shape:          torch.Size([204])


100%|██████████| 13/13 [00:01<00:00,  9.80it/s]


Image embedding shape: torch.Size([204, 512])
Labels shape:          torch.Size([204])


100%|██████████| 13/13 [00:01<00:00, 10.15it/s]


Image embedding shape: torch.Size([204, 512])
Labels shape:          torch.Size([204])


100%|██████████| 13/13 [00:01<00:00, 10.26it/s]


Image embedding shape: torch.Size([204, 512])
Labels shape:          torch.Size([204])


Aggregate the per-subset checkpoint files

In [None]:
# Stitch the per-subset checkpoints back together, in subset order.
n_parts = len(dataloaders)
emb_parts = [torch.load(f'checkpoints/{cfg.DATASET.NAME}_image_embs_{i}.pt') for i in range(n_parts)]
lbl_parts = [torch.load(f'checkpoints/{cfg.DATASET.NAME}_labels_{i}.pt') for i in range(n_parts)]
img_embs = torch.vstack(emb_parts)
lbls = torch.hstack(lbl_parts)
del emb_parts, lbl_parts  # the stacked copies are all that is needed from here on

Save for future use

In [None]:
# Persist the full embedding matrix and label vector so later sessions can skip encoding.
for obj, out_path in (
    (img_embs, f'checkpoints/{cfg.DATASET.NAME}_image_embs.pt'),
    (lbls,     f'checkpoints/{cfg.DATASET.NAME}_labels.pt'),
):
    torch.save(obj, out_path)

# Prompt Encoding

## Prompt Learner

### CoOp

In [None]:
class CoOp(nn.Module):
    """CoOp prompt learner with learnable context vectors.

    Each class prompt is assembled as ``[SOS][V1]...[Vn][CLS].[EOS]``. The
    ``[V1]...[Vn]`` context vectors are the only trainable part, and the class
    name is always placed after the context.

    Args:
        cfg: config node; reads ``DATASET.CLASSNAMES``, ``TRAIN.N_CTX``,
            ``TRAIN.CTX_INIT``, ``TRAIN.CSC`` and ``TRAIN.PARAM_STD``.
        model: loaded CLIP model; its ``token_embedding``, ``ln_final`` and
            ``dtype`` are read, nothing is modified.

    Raises:
        ValueError: if ``TRAIN.CTX_INIT`` is set but tokenizes to zero tokens.
    """

    def __init__(self, cfg, model):
        super().__init__()
        self.n_cls = len(cfg.DATASET.CLASSNAMES)
        self.n_ctx = cfg.TRAIN.N_CTX
        ctx_dim = model.ln_final.weight.shape[0]
        # Token ids must live on the same device as the embedding table.
        emb_device = model.token_embedding.weight.device

        if cfg.TRAIN.CTX_INIT:
            # Fixed init from a phrase (always a global context, even if TRAIN.CSC is set).
            ctx_init = cfg.TRAIN.CTX_INIT.replace("_", " ")  # "_" marks the class-name slot
            init_tokens = clip.tokenize(ctx_init)
            # The context length is the phrase's token count: the EOT position (EOT has
            # the highest token id) minus the leading SOS. Keeping TRAIN.N_CTX here would
            # copy EOT/padding embeddings into the context and misalign the suffix slice.
            self.n_ctx = int(init_tokens[0].argmax()) - 1
            if self.n_ctx < 1:
                raise ValueError("TRAIN.CTX_INIT must contain at least one token")
            with torch.no_grad():
                token_emb = model.token_embedding(init_tokens.to(emb_device)).type(model.dtype)
            ctx_vectors = token_emb[0, 1:self.n_ctx + 1, :].clone()
            prefix = ctx_init
        else:
            # Random init: one context per class (CSC) or one shared by all classes.
            if cfg.TRAIN.CSC:
                ctx_shape = (self.n_cls, self.n_ctx, ctx_dim)
            else:
                ctx_shape = (self.n_ctx, ctx_dim)
            ctx_vectors = torch.empty(*ctx_shape, dtype=model.dtype)
            nn.init.normal_(ctx_vectors, std=cfg.TRAIN.PARAM_STD)
            # Placeholder words; their embeddings are replaced by `ctx` in forward.
            prefix = " ".join(["X"] * self.n_ctx)

        # Context vectors: the only trainable part of the prompt.
        self.ctx = nn.Parameter(ctx_vectors)

        # Tokenize "<prefix> <classname>." for every class. Kept as a plain (CPU)
        # attribute; the text encoder uses it to locate each prompt's EOT token.
        classnames = [name.replace("_", " ") for name in cfg.DATASET.CLASSNAMES]
        raw_prompts = [prefix + " " + name + "." for name in classnames]
        self.tokenized_prompts = clip.tokenize(raw_prompts)
        with torch.no_grad():
            token_emb = model.token_embedding(self.tokenized_prompts.to(emb_device)).type(model.dtype)

        # [SOS] embedding: (n_cls, 1, dim)
        self.register_buffer("prefix", token_emb[:, :1, :])
        # [CLS] . [EOS] (+ padding) embeddings: (n_cls, sfx_len, dim)
        self.register_buffer("suffix", token_emb[:, self.n_ctx + 1:, :])

    def forward(self):
        """Return the prompt embeddings of every class, shape (n_cls, seq_len, dim)."""
        ctx = self.ctx
        if ctx.dim() == 2:
            # Global context: share the same vectors across all classes.
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        return torch.cat(
            [
                self.prefix,  # (n_cls, 1, dim)
                ctx,          # (n_cls, n_ctx, dim)
                self.suffix,  # (n_cls, sfx_len, dim)
            ],
            dim=1,
        )

### CoCoOp

In [None]:
class CoCoOp(nn.Module):
    """CoCoOp prompt learner: a global learnable context, shifted per image.

    A small meta-network maps each image feature to a bias added to every
    context vector. Each image therefore gets its own set of class prompts
    ``[SOS][V1+b]...[Vn+b][CLS].[EOS]``, with the class name placed after the
    context.

    Args:
        cfg: config node; reads ``DATASET.CLASSNAMES``, ``TRAIN.N_CTX``,
            ``TRAIN.CTX_INIT`` and ``TRAIN.PARAM_STD``.
        model: loaded CLIP model; its ``token_embedding``, ``ln_final``,
            ``visual.output_dim`` and ``dtype`` are read, nothing is modified.

    Raises:
        ValueError: if ``TRAIN.CTX_INIT`` is set but tokenizes to zero tokens.
    """

    def __init__(self, cfg, model):
        super().__init__()
        self.n_cls = len(cfg.DATASET.CLASSNAMES)
        self.n_ctx = cfg.TRAIN.N_CTX
        self.dtype = model.dtype
        ctx_dim = model.ln_final.weight.shape[0]
        vis_dim = model.visual.output_dim
        # Token ids must live on the same device as the embedding table.
        emb_device = model.token_embedding.weight.device

        # Context init (always global in CoCoOp).
        if cfg.TRAIN.CTX_INIT:
            ctx_init = cfg.TRAIN.CTX_INIT.replace("_", " ")  # "_" marks the class-name slot
            init_tokens = clip.tokenize(ctx_init)
            # Context length = phrase token count (EOT position minus SOS); keeping
            # TRAIN.N_CTX would pull EOT/padding into ctx and misalign the suffix.
            self.n_ctx = int(init_tokens[0].argmax()) - 1
            if self.n_ctx < 1:
                raise ValueError("TRAIN.CTX_INIT must contain at least one token")
            with torch.no_grad():
                token_emb = model.token_embedding(init_tokens.to(emb_device)).type(model.dtype)
            ctx_vectors = token_emb[0, 1:self.n_ctx + 1, :].clone()
            prefix = ctx_init
        else:
            ctx_vectors = torch.empty(self.n_ctx, ctx_dim, dtype=model.dtype)
            nn.init.normal_(ctx_vectors, std=cfg.TRAIN.PARAM_STD)
            # Placeholder words; their embeddings are replaced by the context in forward.
            prefix = " ".join(["X"] * self.n_ctx)

        # Shared context vectors.
        self.ctx = nn.Parameter(ctx_vectors)

        # Meta-network (image feature -> context bias). Cast to the model dtype
        # because forward feeds it `img_feats.type(self.dtype)`.
        self.net = nn.Sequential(
            nn.Linear(vis_dim, vis_dim // 16),
            nn.ReLU(inplace=True),
            nn.Linear(vis_dim // 16, ctx_dim),
        ).to(self.dtype)

        # Tokenize "<prefix> <classname>." for every class (plain CPU attribute,
        # used by the text encoder to find each prompt's EOT token).
        classnames = [name.replace("_", " ") for name in cfg.DATASET.CLASSNAMES]
        raw_prompts = [prefix + " " + name + "." for name in classnames]
        self.tokenized_prompts = clip.tokenize(raw_prompts)
        with torch.no_grad():
            token_emb = model.token_embedding(self.tokenized_prompts.to(emb_device)).type(model.dtype)

        # [SOS] embedding: (n_cls, 1, dim)
        self.register_buffer("prefix", token_emb[:, :1, :])
        # [CLS] . [EOS] (+ padding) embeddings: (n_cls, sfx_len, dim)
        self.register_buffer("suffix", token_emb[:, self.n_ctx + 1:, :])

    def forward(self, img_feats):
        """Build per-image class prompts.

        Args:
            img_feats: (batch, vis_dim) image features.

        Returns:
            (batch, n_cls, seq_len, dim) prompt embeddings.
        """
        bias = self.net(img_feats.type(self.dtype)).unsqueeze(1)  # (batch, 1, dim)
        ctx_shifted = self.ctx.unsqueeze(0) + bias                 # (batch, n_ctx, dim)
        batch = ctx_shifted.shape[0]

        # Broadcast everything to (batch, n_cls, ..., dim) and concatenate along
        # the token axis in one shot instead of looping over images.
        ctx = ctx_shifted.unsqueeze(1).expand(-1, self.n_cls, -1, -1)  # (batch, n_cls, n_ctx, dim)
        prefix = self.prefix.unsqueeze(0).expand(batch, -1, -1, -1)    # (batch, n_cls, 1, dim)
        suffix = self.suffix.unsqueeze(0).expand(batch, -1, -1, -1)    # (batch, n_cls, sfx_len, dim)
        return torch.cat([prefix, ctx, suffix], dim=2)

## Custom Model

In [None]:
class CustomCLIPCoOp(nn.Module):
    """Frozen CLIP classifier whose text prompts come from a trainable CoOp learner.

    Args:
        cfg: config node forwarded to ``CoOp``.
        model: loaded CLIP model; all of its parameters are frozen here.
    """

    def __init__(self, cfg, model):
        super().__init__()
        self.prompt_learner = CoOp(cfg, model)
        self.model = model
        # CLIP itself is never updated, so freeze it once here; only the prompt
        # learner's parameters remain trainable.
        self.model.requires_grad_(False)

    def encode_text(self, prompt_embs, tokenized_prompts):
        """Run CLIP's text transformer on precomputed prompt embeddings.

        Mirrors CLIP's ``encode_text`` but skips the token-embedding lookup,
        because ``prompt_embs`` already holds the (partly learned) embeddings.

        Args:
            prompt_embs: (n_prompts, seq_len, width) prompt embeddings.
            tokenized_prompts: (n_prompts, seq_len) token ids, used only to find
                each prompt's EOT position.

        Returns:
            (n_prompts, text_projection.shape[1]) text features.
        """
        dtype = self.model.dtype
        x = prompt_embs + self.model.positional_embedding.type(dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x).type(dtype)

        # x.shape = [n_prompts, seq_len, width]; take the feature at the EOT token
        # (EOT has the highest id in each sequence) and project it.
        eot_idx = tokenized_prompts.argmax(dim=-1)
        return x[torch.arange(x.shape[0]), eot_idx] @ self.model.text_projection

    def forward(self, img_feats):
        """Return class logits of shape (batch, n_cls) for precomputed image features."""
        dtype = self.model.dtype
        prompt_embs = self.prompt_learner()
        txt_feats = self.encode_text(prompt_embs, self.prompt_learner.tokenized_prompts)

        # Cast the features themselves *before* normalising. Cached embeddings may be
        # float16 while the model runs in float32. Previously only the norm was cast,
        # so the norm was computed at the lower precision.
        img_feats = img_feats.type(dtype)
        img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)
        txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True)

        return self.model.logit_scale.exp() * img_feats @ txt_feats.t()

In [None]:
class CustomCLIPCoCoOp(nn.Module):
    """Frozen CLIP classifier whose per-image prompts come from a trainable CoCoOp learner.

    Args:
        cfg: config node forwarded to ``CoCoOp``.
        model: loaded CLIP model; all of its parameters are frozen here.
    """

    def __init__(self, cfg, model):
        super().__init__()
        self.prompt_learner = CoCoOp(cfg, model)
        self.model = model
        # CLIP itself is never updated, so freeze it once here; only the prompt
        # learner's parameters (context + meta-net) remain trainable.
        self.model.requires_grad_(False)

    def encode_text(self, prompt_embs, tokenized_prompts):
        """Run CLIP's text transformer on precomputed prompt embeddings.

        Mirrors CLIP's ``encode_text`` but skips the token-embedding lookup,
        because ``prompt_embs`` already holds the (partly learned) embeddings.

        Args:
            prompt_embs: (n_prompts, seq_len, width) prompt embeddings.
            tokenized_prompts: (n_prompts, seq_len) token ids, used only to find
                each prompt's EOT position.

        Returns:
            (n_prompts, text_projection.shape[1]) text features.
        """
        dtype = self.model.dtype
        x = prompt_embs + self.model.positional_embedding.type(dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x).type(dtype)

        # x.shape = [n_prompts, seq_len, width]; take the feature at the EOT token
        # (EOT has the highest id in each sequence) and project it.
        eot_idx = tokenized_prompts.argmax(dim=-1)
        return x[torch.arange(x.shape[0]), eot_idx] @ self.model.text_projection

    def forward(self, img_feats):
        """Return class logits of shape (batch, n_cls); each image is scored against its own prompts."""
        logit_scale = self.model.logit_scale.exp()
        tokenized_prompts = self.prompt_learner.tokenized_prompts
        # The prompt learner sees the raw (un-normalised) features.
        prompt_embs = self.prompt_learner(img_feats)

        # Cast the features themselves *before* normalising. Cached embeddings may be
        # float16 while the model runs in float32. Previously only the norm was cast,
        # so the norm was computed at the lower precision.
        img_feats = img_feats.type(self.model.dtype)
        img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)

        logits = []
        for img_prompts, img_feat in zip(prompt_embs, img_feats):
            txt_feats = self.encode_text(img_prompts, tokenized_prompts)
            txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True)
            logits.append(logit_scale * img_feat @ txt_feats.t())

        return torch.stack(logits)

# Train

In [None]:
def train(model, data, criterion, optimizer, scheduler, n_epoch, n_shots=None, device=None):
    """Train ``model`` for ``n_epoch`` epochs, printing per-epoch loss and accuracy.

    Args:
        model: module mapping a batch of image embeddings to class logits.
        data: iterable of batches, each a dict with ``"image"`` and ``"label"`` tensors.
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimizer over the trainable parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        n_epoch: number of passes over ``data``.
        n_shots: number of samples used to average the epoch metrics; defaults
            to the number of samples actually seen in the epoch.
        device: where batches are moved; defaults to the device of the model's
            first parameter (CPU if the model has no parameters).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    for epoch in range(n_epoch):
        print(f"Epoch {epoch+1}/{n_epoch}:")
        print('-' * 20)
        running_loss = 0.0
        running_corr = 0
        n_seen = 0

        for batch in data:
            img_embs = batch["image"].to(device)
            lbls = batch["label"].to(device)
            outputs = model(img_embs)
            loss = criterion(outputs, lbls)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = img_embs.size(0)
            preds = outputs.argmax(dim=1)
            running_loss += loss.item() * batch_size
            # Plain Python ints: no autograd/device baggage, works for empty data too.
            running_corr += (preds == lbls).sum().item()
            n_seen += batch_size

        n_total = n_shots if n_shots is not None else n_seen
        if n_total:
            epoch_loss = running_loss / n_total
            epoch_acc = running_corr / n_total
        else:
            # Nothing to average over; avoid ZeroDivisionError.
            epoch_loss = epoch_acc = float("nan")
        print('loss: {:.4f}; acc: {:.4f}'.format(epoch_loss, epoch_acc))
        scheduler.step()

Initialize the model, loss, optimizer and scheduler

In [None]:
# Wrap a fresh CPU copy of CLIP in the selected prompt-learning model, then move it to `device`.
clip_backbone = clip.load(cfg.MODEL.NAME, device='cpu')[0]
if cfg.USE_COCOOP:
    model = CustomCLIPCoCoOp(cfg, clip_backbone)
else:
    model = CustomCLIPCoOp(cfg, clip_backbone)
model.to(device)

criterion = nn.CrossEntropyLoss()
# Only the prompt learner is optimised; the wrapper freezes the CLIP weights.
optimizer = optim.SGD(model.prompt_learner.parameters(), lr=cfg.OPTIM.LR)
scheduler = build_lr_scheduler(optimizer, cfg.OPTIM)

Load the cached image embeddings and labels

In [None]:
emb_path = f'checkpoints/{cfg.DATASET.NAME}_image_embs.pt'
lbl_path = f'checkpoints/{cfg.DATASET.NAME}_labels.pt'
img_embs = torch.load(emb_path)
lbls = torch.load(lbl_path)
print(img_embs.shape, lbls.shape)
# Per-class sample counts, to check the class balance of the cached split.
torch.unique(lbls, return_counts=True)

torch.Size([1020, 512]) torch.Size([1020])


(tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101]),
 tensor([10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
         10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
         10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
         10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
         10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
         10

Select the few-shot training examples

In [None]:
# NOTE: replace this selection with the provided few-shot examples, if any.
# Keeps every 10th cached sample among the first 31 (indices 0, 10, 20, 30);
# presumably one sample from each of the first four classes if the split is
# ordered by class — verify before relying on that.
shot_idx = slice(0, 31, 10)
img_embs, lbls = img_embs[shot_idx], lbls[shot_idx]

# With so few shots everything fits in one batch, so the "dataloader" is just a
# one-element list of batch dicts (the format `train` iterates over).
data = [{'image': img_embs, 'label': lbls}]

In [None]:
# Average the epoch metrics over the samples actually in `data` instead of a
# hard-coded 4, so the printout stays correct if the selection above changes.
train(
    model, data, criterion, optimizer, scheduler,
    n_epoch=cfg.OPTIM.MAX_EPOCH,
    n_shots=sum(len(batch['label']) for batch in data),
)

Epoch 1/50:
--------------------
torch.Size([102, 77, 512])
torch.Size([102, 77])
loss: 3.5033; acc: 0.2500
Epoch 2/50:
--------------------
torch.Size([102, 77, 512])
torch.Size([102, 77])
loss: 3.4098; acc: 0.2500
Epoch 3/50:
--------------------
torch.Size([102, 77, 512])
torch.Size([102, 77])
loss: 1.9729; acc: 0.5000
Epoch 4/50:
--------------------
torch.Size([102, 77, 512])
torch.Size([102, 77])
loss: 1.5523; acc: 0.5000
Epoch 5/50:
--------------------
torch.Size([102, 77, 512])
torch.Size([102, 77])
loss: 1.1077; acc: 0.7500
Epoch 6/50:
--------------------
torch.Size([102, 77, 512])
torch.Size([102, 77])
loss: 1.2063; acc: 0.7500
Epoch 7/50:
--------------------
torch.Size([102, 77, 512])
torch.Size([102, 77])
loss: 0.6031; acc: 0.7500
Epoch 8/50:
--------------------
torch.Size([102, 77, 512])
torch.Size([102, 77])
loss: 0.4643; acc: 0.7500
Epoch 9/50:
--------------------
torch.Size([102, 77, 512])
torch.Size([102, 77])
loss: 0.4591; acc: 0.7500
Epoch 10/50:
---------------