In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import json
import os
from PIL import Image
import subprocess
from typing import List
import torch.optim as optim
from google.colab import drive

In [None]:
# Run on the first GPU when one is available. To force CPU, replace the line
# below with: device = torch.device('cpu')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook draws from (NumPy: data generation and evaluation
# shuffles; torch: weight init and batch shuffles) so reruns are repeatable,
# up to any nondeterminism in GPU kernels.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

print(device)

cuda:0


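# Create synthetic embeddings

Each class gets a random text embedding. Every image of that class is the text embedding multiplied element-wise by a random vector shared by all classes, plus small Gaussian noise.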

In [None]:
def generate_fake_embeddings(n_classes=32, n_images_per_class=128, dim=16, seed=None):
  """Generate a synthetic paired text/image embedding dataset.

  Each class c gets one random text embedding t_c ~ N(0, I). All classes share
  one random vector `diff`, and every image of class c is t_c * diff
  (element-wise) plus N(0, 0.2^2) noise.

  Args:
    n_classes: number of classes.
    n_images_per_class: images (and repeated text rows) per class.
    dim: embedding size, shared by text and image embeddings.
    seed: optional int. When given, draws come from a private
      np.random.RandomState(seed), so the dataset is reproducible without
      touching the global NumPy RNG. When None (default), the global
      np.random state is used.

  Returns:
    (text_emb, image_emb, ids):
      text_emb: float32 tensor (n_classes * n_images_per_class, dim) on the
        notebook-level `device`; row k is the text embedding of row k's class.
      image_emb: float32 tensor of the same shape on `device`.
      ids: NumPy int array (n_classes * n_images_per_class,) with each row's
        class index. Rows are grouped by class: all of class 0 first, then
        class 1, and so on.
  """
  rng = np.random if seed is None else np.random.RandomState(seed)

  # Draw order (text, diff, noise) is fixed so a given RNG state always yields
  # the same dataset.
  text_emb = rng.normal(size=(n_classes, dim))
  diff = rng.normal(size=dim)
  noise = rng.normal(size=(n_classes, n_images_per_class, dim)) * 0.2

  # Broadcasting (n_classes, 1, dim) * (dim,) + (n_classes, n_images_per_class, dim)
  # does the per-element text * diff + noise without Python loops.
  image_emb = text_emb[:, None, :] * diff + noise

  image_emb = image_emb.reshape(n_classes * n_images_per_class, dim)
  text_emb = np.repeat(text_emb, n_images_per_class, axis=0)
  ids = np.repeat(np.arange(n_classes).astype(int), n_images_per_class, axis=0)

  return (
      torch.tensor(text_emb, dtype=torch.float32, device=device),
      torch.tensor(image_emb, dtype=torch.float32, device=device),
      ids,
  )

In [None]:
# Tiny demo dataset (4 classes x 3 images, 2-D embeddings) to inspect the layout.
text, img, ids = generate_fake_embeddings(n_classes=4, n_images_per_class=3, dim=2)

In [None]:
# Text embeddings: one row per image; each class's embedding is repeated once per image.
text

tensor([[ 1.6409, -1.0399],
        [ 1.6409, -1.0399],
        [ 1.6409, -1.0399],
        [-0.3177,  0.1545],
        [-0.3177,  0.1545],
        [-0.3177,  0.1545],
        [-0.4827,  0.4798],
        [-0.4827,  0.4798],
        [-0.4827,  0.4798],
        [-0.4576, -0.3284],
        [-0.4576, -0.3284],
        [-0.4576, -0.3284]], device='cuda:0')

In [None]:
# Image embeddings, in the same row order as `text`.
img

tensor([[ 0.0969,  0.4120],
        [ 0.2877, -0.0457],
        [ 0.0748,  0.0336],
        [-1.4416,  1.3112],
        [-1.8129,  1.0428],
        [-1.3793,  1.3609],
        [-1.8015,  1.5771],
        [-1.7325,  1.4416],
        [-1.7159,  1.3838],
        [-1.8142,  0.7534],
        [-1.8160,  0.9299],
        [-1.4563,  0.7801]], device='cuda:0')

In [None]:
# Class index of each row.
# NOTE(review): generate_fake_embeddings returns a NumPy int array, but the saved
# output below is a CUDA float tensor. It appears stale; re-run to refresh it.
ids

tensor([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3.], device='cuda:0')

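# Model

A minimal CLIP-style model. Separate linear projections map text and image embeddings into a shared latent space, and the model outputs the cosine similarity of every text/image pair.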

In [None]:
class CLIP(nn.Module):
  """Minimal CLIP-style model: two linear projections into a shared latent space.

  forward(text, image) returns the cosine similarity between every projected
  text row and every projected image row.
  """

  # Lower bound on latent norms, so an all-zero projection gives similarity 0
  # instead of NaN (0 / 0).
  eps = 1e-8

  def __init__(self, text_input_dim, image_input_dim, latent_dim):
    """
    Args:
      text_input_dim: feature size of text embeddings (last dim of `text`).
      image_input_dim: feature size of image embeddings (last dim of `image`).
      latent_dim: size of the shared latent space both are projected into.
    """
    super().__init__()

    self.text_input_dim = text_input_dim
    self.image_input_dim = image_input_dim
    self.latent_dim = latent_dim

    self.text_fc = nn.Linear(text_input_dim, latent_dim)
    self.image_fc = nn.Linear(image_input_dim, latent_dim)

  def forward(self, text, image):
    """Return pairwise cosine similarities.

    Args:
      text: (n_text, text_input_dim) tensor.
      image: (n_image, image_input_dim) tensor.

    Returns:
      (n_text, n_image) tensor. Entry [i, j] is the cosine similarity between
      text_fc(text[i]) and image_fc(image[j]), in [-1, 1] up to rounding.
    """
    text_latent = self.text_fc(text)
    image_latent = self.image_fc(image)

    # Unit-normalise each row; a single matmul then gives all pairwise cosines.
    text_unit = text_latent / text_latent.norm(dim=1, keepdim=True).clamp_min(self.eps)
    image_unit = image_latent / image_latent.norm(dim=1, keepdim=True).clamp_min(self.eps)

    return text_unit @ image_unit.T

In [None]:
def train_clip(model, text_embeddings, image_embeddings, ids, optimizer, loss_fn, batch_size, n_epochs):
  """Train `model` so each text scores high against images of its own class.

  Each step takes a shuffled mini-batch of (text, image) rows and runs
  model(batch_text, batch_image) to get a (batch_size, batch_size) similarity
  matrix. That matrix is regressed with `loss_fn` onto a target that is +1 where
  text i and image j share a class id and -1 otherwise. Updates `model` and
  `optimizer` in place.

  Args:
    model: module called as model(text, image) -> (n_text, n_image) scores.
    text_embeddings: (N, text_dim) tensor; row k pairs with image_embeddings[k].
    image_embeddings: (N, image_dim) tensor.
    ids: length-N class ids (NumPy array, list or tensor).
    optimizer: torch optimizer over model's parameters.
    loss_fn: callable loss_fn(output, target) -> scalar tensor.
    batch_size: rows per mini-batch. Trailing rows that do not fill a whole
      batch are skipped each epoch (drop-last); the per-epoch shuffle means
      different rows are skipped each time.
    n_epochs: number of passes over the data.

  Returns:
    List with the mean training loss of each epoch.

  Raises:
    ValueError: if batch_size is larger than N, so no full batch exists.
  """
  assert text_embeddings.shape[0] == image_embeddings.shape[0]

  xs_len = text_embeddings.shape[0]
  n_batches = xs_len // batch_size
  if n_batches == 0:
    raise ValueError(f"batch_size={batch_size} is larger than the dataset ({xs_len} rows)")

  # Keep class ids as a tensor beside the embeddings so targets are built on-device.
  ids_t = torch.as_tensor(ids, device=text_embeddings.device)

  epoch_losses = []
  for _ in range(n_epochs):
    model.train()

    shuff_idxs = torch.randperm(xs_len)
    running_loss = 0.0
    for batch in range(n_batches):
      batch_idxs = shuff_idxs[batch * batch_size:(batch + 1) * batch_size]

      batch_text = text_embeddings[batch_idxs]
      batch_image = image_embeddings[batch_idxs]
      batch_ids = ids_t[batch_idxs]

      optimizer.zero_grad()

      output = model(batch_text, batch_image)

      # target[i, j] = +1 if text i and image j share a class, else -1.
      same_class = batch_ids[:, None] == batch_ids[None, :]
      target = same_class.to(dtype=output.dtype, device=output.device) * 2 - 1

      loss = loss_fn(output, target)
      loss.backward()
      optimizer.step()

      # Accumulate on-device; a single host sync per epoch below.
      running_loss = running_loss + loss.detach()

    epoch_losses.append(float(running_loss) / n_batches)

  return epoch_losses

In [None]:
def evaluate_clip(model, test_text, test_image, test_ids, n_way=4):
  """Estimate n-way zero-shot retrieval accuracy of `model`.

  The test rows are shuffled with the global NumPy RNG and split into groups of
  `n_way`. For each group, the text of the first row is scored against all
  `n_way` images with model(text, images). The query counts as correct when the
  highest-scoring image has the same class id as that first row. Another image
  in the group may share the query's class, and retrieving it is also correct.

  Args:
    model: callable as model(text, image) -> (n_text, n_image) scores. If it
      is an nn.Module, it runs in eval mode and its previous mode is restored.
    test_text: (N, text_dim) tensor; row k describes image k.
    test_image: (N, image_dim) tensor.
    test_ids: length-N class ids (NumPy array or tensor).
    n_way: candidate images per query (default 4).

  Returns:
    Fraction of groups answered correctly, in [0, 1]. Leftover rows that do not
    fill a whole group are ignored.

  Raises:
    ValueError: if there are fewer than n_way rows.
  """
  n_groups = len(test_ids) // n_way
  if n_groups == 0:
    raise ValueError(f"need at least n_way={n_way} test rows, got {len(test_ids)}")

  order = np.arange(len(test_ids))
  np.random.shuffle(order)
  # Drop the remainder so the permutation reshapes cleanly into groups.
  groups = order[:n_groups * n_way].reshape(n_groups, n_way)

  is_module = isinstance(model, torch.nn.Module)
  was_training = model.training if is_module else None
  if is_module:
    model.eval()

  correct = 0
  try:
    with torch.no_grad():
      for group in groups:
        query_text = test_text[group[0]].unsqueeze(0)
        scores = model(query_text, test_image[group])
        group_ids = test_ids[group]

        best = int(scores.argmax())
        if group_ids[best] == group_ids[0]:
          correct += 1
  finally:
    if is_module:
      model.train(was_training)

  return correct / n_groups

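# Training

Samples are stored grouped by class. Splitting by index at 75% therefore puts the first 75% of classes in the training set and holds out the remaining classes entirely, so the evaluation below is zero-shot on unseen classes.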

In [None]:
N_CLASSES = 512
N_IMAGES_PER_CLASS = 128
DIM = 16
TRAIN_FRACTION = 0.75  # fraction of rows (and, since rows are class-grouped, of classes) used for training
LATENT_DIM = 12
BATCH_SIZE = 256
N_EPOCHS = 5

text_embs, image_embs, ids = generate_fake_embeddings(n_classes=N_CLASSES, n_images_per_class=N_IMAGES_PER_CLASS, dim=DIM)

# generate_fake_embeddings emits rows grouped by class, so this index split is
# also a class split: test classes are never seen during training.
n_train = int(N_CLASSES * N_IMAGES_PER_CLASS * TRAIN_FRACTION)

train_text, test_text = text_embs[:n_train], text_embs[n_train:]
train_image, test_image = image_embs[:n_train], image_embs[n_train:]
train_ids, test_ids = ids[:n_train], ids[n_train:]
assert len(train_ids) + len(test_ids) == len(ids)

model = CLIP(train_text.shape[1], train_image.shape[1], LATENT_DIM).to(device)
optimizer = optim.Adam(model.parameters())
loss_fn = nn.MSELoss().to(device)

train_clip(model, train_text, train_image, train_ids, optimizer, loss_fn, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS)

In [None]:
# 4-way zero-shot accuracy on the held-out test split.
csn = evaluate_clip(model=model, test_text=test_text, test_image=test_image, test_ids=test_ids)

In [None]:
# Accuracy in [0, 1]; with 4 candidates per query, random guessing scores about 0.25.
csn

0.76904296875