In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

In [4]:
# Demo on CIFAR-10: build the train/test datasets and their DataLoaders.
# ToTensor() turns each image into a (C, H, W) tensor.
transform = torchvision.transforms.ToTensor()
TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 200
DATA_ROOT = './data'


def make_cifar10_split(is_train, batch_size):
    """Return (dataset, loader) for one CIFAR-10 split.

    The dataset is downloaded into DATA_ROOT if missing; only the training
    split is shuffled so evaluation order stays fixed.
    """
    dataset = torchvision.datasets.CIFAR10(
        root=DATA_ROOT,
        train=is_train,
        download=True,
        transform=transform,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=2,
    )
    return dataset, loader


trainset, trainloader = make_cifar10_split(True, TRAIN_BATCH_SIZE)
testset, testloader = make_cifar10_split(False, TEST_BATCH_SIZE)

In [None]:
# Model configuration for a minimal CLIP-style demo on CIFAR-10 (32x32 RGB images).
SEED = 0
torch.manual_seed(SEED)  # seed torch's global RNG so layer init (and later shuffling) is reproducible

NUM_CLASSES = 10
IMG_SIZE = 32
IMG_CHANNELS = 3
IMG_PATCH_SIZE = 2
# Number of non-overlapping IMG_PATCH_SIZE x IMG_PATCH_SIZE patches per image:
# (32 // 2) ** 2 = 256. Integer division keeps this an int usable as a layer size.
NUM_IMG_PATCHES = (IMG_SIZE // IMG_PATCH_SIZE) ** 2
# Length of one flattened patch (channels * patch area) = 12.
IMG_PATCH_DIM = IMG_CHANNELS * IMG_PATCH_SIZE ** 2
TEXT_EMBED_DIM = 64
SHARED_EMBED_DIM = 32

# embeddings for both images and text
# img_embed maps each flattened patch to a single scalar, so after squeezing the
# trailing dim an image becomes a NUM_IMG_PATCHES-long feature vector.
img_embed = nn.Linear(in_features=IMG_PATCH_DIM, out_features=1)
# "Text" is just the class index, looked up in a learned table.
text_embed = nn.Embedding(num_embeddings=NUM_CLASSES, embedding_dim=TEXT_EMBED_DIM)

# projections from each modality into the shared embedding space
img_embed_map = nn.Linear(in_features=NUM_IMG_PATCHES, out_features=SHARED_EMBED_DIM)
text_embed_map = nn.Linear(in_features=TEXT_EMBED_DIM, out_features=SHARED_EMBED_DIM)

# learnable log-temperature; similarity logits are multiplied by exp(tau).
# Initialised as in CLIP to log(1 / 0.07) rather than a random draw, so training
# starts from a sensible logit scale (~14.3).
tau = nn.Parameter(torch.tensor([1 / 0.07]).log())

In [14]:
# Pull one shuffled training batch to drive the forward-pass demo.
first_batch = next(iter(trainloader))
X, y = first_batch
X.shape, y.shape

(torch.Size([128, 3, 32, 32]), torch.Size([128]))

In [None]:
# forward pass on the single batch (X, y)

# Split each image into non-overlapping IMG_PATCH_SIZE x IMG_PATCH_SIZE patches.
# unfold returns (B, C*p*p, num_patches); transpose to (B, num_patches, C*p*p).
img_patches = F.unfold(X, kernel_size=IMG_PATCH_SIZE, stride=IMG_PATCH_SIZE).transpose(-2, -1)
# One scalar per patch -> (B, num_patches) image feature vector.
img_feat = img_embed(img_patches).squeeze(-1)
# Class indices stand in for text prompts.
labels_feat = text_embed(y)

# project both modalities into the shared space and l2-normalize,
# so the dot products below are cosine similarities
img_e = F.normalize(img_embed_map(img_feat), dim=-1)
labels_e = F.normalize(text_embed_map(labels_feat), dim=-1)

# scaled cosine similarity: logits[i, j] = <image i, label j> * exp(tau).
# The scale is capped at 100, as in CLIP, to keep training stable.
logit_scale = torch.exp(tau).clamp(max=100)
logits = img_e @ labels_e.transpose(-2, -1) * logit_scale

# symmetric contrastive loss: the i-th image is paired with the i-th label.
# NOTE: with only NUM_CLASSES distinct labels, a batch contains duplicate labels,
# and arange targets treat those identical (image, label) pairs as negatives.
batch_size = logits.shape[0]
targets = torch.arange(batch_size, device=logits.device)
loss_i = F.cross_entropy(logits, targets)  # image -> label: softmax over each row
loss_t = F.cross_entropy(logits.transpose(-2, -1), targets)  # label -> image: softmax over each column
loss = (loss_i + loss_t) / 2

loss