In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.models as models
import torch.utils.data as data
import json
import os
from PIL import Image
import subprocess
from typing import List
from torchvision.transforms import Compose
from torchvision import transforms
import torch.optim as optim

In [None]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
# (To force CPU execution, bind `device` to torch.device('cpu') instead.)
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')
print(device)

cuda:0


# Create Embeddings

In [None]:
def generate_fake_embeddings(n=2, d=5, seed=None):
  """Create two synthetic, row-aligned embedding matrices from one hidden sample.

  A hidden matrix ``base_truth`` of shape ``(n, d)`` is drawn from a standard
  normal distribution, and two deterministic transformations of it are
  returned. Row ``i`` of the first output and row ``i`` of the second output
  are therefore derived from the same hidden vector.

  Args:
    n: Number of paired samples (rows).
    d: Dimensionality of the hidden vectors.
    seed: Optional seed. When given, the sample is drawn from a dedicated
      ``np.random.default_rng(seed)`` generator, so the result is reproducible
      and NumPy's global random state is left untouched. When ``None``
      (the default) the global ``np.random`` state is used.

  Returns:
    A tuple ``(emb1, emb2)`` of NumPy arrays:
      emb1: shape ``(n, d)``, the element-wise square of the hidden matrix.
      emb2: shape ``(n, 2 * d)``, the hidden matrix concatenated column-wise
        with its element-wise exponential.
  """
  if seed is None:
    # Default path: keep drawing from the global NumPy random state so that
    # callers relying on np.random.seed(...) see the same behaviour as before.
    base_truth = np.random.normal(size=(n, d))
  else:
    base_truth = np.random.default_rng(seed).normal(size=(n, d))

  emb1 = base_truth ** 2
  emb2 = np.concatenate((base_truth, np.exp(base_truth)), axis=1)

  return emb1, emb2

# Model

In [None]:
class CLIP(nn.Module):
  """Minimal CLIP-style model: two linear projections into a shared latent space.

  Each modality is projected by its own ``nn.Linear`` layer to ``latent_dim``
  features, and the forward pass returns the cosine similarity between every
  text latent and every image latent.
  """

  # Lower bound applied to latent norms so that an all-zero latent vector gives
  # a similarity of 0 instead of a 0/0 = NaN.
  NORM_EPS = 1e-8

  def __init__(self, text_input_dim, image_input_dim, latent_dim):
    """Build the two projection layers.

    Args:
      text_input_dim: Number of features in each text embedding.
      image_input_dim: Number of features in each image embedding.
      latent_dim: Number of features of the shared latent space.
    """
    super().__init__()

    self.text_input_dim = text_input_dim
    self.image_input_dim = image_input_dim
    self.latent_dim = latent_dim

    self.text_fc = nn.Linear(text_input_dim, latent_dim)
    self.image_fc = nn.Linear(image_input_dim, latent_dim)

  def forward(self, text, image):
    """Return the pairwise cosine-similarity matrix between text and image latents.

    Args:
      text: 2-D tensor ``[n_text, text_input_dim]``.
      image: 2-D tensor ``[n_image, image_input_dim]``.

    Returns:
      Tensor ``[n_text, n_image]`` whose entry ``(i, j)`` is the cosine
      similarity between the projection of ``text[i]`` and the projection of
      ``image[j]``. ``n_text`` and ``n_image`` do not have to be equal.
    """
    text_latent = self.text_fc(text)     # [n_text, latent_dim]
    image_latent = self.image_fc(image)  # [n_image, latent_dim]

    # keepdim=True yields column vectors that broadcast against the similarity
    # matrix, so no [B, B] copies of the norms have to be materialised and the
    # two batch sizes are free to differ.
    text_norms = torch.linalg.norm(text_latent, dim=1, keepdim=True).clamp_min(self.NORM_EPS)    # [n_text, 1]
    image_norms = torch.linalg.norm(image_latent, dim=1, keepdim=True).clamp_min(self.NORM_EPS)  # [n_image, 1]

    dot_products = text_latent @ image_latent.T  # [n_text, n_image]

    # Row i is divided by ||text_i||, column j by ||image_j||.
    return dot_products / text_norms / image_norms.T

