# Understanding BERT and how to use it :D

In [None]:
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer, BertModel, BertConfig, BertTokenizer
import torch

In [None]:
# Three toy captions for comparing sentence embeddings. The first two are
# identical, so (with the model in eval mode) their cosine similarity should be ~1,
# which acts as a sanity check for the similarity computation.
# caption = "An x-ray image of a shoulder with an abnormality"
caption = "Pizza is amazing"
caption2 = "Pizza is amazing"
caption3 = "Everthing is going to end. But I am sure that this is not the end yet"

captions = [caption, caption2, caption3]

In [None]:
# Tokenize the captions; padding=True pads to the longest caption in the batch.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
tokens = tokenizer(captions, return_tensors="pt", padding=True, truncation=True, max_length=20)
print(tokens)
print(tokens.input_ids.shape)

# Load the pretrained weights that belong to this tokenizer's vocabulary.
# DistilBertModel(DistilBertConfig()) builds the same architecture with randomly
# initialised weights, which would make the similarities below meaningless.
distilbert = DistilBertModel.from_pretrained("distilbert-base-uncased")
distilbert.eval()  # disable dropout so identical captions get identical embeddings

with torch.no_grad():
    outputs = distilbert(**tokens)

last_hidden_states = outputs.last_hidden_state  # (batch, seq_len, hidden)
print(last_hidden_states.shape)

# Use the first token ([CLS]) as the sentence embedding
sentence_embeddings = last_hidden_states[:, 0, :]
print(sentence_embeddings.shape)

# Pairwise cosine similarity: dot products divided by the outer product of the row norms
norms = sentence_embeddings.norm(dim=1, keepdim=True)
cosine_similarities = (sentence_embeddings @ sentence_embeddings.T) / (norms * norms.T)
print(cosine_similarities)

In [None]:
# Same experiment with full BERT, additionally returning the attention maps.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokens = tokenizer(captions, return_tensors="pt", padding=True, truncation=True, max_length=20)
print('Tokens:', tokens)
print('Tokens input ids shape:', tokens.input_ids.shape)

# Pretrained weights matching the tokenizer; a bare BertModel(BertConfig()) is
# randomly initialised, so its embeddings and attentions carry no meaning.
bert = BertModel.from_pretrained("bert-base-uncased")
bert.eval()  # disable dropout

with torch.no_grad():
    outputs = bert(**tokens, output_attentions=True)

last_hidden_states = outputs.last_hidden_state
print('Last hidden states shape:', last_hidden_states.shape)

# outputs.attentions is a tuple with one attention tensor per layer
print('Attention shape:', outputs.attentions[0].shape)  # Shape of the first attention layer
print('Attentions:', outputs.attentions)

# Use the first token ([CLS]) as the sentence embedding
sentence_embeddings = last_hidden_states[:, 0, :]
print(sentence_embeddings.shape)

# Pairwise cosine similarity: dot products divided by the outer product of the row norms
norms = sentence_embeddings.norm(dim=1, keepdim=True)
cosine_similarities = (sentence_embeddings @ sentence_embeddings.T) / (norms * norms.T)
print(cosine_similarities)

# Trying a simple training

In [None]:
import os
import torch
from datasets import load_dataset
from torchvision import transforms

