<center><a href="https://www.nvidia.com/en-us/training/"><img src="https://dli-lms.s3.amazonaws.com/assets/general/DLI_Header_White.png" width="400" height="186" /></a></center>

# 2b. Contrastive Pre-training

While our experiments from the previous lab are running, let's take a moment to learn about an important multimodal technique: contrastive pre-training. This technique is famously used in Contrastive Language-Image Pre-training, also known as CLIP. Contrastive pre-training is not limited to comparing language and image data, so in this lab we will explore how to build a contrastive pre-training model for any two correlated data types.

#### Learning Objectives:
The goals of this notebook are to:
* Apply computer vision techniques such as the Sobel filter to create a dataset
* Calculate the cosine similarity between two vectors
* Create a contrastive pre-training model from two embedding models
* Use the contrastive pre-training model to create a vector database

To begin, let's load the libraries needed for this course.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

# Visualization tools
import graphviz
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

import numpy as np
from numpy.linalg import norm

cmap = "gray"

## 2.1 The Data

For this notebook, we will be using the [FashionMNIST](https://www.kaggle.com/datasets/zalando-research/fashionmnist) dataset. We can use contrastive pre-training models to create vector databases, which are commonly used in retrieval augmented generation (RAG) pipelines. Let's create a database where users can look up an article of clothing based on a sketched outline.

In [2]:
train_data = torchvision.datasets.FashionMNIST(
    "./data/", download=True, transform=transforms.Compose([transforms.ToTensor()])
)

In [None]:
def show_images(dataset, num_samples=10):
    plt.figure(figsize=(16, 1))
    for i, img in enumerate(dataset):
        if i == num_samples:
            return
        plt.subplot(1, num_samples, i + 1)
        plt.imshow(torch.squeeze(img[0]), cmap=cmap)

show_images(train_data)

We can generate the image outlines by using a [Sobel filter](https://medium.com/@deepika.vadlamudi/implementing-a-sobel-filter-with-cuda-in-python-2b9b18485e31). `Gx` is a convolution kernel that responds to intensity changes along the horizontal axis, so it highlights vertical edges. `Gy` responds to intensity changes along the vertical axis, so it highlights horizontal edges.

In [None]:
Gx = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float)
Gy = torch.tensor([[[[1, 2, 1], [0, 0, 0], [-1, -2, -1]]]], dtype=torch.float)
print("Gx:")
print(Gx)
print("Gy:")
print(Gy)
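To see which edge orientation each kernel responds to, here is a minimal sketch on a toy image. The 6x6 `toy` tensor is a made-up example (not part of the lab data) with a single vertical edge: the left half is dark and the right half is bright.

```python
import torch
import torch.nn.functional as F

# Same kernels as in the cell above
Gx = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float)
Gy = torch.tensor([[[[1, 2, 1], [0, 0, 0], [-1, -2, -1]]]], dtype=torch.float)

# Hypothetical toy image: left half 0, right half 1 (one vertical edge)
toy = torch.zeros(1, 1, 6, 6)
toy[..., 3:] = 1.0

# No padding here, so image borders do not create extra responses
resp_x = F.conv2d(toy, Gx)
resp_y = F.conv2d(toy, Gy)

print("Strongest Gx response:", resp_x.abs().max().item())
print("Strongest Gy response:", resp_y.abs().max().item())
```

Only `Gx` reacts to this image, because every row is identical and `Gy` compares the row above a pixel with the row below it.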

Let's start by taking a sample image.

In [None]:
img = train_data[0][0][0]
plt.imshow(img)

There is a small amount of noise in the FashionMNIST images. We can set any pixel value at or below a threshold to 0, and everything above the threshold to 1.

In [None]:
threshold = 0.25
new_img = img.clone()
new_img[new_img > threshold] = 1
new_img[new_img <= threshold] = 0
plt.imshow(new_img)

Then, we can apply the Sobel filters.

In [None]:
edges_x = F.conv2d(new_img[None, :,:], Gx, stride=1, padding=1)
edges_y = F.conv2d(new_img[None, :,:], Gy, stride=1, padding=1)
edges = edges_x + edges_y
edges = edges[0] # Remove batch
plt.imshow(edges)

Because the filters are asymmetrical, some parts of the outline will be positive, and some parts will be negative. We will set all nonzero values to 1 in order to have a consistent outline. The result will look like this:

In [None]:
outline = edges
outline[edges != 0] = 1
plt.imshow(outline)
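One side effect of adding the two filter responses is that they can cancel each other out when they have equal size and opposite sign. The sketch below illustrates this on a made-up 5x5 image with a diagonal edge, and compares the sum with the gradient magnitude, a common alternative that this lab does not use. The `toy` image and the `center` index are assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F

Gx = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float)
Gy = torch.tensor([[[[1, 2, 1], [0, 0, 0], [-1, -2, -1]]]], dtype=torch.float)

# Hypothetical toy image: bright above the anti-diagonal, dark below it
rows = torch.arange(5).view(5, 1)
cols = torch.arange(5).view(1, 5)
toy = ((rows + cols) < 4).float()[None, None]  # shape (1, 1, 5, 5)

ex = F.conv2d(toy, Gx)  # no padding, so the output is 3x3
ey = F.conv2d(toy, Gy)

summed = ex + ey                           # approach used in this lab
magnitude = torch.sqrt(ex ** 2 + ey ** 2)  # alternative: gradient magnitude

center = (0, 0, 1, 1)  # output pixel sitting on the diagonal edge
print("Gx:", ex[center].item(), "| Gy:", ey[center].item())
print("Sum:", summed[center].item(), "| Magnitude:", magnitude[center].item())
```

At the center pixel the two responses are equal and opposite, so the sum reports no edge while the magnitude does. For the FashionMNIST outlines above this mostly affects a few pixels on diagonal edges running in that direction.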

**TODO**: Let's combine the above into a reusable function, `outline_img`. Please fix the FIXMEs below. Click the `...` for the solution.

In [9]:
def outline_img(img):
    threshold = 0.25
    new_img = img.clone()
    new_img[new_img > FIXME] = 1
    new_img[new_img <= FIXME] = 0
    edges_x = F.FIXME(new_img, Gx, stride=1, padding=1)
    edges_y = F.FIXME(new_img, Gy, stride=1, padding=1)
    edges = edges_x + edges_y
    edges[edges != FIXME] = 1
    return edges

In [10]:
def outline_img(img):
    threshold = 0.25
    new_img = img.clone()
    new_img[new_img > threshold] = 1
    new_img[new_img <= threshold] = 0
    edges_x = F.conv2d(new_img, Gx, stride=1, padding=1)
    edges_y = F.conv2d(new_img, Gy, stride=1, padding=1)
    edges = edges_x + edges_y
    edges[edges != 0] = 1
    return edges

We can verify that the `outline_img` function works by running it on a batch of data. Let's modify our `show_images` function into a `show_outlined_images` function by adding our image outliner.

In [None]:
def show_outlined_images(dataset, num_samples=10):
    plt.figure(figsize=(16, 1))
    for i, img in enumerate(dataset):
        if i == num_samples:
            return
        plt.subplot(1, num_samples, i + 1)
        img = img[0][None, :, :, :]
        img = torch.squeeze(outline_img(img))
        plt.imshow(img, cmap=cmap)

show_outlined_images(train_data)

In [12]:
IMG_SIZE = 28
BATCH_SIZE = 6

def load_fashionMNIST(data_transform, train=True):
    return torchvision.datasets.FashionMNIST(
        "./",
        download=True,
        train=train,
        transform=data_transform,
    )


def load_transformed_fashionMNIST():
    data_transforms = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),  # Scales data into [0,1]
        transforms.RandomHorizontalFlip()
    ]

    data_transform = transforms.Compose(data_transforms)
    train_set = load_fashionMNIST(data_transform, train=True)
    test_set = load_fashionMNIST(data_transform, train=False)
    return train_set, test_set

train_set, valid_set = load_transformed_fashionMNIST()
train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
valid_dataloader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

## 2.2 Cosine Similarity

We can use vector embeddings to compare an outline with an article of clothing. One way to think of these embeddings is as geometric rays. We can use the angle between these rays to measure how similar they are. Try changing `x1`, `x2`, `y1`, and `y2` in the code block below to see what the cosine similarity is between the two vectors.

In [None]:
x1, y1 = [0, 1] # Change me
x2, y2 = [1, 0] # Change me

p1 = [x1, y1]
p2 = [x2, y2]

arrow_width = 0.05
plt.axis('square')
plt.xlim(-1.5, 1.5)
plt.ylim(-1.5, 1.5)
plt.arrow(0, 0, x1, y1, width=arrow_width, color="black")
plt.arrow(0, 0, x2, y2, width=arrow_width, color="green")
plt.show()

cosine = np.dot(p1, p2) / (norm(p1) * norm(p2))
print("Cosine Similarity:", cosine)

The cosine similarity is not limited to two dimensions. It can be used to compare two vectors with any number of dimensions, as long as both vectors have the same number of dimensions.

In [None]:
p1 = [1, 8, 6, 7]
p2 = [5, 3, 0, 9]

cosine = np.dot(p1, p2) / (norm(p1) * norm(p2))
print("Cosine Similarity:", cosine)
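Because cosine similarity measures the angle between two vectors, it ignores their lengths. The sketch below uses a small helper written for illustration (`cosine_similarity` is not a function from this lab) to show that scaling a vector leaves the score unchanged.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("Vectors must have the same number of dimensions")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


p1 = [1, 8, 6, 7]
p2 = [5, 3, 0, 9]
p1_scaled = [10 * x for x in p1]  # same direction, ten times longer

original = cosine_similarity(p1, p2)
scaled = cosine_similarity(p1_scaled, p2)
print("Original:", original)
print("Scaled:  ", scaled)
```

Both calls print the same value up to floating-point rounding. This is why the embedders later in this lab can normalize their outputs without losing the information that matters for comparison.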

## 2.3 The Model

Before we can create a contrastive pre-training model, we should define a model to create embeddings for each of our data types. Any type of model could work as long as the output is a vector. Since both of our inputs are images, we can use a convolutional neural network. In this case, the final layer is a `Linear` layer. The size of that layer determines the size of the output embedding vector.

In [15]:
class ImgEmbedder(nn.Module):
    def __init__(
        self, in_ch, img_size, down_ch_1=32, down_ch_2=64, embed_dim=10
    ):
        super().__init__()
        kernel_size = 3
        stride = 1
        padding = 1

        # Convolution
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, down_ch_1, kernel_size, stride, padding),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(down_ch_1, down_ch_2, kernel_size, stride, padding),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # Embeddings
        self.dense_emb = nn.Sequential(
            nn.Flatten(),
            nn.Linear(down_ch_2 * (img_size // 4) ** 2, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, x):
        conv = self.conv(x)
        emb = self.dense_emb(conv)
        return F.normalize(emb)
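The input size of the first `Linear` layer, `down_ch_2 * (img_size // 4) ** 2`, follows from the two `MaxPool2d(2)` layers, each of which halves the height and width. Here is a quick sketch that rebuilds only the convolutional part with the default channel sizes and checks the arithmetic on a dummy 28x28 input.

```python
import torch
import torch.nn as nn

img_size, down_ch_1, down_ch_2 = 28, 32, 64

# Same layer layout as the `conv` block in ImgEmbedder
conv = nn.Sequential(
    nn.Conv2d(1, down_ch_1, 3, 1, 1),  # padding=1 keeps 28x28
    nn.ReLU(),
    nn.MaxPool2d(2),                   # 28x28 -> 14x14
    nn.Conv2d(down_ch_1, down_ch_2, 3, 1, 1),
    nn.ReLU(),
    nn.MaxPool2d(2),                   # 14x14 -> 7x7
)

features = conv(torch.zeros(1, 1, img_size, img_size))
flat_size = features.flatten(1).shape[1]

print("Feature map shape:", tuple(features.shape))
print("Flattened size:", flat_size)
print("Formula:", down_ch_2 * (img_size // 4) ** 2)
```

The flattened size and the formula should agree, which is what lets `nn.Flatten()` feed directly into the `Linear` layer.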

Next, we can create a model to compare the two vectors from our embedding models. A couple of important operations are [repeat_interleave](https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html) and [repeat](https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html). Since we are comparing a batch at a time, these operations allow each vector from one embedder to be compared against each vector from the other embedder.

In [None]:
x = torch.tensor([1, 2, 3])
repeat_x = x.repeat(3)
repeat_interleave = x.repeat_interleave(3)
print("repeat: ", repeat_x)
print("repeat_interleave: ", repeat_interleave)

We'll then calculate the `CosineSimilarity` of each combination of vectors. Then, we will use [torch.unflatten](https://pytorch.org/docs/stable/generated/torch.unflatten.html) to turn the result into a matrix.

In [None]:
print(torch.unflatten(repeat_x, 0, (3, 3)))

In [None]:
print(torch.unflatten(repeat_interleave, 0, (3, 3)))
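To connect these operations to embeddings, the sketch below pairs every row of one batch with every row of another and reshapes the scores into a matrix. The random tensors `a` and `b` are stand-ins for embedder outputs, not real model results. Since the rows are normalized to unit length, the same matrix can also be obtained with a single matrix multiplication, which serves as a sanity check here.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 4, 10  # hypothetical batch size and embedding size

# Stand-ins for normalized embedder outputs
a = F.normalize(torch.randn(n, d))
b = F.normalize(torch.randn(n, d))

rep_a = a.repeat_interleave(n, dim=0)  # a0, a0, a0, a0, a1, a1, ...
rep_b = b.repeat(n, 1)                 # b0, b1, b2, b3, b0, b1, ...

# Entry [i, j] is the cosine similarity between a[i] and b[j]
pairwise = torch.unflatten(F.cosine_similarity(rep_a, rep_b), 0, (n, n))

# For unit-length rows, cosine similarity is just the dot product
matmul = a @ b.T

print(pairwise)
print("Matches a @ b.T:", torch.allclose(pairwise, matmul, atol=1e-5))
```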

Let's put it all together in the `ContrastivePretraining` class below.

In [17]:
class ContrastivePretraining(nn.Module):
    def __init__(self, in_ch, img_size, embed_dim=10):
        super().__init__()
        self.baseImgEmbedder = ImgEmbedder(in_ch, img_size, down_ch_2=128, embed_dim=embed_dim)
        self.outlineEmbedder = ImgEmbedder(in_ch, img_size, embed_dim=embed_dim)
        self.cos = nn.CosineSimilarity()

    def forward(self, base_imgs, outlined_imgs):
        base_emb = self.baseImgEmbedder(base_imgs)
        outline_emb = self.outlineEmbedder(outlined_imgs)

        # Pair every base embedding with every outline embedding
        repeated_base_emb = base_emb.repeat_interleave(len(outline_emb), dim=0)
        repeated_outline_emb = outline_emb.repeat(len(base_emb), 1)

        similarity = self.cos(repeated_base_emb, repeated_outline_emb)
        # Rows index base images, columns index outlines
        similarity = torch.unflatten(similarity, 0, (len(base_emb), len(outline_emb)))
        similarity = (similarity + 1) / 2  # Rescale from [-1, 1] to [0, 1]

        logits_per_base = similarity
        logits_per_outline = similarity.T
        return logits_per_base, logits_per_outline

Time to define the model. We chose to use two [CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) functions to make the math a little clearer in our training loop. Since we're comparing two data types, our `total_loss` will be the average of the loss for the outlines and the loss for the original images. Since [CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) compares against the index of the target value, we'll use `arange` to generate the `ground_truth` indices.

In [18]:
model = ContrastivePretraining(1, 28)
optimizer = Adam(model.parameters(), lr=0.0001)
loss_base = nn.CrossEntropyLoss()
loss_outline = nn.CrossEntropyLoss()
ground_truth = torch.arange(BATCH_SIZE, dtype=torch.long)
epochs = 1
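It can help to know what loss values to expect before training. The sketch below feeds two hand-made similarity matrices to `CrossEntropyLoss` with the same `arange` targets: an ideal one (1 on the diagonal, 0 elsewhere) and an uninformative one (0.5 everywhere). The matrices are illustrative assumptions, not model outputs.

```python
import math

import torch
import torch.nn as nn

batch_size = 6
loss_fn = nn.CrossEntropyLoss()
ground_truth = torch.arange(batch_size)  # row i should match column i

ideal = torch.eye(batch_size)                              # perfect matches
uninformative = torch.full((batch_size, batch_size), 0.5)  # no preference

ideal_loss = loss_fn(ideal, ground_truth).item()
flat_loss = loss_fn(uninformative, ground_truth).item()

print(f"Ideal matrix loss:         {ideal_loss:.4f}")
print(f"Uninformative matrix loss: {flat_loss:.4f}")
print(f"log(batch_size):           {math.log(batch_size):.4f}")
```

Because the similarities are confined to the range 0 to 1, even the ideal matrix does not reach a loss of 0. With a batch size of 6 the loss should sit between roughly 1.04 (ideal) and 1.79 (uninformative), so small changes in the printed loss can still reflect meaningful progress.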

Run the cell below to begin training. As the model is training, try to keep an eye on the output matrix. Do the values along the diagonal get closer to 1, and do the other values get closer to 0?

In [None]:
for epoch in range(epochs):
    model.train()
    for step, batch in enumerate(train_dataloader):
        optimizer.zero_grad()

        images = batch[0]
        outlines = outline_img(images)
        logits_per_base, logits_per_outline = model(images, outlines)
        total_loss = (loss_base(logits_per_base, ground_truth) + loss_outline(logits_per_outline, ground_truth))/2
        total_loss.backward()
        optimizer.step()

        if epoch % 1 == 0 and step % 2000 == 0:
            print(f"Train Epoch {epoch} | Step {step:03d} Loss: {total_loss.item()} ")
            print("Similarity:")
            print(logits_per_base)
    model.eval()
    valid_loss = 0
    with torch.no_grad():  # Validation only measures the loss; it must not update the weights
        for step, batch in enumerate(valid_dataloader):
            images = batch[0]
            outlines = outline_img(images)
            logits_per_base, logits_per_outline = model(images, outlines)
            total_loss = (loss_base(logits_per_base, ground_truth) + loss_outline(logits_per_outline, ground_truth))/2
            valid_loss += total_loss.item()

    # step is zero-indexed, so the number of batches is step + 1
    print(f"Valid Loss: {valid_loss / (step + 1)} ")
    print("Similarity:")
    print(logits_per_base)

## 2.4 Vector Lookup

Now that we have a model, let's put it to use! We've provided an image at `images/my_outline.png`, but feel free to upload your own sketch to see how it compares. Once we've identified the image we'd like to search with, we'll convert it from 3 color channels to 1 by using `gray_mode`. 

In [None]:
gray_mode = torchvision.io.ImageReadMode.GRAY
my_outline = torchvision.io.read_image("images/my_outline.png", gray_mode)
my_outline = my_outline.float() / 255
my_outline.size()

In [None]:
plt.imshow(my_outline[0], cmap=cmap)

Next, we'll wrap it in a batch so it can be run through the neural network.

In [None]:
my_batched_outline = my_outline[None,:,:,:]
my_batched_outline.size()

Since we're searching with an outline, we'll use the `outlineEmbedder` to get its corresponding embedding.

In [None]:
out_emb = model.outlineEmbedder(my_batched_outline)
out_emb

We can turn our `train_set` into a vector database by running the images through the `baseImgEmbedder`. Then, we can calculate the cosine similarity (`cos`) between each image embedding and the embedding for our outline.

Since the dataset is so large, it would be challenging to run all the images through the model at once, so we'll process a batch at a time.

In [None]:
best_score = -1
best_img = None
compare_batch_size = 5000

cos = nn.CosineSimilarity()
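# out_emb has a single row, so this repeat leaves it unchanged.
# CosineSimilarity then broadcasts that row against every image embedding in a batch.
repeated_out_emb = out_emb.repeat(len(out_emb), 1)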
compare_dataloader = DataLoader(train_set, batch_size=compare_batch_size)

In [None]:
model.eval()
with torch.no_grad():  # Inference only, so skip gradient tracking to save memory
    for step, batch in enumerate(compare_dataloader):
        images = batch[0]
        img_embs = model.baseImgEmbedder(images)
        scores = cos(img_embs, repeated_out_emb)
        best_idx = torch.argmax(scores)
        batch_best_score = scores[best_idx]
        print("Step:", step, "| Batch Best Score:", batch_best_score.item())
        if batch_best_score.item() > best_score:
            best_score = batch_best_score.item()
            best_img = images[best_idx]
print("Best Score: ", best_score)
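The loop above re-embeds the images and keeps only the single best match. A vector database typically stores the embeddings once and returns several candidates per query. Here is a minimal sketch of that idea, assuming the embeddings have already been computed. The `database` and `query` tensors are random stand-ins for `baseImgEmbedder` and `outlineEmbedder` outputs, and `torch.topk` is one possible way to retrieve the closest entries.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for stored image embeddings: 1000 unit-length vectors of size 10
database = F.normalize(torch.randn(1000, 10))

# Stand-in for an outline embedding that lies close to entry 42
query = F.normalize(database[42:43] + 0.01 * torch.randn(1, 10))

# For unit-length vectors, the dot product equals the cosine similarity
scores = (database @ query.T).squeeze(1)

top_scores, top_idx = torch.topk(scores, k=5)
for score, idx in zip(top_scores.tolist(), top_idx.tolist()):
    print(f"Index {idx:4d} | Score {score:.4f}")
```

In this toy setup the first result should be index 42, since the query was built from that entry. With real embeddings, the returned indices would point back into `train_set`.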

Let's look at the lucky winner. How closely does the image match the outline?

In [None]:
plt.imshow(best_img[0])

In [None]:
compare_outline = outline_img(best_img)
plt.imshow(compare_outline[0])

In [None]:
plt.imshow(my_outline[0])

## Next

Congrats on finishing this lab. Contrastive pre-training is a powerful technique that allows us to substitute one data type for another. Not only can it be used for searching through vector databases, but training other models on the embeddings of a contrastive pre-training model also allows us to perform inference with either data type.

By now, the experiments from the previous lab are likely complete. Please return to [02a_Intermediate_Fusion](02a_Intermediate_Fusion.ipynb) to see the results. Please also run the cell below to free up resources for future labs.

In [None]:
import IPython
app = IPython.Application.instance()
app.kernel.do_shutdown(True)
