I know how to do vector look ups from RAG, but handling the multimodal input is new to me.

I want to try learning how to get multimodal embeddings myself.

In [64]:
# Don't wanna follow the tutorial here, try pulling a multimodal embedding model from openclip

import torch
from PIL import Image
import open_clip
from datasets import load_dataset
from tqdm import tqdm
from chromadb import Client
import numpy as np

In [None]:
import warnings
warnings.filterwarnings('ignore')

import os
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv()) # read local .env file

In [None]:
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
model.eval()  # model in train mode by default, impacts some models with BatchNorm or stochastic depth active
tokenizer = open_clip.get_tokenizer('ViT-B-32')

image = preprocess(Image.open("docs/CLIP.png")).unsqueeze(0)
text = tokenizer(["a diagram", "a dog", "a cat"])

with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)  # prints: [[1., 0., 0.]]

Label probs: tensor([[9.9950e-01, 4.1207e-04, 8.5316e-05]])


In [None]:
# embedding size is 512
text_features.shape, image_features.shape

torch.Size([3, 512])

In [None]:
# use flickr30k to build image encoding and text encoding, then store them into vector db for look ups. How to track whether it's correct?
from torch.utils.data import DataLoader
dataset = load_dataset('nlphuji/flickr30k')

# Have to load each image individually, no batching
processed_images = []
captions_tokens = []
for sample in tqdm(dataset['test']):
    processed_images.append(preprocess(sample['image']).unsqueeze(0))
    captions_tokens.append(tokenizer(sample['caption']))
# batch_size = 16
# dataloader = DataLoader(dataset['test'], batch_size=batch_size, shuffle=True)

# text_embeddings = []
# image_embeddings = []

# for batch in dataloader:
#     images = preprocess(batch['image'])
#     texts = batch['caption'] # text will be nested as there's multiple captions per image

#     image_features = model.encode_image(image)

100%|██████████| 31014/31014 [02:26<00:00, 211.21it/s]


In [55]:
# Batch and convert into embeddings
batch_size = 16
image_embeddings = []
text_embeddings = []
with torch.no_grad(), torch.cuda.amp.autocast():
    for i in tqdm(range(0, len(processed_images), batch_size)):
        image_batch = processed_images[i:i+batch_size]
        caption_batch = captions_tokens[i:i+batch_size]
        _image_embeddings = model.encode_image(image)
        _text_embeddings = model.encode_text(text)
        _image_embeddings /= _image_embeddings.norm(dim=-1, keepdim=True)
        _text_embeddings /= _text_embeddings.norm(dim=-1, keepdim=True)
        image_embeddings.append(_image_embeddings)
        text_embeddings.append(_text_embeddings)

100%|██████████| 1939/1939 [03:29<00:00,  9.27it/s]


In [None]:
# ok as expected, now to store the embeddings into chroma
image_embeddings[0].shape, text_embeddings[0].shape

(torch.Size([1, 512]), torch.Size([3, 512]))

In [71]:
# Convert list of tensor to np array for loading into chroma
np_image_embeddings = np.array(image_embeddings).squeeze()
np_text_embeddings = np.array(text_embeddings).squeeze()

In [None]:
# Load everything into chroma
client = Client()
collection = client.create_collection("flickr30k")

# Set ids as the index in the original list, so we can verify the results.
image_ids = [f'id_{i}' for i in range(len(np_image_embeddings))]

# Set ids for captions properly as each set of image has multiple 
caption_ids = []
for image_id, caption_embeddings in enumerate(np_text_embeddings):
    # each caption_embeddings can contain multiple vectors, as flickr30k has multiple captions per image
    caption_embeddings_ids = [f'image_id_{image_id}_caption_id_{i}' for i in range(len(caption_embeddings))]
    caption_ids.append(caption_embeddings_ids)

# load into chroma
vector_load_batch_size = 512
for i in tqdm(range(0, len(image_ids), vector_load_batch_size)):
    # Load images
    image_id_batch = image_ids[i:i+vector_load_batch_size]
    image_embedding_batch = np_image_embeddings[i:i+vector_load_batch_size]
    collection.add(ids=image_id_batch, embeddings=image_embedding_batch)

    # Load captions
    caption_id_batch = [item for sublist in caption_ids[i:i+vector_load_batch_size] for item in sublist]
    caption_embedding_batch = [item for sublist in np_text_embeddings[i:i+vector_load_batch_size] for item in sublist]
    collection.add(ids=caption_id_batch, embeddings=caption_embedding_batch)

