## Dataset download
Reference tutorial for the Flickr8k files downloaded below: https://machinelearningmastery.com/develop-a-deep-learning-caption-generation-model-in-python/

In [None]:
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip

In [None]:
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip

In [None]:
!unzip Flickr8k_Dataset.zip

In [None]:
!unzip Flickr8k_text.zip

## Imports

In [None]:
!pip install wandb

In [None]:
!pip install transformers

In [None]:
!pip install pygments -U

In [None]:
!pip install lightning -U

In [None]:
import torch
import torchvision.transforms as transforms
import cv2
import wandb
import pandas as pd
import os

## Creating DataLoader

In [None]:
from abc import abstractmethod
class ImageRetrievalDataset(torch.utils.data.Dataset):
  """Base dataset that pairs image files with tokenized captions.

  Subclasses implement :meth:`fetch_dataset`, which supplies the image paths
  and their captions. All captions are tokenized once, up front; images are
  read from disk and preprocessed on every ``__getitem__`` call.

  Args:
    artifact_id: Identifier handed to ``fetch_dataset`` (via ``self``) to
      locate the data.
    tokenizer: Callable used to encode the captions. Required.
    target_size: Side length of the square images produced by the
      preprocessing pipeline. ``None`` skips resizing.
    max_length: Maximum token length; longer captions are truncated.
    lazy_loading: Stored on the instance; not consulted by this base class.

  Raises:
    ValueError: If ``tokenizer`` is ``None``.
  """

  # ImageNet channel statistics, the conventional normalisation constants for
  # pretrained vision backbones.
  IMAGE_MEAN = (0.485, 0.456, 0.406)
  IMAGE_STD = (0.229, 0.224, 0.225)

  def __init__(self, artifact_id, tokenizer=None, target_size=None, max_length=200, lazy_loading=False):
    super().__init__()
    # Validate before fetching so a missing tokenizer fails fast instead of
    # after a potentially slow dataset download.
    if tokenizer is None:
      raise ValueError("a tokenizer is required to encode the captions")

    self.artifact_id = artifact_id
    self.target_size = target_size
    self.max_length = max_length
    self.lazy_loading = lazy_loading
    self.image_files, self.captions = self.fetch_dataset()
    self.images = self.image_files

    self.tokenizer = tokenizer
    self.tokenized_captions = tokenizer(
        list(self.captions), padding=True, truncation=True,
        max_length=self.max_length, return_tensors='pt'
    )

    # torchvision pipeline: HWC uint8 array -> PIL -> (resize) -> CHW float
    # tensor in [0, 1] -> channel-wise normalisation.
    steps = [transforms.ToPILImage()]
    if target_size is not None:
      steps.append(transforms.Resize((target_size, target_size)))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=self.IMAGE_MEAN, std=self.IMAGE_STD))
    self.transforms = transforms.Compose(steps)

  @abstractmethod
  def fetch_dataset(self):
    """Return ``(image_files, captions)`` as two parallel sequences.

    Must be overridden by subclasses.
    """
    raise NotImplementedError("subclasses must implement fetch_dataset()")

  def __len__(self):
    """Return the number of (image, caption) pairs."""
    return len(self.captions)

  def __getitem__(self, index):
    """Return the tokenized caption, preprocessed image and raw caption.

    The returned dict holds every tokenizer output field sliced at ``index``,
    plus ``"image"`` (a channels-first float tensor) and ``"caption"`` (the
    original caption entry).

    Raises:
      FileNotFoundError: If the image file cannot be read.
    """
    item = {
        key: values[index]
        for key, values in self.tokenized_captions.items()
    }
    image_path = self.image_files[index]
    image = cv2.imread(image_path)
    if image is None:
      # cv2.imread signals failure by returning None rather than raising.
      raise FileNotFoundError(f"could not read image: {image_path}")
    # OpenCV decodes to BGR; the rest of the pipeline expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    item["image"] = self.transforms(image)
    item["caption"] = self.captions[index]
    return item

In [None]:
class Flickr8kDataset(ImageRetrievalDataset):
  """Flickr8k image/caption dataset stored as a Weights & Biases artifact.

  The artifact is expected to contain a ``captions.txt`` CSV with ``image``
  and ``caption`` columns and an ``Images`` directory holding the files named
  in the ``image`` column.
  """

  def __init__(self, artifact_id, tokenizer=None, target_size=None, max_length=100, lazy_loading=False):
    super().__init__(artifact_id, tokenizer, target_size, max_length, lazy_loading)

  def fetch_dataset(self):
    """Download the artifact and return ``(image_files, captions)``.

    Returns:
      A pair of parallel lists: image file paths and their captions, in the
      row order of ``captions.txt``.

    Raises:
      FileNotFoundError: If any image listed in ``captions.txt`` is missing
        from the downloaded artifact.
    """
    if wandb.run is None:
      # No active run: fetch through the public API.
      artifact = wandb.Api().artifact(self.artifact_id, type="dataset")
    else:
      artifact = wandb.use_artifact(self.artifact_id, type="dataset")

    artifact_dir = artifact.download()
    annotations = pd.read_csv(os.path.join(artifact_dir, "captions.txt"))
    image_files = [
        os.path.join(artifact_dir, "Images", image_file)
        for image_file in annotations["image"].to_list()
    ]
    missing = [path for path in image_files if not os.path.isfile(path)]
    if missing:
      raise FileNotFoundError(
          f"{len(missing)} image file(s) missing from artifact, e.g. {missing[0]}"
      )
    captions = annotations["caption"].to_list()
    return image_files, captions


# Backward-compatible alias for the original, misspelled class name.
Filckr8kDatasert = Flickr8kDataset

## DataModule

In [None]:
from typing import Optional
from torch.utils.data import random_split, DataLoader
from pytorch_lightning import LightningDataModule
from transformers import AutoTokenizer

# Registry mapping a dataset name to its dataset class.
DATASET_LOOKUP = dict([("flickr8k", Flickr8kDataset)])