In [1]:
import os
import fnmatch
import random

import numpy as np

from batchgenerators.dataloading.data_loader import DataLoader
from batchgenerators.dataloading import MultiThreadedAugmenter

In [2]:
def get_transforms(mode="train", target_size=128):
    tranform_list = []

    if mode == "train":
        tranform_list = [# CenterCropTransform(crop_size=target_size),
                         ResizeTransform(target_size=(target_size,target_size), order=1),
                         MirrorTransform(axes=(1,)),
                         ]


    elif mode == "val":
        tranform_list = [CenterCropTransform(crop_size=target_size),
                         ResizeTransform(target_size=target_size, order=1),
                         ]

    elif mode == "test":
        tranform_list = [CenterCropTransform(crop_size=target_size),
                         ResizeTransform(target_size=target_size, order=1),
                         ]

    tranform_list.append(NumpyToTensor())

    return Compose(tranform_list)

In [146]:
def load_data_set(root_dir,mode,keys,taskname):
    image_names = keys
    print("root_dir:",root_dir,"taskname : ",taskname,"mode :",mode,"keys : ",keys )
    dataDir = "imagesTr"
    maskDir = "masksTr"
    suffix =".png"  
    img_data = []
    img_labels = []

    for image in image_names : 
        img = cv2.imread(osp.join(root_dir,taskname,dataDir,image+suffix))
        print("image path: ",osp.join(root_dir,taskname,dataDir,image+suffix))
        img_data.append(img)
                
        target_img = np.zeros(img.shape[:2], dtype=np.uint8)
        target_img_ = cv2.imread(osp.join(root_dir,taskname,maskDir,image+suffix),0)
        target_img = np.maximum(target_img, target_img_)
        img_labels.append(target_img)

    return img_data,img_labels

In [138]:
class MedImageDataSet(object):

        def __init__(self, base_dir, mode="train",batch_size=16,num_batches=10000000,taskname=None,seed=None, num_processes=8,
                     num_cached_per_queue=8 * 4, target_size=128, file_pattern='*.png', do_reshuffle=True, keys=None):

            data_loader = MedImageDataLoader(base_dir=base_dir,mode,batch_size, 
                                        num_batches,taskname,seed,file_pattern,keys)

            self.data_loader = data_loader
            self.batch_size = batch_size
            #self.do_reshuffle = do_reshuffle
            self.number_of_slices = 1

            self.transforms = get_transforms(mode=mode, target_size=target_size)
            self.augmenter = MultiThreadedAugmenter(data_loader, self.transforms, num_processes=num_processes,num_cached_per_queue=num_cached_per_queue, seeds=seed,
                                                 shuffle=do_reshuffle)
            self.augmenter.restart()

        def __len__(self):
            return len(self.data_loader)

        def __iter__(self):
            self.augmenter.renew()
            return self.augmenter

        def __next__(self):
            return next(self.augmenter)


SyntaxError: positional argument follows keyword argument (<ipython-input-138-403db30e36fd>, line 6)

