# Dataset

> Dev notebook

In [1]:
# | default_exp dataset

In [2]:
# | hide
%reload_ext autoreload
%reload_ext nb_black
%autoreload 2
from nbdev.showdoc import *
import sys

__root = "../"
sys.path.append(__root)


<IPython.core.display.Javascript object>

In [3]:
# | export
from torch_snippets import *
from torch_snippets.imgaug_loader import iaa
from transformers import DistilBertTokenizer
from clip.core import *
from clip.config import ClipConfig


<IPython.core.display.Javascript object>

In [4]:
# | export


def normalize(images, random_state, parents, hooks):
    images = [img / 255 for img in images]
    return images


def get_transforms(config):
    return iaa.Sequential(
        [
            iaa.Resize({"height": config.size, "width": config.size}),
            iaa.Lambda(normalize),
        ]
    )


class CLIPDataset(Dataset):
    def __init__(self, df, config, mode):
        """
        image_filenames and cpations must have the same length; so, if there are
        multiple captions for each image, the image_filenames must have repetitive
        file names
        """
        self.config = config
        self.tokenizer = DistilBertTokenizer.from_pretrained(
            config.distilbert_text_tokenizer
        )
        self.image_filenames = df.image.tolist()
        self.captions = df.caption.tolist()
        with notify_waiting(f"Creating encoded captions for {mode} dataset..."):
            self.encoded_captions = self.tokenizer(
                self.captions,
                padding=True,
                truncation=True,
                max_length=config.max_length,
            )
        self.transforms = get_transforms(config)

    def __getitem__(self, idx):
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        image = read(f"{self.config.image_path}/{self.image_filenames[idx]}", 1)
        image = self.transforms(image=image)
        item["image"] = torch.tensor(image).permute(2, 0, 1).float()
        item["caption"] = self.captions[idx]
        return item

    def __len__(self):
        return len(self.captions)

    @classmethod
    def train_test_split(cls, config):
        dataframe = pd.read_csv(config.captions_csv_path)
        max_id = dataframe["id"].max() + 1 if not config.debug else 100
        image_ids = np.arange(0, max_id)
        np.random.seed(42)
        valid_ids = np.random.choice(
            image_ids, size=int(0.2 * len(image_ids)), replace=False
        )
        train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
        train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(
            drop=True
        )
        valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(
            drop=True
        )
        return cls(train_dataframe, config, mode="train"), cls(
            valid_dataframe, config, mode="valid"
        )


def build_clip_data_loaders(config):
    trn_ds, val_ds = CLIPDataset.train_test_split(config)

    return (
        DataLoader(
            trn_ds,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            shuffle=True,
        ),
        DataLoader(
            val_ds,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            shuffle=False,
        ),
    )



<IPython.core.display.Javascript object>

In [None]:
from torch_snippets import *
from clip.core import download_flickr8k_from_kaggle
from clip.config import ClipConfig
from clip.dataset import CLIPDataset
from clip.models import CLIP

CLIPDataset.train_test_split()

In [5]:
# | hide
import nbdev
nbdev.nbdev_export()
import subprocess

subprocess.run(["/home/yyr/anaconda3/envs/mcvp-book/bin/black", __root])


Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
Skipping .ipynb files as Jupyter dependencies are not installed.
You can fix this by running ``pip install "black[jupyter]"``
reformatted /mnt/347832F37832B388/projects/MCVP2e/Chapter-15b/CLIP/clip/core.py
reformatted /mnt/347832F37832B388/projects/MCVP2e/Chapter-15b/CLIP/clip/config.py
reformatted /mnt/347832F37832B388/projects/MCVP2e/Chapter-15b/CLIP/clip/models.py
reformatted /mnt/347832F37832B388/projects/MCVP2e/Chapter-15b/CLIP/clip/dataset.py
reformatted /mnt/347832F37832B388/projects/MCVP2e/Chapter-15b/CLIP/clip/_modidx.py

All done! ✨ 🍰 ✨
5 files reformatted, 3 files left unchanged.


CompletedProcess(args=['/home/yyr/anaconda3/envs/mcvp-book/bin/black', '../'], returncode=0)

<IPython.core.display.Javascript object>