In [2]:
%load_ext autoreload
%autoreload 2

import sys
from itertools import islice

import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader
from torchvision import transforms as T

sys.path.append("../")
from SADE.assets.dataset import LMDBImageDataset
from SADE.assets.utils import create_image_grid, tensor_to_pil

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
# Location of the LMDB database whose images this notebook visualizes
# (relative to the notebook's working directory).
samples_path = "../experiments/<examples-lmdb>/"

In [13]:
image_size = (64, 64)  # (height, width) target handed to T.Resize in the DataLoader cell
label_length = 0  # forwarded to LMDBImageDataset; NOTE(review): semantics not visible here — confirm

In [14]:
# Untransformed view of the LMDB data, used in the next cell to report its size.
# The DataLoader cell later rebinds `dataset` to a transformed version.
dataset = LMDBImageDataset(
    samples_path,
    label_length=label_length,
)

In [None]:
num_samples = len(dataset)  # number of entries in the untransformed dataset
num_samples

In [16]:
def collate_fn(examples):
    """Merge a list of dataset items into one batch dictionary.

    Args:
        examples: sequence of indexable items where item[0] is an image tensor
            and item[1] is its label. All image tensors must share one shape.

    Returns:
        dict with, in this key order (the visualization cell relies on it):
            "image_data": the images stacked along a new leading batch dim.
            "labels": plain list of the labels, in input order.
    """
    image_list = [example[0] for example in examples]
    stacked_images = torch.stack(image_list)
    label_list = [example[1] for example in examples]
    return {"image_data": stacked_images, "labels": label_list}

In [23]:
batch_size = 256
# Seed for the shuffle order, so Restart & Run All shows the same batches.
# Set to None to draw a fresh random order on every run.
shuffle_seed = 0

# NOTE(review): assumes LMDBImageDataset applies `transform` to PIL images
# (Grayscale/ToTensor below expect that) — confirm in SADE.assets.dataset.
image_transforms = T.Compose(
    [
        T.Resize(image_size),       # (h, w) sequence -> exact output size
        T.Grayscale(),              # one output channel (torchvision default)
        T.ToTensor(),               # PIL 'L' in [0, 255] -> float (C, H, W) in [0, 1]
        T.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]; one mean/std for the single channel
    ]
)
# Rebinds `dataset` (the untransformed one from the earlier cell) to a
# version that applies the transforms above.
dataset = LMDBImageDataset(
    samples_path,
    transform=image_transforms,
    label_length=label_length,
)

# A dedicated generator makes the shuffle reproducible without touching the
# global torch RNG; generator=None keeps DataLoader's default behaviour.
shuffle_generator = None
if shuffle_seed is not None:
    shuffle_generator = torch.Generator().manual_seed(shuffle_seed)

dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    generator=shuffle_generator,
)


In [24]:
def bmw_case(sequence):
    """Lower-case every 'I', 'D' and 'E' in ``sequence``; keep all other items as-is.

    ``sequence`` may be a string or any iterable of strings; the (possibly
    lowered) items are joined back into a single string.
    """
    lowered = {"I": "i", "D": "d", "E": "e"}
    return "".join(lowered.get(char, char) for char in sequence)

In [None]:
# Browse the DataLoader output as grids of num_rows x num_columns images.
num_columns = 3
num_rows = 2
num_batches = 10  # how many DataLoader batches to browse

images_per_grid = num_columns * num_rows
batches_seen = 0

# islice stops after num_batches, or earlier if the loader runs out, so no
# StopIteration handling is needed. A StopIteration raised inside the
# plotting helpers is therefore no longer silently swallowed.
for batch in islice(dataloader, num_batches):
    batches_seen += 1
    # Read fields by key instead of unpacking .values(), which would silently
    # depend on the insertion order of collate_fn's dict.
    images = [tensor_to_pil(sample, mode="L") for sample in batch["image_data"]]
    labels = batch["labels"]
    print(f"batch {batches_seen}: {len(images)} images, {len(labels)} labels")

    # Only complete grids are drawn; the last len(images) % images_per_grid
    # images of each batch are skipped.
    for grid_idx in range(len(images) // images_per_grid):
        start = grid_idx * images_per_grid
        stop = start + images_per_grid
        grid = create_image_grid(
            images[start:stop],
            num_columns=num_columns,
            num_rows=num_rows,
            labels=labels[start:stop],
        )
        fig, ax = plt.subplots()
        # cmap only applies to single-channel data (matplotlib ignores it for
        # RGB(A)). It prevents grayscale grids from being drawn in viridis.
        ax.imshow(grid, cmap="gray")
        ax.axis("off")
        plt.show()
        plt.close(fig)  # many figures are produced; free each one explicitly

if batches_seen < num_batches:
    print(f"Stopped: dataloader exhausted after {batches_seen} batch(es)")