# CoOp

## CoOp

## 0. Info

### Paper
* title: Learning to Prompt for Vision-Language Models
* author: Kaiyang Zhou et al.
* url: https://arxiv.org/abs/2109.01134


### Features
* dataset: sports-100
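* trains learnable prompt context vectors and a text projection layer (only these are optimized; the CLIP backbone is not updated)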

### References
* https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip

## 1. Setup

In [1]:
import os
from glob import glob

import einops
import easydict
import numpy as np
import pandas as pd
import scipy.io as sio
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from transformers import CLIPProcessor, CLIPModel

In [2]:
# Experiment configuration.
cfg = easydict.EasyDict(
    # Fall back to CPU so the notebook still runs on machines without a GPU.
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
    pretrained='openai/clip-vit-base-patch32',
    prefix_length=32,         # number of learnable context vectors (Prefix)
    num_training_steps=1000,  # optimizer steps in the training loop
    k=16,                     # training images sampled per class (few-shot)
)

## 2. Data

In [3]:
def find_samples(split, k, root='sports-100', rng=None):
    """Collect image paths and class names for one split of the dataset.

    Args:
        split: Sub-directory of ``root`` to scan. ``'train'`` triggers few-shot
            sampling; any other split returns every ``*.jpg`` image.
        k: Images drawn per class, without replacement, for the train split.
        root: Dataset root directory (default ``'sports-100'``).
        rng: Optional ``numpy.random.Generator`` for reproducible sampling.
            When ``None``, the global ``np.random`` state is used.

    Returns:
        ``(samples, classes)``. ``samples`` is a list of image paths and
        ``classes`` is the sorted list of class-directory names.

    Raises:
        ValueError: if a class in the train split has fewer than ``k`` images.
    """
    sampler = np.random if rng is None else rng
    class_dirs = sorted(glob(os.path.join(root, split, '*')))

    samples = []
    for class_dir in class_dirs:
        class_samples = sorted(glob(os.path.join(class_dir, '*.jpg')))
        if split != 'train':
            samples += class_samples
            continue
        if len(class_samples) < k:
            raise ValueError(
                f'{class_dir} has {len(class_samples)} images, '
                f'cannot sample k={k} without replacement'
            )
        samples += sampler.choice(class_samples, k, replace=False).tolist()

    # os.path.basename instead of split('/') so Windows paths work too.
    classes = [os.path.basename(class_dir) for class_dir in class_dirs]
    return samples, classes


class Dataset(torch.utils.data.Dataset):
    """Image-classification dataset built from ``find_samples(split, k)``.

    Each item is ``(image, class_id)``. ``image`` is a PIL image converted to
    RGB, and ``class_id`` is the index of the image's parent-directory name in
    ``self.classes``.

    Args:
        split: Split name passed to ``find_samples`` (e.g. ``'train'``, ``'test'``).
        k: Per-class sample count forwarded to ``find_samples``. It only
            matters for the train split.
    """

    def __init__(self, split, k=4):
        self.samples, self.classes = find_samples(split, k)
        # O(1) label lookup instead of list.index() on every __getitem__.
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path = self.samples[idx]
        # The class is the name of the directory that contains the image.
        class_name = os.path.basename(os.path.dirname(path))

        # Close the file handle promptly; convert() returns an independent image.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        return img, self.class_to_idx[class_name]

    
def collate_fn(batch, processor, prefix_length):
    """Turn a list of ``(image, label)`` pairs into model-ready tensors.

    Args:
        batch: Sequence of ``(image, class_id)`` items, as yielded by ``Dataset``.
        processor: Callable invoked as ``processor(images=..., return_tensors='pt')``.
            The ``pixel_values`` attribute of its result is returned.
        prefix_length: Unused here; kept so existing call sites stay valid.

    Returns:
        ``(pixel_values, labels)``, where ``labels`` is a tensor of the class ids.
    """
    image_list, label_list = (list(column) for column in zip(*batch))
    encoded = processor(images=image_list, return_tensors='pt')
    label_tensor = torch.tensor(label_list)
    return encoded.pixel_values, label_tensor

In [4]:
# Image preprocessor + tokenizer matching the pretrained CLIP checkpoint.
processor = CLIPProcessor.from_pretrained(pretrained_model_name_or_path=cfg.pretrained)

In [5]:
from functools import partial

train_dataset = Dataset('train', cfg.k)
eval_dataset = Dataset('test')

# Labels index into prompts built from train_dataset.classes, so both splits
# must enumerate the same classes in the same order.
if eval_dataset.classes != train_dataset.classes:
    raise ValueError('train and test splits have different class lists')

# functools.partial (unlike a lambda) can be pickled, so these loaders can also
# use num_workers > 0 with the "spawn" start method (provided processor pickles).
collate = partial(collate_fn, processor=processor, prefix_length=cfg.prefix_length)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate)
eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=16, shuffle=False, collate_fn=collate)

In [6]:
# Tokenize every class name once; padding=True pads to the longest name.
# max_length was dropped: without truncation it is ignored and only raised a warning.
class_input_ids = processor(text=train_dataset.classes, padding=True, return_tensors='pt').input_ids.to(cfg.device)

  "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. "


In [7]:
# Sanity check: shape of one image batch and of the tokenized class prompts.
inputs, labels = next(iter(train_loader))
inputs.shape, class_input_ids.shape

(torch.Size([16, 3, 224, 224]), torch.Size([100, 6]))

## 3. Model

In [8]:
class Prefix(nn.Module):
    """Learnable prompt context vectors plus a trainable text projection.

    Args:
        prefix_length: Number of learnable context vectors prepended to every prompt.
        hidden_size: Width of each context vector. It must equal the text
            encoder's token-embedding width, because the vectors are
            concatenated with the token embeddings.
        projection_dim: Output width of ``self.linear``, which is applied to
            the pooled text state. Defaults to ``hidden_size``.
    """

    def __init__(self, prefix_length, hidden_size, projection_dim=None):
        super().__init__()
        self.prefix_length = prefix_length
        # Registered as a buffer so `.to(device)` moves it with the module.
        # Non-persistent so state_dict keys are unchanged.
        self.register_buffer(
            'input_tokens', torch.arange(prefix_length, dtype=torch.long), persistent=False
        )
        self.embeddings = nn.Embedding(prefix_length, hidden_size)
        # Previously hard-coded to 512x512; now follows the configured widths.
        out_dim = hidden_size if projection_dim is None else projection_dim
        self.linear = nn.Linear(hidden_size, out_dim)

    def forward(self, input_ids):
        """Return the context vectors repeated once per row of ``input_ids``.

        Only ``input_ids.size(0)`` is used. The output shape is
        ``(batch, prefix_length, hidden_size)``.
        """
        tokens = self.input_tokens.expand(input_ids.size(0), -1)
        return self.embeddings(tokens)


def get_text_features(model, prefix, class_input_ids):
    """Encode tokenized class names with learnable context vectors prepended.

    The steps are:
      1. Concatenate the context vectors from ``prefix`` in front of the
         token embeddings of ``class_input_ids``.
      2. Run the sequence through the CLIP text encoder under a causal mask.
      3. Apply the final layer norm.
      4. Take the hidden state at each prompt's end-of-text position and
         project it with ``prefix.linear``.

    Args:
        model: CLIP model exposing ``text_model.embeddings``,
            ``text_model.encoder`` and ``text_model.final_layer_norm``.
        prefix: Module returning ``(batch, prefix_length, hidden)`` context
            vectors, with ``prefix_length`` and ``linear`` attributes.
        class_input_ids: ``(num_classes, seq_len)`` token ids.

    Returns:
        ``(num_classes, out_dim)`` un-normalized text features.
    """
    bsz, token_len = class_input_ids.size()
    seq_len = token_len + prefix.prefix_length

    inputs_embeds = model.text_model.embeddings(class_input_ids)
    prefix_embeds = prefix(class_input_ids)
    hidden_states = torch.cat([prefix_embeds, inputs_embeds], dim=1)

    causal_attention_mask = _causal_attention_mask(
        bsz, seq_len, hidden_states.dtype, hidden_states.device
    )
    encoder_outputs = model.text_model.encoder(
        inputs_embeds=hidden_states,
        causal_attention_mask=causal_attention_mask,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    )
    last_hidden_state = model.text_model.final_layer_norm(encoder_outputs[0])

    # Pool at the end-of-text token, located via argmax of the ids (first
    # occurrence). This assumes EOT has the largest id in the vocabulary, as with
    # the OpenAI CLIP tokenizer; confirm for other tokenizers. The position is
    # shifted past the prepended context vectors.
    eot_positions = class_input_ids.argmax(dim=-1) + prefix.prefix_length
    rows = torch.arange(bsz, device=last_hidden_state.device)
    pooled_output = last_hidden_state[rows, eot_positions.to(last_hidden_state.device)]

    return prefix.linear(pooled_output)


def _causal_attention_mask(bsz, seq_len, dtype, device):
    """Build an additive causal attention mask of shape ``(bsz, 1, seq_len, seq_len)``.

    Entries above the diagonal hold the most negative finite value of ``dtype``
    (blocked). The diagonal and everything below it are 0 (visible).

    The mask is built locally, directly on ``device``, instead of calling the
    private ``_build_causal_attention_mask``, which later transformers releases
    removed.
    """
    mask = torch.full(
        (bsz, 1, seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype, device=device
    )
    return mask.triu_(1)

In [9]:
# Frozen backbone: only the prompt module is optimized. Disabling gradients on
# every CLIP parameter stops backward() from filling .grad tensors that are never
# zeroed and keep accumulating across steps.
model = CLIPModel.from_pretrained(cfg.pretrained).eval().requires_grad_(False).to(cfg.device)

In [10]:
# The context vectors are concatenated with CLIP's token embeddings, so their width
# must be the text encoder's hidden size. projection_dim happened to match it for
# this checkpoint, but the two are separate config values.
prefix = Prefix(cfg.prefix_length, model.config.text_config.hidden_size).to(cfg.device)
# Only the prompt module's parameters are optimized.
optimizer = torch.optim.Adam(prefix.parameters(), lr=1e-4, weight_decay=1e-2)

## 4. Train

In [11]:
def _cycle(loader):
    """Yield batches forever, starting a new (reshuffled) pass when one ends."""
    while True:
        yield from loader


# Freeze CLIP so backward() only populates gradients for the prompt module
# (idempotent if the model was already frozen).
model.requires_grad_(False)

train_batches = _cycle(train_loader)
pbar = tqdm(range(cfg.num_training_steps))

for _step in pbar:
    # Pull from one persistent iterator. Calling next(iter(train_loader)) every
    # step rebuilds the DataLoader iterator (and any worker processes) each time.
    inputs, labels = next(train_batches)
    inputs, labels = inputs.to(cfg.device), labels.to(cfg.device)

    text_embeds = get_text_features(model, prefix, class_input_ids)
    image_embeds = model.get_image_features(inputs)
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

    # Scaled cosine similarity of each image against every class prompt.
    logit_scale = model.logit_scale.exp()
    logits = torch.matmul(image_embeds, text_embeds.T) * logit_scale
    loss = F.cross_entropy(logits, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    preds = logits.argmax(dim=-1)
    acc = (preds == labels).float().mean()
    pbar.set_postfix({'loss': loss.item(), 'acc': acc.item()})

  0%|          | 0/1000 [00:00<?, ?it/s]

In [12]:
preds, targets = [], []
with torch.no_grad():
    # The prompts are fixed during evaluation, so encode and normalize the class
    # texts once instead of re-running the text encoder for every image batch.
    text_embeds = get_text_features(model, prefix, class_input_ids)
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
    logit_scale = model.logit_scale.exp()

    for inputs, labels in tqdm(eval_loader):
        image_embeds = model.get_image_features(inputs.to(cfg.device))
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        logits = torch.matmul(image_embeds, text_embeds.T) * logit_scale

        preds.append(logits.argmax(dim=-1).cpu())
        # Labels stay on the CPU; they are only compared against predictions.
        targets.append(labels.cpu())

preds = torch.cat(preds, dim=0)
targets = torch.cat(targets, dim=0)
acc = (preds == targets).float().mean().item() * 100
print(f'Accuracy: {acc:.2f}')

  0%|          | 0/32 [00:00<?, ?it/s]

Accuracy: 89.60