class CLIPDataset(torch.utils.data.Dataset):
    """
    Image-caption pairs from the first 256 training examples of the
    ``jxie/flickr8k`` dataset, using the ``caption_0`` column as the caption.

    Each item is a dict containing the tokenizer outputs for the caption
    (one tensor per tokenizer output key, e.g. ``input_ids`` and
    ``attention_mask``), the transformed ``image`` tensor and the raw
    ``caption`` string.
    """

    def __init__(self, tokenizer, max_length=None):
        """
        Args:
            tokenizer: Hugging Face tokenizer used to encode all captions up front.
            max_length: truncation length for the captions. Defaults to
                ``min(1000, tokenizer.model_max_length)`` so that no caption can
                exceed the text model's maximum input length.
        """
        dataset = load_dataset("jxie/flickr8k", split="train[:256]")

        self.images = dataset[:]["image"]
        self.captions = dataset[:]["caption_0"]

        if max_length is None:
            max_length = min(1000, tokenizer.model_max_length)
        # padding=True pads every caption to the longest one, so all items have the
        # same sequence length and can be batched by the default collate function.
        self.encoded_captions = tokenizer(
            list(self.captions), padding=True, truncation=True, max_length=max_length
        )

        self.transforms = transforms.Compose([
            transforms.Resize((224, 224)),      # Resize to fixed size
            transforms.ToTensor(),              # Convert to tensor, scales to [0,1]
            transforms.Normalize(               # Normalize with ImageNet stats
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def __getitem__(self, idx):
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        image = self.images[idx]
        # Grayscale or RGBA images would not match the 3-channel Normalize above;
        # convert any PIL image that is not already RGB.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")
        item['image'] = self.transforms(image)
        item['caption'] = self.captions[idx]

        return item

    def __len__(self):
        return len(self.captions)


# Build the training dataset with the DistilBERT tokenizer (the same checkpoint
# name that TextEncoder loads) and inspect one item.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
dataset = CLIPDataset(tokenizer)
print("First sample:", dataset[0])
print("Dataset length:", len(dataset))

In [None]:
import timm
import torch.nn as nn

class ImageEncoder(nn.Module):
    """
    Turn a batch of images into one pooled feature vector per image using a
    pretrained timm ResNet-34.

    The classifier head is removed (``num_classes=0``) and spatial features are
    average-pooled, so the output is presumably (batch, 512) for resnet34; this
    matches the ``image_embedding_dim=512`` default used by CLIPModule.
    """

    def __init__(self):
        super().__init__()
        backbone = timm.create_model(
            "resnet34",
            pretrained=True,
            num_classes=0,
            global_pool="avg",
        )
        self.model = backbone

    def forward(self, x):
        features = self.model(x)
        return features
    
# Quick shape check on a single image. eval() makes the ResNet's BatchNorm layers
# use their pretrained running statistics; in train mode they would normalise with
# this one image's statistics and update the running stats. no_grad avoids
# building an autograd graph for a pure inspection.
image_encoder = ImageEncoder()
image_encoder.eval()
image = dataset[0]['image'].unsqueeze(0)  # Add batch dimension
with torch.no_grad():
    image_embeddings = image_encoder(image)
print("Image embedding shape:", image_embeddings.shape)

In [None]:
class TextEncoder(nn.Module):
    """
    Encode tokenized captions with pretrained DistilBERT and return the hidden
    state of the first ([CLS]) token as the sentence embedding.
    """

    def __init__(self):
        super().__init__()
        self.model = DistilBertModel.from_pretrained("distilbert-base-uncased")
        # Position of the token whose hidden state represents the whole sentence
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        hidden = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        return hidden[:, self.target_token_idx, :]
    
# Encode the first caption. The result is the [CLS] embedding tensor itself
# (not a dimension), hence the name `text_embedding`.
text_encoder = TextEncoder()
input_ids = dataset[0]['input_ids'].unsqueeze(0)  # Add batch dimension
attention_mask = dataset[0]['attention_mask'].unsqueeze(0)  # Add batch dimension
with torch.no_grad():
    text_embedding = text_encoder(input_ids, attention_mask)
print("Text embedding shape:", text_embedding.shape)

In [None]:
import lightning as L
import numpy as np
import torch.nn.functional as F

class CLIPModule(L.LightningModule):
    """
    CLIP-style contrastive model built on pretrained, frozen encoders.

    Only the two linear projections into the shared embedding space and the
    log-scale temperature are trained.

    Args:
        image_embedding_dim: width of the ImageEncoder features.
        text_embedding_dim: width of the TextEncoder features.
        embedding_dim: width of the shared image/text embedding space.
        lr: AdamW learning rate.
    """

    def __init__(
        self,
        image_embedding_dim=512,
        text_embedding_dim=768,
        embedding_dim=128,
        lr=1e-4
    ):
        super().__init__()
        # Store the constructor arguments in checkpoints so that
        # CLIPModule.load_from_checkpoint rebuilds the module with the same
        # settings (e.g. lr=1e-3) instead of the defaults.
        self.save_hyperparameters()

        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()

        # Freeze the encoders: use their pre-trained weights and only train the projections.
        for param in self.image_encoder.parameters():
            param.requires_grad = False
        for param in self.text_encoder.parameters():
            param.requires_grad = False
        # requires_grad=False does not stop BatchNorm running-statistic updates
        # or dropout; those follow the module's training flag, so keep the
        # encoders in eval mode (see also train() below).
        self.image_encoder.eval()
        self.text_encoder.eval()

        self.image_projection = nn.Parameter(torch.empty(image_embedding_dim, embedding_dim))
        nn.init.normal_(self.image_projection, std=image_embedding_dim ** -0.5)
        self.text_projection = nn.Parameter(torch.empty(text_embedding_dim, embedding_dim))
        nn.init.normal_(self.text_projection, std=text_embedding_dim ** -0.5)

        # Log of the logit scale: forward() applies exp(), so 0.0 means a scale of 1.
        self.temperature = nn.Parameter(torch.empty(1))
        nn.init.constant_(self.temperature, 0.0)

        self.lr = lr

    def train(self, mode=True):
        """Set train/eval mode on this module while keeping the frozen encoders in eval mode."""
        super().train(mode)
        self.image_encoder.eval()
        self.text_encoder.eval()
        return self

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer

    def forward(self, batch):
        """
        Compute the image-caption logit matrix for a batch.

        Args:
            batch: dict with "image", "input_ids" and "attention_mask" entries.

        Returns:
            (num_images, num_captions) logits; entry [i, j] scores image i
            against caption j.
        """
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )

        # Project both modalities into the shared embedding space
        image_embeddings = image_features @ self.image_projection
        text_embeddings = text_features @ self.text_projection

        # NOTE(review): the embeddings are not L2-normalised here, so these are
        # scaled dot products rather than cosine similarities; confirm intent.
        logits = (image_embeddings @ text_embeddings.T) * torch.exp(self.temperature)

        return logits

    def training_step(self, batch, batch_idx):
        """Symmetric contrastive loss: image i should match caption i and vice versa."""
        logits = self(batch)
        # Create the targets on the logits' device so training also works on a GPU
        labels = torch.arange(len(batch["image"]), device=logits.device)

        image_loss = F.cross_entropy(logits, labels, reduction='mean')
        text_loss = F.cross_entropy(logits.T, labels, reduction='mean')
        loss = (image_loss + text_loss) / 2

        self.log("image_loss", image_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("text_loss", text_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        print("Training step loss:", loss.item())

        return loss
    
# Smoke test: push one small batch through an untrained CLIPModule and display the
# resulting (3, 3) image-caption logit matrix. This non-shuffled `dataloader`
# is also what get_image_embeddings iterates over.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=3,
    shuffle=False,
    num_workers=1,
)
batch = next(iter(dataloader))
clip_model = CLIPModule()
clip_model(batch)

In [None]:
# checking dimensions for projection: (1, 512) image features @ (512, 128) -> (1, 128)
image = dataset[0]['image'].unsqueeze(0)  # Add batch dimension
image_embeddings = image_encoder(image)

# torch.empty leaves the values uninitialised; only the resulting shape is meaningful here
image_projection = nn.Parameter(torch.empty(512, 128))
print("Image embeddings shape:", image_embeddings.shape)
print("Image projection shape:", image_projection.shape)
final_embeddings = image_embeddings @ image_projection
print("Image embedding shape:", final_embeddings.shape)

In [None]:
# Train the projections and temperature (CLIPModule freezes the encoders) on the
# 256-example dataset on CPU; with batch_size=64 an epoch is 4 steps.
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True  # NOTE(review): pinned memory only speeds host-to-GPU copies; no benefit with accelerator="cpu"
)
trainer = L.Trainer(
    max_epochs=20,
    accelerator="cpu",
    devices=1,
    log_every_n_steps=1
)
clip_model = CLIPModule(lr=1e-3)
trainer.fit(clip_model, train_dataloader)

In [None]:
def get_image_embeddings(checkpoint_path=None, loader=None):
    """
    Load a trained CLIPModule checkpoint and project every image of a loader
    into the shared embedding space.

    Args:
        checkpoint_path: checkpoint to load. Defaults to the best checkpoint
            recorded by the notebook-level ``trainer``'s checkpoint callback.
        loader: iterable of batches with an "image" entry. Defaults to the
            notebook-level ``dataloader`` (not shuffled, so rows follow dataset order).

    Returns:
        Tensor stacking ``image_features @ image_projection`` for every batch
        (un-normalised embeddings).

    Raises:
        FileNotFoundError: if no checkpoint path is available.
    """
    if checkpoint_path is None:
        callback = trainer.checkpoint_callback
        checkpoint_path = callback.best_model_path if callback is not None else ""
    if loader is None:
        loader = dataloader
    if not checkpoint_path:
        # best_model_path is "" when no checkpoint was saved; fail with a clear message
        raise FileNotFoundError("No checkpoint available: pass checkpoint_path or run trainer.fit first")

    print("Best model path:", checkpoint_path)
    model = CLIPModule.load_from_checkpoint(checkpoint_path)
    model.eval()

    chunks = []
    with torch.no_grad():
        for batch in loader:
            images = batch["image"].to(model.device)
            image_features = model.image_encoder(images)
            chunks.append(image_features @ model.image_projection)
    image_embeddings = torch.cat(chunks, dim=0)
    print("Image embeddings shape:", image_embeddings.shape)
    return image_embeddings

# Masking out logits for duplicate captions


In [None]:
# let's pretend that the first and third captions are the same
captions = ["caption1", "caption2", "caption1"]
# we get logits like this
logits = torch.Tensor([[1.0, 0.5, 1.0],
                      [0.5, 1.0, 0.3],
                      [1.0, 0.3, 1.0]])

# then we want to mask out the logits at [0, 2] and [2, 0] (0-indexed) like so:
logits_desired_result = torch.Tensor([[1.0, 0.5, 0.0],
                                      [0.5, 1.0, 0.3],
                                      [0.0, 0.3, 1.0]])

def get_mask(captions):
    """
    Build an (N, N) float mask that is 0.0 wherever two *different* positions
    hold the same caption and 1.0 everywhere else (including the diagonal).

    Multiplying a logit matrix by this mask zeroes the off-diagonal entries
    that pair an image with a duplicate of its own caption.
    """
    size = len(captions)
    mask = torch.ones((size, size))
    for row, row_caption in enumerate(captions):
        for col, col_caption in enumerate(captions):
            if row != col and row_caption == col_caption:
                mask[row, col] = 0.0
    return mask

# now I asked chatgpt for a more pytorch/numpy way to do this
def get_mask_pytorch(captions):
    """
    Vectorised version of get_mask: an (N, N) float mask with 0.0 at
    off-diagonal positions whose captions are equal and 1.0 elsewhere.
    """
    # Map each distinct caption to a small integer id (first-seen order)
    lookup = {}
    ids = torch.tensor([lookup.setdefault(text, len(lookup)) for text in captions])

    # Pairwise "same caption" matrix, shape (N, N)
    same_caption = ids[:, None] == ids[None, :]
    off_diagonal = ~torch.eye(len(captions), dtype=torch.bool)

    # 0.0 only where captions match away from the diagonal
    return (~(same_caption & off_diagonal)).float()

# Apply both mask implementations to the example logits and compare the results
# with the hand-written expected matrix.
mask = get_mask(captions)
mask_pytorch = get_mask_pytorch(captions)
logits_masked = logits * mask
logits_masked_pytorch = logits * mask_pytorch

print("Logits before masking:\n", logits)
print("Mask:\n", mask)
print("Logits after masking:\n", logits_masked)
print("Logits after masking (pytorch):\n", logits_masked_pytorch)
print("Desired result:\n", logits_desired_result)
print("Are logits masked correctly?", torch.allclose(logits_masked, logits_desired_result, atol=1e-6))
print("Are logits masked correctly (pytorch)?", torch.allclose(logits_masked_pytorch, logits_desired_result, atol=1e-6))

# Non Square Matrix Loss

Problem: I have only 22 different captions in my pretraining dataset. That means captions will appear multiple times.

Illustration: Let's consider the logits, i.e. the (scaled) cosine similarities between image and caption embeddings. The rows correspond to images and the columns to captions.

Legend: the cross-entropy (CE) loss maximizes the entries marked X and minimizes all other entries; O marks two entries that hold the same value.

|   |  caption1 | caption2  | caption1  | caption2  |
|---|---|---|---|---|
|  img1 |  XO |   | O  |   |
| img2  |   | X  |   |   |
| img3  |   |   | X  |   |
| img4  |   |   |   | X  |

Both O's hold the same value, because both columns come from the same caption (caption1). The first one is maximized and the second one minimized, so the loss tries to maximize and minimize the same quantity at the same time.

Idea: remove the duplicated columns and, for each image, maximize the column of its caption:

|   |  caption1 | caption2 |
|---|---|---|
|  img1 |  X |   |
| img2  |   | X  |
| img3  | X  |   |
| img4  |   | X  |
            




In [None]:
# Toy batch in which "caption1" appears three times, used to compare the loss variants.
images = torch.randn(4, 3, 224, 224)  # Example images
captions = ["caption1", "caption2", "caption1", "caption1"]

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
captions_tokenized = tokenizer(captions, return_tensors="pt", padding=True, truncation=True, max_length=20)

# Encoders in eval mode so get_logits is deterministic and does not update the
# ResNet's BatchNorm running statistics with random-noise images.
image_encoder = ImageEncoder()
image_encoder.eval()
image_embedding_dim = 512
text_encoder = TextEncoder()
text_encoder.eval()
text_embedding_dim = 768

embedding_dim = 128  # Dimension for the final embeddings

image_projection = nn.Parameter(torch.empty(image_embedding_dim, embedding_dim))
nn.init.normal_(image_projection, std=image_embedding_dim ** -0.5)
text_projection = nn.Parameter(torch.empty(text_embedding_dim, embedding_dim))
nn.init.normal_(text_projection, std=text_embedding_dim ** -0.5)

# Log of the logit scale; exp(0.0) = 1 leaves the cosine similarities unscaled
temperature = nn.Parameter(torch.empty(1))
nn.init.constant_(temperature, 0.0)

In [None]:
def get_logits(images, captions_tokenized):
    """
    Score every image against every caption using the notebook-level
    ``image_encoder``, ``text_encoder``, projections and ``temperature``.

    Returns a (num_images, num_captions) matrix of cosine similarities between
    the projected, L2-normalised embeddings, multiplied by ``exp(temperature)``.
    """
    # Images are encoded first, then captions
    img_feats = image_encoder(images)
    txt_feats = text_encoder(
        input_ids=captions_tokenized["input_ids"],
        attention_mask=captions_tokenized["attention_mask"],
    )

    # Project into the shared space and unit-normalise so dot products are cosines
    img_emb = F.normalize(img_feats @ image_projection, dim=-1)
    txt_emb = F.normalize(txt_feats @ text_projection, dim=-1)

    scale = torch.exp(temperature)
    return (img_emb @ txt_emb.T) * scale

# Logit matrix for the toy batch: rows are images, columns are captions
logits = get_logits(images, captions_tokenized)



In [None]:
def compute_non_square_loss(logits, captions):
    """
    Symmetric contrastive loss that treats repeated captions as one class.

    Duplicate caption columns are collapsed to the first occurrence of each
    distinct caption, giving a (num_images, num_unique_captions) logit matrix.
    Each image's target is its caption's column. Each distinct caption's target
    is spread uniformly over all images paired with it.

    Args:
        logits: (num_images, num_captions) logits where column j belongs to
            captions[j] and image i was paired with captions[i].
        captions: caption strings, one per row/column of ``logits``.

    Returns:
        Scalar tensor: the mean of the image-side and caption-side losses.
    """
    # Group by caption text rather than by embedding: embeddings of identical
    # captions can differ during training because of dropout.
    # np.unique yields the distinct captions in sorted order, the index of each
    # one's first occurrence, and each caption's id among the distinct ones.
    _, first_idx, caption_ids = np.unique(captions, return_index=True, return_inverse=True)
    first_idx = torch.as_tensor(first_idx, dtype=torch.long, device=logits.device)
    caption_ids = torch.as_tensor(caption_ids, dtype=torch.long, device=logits.device)

    # One column per distinct caption, in the same sorted order as the ids,
    # so column k of selected_logits corresponds to class k.
    selected_logits = logits[:, first_idx]

    # labels[i, k] = 1 iff image i was paired with distinct caption k
    labels = F.one_hot(caption_ids, num_classes=first_idx.numel()).to(selected_logits.dtype)

    # Image side: exactly one positive class per row.
    loss_img = F.cross_entropy(selected_logits, labels)
    # Caption side: a caption shared by several images has several positives;
    # normalise each row so the soft target is a probability distribution.
    text_targets = labels.T / labels.T.sum(dim=1, keepdim=True)
    loss_text = F.cross_entropy(selected_logits.T, text_targets)

    print("Loss image:", loss_img.item())
    print("Loss text:", loss_text.item())

    return (loss_img + loss_text) / 2

# Inspect the toy-batch logits and evaluate the duplicate-aware loss on them.
print("Logits shape:", logits.shape)
print('Logits:', logits)
compute_non_square_loss(logits, captions)


In [None]:
print("Logits:\n", logits)

# Sigmoid is applied element-wise (it has no dimension argument): each entry
# becomes an independent per-pair probability, not a distribution over a row/column.
sigmoid_logits = torch.sigmoid(logits)

print("Element-wise sigmoid of logits:\n", sigmoid_logits)

compute_non_square_loss(logits, captions)

In [None]:
def compute_normal_loss(logits):
    """
    Standard symmetric CLIP loss for a square logit matrix.

    Image i is assumed to match caption i. Cross-entropy is taken over rows
    (image -> caption) and over columns (caption -> image), and the two are averaged.

    Args:
        logits: (N, N) image-caption logits.

    Returns:
        Scalar tensor with the averaged loss.
    """
    # Targets on the same device as the logits so this also works on a GPU
    labels = torch.arange(logits.shape[0], device=logits.device)

    image_loss = F.cross_entropy(logits, labels, reduction='mean')
    text_loss = F.cross_entropy(logits.T, labels, reduction='mean')

    print("Image loss:", image_loss.item())
    print("Text loss:", text_loss.item())

    return (image_loss + text_loss) / 2

In [None]:
def compute_masked_loss(logits, captions):
    """
    Symmetric CLIP loss that ignores off-diagonal pairs with identical captions.

    Duplicate-caption logits are set to -inf instead of being multiplied by 0.
    A logit of 0 still contributes exp(0) = 1 to the softmax denominator, so
    zeroing does not remove it from the cross-entropy. The mask never touches
    the diagonal, so every row and column keeps its finite target logit.

    Args:
        logits: (N, N) image-caption logits, image i paired with captions[i].
        captions: caption strings, one per row/column.

    Returns:
        Scalar tensor with the averaged loss.
    """
    mask = get_mask_pytorch(captions).to(logits.device)
    print("Mask:", mask)
    logits_masked = logits.masked_fill(mask == 0, float("-inf"))
    loss = compute_normal_loss(logits_masked)
    return loss

In [None]:
# Compare the standard (diagonal-target) loss with the duplicate-aware
# non-square loss on the toy batch that contains repeated captions.
loss_non_square = compute_non_square_loss(logits, captions)
loss_normal = compute_normal_loss(logits)
print("Computed normal loss:", loss_normal.item())
print("Computed non square loss:", loss_non_square.item())

In [None]:
# Sanity check, if no captions are duplicated: the mask is all ones and no
# columns are removed, so all three losses should agree.
images = torch.randn(4, 3, 224, 224)  # Example images
captions_unique = ["caption1", "caption2", "caption3", "caption4"]
captions_tokenized_unique = tokenizer(captions_unique, return_tensors="pt", padding=True, truncation=True, max_length=20)

logits_unique = get_logits(images, captions_tokenized_unique)
loss_normal_unique = compute_normal_loss(logits_unique)
loss_masked_unique = compute_masked_loss(logits_unique, captions_unique)
loss_non_square_unique = compute_non_square_loss(logits_unique, captions_unique)

print("Computed normal loss (unique captions):", loss_normal_unique.item())
print("Computed masked loss (unique captions):", loss_masked_unique.item())
print("Computed non square loss (unique captions):", loss_non_square_unique.item())

In [None]:
# One caption duplicated: "caption1" is paired with images 0 and 3, so the
# three losses are expected to differ here.
images = torch.randn(4, 3, 224, 224)  # Example images
captions_dup = ["caption1", "caption2", "caption3", "caption1"]
captions_tokenized_dup = tokenizer(captions_dup, return_tensors="pt", padding=True, truncation=True, max_length=20)

logits_dup = get_logits(images, captions_tokenized_dup)
loss_normal_dup = compute_normal_loss(logits_dup)
loss_masked_dup = compute_masked_loss(logits_dup, captions_dup)
loss_non_square_dup = compute_non_square_loss(logits_dup, captions_dup)

print("Computed normal loss (one duplicated caption):", loss_normal_dup.item())
print("Computed masked loss (one duplicated caption):", loss_masked_dup.item())
print("Computed non square loss (one duplicated caption):", loss_non_square_dup.item())

## Precision @ k

In [None]:
def precision_at_k_on_image_embeddings(image_embeddings, labels, k=3):
    """
    Mean precision@k of nearest-neighbour retrieval among image embeddings.

    For each embedding, the k most cosine-similar *other* embeddings are
    retrieved. Precision is the fraction of those neighbours that share its label.

    Args:
        image_embeddings: (batch_size, dim) tensor.
        labels: (batch_size,) tensor of class labels.
        k: number of neighbours to retrieve; requires 1 <= k <= batch_size - 1.

    Returns:
        Scalar tensor: precision@k averaged over the batch.

    Raises:
        ValueError: if k is outside [1, batch_size - 1].
    """
    batch_size = image_embeddings.shape[0]
    if k < 1 or k + 1 > batch_size:
        raise ValueError("k+1 must be less than or equal to the batch size (and k >= 1)")

    # Cosine similarity between all pairs of L2-normalised embeddings
    normalized = torch.nn.functional.normalize(image_embeddings, dim=1)
    similarity_matrix = normalized @ normalized.T  # [batch_size, batch_size]

    # Exclude self-matches explicitly. Dropping the first of the top-(k+1)
    # indices is not enough: a duplicate embedding ties with the item itself
    # and may be ranked first, so the item would be counted as its own neighbour.
    self_mask = torch.eye(batch_size, dtype=torch.bool, device=similarity_matrix.device)
    similarity_matrix = similarity_matrix.masked_fill(self_mask, float("-inf"))

    top_k_indices = similarity_matrix.topk(k=k, dim=1).indices  # [batch_size, k]

    # Count neighbours whose label matches the query's label
    correct_predictions = (labels.unsqueeze(1) == labels[top_k_indices]).sum(dim=1)  # [batch_size]
    precision_at_k = correct_predictions.float() / k  # [batch_size]
    return precision_at_k.mean()

In [None]:
# Toy check: four 2-D embeddings with labels [0, 0, 1, 1]. Each point's most
# cosine-similar other point shares its label, so precision@1 is expected to be 1.0.
image_embeddings = torch.Tensor([[1, 1], [1, 1.1],  [2, 1], [3, 1]])
labels = torch.Tensor([0, 0, 1, 1])
precision_at_k_on_image_embeddings(image_embeddings, labels, k=1)

In [None]:

# Pairwise cosine similarities of the toy embeddings (rows L2-normalised first)
image_embeddings_normalized_1 = F.normalize(image_embeddings)
cosine_similarities_1 = image_embeddings_normalized_1 @ image_embeddings_normalized_1.T
print("Cosine similarities (normalized):", cosine_similarities_1)