### Demo of Crop_Dataset dataloader

Should (eventually) support loading `CA17`, `ON17` (both partial and full annotations) and `SB16` datasets.

For now has only been tested on `CA17` with full annotations (as there is structurally no difference from `ON17` data that should work too).

Input `sample_size` will ensure slicing the input images to samples of given size.

**TODO**:
- partial annotations `CA17`/`ON17`
- `SB16`

Functions to collate the batches into a 4D tensors instead of tuples of 3D images.

In [1]:
def cat_list(images, fill_value=0):
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
    batch_shape = (len(images),) + max_size
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
    for img, pad_img in zip(images, batched_imgs):
        pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
    return batched_imgs

# Collates into a 4th tensor dimension rather than tuples:
def collate_fn(batch):
    images, targets = list(zip(*batch))
    batched_imgs = cat_list(images, fill_value=0)
    batched_targets = cat_list(targets, fill_value=255)
    return batched_imgs, batched_targets


Dataset settings:
- input folder
- `sample_size` to slice the input images
- `samples_per_image` to control the number of input slices if less than allowed by `sample_size`

**Note:** Customised transforms for segmentation-type datasets are used from `vegseg_transforms.py` (local file). These ensure that the (image, mask) pairs are correctly transformed together. (taken from a VOC example somewhere)

In [2]:
from skimage import io
import torch
import torchvision
import torchvision.transforms as transforms

from vegseg_transforms import Normalize, Compose, Resize, ToTensor

import crop_datasets

#DATA_ROOT = './voc/'

CARROT_ROOT = "/home/pbosilj/Data/CA17/carrots_labelled"
#sample_size = (int(384/2),int(512/2))
sample_size = (384,512)
samples_per_image = (4,4)

Get dataset image stats (mean and std).

This is done by loading all the data at once and calculating mean and std accross all data. This is then used to set up a normalised dataloader.

In [3]:
my_transforms = ToTensor()



carrots_train = crop_datasets.Crop_Dataset(root=CARROT_ROOT,
                                           train=True,
                                           partial_truth=True,
                                           sample_size=sample_size,
                                           samples_per_image=samples_per_image,
                                           transforms=ToTensor())

print("{} images in the training set".format(len(carrots_train)))

single_loader = torch.utils.data.DataLoader(carrots_train, batch_size=len(carrots_train), num_workers=1)
data, labels = next(iter(single_loader))

d_mean = data.mean(axis=(0,2,3))
d_std = data.std(axis=(0,2,3))
    
print("Dataset mean: {} std: {}".format(d_mean, d_std))

File 'train.txt' does not exist in root. Looking for 'test.txt'
256 images in the training set
Dataset mean: tensor([0.6396, 0.5730, 0.4229, 0.5532]) std: tensor([0.1391, 0.1122, 0.0918, 0.1718])


Get datasets stats class. This can be used when setting up the loss criterion:

In [4]:
class_probs = carrots_train.get_class_probability()
class_weights = 1.0/class_probs
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

print("Class probabilities: {}".format(class_probs))

Class probabilities: tensor([0.7844, 0.0781, 0.1374])


Set up a normalised dataset (using mean and std calculated above). Relies on customised transforms for segmentation data implemented in `vegseg_transforms.py`.

Set up a data loader which loads 4 by 4 images in a single batch (using the collate functions from above).

In [5]:
my_transform = Compose([ToTensor(), Normalize(d_mean, d_std)])

carrots_train_norm = crop_datasets.Crop_Dataset(
                    root = CARROT_ROOT,
                    train = True,
                    partial_truth = True,
                    sample_size = sample_size,
                    samples_per_image=samples_per_image,
                    transforms = my_transforms
                    )
#test_sampler = torch.utils.data.SequentialSampler(carrots_test_norm)

carrot_loader_train = torch.utils.data.DataLoader(
    carrots_train_norm, batch_size=4,
    num_workers=1,
    collate_fn=collate_fn)

File 'train.txt' does not exist in root. Looking for 'test.txt'


In [7]:
print(carrots_train_norm.get_ignore_index())

255


Test the data loader by looping through three batches and printing some basic stats. **Note:** this example is using samples of half width and length from the paper, just to test stuff.

In [9]:
import numpy as np

for i, (image, target) in enumerate(carrot_loader_train):
    print(type(image), image.size())
    print(type(target), target.size())
    
    np_image = image.numpy()
    np_target = target.numpy()
    
    print("Image  max: {} min: {}".format(np.amax(np_image), np.amin(np_image)))
    print("Target max: {} min: {}".format(np.amax(np_target), np.amin(np_target)))
    print("Unique labels: {}".format(np.unique(np_target)))
    print()
    
    if i > 3:
        break

<class 'torch.Tensor'> torch.Size([4, 4, 384, 512])
<class 'torch.Tensor'> torch.Size([4, 384, 512])
Image  max: 1.0 min: 0.09019608050584793
Target max: 255 min: 0
Unique labels: [  0   1   2 255]

<class 'torch.Tensor'> torch.Size([4, 4, 384, 512])
<class 'torch.Tensor'> torch.Size([4, 384, 512])
Image  max: 1.0 min: 0.13725490868091583
Target max: 255 min: 0
Unique labels: [  0   1   2 255]

<class 'torch.Tensor'> torch.Size([4, 4, 384, 512])
<class 'torch.Tensor'> torch.Size([4, 384, 512])
Image  max: 1.0 min: 0.08235294371843338
Target max: 255 min: 0
Unique labels: [  0   1   2 255]

<class 'torch.Tensor'> torch.Size([4, 4, 384, 512])
<class 'torch.Tensor'> torch.Size([4, 384, 512])
Image  max: 1.0 min: 0.07450980693101883
Target max: 255 min: 0
Unique labels: [  0   1   2 255]

<class 'torch.Tensor'> torch.Size([4, 4, 384, 512])
<class 'torch.Tensor'> torch.Size([4, 384, 512])
Image  max: 1.0 min: 0.10588235408067703
Target max: 255 min: 0
Unique labels: [  0   1   2 255]

