In [None]:
import torch
import clip
from PIL import Image

# Load the CLIP model together with the preprocessing transform that matches it
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device=device)

# Load the image, preprocess it and add a batch dimension
image = Image.open("example_image.jpg")
image_input = preprocess(image).unsqueeze(0).to(device)

# Prepare the caption prompt
caption_prompt = "a photo of a cat sitting on a table"

# Candidate captions to rank against the image. CLIP has no text decoder, so
# the "most likely caption" is chosen from this list; add more strings to make
# the ranking meaningful.
candidate_captions = ["a cat sitting on a table"]

# Inference only: no_grad avoids building an autograd graph and allows the
# results to be converted to NumPy afterwards.
with torch.no_grad():
    # Encode the caption prompt and the image using CLIP
    caption_input = clip.tokenize(caption_prompt).to(device)
    caption_features = model.encode_text(caption_input)
    image_features = model.encode_image(image_input)

    # Cosine similarity = dot product of L2-normalised embeddings. (A softmax
    # over a single caption would always be 1.0, so the raw cosine is kept.)
    image_unit = image_features / image_features.norm(dim=-1, keepdim=True)
    caption_unit = caption_features / caption_features.norm(dim=-1, keepdim=True)
    similarity = image_unit @ caption_unit.T

    # Score every candidate caption against the image. The model's forward
    # pass takes the *tokenised* text, not pre-computed text features.
    captions = clip.tokenize(candidate_captions).to(device)
    caption_features = model.encode_text(captions)
    logits_per_image, logits_per_text = model(image_input, captions)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# probs has one row per image; pick the best-scoring candidate string
most_likely_caption = candidate_captions[int(probs[0].argmax())]

print("Most likely caption:", most_likely_caption)


# Fine-tuning CLIP for image captioning (on the COCO dataset)

In [None]:
import torch
import clip
from PIL import Image
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision.datasets import CocoCaptions
from torchvision.transforms import transforms

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pre-trained CLIP model. `preprocess` is the image transform that
# matches the statistics CLIP was trained with, so it is used for the dataset
# instead of hand-written ImageNet normalisation.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
# On CUDA clip.load returns half-precision weights; train in float32 so the
# optimizer updates are numerically stable.
model.float()


def collate_first_caption(batch):
    """Batch (image, captions) pairs from CocoCaptions.

    CocoCaptions yields a list of caption strings per image, which the default
    collate function cannot batch. This keeps the first caption of each image
    and tokenises it.

    Returns:
        (images, tokens): the stacked image tensor and the tokenised captions.
    """
    images, caption_lists = zip(*batch)
    first_captions = [image_captions[0] for image_captions in caption_lists]
    # truncate=True keeps over-long captions from raising inside clip.tokenize
    return torch.stack(images), clip.tokenize(first_captions, truncate=True)


# Load the COCO dataset
train_set = CocoCaptions(root='path/to/coco', annFile='path/to/annotations.json', transform=preprocess)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, collate_fn=collate_first_caption)

# Freeze the image encoder and unfreeze the text encoder
for param in model.visual.parameters():
    param.requires_grad = False
for param in model.transformer.parameters():
    param.requires_grad = True

