In [1]:
import torch
from torch import nn
import torch.nn.functional as F

import config as CFG
from modules import ImageEncoder, TextEncoder, ProjectionHead


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "/storage/homefs/yc24j783/miniconda3/envs/pyg/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/storage/homefs/yc24j783/miniconda3/envs/pyg/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/storage/homefs/yc24j783/.local/lib/python3.9/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/storage/homefs/yc24j783/.local/lib/python3.9/site-packages/traitlets/config/application.py

In [6]:
class CLIPModel(nn.Module):
    """Image/text dual encoder trained with a symmetric soft-target contrastive loss.

    Each modality is encoded by its own encoder and then passed through its own
    ``ProjectionHead``. ``forward`` returns the scalar batch loss.

    NOTE(review): the recorded run of this notebook ended with
    ``mat1 and mat2 shapes cannot be multiplied (8x2048 and 256x256)``.
    Presumably an encoder emits 2048-wide features while the matching
    ``*_embedding`` value from ``CFG`` is 256 -- verify ``config`` and
    ``modules.ProjectionHead`` against the encoders' output sizes.
    """

    def __init__(
        self,
        temperature = CFG.temperature,
        image_embedding = CFG.image_embedding,
        text_embedding = CFG.text_embedding,
    ):
        super().__init__()
        # Scalar used to scale both the logits and the soft targets in forward().
        self.temperature = temperature

        # Encoders first, then projection heads (same construction order as before,
        # so parameter initialisation draws from the RNG in the same sequence).
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)

    def forward(self, batch):
        """Compute the mean contrastive loss for one batch.

        Args:
            batch: mapping providing the keys ``'image'``, ``'input_ids'`` and
                ``'attention_mask'``.

        Returns:
            Scalar tensor: mean over the batch of the averaged
            text->image and image->text losses.
        """
        # Raw encoder outputs for both modalities.
        visual_feats = self.image_encoder(batch['image'])
        token_feats = self.text_encoder(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
        )

        # Project both modalities through their own heads.
        img_emb = self.image_projection(visual_feats)
        txt_emb = self.text_projection(token_feats)

        tau = self.temperature

        # Rows index texts, columns index images.
        logits = (txt_emb @ img_emb.T) / tau

        # Soft targets come from the averaged within-modality similarities.
        img_sim = img_emb @ img_emb.T
        txt_sim = txt_emb @ txt_emb.T
        soft_targets = F.softmax((img_sim + txt_sim) / 2 * tau, dim=-1)

        # Symmetric loss: text->image on the logits, image->text on the transpose.
        per_text = cross_entropy(logits, soft_targets, reduction='none')
        per_image = cross_entropy(logits.T, soft_targets.T, reduction='none')
        per_sample = (per_image + per_text) / 2  # shape: (batch_size)

        return per_sample.mean()

def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits and soft (probability-like) targets.

    Args:
        preds: Unnormalised logits; log-softmax is taken over the last dimension.
        targets: Target weights, broadcastable against ``preds``.
        reduction: ``'none'`` returns one loss per row, ``'mean'`` their mean,
            ``'sum'`` their sum.

    Returns:
        A tensor of per-row losses (``'none'``) or a scalar tensor
        (``'mean'`` / ``'sum'``).

    Raises:
        ValueError: if ``reduction`` is not one of the supported values
            (previously this case silently returned ``None``).
    """
    if reduction not in ('none', 'mean', 'sum'):
        raise ValueError(
            f"Unsupported reduction {reduction!r}; expected 'none', 'mean' or 'sum'."
        )

    # Reduce over the same (last) dimension the softmax is taken over; for the
    # 2-D inputs used in this file this is identical to the former ``.sum(1)``.
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(dim=-1)

    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss

In [7]:
if __name__ == '__main__':
    # Smoke test: push one random batch through the model and report the loss.
    # The batch size and sequence length are arbitrary placeholders.
    images = torch.randn(8, 3, 224, 224)
    input_ids = torch.randint(5, 300, size=(8, 25))
    attention_mask = torch.ones(8, 25)
    batch = {
        'image': images,
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }

    CLIP = CLIPModel()
    loss = CLIP(batch)
    # Show the result instead of an empty line so the check is actually visible.
    print(f"loss: {loss.item():.4f}")


RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x2048 and 256x256)