In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import nibabel as nib
from tqdm.notebook import tqdm
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
import pandas as pd
import os

ngpu = torch.cuda.device_count()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device, ngpu)

cuda:0 1


In [12]:
from batchgenerators.dataloading.data_loader import DataLoaderBase

In [2]:
class MSD(DataLoaderBase):
    def __init__(self, root, labeldir1, labeldir2, slices=1):
        self.slices = slices
        assert self.slices % 2 == 1, "slices must be odd!"
        es = int(self.slices/2) # floor operation
        img_fn = sorted(os.listdir(root + '/' + 'imagesTr'))
        seg_fn = sorted(os.listdir(root + '/' + labeldir1))
        seg_fn2 = sorted(os.listdir(root + '/' + labeldir2))
        self.datalist = []
        for i, (img_file, seg_file, seg_file2) in enumerate(zip(img_fn, seg_fn, seg_fn2)):
            assert img_file == seg_file, "Image and label files not from the same patient"
            if (img_file == seg_file):
                seg_img = nib.load(root + '/' + labeldir1 + '/' + seg_file)
                seg_img_data = seg_img.get_fdata()
                seg_img_data = np.array(seg_img_data)
                print(seg_img_data.shape)
                seg_img2 = nib.load(root + '/' + labeldir2 + '/' + seg_file2)
                seg_img_data2 = seg_img2.get_fdata()
                seg_img_data2 = np.array(seg_img_data2)
                print(seg_img_data2.shape)
                img = nib.load(root + '/' + 'imagesTr' + '/' + img_file)
                img_data = img.get_fdata()
                img_data = np.array(img_data)
                assert img_data.shape == seg_img_data.shape, "Image and Labels have different shapes"
                for j in range(img_data.shape[0]):
                    w = img_data.shape[1]
                    h = img_data.shape[2]

                    image2d = torch.from_numpy(img_data[j,:,:]).unsqueeze(dim=0)
                    label2d = torch.from_numpy(seg_img_data[j,:,:]).unsqueeze(dim=0)
                    label2d2 = torch.from_numpy(seg_img_data2[j,:,:]).unsqueeze(dim=0)
                    if (len(torch.unique(label2d))>=2):
                        self.datalist.append([ image2d , label2d, label2d2 ])

    def __getitem__(self, index):
        [img, seg, seg2] = self.datalist[index]
#         return {"A": img, "B": seg}
        return [img, seg, seg2]

    def __len__(self):
        return len(self.datalist)

In [3]:
root_dir = '../data/Task04_Hippocampus'
labeldir = 'labels_stage1'
labeldir_2 = 'labelsTr'

def my_collate(batch):
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    target2 = [item[2] for item in batch]
#     target = torch.LongTensor(target)
    
    return [data, target, target2]
validation_split = 0.2
shuffle_dataset = True
random_seed= 42
batch_size = 1

dataset = MSD(root = root_dir, labeldir1 = labeldir, labeldir2 = labeldir_2)

dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
# print(train_indices, val_indices)

(35, 51, 35)
(35, 51, 35)
(34, 52, 35)
(34, 52, 35)
(36, 52, 38)
(36, 52, 38)
(35, 52, 34)
(35, 52, 34)
(34, 47, 40)
(34, 47, 40)
(36, 48, 40)
(36, 48, 40)
(36, 50, 31)
(36, 50, 31)
(39, 50, 40)
(39, 50, 40)
(42, 51, 28)
(42, 51, 28)
(35, 48, 32)
(35, 48, 32)
(36, 47, 41)
(36, 47, 41)
(36, 46, 43)
(36, 46, 43)
(35, 51, 35)
(35, 51, 35)
(38, 52, 33)
(38, 52, 33)
(35, 48, 35)
(35, 48, 35)
(36, 50, 36)
(36, 50, 36)
(33, 48, 38)
(33, 48, 38)
(36, 49, 40)
(36, 49, 40)
(35, 47, 37)
(35, 47, 37)
(36, 47, 39)
(36, 47, 39)
(34, 51, 32)
(34, 51, 32)
(37, 51, 35)
(37, 51, 35)
(34, 53, 34)
(34, 53, 34)
(36, 52, 37)
(36, 52, 37)
(36, 51, 34)
(36, 51, 34)
(37, 52, 34)
(37, 52, 34)
(38, 48, 33)
(38, 48, 33)
(36, 48, 37)
(36, 48, 37)
(36, 49, 38)
(36, 49, 38)
(38, 52, 29)
(38, 52, 29)
(35, 51, 36)
(35, 51, 36)
(38, 49, 38)
(38, 49, 38)
(33, 54, 39)
(33, 54, 39)
(34, 52, 40)
(34, 52, 40)
(37, 51, 35)
(37, 51, 35)
(41, 47, 42)
(41, 47, 42)
(35, 51, 34)
(35, 51, 34)
(34, 53, 36)
(34, 53, 36)
(39, 52, 31)

In [4]:
from batchgenerators.transforms.color_transforms import ContrastAugmentationTransform
from batchgenerators.transforms.spatial_transforms import MirrorTransform
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter

In [10]:
my_transforms = []
brightness_transform = ContrastAugmentationTransform((0.3, 3.), preserve_range=True)
my_transforms.append(brightness_transform)
mirror_transform = MirrorTransform(axes=(0, 1))
my_transforms.append(mirror_transform)

all_transforms = Compose(my_transforms)
multithreaded_generator = MultiThreadedAugmenter(dataset, all_transforms, 4, 2, seeds=None)

In [11]:
new_batch = multithreaded_generator.next()

Process Process-7:
Traceback (most recent call last):
Process Process-8:
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Process Process-9:
  File "/home/rahul/.local/lib/python3.6/site-packages/batchgenerators/dataloading/multi_threaded_augmenter.py", line 43, in producer
    data_loader.set_thread_id(thread_id)
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
AttributeError: 'MSD' object has no attribute 'set_thread_id'
  File "/home/rahul/.local/lib/python3.6/site-packages/batchgenerators/dataloading/multi_threaded_augmenter.py", line 43, in producer
    data_loader.set_thread_id(thread_id)
Traceback (most recent call last):


KeyboardInterrupt: 

In [None]:
fig = plt.figure()
plt.subplot(2,1,1).imshow(dataset[10][0].squeeze().numpy().T, cmap='gray')
plt.subplot(2,1,2).imshow(dataset[10][1].squeeze().numpy().T, cmap='gray')
plt.show()