In [148]:
class MedImageDataLoader(DataLoader):
    def __init__(self, base_dir,mode ,batch_size,num_batches,taskname,seed=1234, file_pattern=".png",keys=None):
        self.img_data ,self.img_labels = load_data_set(base_dir,mode,keys,taskname)
        self.data = list(range(len(self.img_data)))

        super().__init__(self.data, batch_size, seed_for_shuffle=seed)
        self.patch_size = patch_size
        self.batch_size = batch_size

        self.use_next = False
        if mode == "train":
            self.use_next = False

        self.indices = list(range(len(self.img_data))) 
        self.data_len = len(self.img_data)

        self.num_batches = min((self.data_len // self.batch_size)+10, num_batches)

    def generate_train_batch(self):
        
        idx = self.get_indices()
        patients_for_batch = [self._data[i] for i in idx]
        
        data = []
        labels = []
        
        for idx in patients_for_batch:
            data.append(self.img_data[idx])
            labels.append(self.img_labels[idx])
            
        return {'data': data, 'seg':seg}

    def __len__(self):
        n_items = min(self.data_len // self.batch_size, self.num_batches)
        return n_items

In [140]:
import os
import pickle
from collections import OrderedDict
from collections import defaultdict
import numpy as np
import torch
import torchvision.transforms as TF
import torch.optim as optim

In [128]:
import os

from trixi.util import Config


def get_config():
    # Set your own path, if needed.
    #data_root_dir = os.path.abspath('data')  # The path where the downloaded dataset is stored.
    #data_root_dir = "/home/ramesh/Desktop/WS/Implementation/experiment/Data/Filtereddataset"
    data_root_dir ="/home/ramesh/Desktop/IIITB/experiment/data/FilteredDataSet"
    taskName = "Task01_Hippocampus"
    #taskName = "Task09_Spleen"
    c = Config(
        update_from_argv=True,
        # Train parameters
        #num_classes=3,
        #in_channels=1,
        batch_size=2,
        patch_size=64,
        n_epochs=5,
        learning_rate=0.0002,
        fold=0,  # The 'splits.pkl' may contain multiple folds. Here we choose which one we want to use.

        device="cuda",  # 'cuda' is the default CUDA device, you can use also 'cpu'. For more information, see https://pytorch.org/docs/stable/notes/cuda.html

        # Logging parameters
        name='Segmentation_Experiment_Unet',
        plot_freq=10,  # How often should stuff be shown in visdom
        append_rnd_string=False,
        start_visdom=True,

        do_instancenorm=True,  # Defines whether or not the UNet does a instance normalization in the contracting path
        do_load_checkpoint=False,
        checkpoint_dir='',

        # Adapt to your own path, if needed.
        #google_drive_id='1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C',
        
        dataset_name = taskName,
        base_dir=os.path.abspath('output_experiment'),  # Where to log the output of the experiment.

        data_root_dir=data_root_dir,  # The path where the downloaded dataset is stored.
        data_dir=os.path.join(data_root_dir, taskName,'imagesTr'),  # This is where your training and validation data is stored
       
        #data_test_dir=os.path.join(data_root_dir, 'Task04_Hippocampus/preprocessed'),  # This is where your test data is stored
        split_dir= os.path.join(data_root_dir, taskName,'preprocessed'),  # This is where the 'splits.pkl' file is located, that holds your splits.
    )

    print(c)
    return c

In [129]:
config = get_config()

{
    "append_rnd_string": false,
    "base_dir": "/home/ramesh/Desktop/IIITB/MedicalImaging/output_experiment",
    "batch_size": 2,
    "checkpoint_dir": "",
    "data_dir": "/home/ramesh/Desktop/IIITB/experiment/data/FilteredDataSet/Task01_Hippocampus/imagesTr",
    "data_root_dir": "/home/ramesh/Desktop/IIITB/experiment/data/FilteredDataSet",
    "dataset_name": "Task01_Hippocampus",
    "device": "cuda",
    "do_instancenorm": true,
    "do_load_checkpoint": false,
    "fold": 0,
    "learning_rate": 0.0002,
    "n_epochs": 5,
    "name": "Segmentation_Experiment_Unet",
    "patch_size": 64,
    "plot_freq": 10,
    "split_dir": "/home/ramesh/Desktop/IIITB/experiment/data/FilteredDataSet/Task01_Hippocampus/preprocessed",
    "start_visdom": true
}


In [130]:
pkl_dir = config.split_dir
with open(os.path.join(pkl_dir, "splits.pkl"), 'rb') as f:
    splits = pickle.load(f)
tr_keys = splits[config.fold]['train']
val_keys = splits[config.fold]['val']
test_keys = splits[config.fold]['test']
print("pkl_dir: ",pkl_dir) 
print("tr_keys: ",tr_keys)
print("val_keys: ",val_keys)
print("test_keys: ",test_keys)

pkl_dir:  /home/ramesh/Desktop/IIITB/experiment/data/FilteredDataSet/Task01_Hippocampus/preprocessed
tr_keys:  ['hippocampus_001_0020', 'hippocampus_001_0000', 'hippocampus_001_0019', 'hippocampus_001_0002', 'hippocampus_001_0001', 'hippocampus_001_0005', 'hippocampus_001_0014', 'hippocampus_001_0026', 'hippocampus_001_0033', 'hippocampus_001_0031', 'hippocampus_001_0018', 'hippocampus_001_0030', 'hippocampus_001_0010', 'hippocampus_001_0011', 'hippocampus_001_0007', 'hippocampus_001_0025', 'hippocampus_001_0024', 'hippocampus_001_0022', 'hippocampus_001_0003', 'hippocampus_001_0032', 'hippocampus_001_0028']
val_keys:  ['hippocampus_001_0016', 'hippocampus_001_0013', 'hippocampus_001_0004', 'hippocampus_001_0027', 'hippocampus_001_0021', 'hippocampus_001_0017', 'hippocampus_001_0029', 'hippocampus_001_0015', 'hippocampus_001_0009', 'hippocampus_001_0008']
test_keys:  ['hippocampus_001_0023', 'hippocampus_001_0006', 'hippocampus_001_0034']


In [131]:
task = config.dataset_name
task

'Task01_Hippocampus'

In [132]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os

import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms

import torch.nn as nn

import numpy as np
import cv2

import os.path as osp
from glob import glob
from tqdm import tqdm


def load_data_set(root_dir,taskname,mode="train",keys=None):
    image_names = keys

    dataDir = "imagesTr"
    maskDir = "masksTr"
    suffix =".png"  
    img_data = []
    img_labels = []

    for image in image_names : 
        img = cv2.imread(osp.join(root_dir,taskname,dataDir,image+suffix))
        #print("image path: ",osp.join(root_dir,taskname,dataDir,image+suffix))
        img_data.append(img)
                
        target_img = np.zeros(img.shape[:2], dtype=np.uint8)
        target_img_ = cv2.imread(osp.join(root_dir,taskname,maskDir,image+suffix),0)
        target_img = np.maximum(target_img, target_img_)
        img_labels.append(target_img)

    return img_data,img_labels

class NucleusDataset(Dataset):
    def __init__(self, root_dir, train=True, transform=None, target_transform=None, mode ="train",
                  do_reshuffle=True, keys=None,taskname = None,batch_size=16, num_batches=10000000, seed=None):
        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.taskname = taskname
        self.image_names = keys
        self.mode = mode
        self.data_len = len(self.image_names)
        self.batch_size = batch_size
        self.num_batches = min((self.data_len // self.batch_size)+10, num_batches)

        dataDir = "imagesTr"
        maskDir = "masksTr"
        suffix =".png" 
        print("root_dir :",root_dir, " taskname : ",taskname,"self.mode :",self.mode)
        print(" path : ",osp.join(self.root_dir, taskname))
        
        if not self._check_task_exists():
            raise RuntimeError("Task does not exist")
            

        if self.mode=="train":

            print(" Mode : ",mode , " train image_names :",self.image_names)
            self.train_data ,self.train_labels =   load_data_set(root_dir=self.root_dir,taskname=self.taskname,mode=self.mode,keys=self.image_names)
            
        elif self.mode =="val":

            print(" Mode : ",mode , " val image_names :",self.image_names)
            self.val_data ,self.val_labels = load_data_set(root_dir=self.root_dir,taskname=self.taskname,mode=self.mode,keys=self.image_names)

        else :

            print(" Mode : ",mode , " test image_names :",self.image_names)
            self.test_data ,self.test_labels = load_data_set(root_dir=self.root_dir,taskname=self.taskname,mode=self.mode,keys=self.image_names)
            
    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, item):
        if self.mode=="train":
            image, mask = self.train_data[item], self.train_labels[item]

            if self.transform:
                image = self.transform(image)

            if self.target_transform:
                mask = self.target_transform(mask)

            return image, mask
                                              
        elif self.mode=="val":
            image, mask = self.val_data[item], self.val_labels[item]

            if self.transform:
                image = self.transform(image)

            if self.target_transform:
                mask = self.target_transform(mask)

            return image, mask     
                                              
        else:
            image, mask = self.test_data[item], self.test_labels[item]

            if self.transform:
                image = self.transform(image)

            if self.target_transform:
                mask = self.target_transform(mask)

            return image, mask     

    def _check_exists(self):
        return osp.exists(osp.join(self.root_dir, "train")) and osp.exists(osp.join(self.root_dir, "test"))
    
    def _check_task_exists(self):
        return osp.exists(osp.join(self.root_dir, self.taskname))

In [133]:
def tensor_to_numpy(tensor):
    t_numpy = tensor.cpu().numpy()
    t_numpy = np.transpose(t_numpy, [0, 2, 3, 1])
    t_numpy = np.squeeze(t_numpy)

    return t_numpy

class ToTensor:
    def __call__(self, data):
        if len(data.shape) == 2:
            data = np.expand_dims(data, axis=0)
        elif len(data.shape) == 3:
            data = data.transpose((2, 0, 1))
        else:
            print("Unsupported shape!")
        return torch.from_numpy(data)

class Normalize:
    def __call__(self, image):
        image = image.astype(np.float32) / 255
        return image

class Horizontal_flip:
    def __call__(self, image):
        # horizontal flip doesn't need skimage, it's easy as flipping the image array of pixels !
        image = image[:, ::-1]
        return image

class Rescale:
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, image):
        return cv2.resize(image, (self.output_size, self.output_size), cv2.INTER_AREA)

In [134]:
train_data_loader = torch.utils.data.DataLoader(
        NucleusDataset(config.data_root_dir, train=True,
                       transform=transforms.Compose([
                           Normalize(),
                           Rescale(config.patch_size),
                           ToTensor()
                       ]),
                       target_transform=transforms.Compose([
                           Normalize(),
                           Rescale(config.patch_size),
                           ToTensor()
                       ]),
                      mode ="train",
                      keys = tr_keys,
                      taskname = task),
        batch_size=config.batch_size, shuffle=True)

root_dir : /home/ramesh/Desktop/IIITB/experiment/data/FilteredDataSet  taskname :  Task01_Hippocampus self.mode : train
 path :  /home/ramesh/Desktop/IIITB/experiment/data/FilteredDataSet/Task01_Hippocampus
 Mode :  train  train image_names : ['hippocampus_001_0020', 'hippocampus_001_0000', 'hippocampus_001_0019', 'hippocampus_001_0002', 'hippocampus_001_0001', 'hippocampus_001_0005', 'hippocampus_001_0014', 'hippocampus_001_0026', 'hippocampus_001_0033', 'hippocampus_001_0031', 'hippocampus_001_0018', 'hippocampus_001_0030', 'hippocampus_001_0010', 'hippocampus_001_0011', 'hippocampus_001_0007', 'hippocampus_001_0025', 'hippocampus_001_0024', 'hippocampus_001_0022', 'hippocampus_001_0003', 'hippocampus_001_0032', 'hippocampus_001_0028']


In [150]:
tr_data_loader = MedImageDataSet(base_dir=config.data_root_dir, mode="train", batch_size=4, num_batches=10000000, taskname = config.dataset_name,seed=None, 
                                num_processes=8,num_cached_per_queue=8 * 4, target_size=128, file_pattern='*.png', do_reshuffle=True, keys=tr_keys)

root_dir: /home/ramesh/Desktop/IIITB/experiment/data/FilteredDataSet taskname :  Task01_Hippocampus mode : train keys :  ['hippocampus_001_0020', 'hippocampus_001_0000', 'hippocampus_001_0019', 'hippocampus_001_0002', 'hippocampus_001_0001', 'hippocampus_001_0005', 'hippocampus_001_0014', 'hippocampus_001_0026', 'hippocampus_001_0033', 'hippocampus_001_0031', 'hippocampus_001_0018', 'hippocampus_001_0030', 'hippocampus_001_0010', 'hippocampus_001_0011', 'hippocampus_001_0007', 'hippocampus_001_0025', 'hippocampus_001_0024', 'hippocampus_001_0022', 'hippocampus_001_0003', 'hippocampus_001_0032', 'hippocampus_001_0028']
image path:  /home/ramesh/Desktop/IIITB/experiment/data/FilteredDataSet/Task01_Hippocampus/imagesTr/hippocampus_001_0020.png
image path:  /home/ramesh/Desktop/IIITB/experiment/data/FilteredDataSet/Task01_Hippocampus/imagesTr/hippocampus_001_0000.png
image path:  /home/ramesh/Desktop/IIITB/experiment/data/FilteredDataSet/Task01_Hippocampus/imagesTr/hippocampus_001_0019.png

TypeError: __init__() got an unexpected keyword argument 'seed_for_shuffle'

In [121]:
base_dir, mode="train",batch_size=16,num_batches=10000000,taskname=None,seed=None, num_processes=8,num_cached_per_queue=8 * 4, target_size=128, file_pattern='*.png', do_reshuffle=True, keys=tr_keys):

SyntaxError: invalid syntax (<ipython-input-121-80f469ac8712>, line 2)