In [108]:
from PIL import Image
from pathlib import Path
from datasets import load_dataset
from decoding.listener import CLIPListener
import open_clip
import torch

import matplotlib.pyplot as plt

# Debugging aid: switch on autograd anomaly detection for the whole kernel
# session so a failing backward pass can be traced to its forward op.
# NOTE(review): PyTorch documents this mode as slowing execution; consider
# turning it off once the gradient code further down is trusted.
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f3918187eb0>

In [2]:
# Root folder that holds one sub-directory of images per image set.
images_path = Path(
    "/data/tir/projects/tir3/users/svadugur/pragmatic-clip/image-sets"
)

# Validation rows whose image set name contains "open-images"; every row gains
# a "captions" list that holds its description twice. The result is converted
# to a plain column-oriented dict.
dataset = (
    load_dataset("BennoKrojer/ImageCoDe", split="validation")
    .filter(lambda row: "open-images" in row["image_set"])
    .map(lambda row: {"captions": [row["description"]] * 2})
    .to_dict()
)

# Pivot the column dict into one dict per example (row-oriented view).
test_predictions = [
    {column: values[row_idx] for column, values in dataset.items()}
    for row_idx in range(len(dataset["image_set"]))
]

# ViT-B-32 CLIP with the "openai" pretrained weights, bf16 precision, on the
# first CUDA device; the listener wraps model, preprocessing and tokenizer.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32",
    pretrained="openai",
    device="cuda:0",
    precision="bf16",
)
tokenizer = open_clip.get_tokenizer("ViT-B-32")
listener = CLIPListener(model, preprocess, tokenizer, device="cuda:0")

In [119]:
import torch.optim as optim
from torch import nn
from open_clip.transformer import text_global_pool

# Work with a single example: the first row of the filtered validation split.
x = test_predictions[0]
texts = x["captions"]

# The ten candidate images of this example are read as img0.jpg ... img9.jpg
# from the example's image-set folder.
image_set_dir = images_path / x["image_set"]
images = [Image.open(image_set_dir / f"img{idx}.jpg") for idx in range(10)]

# Encoding only -- no autograd graph is recorded for these tensors.
with torch.no_grad():
    text_inputs = tokenizer(texts).to(listener.device)
    cast_dtype = listener.model.transformer.get_cast_dtype()
    pooled_text_features, seq_text_features = listener.encode_texts(texts)
    # The image list is wrapped in an outer list, i.e. passed as one group.
    image_embeddings = listener.encode_images([images])

In [120]:
# Plain gradient descent on the per-token text features themselves (the model
# weights are never updated): every step re-pools the features, scores them
# against the fixed image embeddings, and moves the features against the
# gradient of the loss.
new_text_features = seq_text_features.clone()
lr = 10

for i in range(10):
    # Start each step from a fresh leaf tensor so .grad holds only this
    # iteration's gradient. A detached tensor is a leaf, so no retain_grad()
    # call is needed for .grad to be populated.
    text_features = new_text_features.detach().requires_grad_(True)

    # Get pooled features
    pooled_text_features, _ = text_global_pool(text_features, text_inputs, listener.model.text_pool_type)
    if listener.model.text_projection is not None:
        if isinstance(listener.model.text_projection, nn.Linear):
            pooled_text_features = listener.model.text_projection(pooled_text_features)
        else:
            pooled_text_features = pooled_text_features @ listener.model.text_projection
    pooled_text_features = pooled_text_features / pooled_text_features.norm(dim=-1, keepdim=True)

    # NOTE(review): open_clip's own forward pass scales by logit_scale.exp();
    # the raw parameter is used here -- confirm this is intended.
    text_logits = listener.model.logit_scale * image_embeddings @ pooled_text_features.T
    targets = torch.arange(text_logits.size(0), device=text_logits.device)
    targets_one_hot = torch.nn.functional.one_hot(targets, num_classes=text_logits.size(1))
    # NOTE(review): one_hot returns an integer tensor, so cross_entropy reads
    # its 0/1 entries as class indices, not as a one-hot distribution. The
    # step-0 loss (~10 * ln 2) is consistent with 10 rows over 2 identical
    # classes -- confirm this is the objective that was meant.
    loss = torch.nn.functional.cross_entropy(text_logits[0, ...], targets_one_hot[0, :], reduction='sum')
    # Accumulate gradients into text_features only. A bare backward() would
    # also add to .grad of any model parameter in the graph that requires grad
    # (e.g. logit_scale, text_projection) on every iteration, and nothing here
    # ever zeroes those.
    loss.backward(inputs=[text_features])
    # TODO(jykoh): text_grads will be all zeros except for the eot token index
    # Is this intended?
    # NOTE(review): presumably expected when text_pool_type selects a single
    # token per sequence -- the other positions never reach the loss. Verify.
    text_grads = text_features.grad  # (len(texts), 77, 512)

    # Take a step in the negative gradient direction
    new_text_features = text_features - lr * text_grads

    print(f"Step {i}: Loss = {loss.item()}")

Step 0: Loss = 6.90625
Step 1: Loss = 4.0625
Step 2: Loss = 3.453125
Step 3: Loss = 3.234375
Step 4: Loss = 3.109375
Step 5: Loss = 3.015625
Step 6: Loss = 2.9375
Step 7: Loss = 2.875
Step 8: Loss = 2.8125
Step 9: Loss = 2.75