In [None]:
def train_clip(model, text_embeddings, image_embeddings, optimizer, loss_fn, batch_size, n_epochs):
  """Train ``model`` on row-aligned text/image embeddings with mini-batch updates.

  Every epoch the sample indices are reshuffled and split into
  ``len(text_embeddings) // batch_size`` full batches; samples left over after
  the last full batch are skipped for that epoch. For each batch the model
  output is compared by ``loss_fn`` against an identity matrix, i.e. entry
  ``(i, i)`` (a matching text/image pair) is pushed towards 1 and every other
  entry towards 0.

  Args:
    model: Module called as ``model(batch_text, batch_image)``; its output is
      compared against a ``[batch_size, batch_size]`` target.
    text_embeddings: Tensor indexed along dim 0, one row per sample.
    image_embeddings: Tensor with the same number of rows; row ``i`` is paired
      with row ``i`` of ``text_embeddings``.
    optimizer: Optimizer holding the parameters of ``model``.
    loss_fn: Callable ``loss_fn(output, target)`` returning a scalar tensor.
    batch_size: Number of pairs per batch.
    n_epochs: Number of passes over the data.

  Returns:
    None. ``model`` and ``optimizer`` are updated in place.
  """
  assert text_embeddings.shape[0] == image_embeddings.shape[0]

  xs_len = text_embeddings.shape[0]
  n_batches = xs_len // batch_size

  # The target is identical for every batch, so it is built once (lazily, on
  # the first batch) and reused.
  target = None

  for _ in range(n_epochs):
    model.train()

    shuff_idxs = torch.randperm(xs_len)
    for batch in range(n_batches):
      batch_idxs = shuff_idxs[batch * batch_size:(batch + 1) * batch_size]

      batch_text = text_embeddings[batch_idxs]
      batch_image = image_embeddings[batch_idxs]

      optimizer.zero_grad()

      output = model(batch_text, batch_image)

      if target is None:
        # Take device and dtype from the model output instead of a
        # notebook-global `device`, so the function does not depend on
        # hidden state and always matches wherever the model lives.
        # NOTE(review): the notebook author marked this target construction
        # with "CHANGE" banners - presumably a placeholder to revisit.
        target = torch.eye(batch_size, dtype=output.dtype, device=output.device)

      loss = loss_fn(output, target)
      loss.backward()
      optimizer.step()

# Training

In [None]:
# Settings of the synthetic experiment, gathered in one place so they are easy to tune.
SEED = 0
N_SAMPLES = 1024
BASE_DIM = 16
LATENT_DIM = 12
BATCH_SIZE = 64
N_EPOCHS = 5

# Seed NumPy (data generation) and PyTorch (weight initialisation, batch
# shuffling) so that Restart & Run All reproduces the same run.
np.random.seed(SEED)
torch.manual_seed(SEED)

text_embeddings_np, image_embeddings_np = generate_fake_embeddings(n=N_SAMPLES, d=BASE_DIM)

# The generator returns float64 NumPy arrays; convert to float32 tensors on the
# training device to match the default dtype of the nn.Linear parameters.
text_embeddings = torch.as_tensor(text_embeddings_np, dtype=torch.float32, device=device)
image_embeddings = torch.as_tensor(image_embeddings_np, dtype=torch.float32, device=device)

model = CLIP(text_embeddings.shape[1], image_embeddings.shape[1], LATENT_DIM).to(device)
optimizer = optim.Adam(model.parameters())
loss_fn = nn.MSELoss().to(device)

train_clip(model, text_embeddings, image_embeddings, optimizer, loss_fn, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS)