In [8]:
# !pip3 uninstall scikit-learn
# !pip3 install scikit-learn==0.24.2
# !pip3 uninstall imbalanced-learn==0.5.0
!pip3 install imbalanced-learn==0.4.2



In [9]:
!pip3 install trixi



In [None]:
from trixi.util.pytorchutils import set_seed

In [None]:
import os
import fnmatch
import random
from abc import ABCMeta, abstractmethod
import torch
from torch.utils.data import DataLoader, Dataset
from skimage.transform import resize
from trixi.util.pytorchutils import set_seed
import numpy as np
import pickle




def load_dataset(base_dir, pattern='*.npz', keys=None):
    """Collect the files under ``base_dir`` that match ``pattern`` and ``keys``.

    Args:
        base_dir: directory walked recursively with ``os.walk``.
        pattern: ``fnmatch`` pattern the file names must match.
        keys: file names with their last four characters (the ``.npz``
            extension) removed that should be kept. ``None`` keeps every
            matching file.

    Returns:
        ``(fls, files_len, dataset)``:
        - ``fls``: full paths of the selected files.
        - ``files_len``: ``shape[1]`` of each file's ``'data'`` array.
        - ``dataset``: the position of each file in ``fls``.
    """
    fls = []
    files_len = []
    dataset = []
    # A set makes the per-file membership test O(1).
    wanted = None if keys is None else set(keys)

    for root, dirs, files in os.walk(base_dir):
        # Sorting sub-directories in place makes the walk order, and therefore
        # the indices below, deterministic.
        dirs.sort()
        for filename in sorted(fnmatch.filter(files, pattern)):
            if wanted is not None and filename[:-4] not in wanted:
                continue
            npz_file = os.path.join(root, filename)
            # The context manager closes the archive's file handle.
            with np.load(npz_file) as archive:
                files_len.append(archive['data'].shape[1])
            # The index is counted over all directories (not reset per directory),
            # so it always points at this file's entry in ``fls``.
            dataset.append(len(fls))
            fls.append(npz_file)

    return fls, files_len, dataset

class SlimDataLoaderBase(metaclass=ABCMeta):
    """Minimal iterator base class: each ``next()`` returns ``generate_train_batch()``.

    ``ABCMeta`` is set as the real metaclass (a ``__metaclass__`` assignment
    inside a method has no effect in Python 3). A subclass that forgets to
    implement ``generate_train_batch`` therefore fails when it is instantiated.

    Args:
        data: data source, stored as ``self._data``.
        batch_size: stored as ``self.batch_size``.
        number_of_threads_in_multithreaded: stored as-is; not used by this class.
    """

    def __init__(self, data, batch_size, number_of_threads_in_multithreaded=None):
        self.number_of_threads_in_multithreaded = number_of_threads_in_multithreaded
        self._data = data
        self.batch_size = batch_size
        self.thread_id = 0

    def set_thread_id(self, thread_id):
        """Record the id of the thread/worker using this loader."""
        self.thread_id = thread_id

    def __iter__(self):
        return self

    def __next__(self):
        return self.generate_train_batch()

    @abstractmethod
    def generate_train_batch(self):
        '''Override this.

        Generate a batch from ``self._data``; make sure it contains
        ``self.batch_size`` samples.
        '''
        pass


class NumpyDataLoader(SlimDataLoaderBase):
    """Batch loader over the files returned by ``load_dataset``.

    Each file is read through its ``'data'`` array (for ``.npz`` archives).
    ``input`` and ``label`` are lists of indices into that array's first axis.
    These are presumably channels: the BraTS shapes noted in this notebook are
    (5, X, Y, Z). TODO confirm which channel holds the segmentation.

    Args:
        base_dir: directory searched recursively by ``load_dataset``.
        mode: accepted for API compatibility; ``use_next`` is ``False`` for
            every mode.
        batch_size: number of files per batch.
        num_batches: upper bound on ``self.num_batches``.
        seed: accepted for API compatibility; not used here.
        file_pattern: ``fnmatch`` pattern forwarded to ``load_dataset``.
        label: index (int) or indices of the label entries; ``None`` disables
            the ``'seg'`` output.
        input: indices of the image entries.
        keys: case identifiers to load; only the part after the last ``/``
            is matched. ``None`` is forwarded to ``load_dataset`` unchanged.
    """

    def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000,
                 seed=None, file_pattern='*.npz', label=1, input=(0,), keys=None):
        # Keys may be full paths; file matching only uses the last component.
        if keys is not None:
            keys = [key.split('/')[-1] for key in keys]
        self.files, self.file_len, self.dataset = load_dataset(base_dir=base_dir, pattern=file_pattern, keys=keys)

        # NOTE(review): num_batches lands in the base class's
        # ``number_of_threads_in_multithreaded`` slot; kept for compatibility.
        super(NumpyDataLoader, self).__init__(self.dataset, batch_size, num_batches)

        self.batch_size = batch_size

        # Batches are always fetched by index, whatever ``mode`` is.
        self.use_next = False

        # Visiting order used by __getitem__; permuted by reshuffle().
        self.idxs = list(range(0, len(self.dataset)))

        self.data_len = len(self.dataset)

        # The +10 slack is kept from the original code; __len__ still reports
        # only the number of full batches.
        self.num_batches = min((self.data_len // self.batch_size) + 10, num_batches)

        if isinstance(label, int):
            label = (label,)
        self.input = input
        self.label = label

        self.np_data = np.asarray(self.dataset)

    def reshuffle(self):
        """Shuffle the order in which __getitem__ visits the files."""
        print("Reshuffle...")
        random.shuffle(self.idxs)
        print("Initializing... this might take a while...")

    def generate_train_batch(self):
        """Return a batch of ``batch_size`` files sampled without replacement."""
        open_arr = random.sample(self._data, self.batch_size)
        return self.get_data_from_array(open_arr)

    def __len__(self):
        """Number of full batches available (capped by ``num_batches``)."""
        n_items = min(self.data_len // self.batch_size, self.num_batches)
        return n_items

    def __getitem__(self, item):
        """Return batch number ``item`` in the current (possibly shuffled) order.

        Raises:
            StopIteration: when ``item`` is outside ``[0, len(self))``. The
                torch ``DataLoader`` wrapper in this notebook relies on
                StopIteration (not IndexError) to end an epoch.
        """
        # Valid batches are 0 .. len(self) - 1; within that range every batch
        # is full, so no wrap-around arithmetic is needed.
        if item < 0 or item >= len(self):
            raise StopIteration()

        start_idx = item * self.batch_size
        stop_idx = start_idx + self.batch_size
        open_arr = self.np_data[self.idxs[start_idx:stop_idx]]

        return self.get_data_from_array(open_arr)

    @staticmethod
    def _load_volume(fn_name):
        """Return the array stored in ``fn_name`` (its ``'data'`` entry for ``.npz`` archives)."""
        loaded = np.load(fn_name)
        # An .npz archive is returned as an NpzFile (it has ``files``); take the
        # 'data' array, as load_dataset does, and close the archive.
        if hasattr(loaded, 'files'):
            with loaded:
                return loaded['data']
        return loaded

    def get_data_from_array(self, open_array):
        """Load the files referenced by ``open_array`` and assemble a batch dict.

        Args:
            open_array: iterable of indices into ``self.files``.

        Returns:
            A dict with these lists, all in the order of ``open_array``:
            ``'data'`` (arrays of the ``input`` entries), ``'fnames'``, ``'idxs'``,
            and, when ``self.label`` is not None, ``'seg'`` (arrays of the
            ``label`` entries).
        """
        data = []
        fnames = []
        idxs = []
        labels = []

        for idx in open_array:
            fn_name = self.files[idx]
            volume = self._load_volume(fn_name)

            # Indexing with a list keeps the leading axis, e.g. (1, X, Y, Z) for one entry.
            data.append(volume[list(self.input)])

            if self.label is not None:
                labels.append(volume[list(self.label)])

            fnames.append(fn_name)
            idxs.append(idx)

        ret_dict = {'data': data, 'fnames': fnames, 'idxs': idxs}
        if self.label is not None:
            ret_dict['seg'] = labels

        return ret_dict

class WrappedDataset(Dataset):
    """Expose a batch-producing loader as a torch ``Dataset``.

    Each ``__getitem__`` call returns one whole batch from the wrapped loader: a
    dict whose ``'data'`` (and optional ``'seg'``) entries are lists of arrays.
    Every array is resized so that its trailing spatial axes equal
    ``target_shape``; any leading axes are preserved, so (C, X, Y, Z) becomes
    (C, *target_shape).

    Args:
        dataset: the wrapped loader. If it has ``__getitem__`` and does not set
            ``use_next=True``, batches are fetched by index; otherwise they are
            fetched with ``next(dataset)``.
        transform: stored on the instance but not applied (the call was
            commented out in the original code). TODO(review): decide where it
            should run.
        target_shape: spatial output shape; defaults to (128, 128, 128).
    """

    def __init__(self, dataset, transform=None, target_shape=(128, 128, 128)):
        self.transform = transform
        self.dataset = dataset
        self.target_shape = tuple(target_shape)

        # Fetch by index unless the loader explicitly asks to be driven by next().
        self.is_indexable = False
        if hasattr(self.dataset, "__getitem__") and not (hasattr(self.dataset, "use_next") and self.dataset.use_next is True):
            self.is_indexable = True

    def _leading_shape(self, array):
        """Shape of the axes in front of the spatial ones (``()`` if there are none)."""
        n_spatial = len(self.target_shape)
        return array.shape[:-n_spatial] if array.ndim > n_spatial else ()

    def _resize_image(self, volume):
        """Cubic-resize every spatial sub-volume of ``volume`` to ``target_shape``."""
        volume = np.asarray(volume)
        lead = self._leading_shape(volume)
        out = np.empty(lead + self.target_shape, dtype=float)
        # Resizing per leading index stops skimage from treating the last
        # spatial axis as a channel axis and interpolating the real channel axis.
        for idx in np.ndindex(*lead):
            out[idx] = resize(volume[idx].astype(float), self.target_shape, order=3, clip=True, anti_aliasing=False)
        return out

    def _resize_labels(self, labels):
        """Resize a label map to ``target_shape`` one label value at a time.

        Each value's binary mask is linearly resized, and the value is written
        wherever the resized mask is >= 0.5, so no interpolated label values
        appear. Where masks overlap, later (larger) values win.
        """
        labels = np.asarray(labels)
        lead = self._leading_shape(labels)
        out = np.zeros(lead + self.target_shape, dtype=labels.dtype)
        for idx in np.ndindex(*lead):
            channel = labels[idx]
            target = out[idx]  # view into ``out``; writes go through
            for value in np.unique(channel):
                mask = resize((channel == value).astype(float), self.target_shape, order=1, mode="edge", clip=True, anti_aliasing=False)
                target[mask >= 0.5] = value
        return out

    def __getitem__(self, index):
        """Fetch one batch from the wrapped loader and resize all of its volumes."""
        item = self.dataset[index] if self.is_indexable else next(self.dataset)
        item['data'] = [self._resize_image(volume) for volume in item['data']]
        # Resize the label map of every sample in the batch, not only the first,
        # so 'seg' stays aligned with 'data'. Loaders without labels are allowed.
        if 'seg' in item:
            item['seg'] = [self._resize_labels(seg) for seg in item['seg']]
        return item

    def __len__(self):
        return int(self.dataset.num_batches)


class MultiThreadedDataLoader(object):
    """Run a batch loader through a torch ``DataLoader`` with worker processes.

    The loader is wrapped in ``WrappedDataset``, so each index yields a whole
    batch; the torch ``DataLoader`` therefore uses ``batch_size=1``.

    Args:
        data_loader: loader handed to ``WrappedDataset``.
        num_processes: ``num_workers`` of the torch ``DataLoader``.
        transform: forwarded to ``WrappedDataset``.
        **kwargs: accepted for API compatibility and ignored (``NumpyDataSet``
            passes ``num_cached_per_queue``, ``seeds`` and ``shuffle`` here).
    """

    def __init__(self, data_loader, num_processes, transform=None, **kwargs):
        # Incremented by renew() so each new round of workers gets different seeds.
        self.cntr = 1
        self.ds_wrapper = WrappedDataset(data_loader, transform)

        self.generator = DataLoader(self.ds_wrapper, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                                    num_workers=num_processes, pin_memory=True, drop_last=False,
                                    worker_init_fn=self.get_worker_init_fn())

        self.num_processes = num_processes
        self.iter = None

    def get_worker_init_fn(self):
        """Return a worker init function that calls ``set_seed(worker_id + self.cntr)``."""
        def init_fn(worker_id):
            # self.cntr is read when the worker starts, not when init_fn is created.
            set_seed(worker_id + self.cntr)

        return init_fn

    def __iter__(self):
        self.kill_iterator()
        self.iter = iter(self.generator)
        return self.iter

    def __next__(self):
        if self.iter is None:
            self.iter = iter(self.generator)
        return next(self.iter)

    def renew(self):
        """Stop the current workers and start a new iterator with fresh seeds."""
        self.cntr += 1
        self.kill_iterator()
        self.generator.worker_init_fn = self.get_worker_init_fn()
        self.iter = iter(self.generator)

    def kill_iterator(self):
        """Best-effort shutdown of the current iterator's worker processes.

        Failures are reported with a printed message rather than raised, so
        starting a new epoch is never blocked by cleanup of the previous one.
        Iterators without worker hooks (e.g. ``num_workers=0``) are skipped.
        """
        if self.iter is None:
            return
        try:
            shutdown = getattr(self.iter, "_shutdown_workers", None)
            if shutdown is not None:
                shutdown()
            # NOTE(review): torch's multiprocessing iterator keeps its processes
            # in the private ``_workers`` attribute; ``workers`` is a fallback.
            workers = getattr(self.iter, "_workers", None)
            if workers is None:
                workers = getattr(self.iter, "workers", ())
            for proc in workers:
                if proc.is_alive():
                    proc.terminate()
        except Exception:
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
            print("Could not kill Dataloader Iterator")

class NumpyDataSet(object):
    """Epoch-iterable dataset combining ``NumpyDataLoader`` and ``MultiThreadedDataLoader``.

    Iterating reshuffles the loader's visiting order (when ``do_reshuffle``),
    restarts the worker iterator via ``renew()``, and returns that iterator.

    Args:
        base_dir, mode, batch_size, num_batches, seed, file_pattern, label,
        input, keys: forwarded to ``NumpyDataLoader``.
        num_processes: number of workers for ``MultiThreadedDataLoader``.
        num_cached_per_queue: passed to ``MultiThreadedDataLoader`` as a keyword
            argument (``seed`` is also passed there as ``seeds``).
        target_size: accepted but not used by this class.
        do_reshuffle: reshuffle the sample order at the start of every iteration.
    """

    def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000, seed=None, num_processes=8, num_cached_per_queue=8 * 4, target_size=128,
                 file_pattern='*.npz', label=1, input=(0,), do_reshuffle=True, keys=None):
        # Original note "8*4->2*4  8->2": presumably smaller num_cached_per_queue /
        # num_processes values that were tried.
        loader = NumpyDataLoader(
            base_dir=base_dir,
            mode=mode,
            batch_size=batch_size,
            num_batches=num_batches,
            seed=seed,
            file_pattern=file_pattern,
            input=input,
            label=label,
            keys=keys,
        )
        self.data_loader = loader
        self.batch_size = batch_size
        self.do_reshuffle = do_reshuffle
        self.number_of_slices = 1
        self.transforms = None
        self.augmenter = MultiThreadedDataLoader(
            loader,
            num_processes,
            num_cached_per_queue=num_cached_per_queue,
            seeds=seed,
            shuffle=do_reshuffle,
        )

    def __len__(self):
        return len(self.data_loader)

    def __iter__(self):
        if self.do_reshuffle:
            self.data_loader.reshuffle()
        workers = self.augmenter
        workers.renew()
        return workers

    def __next__(self):
        return next(self.augmenter)

# Root of the preprocessed BraTS 2020 training data. Set the BRATS_DATA_DIR
# environment variable to point elsewhere instead of editing this path.
data_dir = os.environ.get("BRATS_DATA_DIR", '/home/jovyan/main/BraTS2020_TrainingData/')

# SECURITY: pickle.load can execute arbitrary code; only load splits.pkl from a trusted source.
with open(os.path.join(data_dir, "splits.pkl"), 'rb') as f:
    splits = pickle.load(f)

# Use the first stored split; each split maps 'train' / 'val' / 'test' to case keys.
tr_keys = splits[0]['train']
val_keys = splits[0]['val']
test_keys = splits[0]['test']

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): NumpyDataSet does not use target_size, so this value has no effect
# on the volume size produced by the loader.
train_data_loader = NumpyDataSet(data_dir, target_size=64, batch_size=8, keys=tr_keys)
print("ok")

In [None]:
# (5, 137, 167, 133)
# (5, 143, 176, 131)
# (5, 137, 167, 124)
# (5, 143, 187, 138)
# (5, 144, 170, 138)
# (5, 140, 186, 136)
# (5, 146, 160, 127)
# (5, 139, 158, 137)
# (5, 145, 172, 140)
# (5, 140, 173, 130)
# (5, 140, 164, 145)
# (5, 140, 182, 132)
# (5, 144, 168, 146)
# (5, 141, 178, 135)
# (5, 145, 177, 140)
# (5, 147, 167, 125)
# (5, 138, 167, 142)
# (5, 146, 178, 139)
# (5, 136, 157, 133)
# (5, 140, 187, 137)
# (5, 137, 174, 139)
# (5, 137, 166, 140)
# (5, 141, 177, 140)
# (5, 137, 169, 138)
# (5, 143, 174, 137)
# (5, 141, 178, 140)
# (5, 143, 187, 132)
# (5, 141, 174, 138)
# (5, 136, 173, 131)
# (5, 136, 168, 134)
# (5, 141, 171, 130)
# (5, 135, 163, 129)
# (5, 138, 168, 128)
# (5, 149, 176, 143)
# (5, 138, 179, 140)
# (5, 138, 167, 135)
# (5, 141, 176, 144)
# (5, 134, 157, 126)
# (5, 142, 184, 141)
# (5, 129, 175, 128)
# (5, 144, 170, 130)
# (5, 144, 173, 137)
# (5, 130, 167, 148)
# (5, 135, 162, 142)
# (5, 140, 176, 133)
# (5, 142, 185, 132)
# (5, 141, 165, 143)
# (5, 141, 173, 131)