In [None]:
from tqdm import tqdm

import torch
# Base seed for the notebook; the global torch RNG is seeded from it here.
rnd_seed = 0
torch.manual_seed(rnd_seed + 0)

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

import h5py
import numpy as np

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import EMNIST, CIFAR100

from module import Binarize, Dilate, IrregularMaskDataset, GenerateDiffraction

In [None]:
# Split selection: train=True builds the training set; otherwise `valid`
# picks between the validation (True) and the test (False) set.
train = False
valid = False

# Object-shape pipeline applied to EMNIST glyphs.
# The `scale` range of RandomAffine sets the object size, which in turn
# controls the oversampling ratio; the oversampling ratio is crucial for
# network performance.
# EMNIST images are 28-by-28, so padding 18 pixels per side gives 64-by-64.
shape_transform = transforms.Compose([
    Binarize(0.1),
    Dilate((3, 7), False),
    transforms.Pad((18, 18), fill=0),
    transforms.RandomAffine(90, scale=(0.8, 1.5), fill=0),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
])

shape_dataset = EMNIST(
    root='./datasets/',
    split='balanced',
    train=train,
    download=True,
    transform=shape_transform,
)

# Density pipeline: grayscale CIFAR100 images randomly cropped and resized to 64-by-64.
density_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.RandomResizedCrop(64, antialias=True),
    transforms.ToTensor(),
])

density_dataset = CIFAR100(
    root='./datasets/',
    train=train,
    download=True,
    transform=density_transform,
)

# Irregular masks provided by the project-local dataset class.
mask_dataset = IrregularMaskDataset(root='./datasets/', train=train)

In [None]:
batch_size = 48

# Keyword arguments common to the three loaders.
loader_options = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=12,
    pin_memory=True,
    drop_last=True,
)

# Each loader gets its own generator so it can be seeded independently.
shape_generator, density_generator, mask_generator = (
    torch.Generator() for _ in range(3)
)

shape_dataloader = DataLoader(shape_dataset, generator=shape_generator, **loader_options)
density_dataloader = DataLoader(density_dataset, generator=density_generator, **loader_options)
mask_dataloader = DataLoader(mask_dataset, generator=mask_generator, **loader_options)

In [None]:
# Number of samples to generate: the training set is larger than the
# validation/test sets.
data_size = 96000 if train else 12000

# Output file name depends on the selected split.
if train:
    path_template = './datasets/dataset_train_n{}k.h5'
elif valid:
    path_template = './datasets/dataset_valid_n{}k.h5'
else:
    path_template = './datasets/dataset_test_n{}k.h5'
h5path = path_template.format(data_size // 1000)

# Number of batches needed to produce data_size samples.
num_iter = data_size // batch_size

# Initialize the random number generators: the global torch RNG and each
# loader's own generator get a distinct offset from rnd_seed.
torch.manual_seed(rnd_seed + 1)
for seed_offset, loader in enumerate(
        (shape_dataloader, density_dataloader, mask_dataloader), start=2):
    loader.generator.manual_seed(seed_offset + rnd_seed)

# Create the iterators only after seeding, in shape -> density -> mask order.
shape_iterator, density_iterator, mask_iterator = map(
    iter, (shape_dataloader, density_dataloader, mask_dataloader))

# generate data
def _next_batch(iterator, dataloader):
    """Fetch the next batch, restarting the loader once it is exhausted.

    Args:
        iterator: The iterator currently being drained.
        dataloader: The iterable passed to ``iter`` when ``iterator`` raises
            ``StopIteration``.

    Returns:
        ``(batch, iterator)``. ``iterator`` is the one to keep using; it is a
        fresh iterator when a restart happened.

    Raises:
        StopIteration: If the freshly created iterator yields nothing.
    """
    try:
        return next(iterator), iterator
    except StopIteration:
        iterator = iter(dataloader)
        return next(iterator), iterator


# Every write below covers one full batch, so a remainder would silently leave
# the trailing rows of the datasets unwritten.
if num_iter * batch_size != data_size:
    raise ValueError(
        'data_size ({}) must be a multiple of batch_size ({})'.format(data_size, batch_size))

# The context manager closes the file even if generation fails part-way.
# require_dataset reuses an existing dataset of matching shape/dtype, so this
# cell can be re-run against a file that already exists (create_dataset would
# raise because the names are taken).
with h5py.File(h5path, 'a') as f:
    dinput = f.require_dataset('input', (data_size, 1, 512, 512), dtype=np.single)
    dtarget = f.require_dataset('target', (data_size, 1, 64, 64), dtype=np.single)
    dmask = f.require_dataset('mask', (data_size, 1, 512, 512), dtype=np.bool_)

    if not train and not valid:
        # Validation and test use the same splits and seeds. Skip the first
        # num_iter batches of every loader, presumably so the test samples
        # differ from the ones a validation run draws.
        for i in tqdm(range(num_iter)):
            (shape, _), shape_iterator = _next_batch(shape_iterator, shape_dataloader)
            (density, _), density_iterator = _next_batch(density_iterator, density_dataloader)
            mask, mask_iterator = _next_batch(mask_iterator, mask_dataloader)

    for i in tqdm(range(num_iter)):
        (shape, _), shape_iterator = _next_batch(shape_iterator, shape_dataloader)
        (density, _), density_iterator = _next_batch(density_iterator, density_dataloader)
        mask, mask_iterator = _next_batch(mask_iterator, mask_dataloader)

        # check empty object: where shape * density sums to zero over the
        # spatial axes, replace the density of that sample with ones.
        empty = torch.sum(shape * density, dim=(-2, -1)) == 0
        if torch.any(empty):
            ind = torch.nonzero(empty, as_tuple=True)[0]
            density[ind, :, :, :] = 1

        shape = shape.to(device)
        density = density.to(device)

        # NOTE: `input` shadows the builtin; the name is kept because later
        # cells may read these module-level variables.
        input, target = GenerateDiffraction(shape * density, ph_ord=6, l_coh=200, false_scale=True, device=device)

        # Move results back to host memory as NumPy arrays for h5py.
        input = input.detach().cpu().numpy()
        target = target.detach().cpu().numpy()
        # Masks are stored as booleans: True wherever the mask value is positive.
        mask = (mask > 0).numpy()

        start = batch_size * i
        stop = start + batch_size
        dinput[start:stop, :, :, :] = input
        dtarget[start:stop, :, :, :] = target
        dmask[start:stop, :, :, :] = mask