# Set up the optimizer and loss function (only trainable parameters are optimised)
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = AdamW(trainable_params, lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()

# Fine-tune the model
model.train()
for epoch in range(10):
    for images, captions in train_loader:
        images = images.to(device)
        captions = captions.to(device)
        optimizer.zero_grad()
        # The forward pass L2-normalises both embeddings and applies the
        # learned temperature, returning image->text and text->image logits.
        logits_per_image, logits_per_text = model(images, captions)
        # The i-th image matches the i-th caption of the batch
        targets = torch.arange(len(images), device=device)
        # Symmetric contrastive loss over both directions
        loss = (loss_fn(logits_per_image, targets) + loss_fn(logits_per_text, targets)) / 2
        loss.backward()
        optimizer.step()

# Generate captions for new images.
# CLIP has no text decoder, so a caption is chosen by ranking candidate
# strings against the image. Replace the path and the candidates as needed.
model.eval()
image = Image.open("example_image.jpg")
candidate_captions = ["a cat sitting on a table"]
with torch.no_grad():
    image_features = model.encode_image(preprocess(image).unsqueeze(0).to(device))
    caption_features = model.encode_text(clip.tokenize(candidate_captions, truncate=True).to(device))
    # Normalise so the dot product below is a cosine similarity
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    caption_features = caption_features / caption_features.norm(dim=-1, keepdim=True)
    best_index = int((image_features @ caption_features.T).argmax(dim=-1).item())
caption = candidate_captions[best_index]

# Code for generating a generalised prompt

In [None]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import clip

from PIL import Image

# Resolve the device and load the image explicitly so this cell does not
# depend on whatever `device` / `image` were bound to by earlier cells.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
image = Image.open("example_image.jpg")
image_input = preprocess(image).unsqueeze(0).to(device)

# Inference only: skip autograd bookkeeping while encoding the image
with torch.no_grad():
    image_features = model.encode_image(image_input)

def generate_captions(image_features, model_gpt2, tokenizer_gpt2, num_captions=5,
                      prompt="a photo of", clip_model=None, max_new_tokens=20):
    """Propose captions with GPT-2 and rank them against an image with CLIP.

    A plain GPT-2 language model cannot be conditioned on image features, so
    the captions are produced in two steps: GPT-2 samples several candidate
    continuations of `prompt`, then CLIP's text encoder embeds the candidates
    and they are ordered by cosine similarity to `image_features`.

    Args:
        image_features: CLIP image embedding of one image, shape (dim,) or
            (1, dim). It must come from the same CLIP variant as `clip_model`.
        model_gpt2: GPT-2 language model exposing `generate`, `eval`, `device`.
        tokenizer_gpt2: tokenizer matching `model_gpt2`.
        num_captions: maximum number of captions to return.
        prompt: text that every candidate caption starts with. If empty, the
            tokenizer's BOS token is used so that generation has an input.
        clip_model: CLIP model used for ranking. If None, 'ViT-B/32' is loaded
            on the CPU (slow; pass a loaded model when calling repeatedly).
        max_new_tokens: maximum number of tokens GPT-2 adds to the prompt.

    Returns:
        A list of at most `num_captions` distinct caption strings, best match
        first. Sampling is stochastic; seed torch for reproducible output.

    Raises:
        ValueError: if `image_features` holds more than one image embedding.
    """
    if num_captions <= 0:
        return []

    # Accept tensors, arrays or lists; detach so no autograd graph is kept
    image_features = torch.as_tensor(image_features).detach().float().cpu()
    if image_features.dim() == 1:
        image_features = image_features.unsqueeze(0)
    if image_features.dim() != 2 or image_features.shape[0] != 1:
        raise ValueError("image_features must have shape (dim,) or (1, dim)")
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    # Generate candidate captions using GPT-2
    model_gpt2.eval()
    seed_text = prompt if prompt else tokenizer_gpt2.bos_token
    encoded = tokenizer_gpt2(seed_text, return_tensors="pt")
    # Over-generate so that ranking has something to choose from
    num_candidates = num_captions * 4
    with torch.no_grad():
        generated = model_gpt2.generate(
            input_ids=encoded["input_ids"].to(model_gpt2.device),
            attention_mask=encoded["attention_mask"].to(model_gpt2.device),
            do_sample=True,
            top_p=0.9,
            max_new_tokens=max_new_tokens,
            num_return_sequences=num_candidates,
            # GPT-2 has no pad token; reuse EOS to silence the warning
            pad_token_id=tokenizer_gpt2.eos_token_id,
        )
    texts = tokenizer_gpt2.batch_decode(generated, skip_special_tokens=True)

    # Keep the first line of each sample and drop empties / duplicates
    candidates = []
    for text in texts:
        first_line = text.strip().split("\n")[0].strip()
        if first_line and first_line not in candidates:
            candidates.append(first_line)
    if not candidates:
        return []

    # Rank the candidates by cosine similarity to the image embedding
    if clip_model is None:
        clip_model, _ = clip.load('ViT-B/32', device='cpu')
    clip_device = next(clip_model.parameters()).device
    tokens = clip.tokenize(candidates, truncate=True).to(clip_device)
    with torch.no_grad():
        text_features = clip_model.encode_text(tokens).float().cpu()
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    scores = (image_features @ text_features.T).squeeze(0)
    ranked = scores.argsort(descending=True).tolist()

    return [candidates[index] for index in ranked[:num_captions]]
