In [1]:
import os
import cv2
import torch
import albumentations as A

import config as CFG


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "/storage/homefs/yc24j783/miniconda3/envs/pyg/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/storage/homefs/yc24j783/miniconda3/envs/pyg/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/storage/homefs/yc24j783/.local/lib/python3.9/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/storage/homefs/yc24j783/.local/lib/python3.9/site-packages/traitlets/config/application.py

In [2]:
class CLIPDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each image file with one caption.

    Each item is a dict containing one tensor per key produced by the tokenizer
    for that caption, the transformed image under ``"image"`` (float tensor,
    channels moved to the first axis), and the raw caption string under
    ``"caption"``.
    """

    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        Args:
            image_filenames: indexable sequence of image file names, resolved
                relative to ``CFG.image_path``.
            captions: iterable of caption strings, aligned one-to-one with
                ``image_filenames``.
            tokenizer: callable invoked as
                ``tokenizer(list_of_str, padding=True, truncation=True, max_length=...)``
                whose result exposes ``.items()`` (presumably a HuggingFace
                tokenizer returning input_ids / attention_mask -- confirm).
            transforms: callable invoked as ``transforms(image=ndarray)`` that
                returns a dict with an ``"image"`` entry (e.g. an albumentations
                ``Compose``).

        image_filenames and captions must have the same length; so if there are
        multiple captions for each image, the image_filenames must have repeated
        file names (one entry per caption).

        Raises:
            ValueError: if image_filenames and captions differ in length.
        """
        self.image_filenames = image_filenames
        self.captions = list(captions)
        if len(self.image_filenames) != len(self.captions):
            # Items are looked up by a shared index, so the two sequences must align.
            raise ValueError(
                f"image_filenames ({len(self.image_filenames)}) and captions "
                f"({len(self.captions)}) must have the same length"
            )
        # Tokenize all captions once up front; __getitem__ only slices the result.
        self.encoded_captions = tokenizer(
            self.captions, padding=True, truncation=True, max_length=CFG.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return a dict with the tokenized caption, the image tensor and the raw caption.

        Raises:
            FileNotFoundError: if the image at ``idx`` cannot be read by OpenCV.
        """
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        image_path = f"{CFG.image_path}/{self.image_filenames[idx]}"
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV decodes to BGR; convert to RGB before the transforms.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)["image"]
        # (H, W, C) -> (C, H, W), the layout torch image models consume.
        item["image"] = torch.tensor(image).permute(2, 0, 1).float()
        item["caption"] = self.captions[idx]

        return item

    def __len__(self):
        """Return the number of (image, caption) pairs, i.e. the number of captions."""
        return len(self.captions)

In [None]:
def get_transforms(mode="train"):
    """Build the albumentations pipeline applied to every image.

    Args:
        mode: "train" or any other value (e.g. "valid"). Both currently yield the
            same deterministic pipeline. The steps are a resize to
            ``CFG.size`` x ``CFG.size``, then ``A.Normalize`` with its default
            mean/std and ``max_pixel_value=255.0``. No random augmentation is
            applied in either mode, which is why the modes do not differ. The
            parameter is kept so train-only augmentation can be added later
            without changing callers.

    Returns:
        An ``A.Compose`` pipeline, called as ``pipeline(image=ndarray)["image"]``.
    """
    # To add train-only augmentation, branch on `mode == "train"` here.
    # p=1.0 always applies each step. It replaces the deprecated
    # always_apply=True, which behaves the same under a Compose with p=1.0.
    return A.Compose(
        [
            A.Resize(CFG.size, CFG.size, p=1.0),
            A.Normalize(max_pixel_value=255.0, p=1.0),
        ]
    )