In [29]:
if __name__ == "__main__":
  # Mount Google Drive
  from google.colab import drive
  drive.mount('/content/drive/')

  # Copy imagen folder locally
  !cp -r /content/drive/MyDrive/imagen .

ModuleNotFoundError: No module named 'google.colab'

In [30]:
from cgitb import text
from pathlib import Path
from functools import partial

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T, utils
import torch.nn.functional as F

from PIL import Image
import numpy as np

import json
import glob



def exists(val):
    """Return ``True`` when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True

def cycle(dl):
    """Yield the items of ``dl`` endlessly, re-iterating it each time it runs out."""
    while True:
        yield from dl

def convert_image_to(img_type, image):
    """Return ``image`` in mode ``img_type``.

    The image is returned untouched when its ``mode`` already equals
    ``img_type``; otherwise the result of ``image.convert(img_type)`` is returned.
    """
    already_in_mode = image.mode == img_type
    return image if already_in_mode else image.convert(img_type)

class COCODataset(Dataset):
    """Dataset of (image tensor, text-embedding tensor) pairs.

    Items are driven by the saved embedding files: item ``i`` loads the
    ``i``-th embedding listed in the id -> embedding-path mapping and the image
    whose id is the first ``_``-separated token of that embedding's file name.

    Two JSON files cache the id -> path mappings so the folders need not be
    globbed on every construction; pass ``save_paths=True`` to (re)build them.
    """

    def __init__(
        self,
        image_folder,
        embedding_folder,
        annotations_folder,
        im_path_file,
        em_path_file,
        image_size,
        max_embed_size,
        exts = ('jpg', 'jpeg', 'png', 'tiff'),
        convert_image_to_type = None,
        save_paths = False
    ):
        """
        Args:
            image_folder: Prefix of the image directories; images are globbed
                as ``{image_folder}-*/*.{ext}``.
            embedding_folder: Root folder; embeddings are globbed as
                ``{embedding_folder}/*/*.pt``.
            annotations_folder: Folder containing ``captions_train2017.json``.
            im_path_file: JSON file holding the image-id -> image-path mapping.
            em_path_file: JSON file holding the embedding-id -> path mapping.
            image_size: Size passed to both ``Resize`` and ``CenterCrop``.
            max_embed_size: Length the second axis of every embedding is
                truncated / zero-padded to.
            exts: Image file extensions to look for (only read, never mutated;
                a tuple default avoids a shared mutable default argument).
            convert_image_to_type: Optional PIL mode every image is converted
                to at the start of the transform pipeline.
            save_paths: When true, rebuild both JSON mapping files before
                loading them.
        """
        super().__init__()

        self.image_folder = image_folder
        self.image_size = image_size
        self.max_embed_size = max_embed_size

        self.embedding_folder = embedding_folder

        if save_paths:
            self.save_id_path_dicts(im_path_file, em_path_file, exts)

        self.load_id_path_dicts(im_path_file, em_path_file)

        with open(f"{annotations_folder}/captions_train2017.json", 'r') as f:
            self.annotations = json.load(f)

        # Keep the annotations at positions 0 and 1 (mod 5): every fifth entry
        # starting at 0, followed by every fifth entry starting at 1.
        # __getitem__ does not read them; they are only exposed as an attribute.
        self.annotations['annotations'] = self.annotations['annotations'][0::5] + self.annotations['annotations'][1::5]

        convert_fn = partial(convert_image_to, convert_image_to_type) if exists(convert_image_to_type) else nn.Identity()

        self.transform = T.Compose([
            T.Lambda(convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

        # Fixes the item order: index i maps to the i-th embedding id.
        self.embed_keys = list(self.id_to_empath.keys())

    def load_id_path_dicts(self, id_to_im_path_file, id_to_em_path_file):
        """Load both id -> path mappings from JSON into ``self``.

        JSON object keys are always strings, so they are converted back to
        ``int`` ids. Sets ``self.id_to_impath`` and ``self.id_to_empath``.
        """
        with open(id_to_im_path_file, "r") as f:
            self.id_to_impath = {int(k): v for k, v in json.load(f).items()}
        with open(id_to_em_path_file, "r") as f:
            self.id_to_empath = {int(k): v for k, v in json.load(f).items()}

    def save_id_path_dicts(self, id_to_im_path_file, id_to_em_path_file, exts):
        """Glob the embedding and image folders and write both mappings as JSON.

        Embedding ids are the last ``_``-separated token of the file name
        (extension removed); image ids are the file name up to its first ``.``.
        Ids are parsed from the file name only, so separators or underscores in
        parent directories cannot corrupt them.

        Raises:
            FileNotFoundError: if no image matches any of ``exts``. Nothing is
                written in that case.
        """
        em_paths = glob.glob(self.embedding_folder + "/*/*.pt")
        id_to_empath = {
            int(Path(path).name.split('_')[-1].split('.')[0]): path
            for path in em_paths
        }

        im_paths = []
        for ext in exts:
            im_paths.extend(glob.glob(f"{self.image_folder}-*/*.{ext}"))
        print(len(im_paths))
        if not im_paths:
            # Fail with an explicit message instead of an IndexError from the
            # sample print below.
            raise FileNotFoundError(
                f"no images found matching '{self.image_folder}-*/*.<ext>' for extensions {list(exts)}"
            )
        print(im_paths[0])
        id_to_impath = {
            int(Path(path).name.split('.')[0]): str(path)
            for path in im_paths
        }

        for out_file, mapping in (
            (id_to_em_path_file, id_to_empath),
            (id_to_im_path_file, id_to_impath),
        ):
            # Create the target directory on first use so a fresh checkout works.
            Path(out_file).parent.mkdir(parents=True, exist_ok=True)
            with open(out_file, 'w') as f:
                json.dump(mapping, f)

    def __len__(self):
        """Number of items, i.e. number of embedding ids."""
        return len(self.embed_keys)

    def __getitem__(self, index):
        """Return ``(image_tensor, text_embeds)`` for the ``index``-th embedding.

        ``text_embeds`` is the first entry along the leading axis of the stored
        tensor, with its next axis truncated / zero-padded to ``max_embed_size``
        (the stored tensor is presumably ``(1, seq_len, dim)`` - confirm against
        the code that produced the embeddings).
        """
        embed_id = self.embed_keys[index]
        embed_path = self.id_to_empath[embed_id]

        # NOTE: torch.load unpickles the file - only load embeddings you trust.
        text_embeds = torch.load(embed_path)
        text_embeds = text_embeds[:, :self.max_embed_size, :]
        pad_len = self.max_embed_size - text_embeds.shape[1]
        # Pad tuple is ordered last axis first: only the end of axis 1 is padded.
        text_embeds = F.pad(input=text_embeds, pad=(0, 0, 0, pad_len, 0, 0), mode='constant', value=0)
        text_embeds = text_embeds[0, :, :]

        # Embedding files are named "<image_id>_<embedding_id>.pt".
        image_id = int(Path(embed_path).name.split('_')[0])

        img_path = self.id_to_impath[image_id]
        # Read the pixels inside a context manager so the file handle is
        # released deterministically.
        with Image.open(img_path) as img_file:
            pixels = np.array(img_file)

        if pixels.ndim == 2:
            # Single-channel image: replicate it into three identical channels.
            pixels = np.stack((pixels,) * 3, axis=-1)

        return self.transform(Image.fromarray(pixels)), text_embeds



In [31]:
if __name__ == "__main__":
    # Build the dataset from the cached id -> path JSON files
    # (save_paths=False, so the mappings are read, not regenerated).
    data = COCODataset(
        './data/train2017/train2017',
        './embeddings',
        './annotations',
        "./paths/im_paths.json",
        "./paths/em_paths.json",
        image_size = 64,
        max_embed_size = 32,
        save_paths = False,
    )
    print(len(data))

236936


In [32]:
if __name__ == "__main__":
    # Fetch the first (image, text-embedding) pair and show the embedding shape.
    img, embed = data[0]
    print(embed.shape)
    # Colab-only image preview, left disabled:
    # from google.colab.patches import cv2_imshow
    # cv2_imshow(np.moveaxis(img.numpy()*255, 0,-1))

torch.Size([3, 64, 64]) torch.Size([32, 768])
torch.Size([32, 768])


In [None]:
if __name__ == "__main__":
  !cp -r ./imagen/paths /content/drive/MyDrive/imagen/paths

cp: cannot stat './paths': No such file or directory
