## Build dataloader

In [1]:
import h5py
import numpy as np
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.utils.data
import torchvision.utils as vutils

from torch_connectomics.utils.seg.aff_util import affinitize, seg_to_affgraph
from torch_connectomics.utils.seg.seg_util import widen_border, mknhood3d
from torch_connectomics.data.dataset import AffinityDataset
from torch_connectomics.data.utils import collate_fn
from torch_connectomics.data.augmentation import *

from matplotlib import pyplot as plt

In [2]:
def show_imgs(imgs, label=None, cmap=None):
    """Display up to four images side by side in a single-row figure.

    Args:
        imgs: Indexable, sized collection of images accepted by
            ``plt.imshow``. Only the first four entries are drawn; if
            fewer are supplied, the remaining panels are left empty.
        label: Optional title prefix. When given, panel ``i`` (1-based)
            is titled ``"<label> <i>"``.
        cmap: Optional colormap forwarded to ``plt.imshow``. ``None``
            keeps Matplotlib's default.
    """
    n_panels = 4
    plt.figure(figsize=(16, 4))
    # Clamp to the number of images supplied so short inputs do not raise
    # IndexError.
    for i in range(1, min(len(imgs), n_panels) + 1):
        # Use the integer (nrows, ncols, index) form: the string form
        # '14<i>' is deprecated and rejected by newer Matplotlib releases.
        plt.subplot(1, n_panels, i)
        plt.imshow(imgs[i - 1], cmap=cmap)
        if label is not None:
            plt.title(label + ' ' + str(i))
    plt.show()

### 1. Load data

For Harvard Research Computing (RC) cluster users, you can directly access the data directory if you have access to the `coxfs01` partition. For external users, please check the tutorial for downloading the dataset and change the `data_path` accordingly.

In [3]:
# Directory holding the SNEMI3D training volumes. External users should point
# `data_path` at their own download location (keep the trailing slash, since
# the file names are appended by plain string concatenation).
data_path = '/n/coxfs01/zudilin/data/SNEMI3D/'
image_path = data_path + 'train_image.h5'
label_path = data_path + 'train_label.h5'

# Copy each file's 'main' dataset into an in-memory NumPy array while the file
# is open; the `with` blocks close the HDF5 handles afterwards instead of
# leaving them open for the lifetime of the kernel.
with h5py.File(image_path, 'r') as image_file:
    image = np.array(image_file['main'])
with h5py.File(label_path, 'r') as label_file:
    label = np.array(label_file['main'])

print(image.shape, image.ndim, image.dtype)
print(label.shape, label.ndim, label.dtype)

(100, 1024, 1024) 3 uint8
(100, 1024, 1024) 3 uint16


### 2. Build dataloader

In [4]:
# Patch size handed to the model (presumably ordered z, y, x -- confirm
# against the dataset implementation).
model_io_size = (8, 256, 256)

# Augmentation pipeline. Each transform is applied with its own probability
# `p`; Compose is told the final patch size through `input_size`.
augmentor = Compose(
    [
        Rotate(p=1.0),
        Rescale(p=0.5),
        Flip(p=1.0),
        Elastic(alpha=10.0, p=0.5),
        Grayscale(p=0.75),
    ],
    input_size=model_io_size,
)

print('data augmentation: ', augmentor is not None)

# Image intensities are divided by 255.0 before being handed to the dataset.
# Both the input and label samples are drawn at the size the augmentor
# reports through `sample_size`, not at `model_io_size`.
dataset = AffinityDataset(
    volume=[image / 255.0],
    label=[label],
    sample_input_size=augmentor.sample_size,
    sample_label_size=augmentor.sample_size,
    augmentor=augmentor,
    mode='train',
)

img_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=1,
    pin_memory=True,
)

Sample size required for the augmentor: [  8 477 477]
data augmentation:  True


In [None]:
# Preview a single batch from the loader, then stop.
# The batch labels are bound to `label_batch` rather than `label` so the raw
# `label` volume loaded earlier is not overwritten; otherwise re-running the
# dataset-construction cell would receive a batch tensor instead of the volume.
for _, volume, label_batch, class_weight, _ in img_loader:
    print(volume.size(), label_batch.size())
    # Show slices 2..5 along the third axis of the first sample, channel 0.
    show_imgs(volume[0, 0, 2:6].detach().numpy(), 'image', cmap='gray')
    # Label channels 0, 1, 2 are displayed as the z, y and x affinities.
    for channel, title in enumerate(('z_aff', 'y_aff', 'x_aff')):
        show_imgs(label_batch[0, channel, 2:6].detach().numpy(), title)
    break