In this example, we first load the pre-trained CLIP model ("ViT-B/32" in this case) together with its preprocessing function. We then load an example image and caption, and generate an embedding for each using the `encode_image` and `encode_text` methods provided by the CLIP model. Finally, we concatenate the two 512-dimensional embeddings to create a joint embedding of size 1024.

Note that this example only uses one image and one caption, but the same approach can be used to create joint embeddings for larger datasets by iterating over the images and captions and concatenating their embeddings. Additionally, the pre-trained CLIP model used in this example can be fine-tuned on a larger image-caption dataset to learn a joint embedding space that is better suited to a specific task.

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# Load a pre-trained ResNet-50 for image feature extraction and drop its final
# classification layer, so the network ends at the global average pool and
# yields a 2048-dim feature vector per image.
resnet = models.resnet50(pretrained=True)
modules = list(resnet.children())[:-1]
resnet = nn.Sequential(*modules)
resnet.eval()

# Load a pre-trained BERT model for text feature extraction, together with its
# matching tokenizer. The tokenizer is a separate hub entry point; the model
# object itself has no `tokenizer` attribute.
tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-uncased')
bert = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')
bert.eval()

# Define the linear projections that map each modality into the joint space.
# The two encoders have different output widths (2048 for ResNet-50, the BERT
# hidden size for text), so each modality needs its own projection.
# NOTE: both layers are randomly initialised here; the similarity printed below
# only becomes meaningful once they are trained on paired image-text data.
projection = nn.Linear(2048, 512)  # Map ResNet-50 output (2048) to 512-dim joint space
text_projection = nn.Linear(bert.config.hidden_size, 512)  # Map BERT [CLS] output to 512-dim joint space
projection.eval()
text_projection.eval()

# Define an image and a text input
image_path = 'path/to/image.jpg'
text_input = "The quick brown fox jumps over the lazy dog."

# Standard ImageNet preprocessing expected by the pre-trained ResNet weights.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Extract image embedding using ResNet-50
with Image.open(image_path) as image:
    # Force 3 channels so grayscale / RGBA files do not break the first conv.
    image_tensor = preprocess(image.convert('RGB')).unsqueeze(0)
with torch.no_grad():
    # (1, 2048, 1, 1) -> (2048,)
    image_embedding = resnet(image_tensor).flatten()

# Extract text embedding using BERT
input_ids = torch.tensor(tokenizer.encode(text_input, add_special_tokens=True)).unsqueeze(0)
with torch.no_grad():
    bert_outputs = bert(input_ids)
    # Last hidden state of the [CLS] token: (1, seq, hidden) -> (hidden,)
    text_embedding = bert_outputs[0][:, 0, :].squeeze(0)

# Map image and text embeddings to joint space
with torch.no_grad():
    image_embedding = projection(image_embedding)
    text_embedding = text_projection(text_embedding)

# Compute cosine similarity between image and text embeddings in joint space.
# Both embeddings are 1-D vectors of length 512 here, so dim=0 is the feature axis.
cos_sim = nn.CosineSimilarity(dim=0)
similarity = cos_sim(image_embedding, text_embedding)

print(similarity)
