In [None]:
%run data_classes.ipynb


## 2D Dataset

In [None]:
# Release cached-but-unused GPU memory held by PyTorch's CUDA caching allocator
# (e.g. left over from earlier runs in this kernel); tensors still referenced are not freed.
torch.cuda.empty_cache()

In [None]:
from torch.utils.data import Dataset, DataLoader
import random

class PrepareData2D(object):
    """
    Loads the data in project classes, combines the slices of all subjects,
    shuffles the slices, and returns a 60/20/20 train, validation, and test split.

    Each split (``self.train``, ``self.val``, ``self.test``) is a tuple
    ``(project names, subject names, pcmra slices, mask slices)``:

    * project names and subject names: 1d numpy arrays of length ``# slices``
      (one entry per slice);
    * pcmra and mask slices: 3d numpy arrays of shape ``height x width x # slices``.

    NOTE(review): the shuffle is done per slice, so slices of the same subject
    can end up in different splits (train/test leakage) -- confirm this is intended.
    """
    def __init__(self, project_names, seed=34):
        self.project_names = project_names  # str or list/tuple of project folder names
        self.seed = seed  # seed of the train/val/test shuffle

        self.project = self.load_project(self.project_names)  # Project instance

        # filter shape of the images, temporarily until further preprocessing;
        # presumably keeps only subjects whose dimension 0 is 128 and dimension 2
        # is 24 (semantics live in Project.filter_dimension_shape -- confirm)
        self.project.filter_dimension_shape(0, 128, print_dropped=False)
        self.project.filter_dimension_shape(2, 24, print_dropped=False)

        # normalize masks and pcmra between 0 and 1
        self.project.normalize()

        # project name and subject name for each slice
        self.proj_list = np.array(self.projects_list())
        self.subj_list = np.array(self.subjects_list())

        # stack masks and pcmras of all subjects along the third axis into one 3d array
        self.masks = np.dstack(self.project.masks)
        self.pcmras = np.dstack(self.project.pcmras)

        # 3 tuples of (project names (1d array), subject names (1d array),
        #              pcmra slices (3d array), mask slices (3d array))
        self.train, self.val, self.test = self.create_train_val_test_split()


    def load_project(self, project_names):
        """
        Load ``project_names`` as a single ``Project``.

        ``project_names`` is either one folder name (str) or a non-empty
        list/tuple of folder names; every folder after the first is added
        with ``Project.append_project``.

        Raises:
            TypeError: if ``project_names`` is neither a str nor a list/tuple.
            ValueError: if ``project_names`` is an empty list/tuple.
        """
        if isinstance(project_names, str):
            return Project(project_names)

        if not isinstance(project_names, (list, tuple)):
            raise TypeError("project_names must be a str or a list of str, got "
                            + type(project_names).__name__)
        if len(project_names) == 0:
            raise ValueError("project_names must contain at least one project name")

        project = Project(project_names[0])
        for name in project_names[1:]:
            project.append_project(name)
        return project


    def _repeat_per_slice(self, names):
        """Repeat ``names[i]`` once per slice of subject ``i`` (``masks_shape[i][2]`` times)."""
        per_slice = []
        for i, name in enumerate(names):
            per_slice.extend([name] * self.project.masks_shape[i][2])
        return per_slice


    def subjects_list(self):
        """
        Returns a list with the subject name of each slice, in ``self.masks`` order.
        """
        return self._repeat_per_slice(self.project.subjects)


    def projects_list(self):
        """
        Returns a list with the project name of each slice, in ``self.masks`` order.
        """
        return self._repeat_per_slice(self.project.subprojects)


    def _subset(self, indices):
        """Select slice ``indices`` from the name arrays and the pcmra/mask stacks."""
        return (self.proj_list[indices],
                self.subj_list[indices],
                self.pcmras[:, :, indices],
                self.masks[:, :, indices])


    def create_train_val_test_split(self):
        """
        Shuffle the slice indices with ``self.seed`` and split them 60/20/20.

        Returns three tuples ``(project names, subject names, pcmras, masks)``
        for train, validation, and test. A private ``random.Random`` is used,
        so the global ``random`` state is not modified; the shuffle is the same
        as ``random.seed(seed); random.shuffle(...)`` would produce.
        """
        rng = random.Random(self.seed)

        # all slice indices (from 0 to # slices - 1), shuffled
        idx = list(range(self.subj_list.shape[0]))
        rng.shuffle(idx)

        # two split points: 60% train, 20% validation, 20% test
        split1 = int(len(idx) * 0.6)
        split2 = int(len(idx) * 0.8)

        train_data = self._subset(idx[:split1])
        val_data = self._subset(idx[split1:split2])
        test_data = self._subset(idx[split2:])

        return train_data, val_data, test_data
    
    
    
class Dataset2D(Dataset):
    """
    Map-style dataset over one split built by ``PrepareData2D``.

    ``data`` is a tuple ``(project names, subject names, pcmras, masks)`` whose
    image arrays hold one slice per index along their last axis. Item ``idx``
    is ``(project name, subject name, pcmra, mask)``, with pcmra and mask
    returned as new numpy arrays of shape ``1 x H x W``.
    """
    def __init__(self, data):
        self.data = data

    def __len__(self):
        names = self.data[0]
        return names.shape[0]

    def __getitem__(self, idx):
        projects, subjects, pcmras, masks = (self.data[k] for k in range(4))

        def with_channel(stack):
            # copy slice ``idx`` and prepend a singleton channel axis
            return np.array([stack[:, :, idx]])

        return projects[idx], subjects[idx], with_channel(pcmras), with_channel(masks)

## Demo Code

In [None]:
# data = PrepareData2D(["Aorta Volunteers", "Aorta BaV", "Aorta Resvcue", "Aorta CoA"])
# train_ds = Dataset2D(data.train)
# train_dl = DataLoader(train_ds, batch_size=32)

# batch = next(iter(train_dl))

# print("Projects:", batch[0])
# print("\n Subjects:", batch[1])
# print("\n PCMRAs:", batch[2].shape)
# print("\n Masks:", batch[3].shape)

In [None]:
# def show_2d_batch(batch):
#     title = [proj + ": " + subj for proj, subj in zip(batch[0], batch[1])]
    
#     pcmra = batch[2].clone()
#     pcmra = pcmra.reshape(pcmra.shape[0], pcmra.shape[2], pcmra.shape[3])
#     pcmra = pcmra.permute(2, 1, 0).detach().numpy()
    
#     mask = batch[3].clone()
#     mask = mask.reshape(mask.shape[0], mask.shape[2], mask.shape[3])
#     mask = mask.permute(2, 1, 0).detach().numpy()


#     show = Show_images(title, (pcmra, "pcmra"), (mask, "mask"), (pcmra + mask, "pcmra + mask"))
    
#     return show

In [None]:
# %matplotlib qt
# show = show_2d_batch(batch)

## 3D Dataset

In [None]:
from torch.utils.data import Dataset, DataLoader
import random

class PrepareData3D(object):
    """
    Loads the data in project classes, keeps every subject as one 3D volume,
    shuffles the subjects, and returns a 60/20/20 train, validation, and test split.

    Each split (``self.train``, ``self.val``, ``self.test``) is a tuple
    ``(project names, subject names, pcmra volumes, mask volumes)``:

    * project names and subject names: 1d numpy arrays of length ``# subjects``;
    * pcmra and mask volumes: 4d numpy arrays of shape
      ``# subjects x height x width x # slices``.
    """
    def __init__(self, project_names, seed=34):
        self.project_names = project_names  # str or list/tuple of project folder names
        self.seed = seed  # seed of the train/val/test shuffle

        self.project = self.load_project(self.project_names)  # Project instance

        # filter shape of the images, temporarily until further preprocessing;
        # presumably keeps only subjects whose dimension 0 is 128 and dimension 2
        # is 24 (semantics live in Project.filter_dimension_shape -- confirm)
        self.project.filter_dimension_shape(0, 128, print_dropped=False)
        self.project.filter_dimension_shape(2, 24, print_dropped=False)

        # normalize masks and pcmra between 0 and 1
        self.project.normalize()

        # one entry per subject: project names come from ``subprojects`` and
        # subject names from ``subjects`` (the same sources PrepareData2D uses)
        self.proj_list = np.array(self.project.subprojects)
        self.subj_list = np.array(self.project.subjects)

        # stack the per-subject volumes with subjects on axis 0.
        # NOTE(review): this needs every volume to share one shape; only
        # dimensions 0 and 2 are filtered above -- confirm dimension 1 is fixed too
        self.masks = np.array(self.project.masks)
        self.pcmras = np.array(self.project.pcmras)

        # 3 tuples of (project names (1d array), subject names (1d array),
        #              pcmra volumes (4d array), mask volumes (4d array))
        self.train, self.val, self.test = self.create_train_val_test_split()


    def load_project(self, project_names):
        """
        Load ``project_names`` as a single ``Project``.

        ``project_names`` is either one folder name (str) or a non-empty
        list/tuple of folder names; every folder after the first is added
        with ``Project.append_project``.

        Raises:
            TypeError: if ``project_names`` is neither a str nor a list/tuple.
            ValueError: if ``project_names`` is an empty list/tuple.
        """
        if isinstance(project_names, str):
            return Project(project_names)

        if not isinstance(project_names, (list, tuple)):
            raise TypeError("project_names must be a str or a list of str, got "
                            + type(project_names).__name__)
        if len(project_names) == 0:
            raise ValueError("project_names must contain at least one project name")

        project = Project(project_names[0])
        for name in project_names[1:]:
            project.append_project(name)
        return project


    def _subset(self, indices):
        """Select subject ``indices`` from the name arrays and the pcmra/mask volumes."""
        return (self.proj_list[indices],
                self.subj_list[indices],
                self.pcmras[indices, :, :, :],
                self.masks[indices, :, :, :])


    def create_train_val_test_split(self):
        """
        Shuffle the subject indices with ``self.seed`` and split them 60/20/20.

        Returns three tuples ``(project names, subject names, pcmras, masks)``
        for train, validation, and test. A private ``random.Random`` is used,
        so the global ``random`` state is not modified; the shuffle is the same
        as ``random.seed(seed); random.shuffle(...)`` would produce.
        """
        rng = random.Random(self.seed)

        # all subject indices (from 0 to # subjects - 1), shuffled
        idx = list(range(self.subj_list.shape[0]))
        rng.shuffle(idx)

        # two split points: 60% train, 20% validation, 20% test
        split1 = int(len(idx) * 0.6)
        split2 = int(len(idx) * 0.8)

        train_data = self._subset(idx[:split1])
        val_data = self._subset(idx[split1:split2])
        test_data = self._subset(idx[split2:])

        return train_data, val_data, test_data
    
    
class Dataset3D(Dataset):
    """
    Map-style dataset over one split built by ``PrepareData3D``.

    ``data`` is a tuple ``(project names, subject names, pcmras, masks)`` with
    one subject per index along axis 0. Item ``idx`` is
    ``(project name, subject name, pcmra, mask)``, where pcmra and mask are
    numpy arrays of shape ``1 x # slices x H x W``: the slice axis is moved
    from last to second position after a singleton channel axis is prepended.
    """
    def __init__(self, data):
        self.data = data

    def __len__(self):
        names = self.data[0]
        return names.shape[0]

    def __getitem__(self, idx):
        names, subjects, pcmras, masks = self.data[0], self.data[1], self.data[2], self.data[3]

        def channel_first(volumes):
            stacked = np.array([volumes[idx, :, :, :]])   # 1 x H x W x S (copy)
            return stacked.transpose(0, 3, 1, 2)          # 1 x S x H x W

        return names[idx], subjects[idx], channel_first(pcmras), channel_first(masks)

## Demo Code

In [None]:
# data = PrepareData3D(["Aorta Volunteers", "Aorta BaV", "Aorta Resvcue", "Aorta CoA"])
# train_ds = Dataset3D(data.train)
# train_dl = DataLoader(train_ds, batch_size=32)

# batch = next(iter(train_dl))

# print("Projects:", batch[0])
# print("\n Subjects:", batch[1])
# print("\n PCMRAs:", batch[2].shape)
# print("\n Masks:", batch[3].shape)

## Siren Dataset

In [None]:
def get_coords(*sidelengths):
    '''Return a flattened grid of coordinates spanning [-1, 1] on every axis.

    Each positional argument is the number of grid points along one axis.
    The result has shape ``(product of sidelengths, len(sidelengths))``;
    rows enumerate the grid points with the last axis varying fastest
    (torch.meshgrid's default ``ij`` indexing).
    '''
    n_dims = len(sidelengths)
    axes = [torch.linspace(-1, 1, steps=n) for n in sidelengths]
    grid = torch.meshgrid(*axes)
    return torch.stack(grid, dim=-1).reshape(-1, n_dims)


def get_sorted_ind(tensor, axis):
    """Return the row indices of a 2d ``tensor`` ordered by column ``axis``.

    Rows with equal values in that column keep ascending row-index order,
    i.e. this is a stable argsort of ``tensor[:, axis]``. Returns a plain
    list of ints. Assumes the column contains no NaN.
    """
    column = tensor[:, axis].tolist()
    # Python's sort is stable, so ties stay in ascending index order
    # (same tie-breaking as sorting each group of equal values by index).
    return sorted(range(len(column)), key=column.__getitem__)


def sort_coords_and_pixels(coords, pixels):
    """Sort coordinate rows lexicographically and reorder ``pixels`` to match.

    Rows of ``coords`` (shape ``(N, D)``) are ordered by axis 0 first, then
    axis 1, and so on; identical rows keep their relative order. This is the
    row order ``get_coords`` produces, so the returned ``pixels`` can be
    reshaped back into an image. Assumes ``coords`` contains no NaN.

    Returns:
        (sorted coords, pixels reordered with the same permutation)
    """
    rows = coords.tolist()
    # one stable lexicographic sort == per-axis stable sorts from last to first axis
    order = sorted(range(len(rows)), key=rows.__getitem__)
    return coords[order], pixels[order]


def prod(val):
    """Multiply all elements of the iterable ``val`` together; an empty iterable gives 1."""
    product = 1
    remaining = iter(val)
    sentinel = object()
    factor = next(remaining, sentinel)
    while factor is not sentinel:
        product *= factor
        factor = next(remaining, sentinel)
    return product

def image_to_array(image):
    """Flatten an image tensor into coordinate / pixel-value pairs.

    Returns:
        coords: ``get_coords(*image.shape)``, shape ``(image.numel(), image.dim())``.
        pixels: ``image`` flattened in row-major order, shape ``(image.numel(), 1)``;
            row ``k`` is the value at the grid point in row ``k`` of ``coords``.

    ``reshape`` is used instead of ``view`` so non-contiguous tensors (e.g. the
    result of a transpose) are accepted; contiguous inputs still get a view.
    """
    coords = get_coords(*image.shape)
    pixels = image.reshape(-1, 1)
    return coords, pixels
      
def array_to_image(coords, pixels, sort=True):
    """Rebuild an image tensor from coordinate / pixel-value pairs.

    The image size along each axis is the number of distinct values in the
    matching column of ``coords``. When ``sort`` is true, the rows are first
    put into lexicographic (row-major) order with ``sort_coords_and_pixels``;
    otherwise ``pixels`` must already be in that order and be viewable
    (``Tensor.view``) with the resulting shape.
    """
    if sort:
        coords, pixels = sort_coords_and_pixels(coords, pixels)

    shape = [len(set(column.tolist())) for column in coords.unbind(dim=1)]
    return pixels.view(*shape)

In [None]:
class SirenDataset(Dataset):
    """
    Dataset for coordinate-based (SIREN-style) fitting: every sample is one
    subject volume flattened into coordinate / value pairs.

    ``data`` is a split tuple ``(project names, subject names, pcmras, masks)``
    as built by ``PrepareData3D``. Item ``idx`` is
    ``(idx, project name, subject name, coords, pcmra, mask)``, where coords,
    pcmra, and mask come from ``image_to_array`` (pcmra and mask have shape
    ``(N, 1)``). All tensors are moved to ``DEVICE`` once, at construction.

    Args:
        data: split tuple as described above.
        DEVICE: torch device (or device string) for the tensors; defaults to "cpu".
        n_samples: number of subjects to load from the start of ``data``;
            defaults to 1, ``None`` loads all of them. It is clipped to the
            number of available subjects, so an empty split gives an empty dataset.
    """

    def __init__(self, data, DEVICE="cpu", n_samples=1):
        self.data = []

        n_available = len(data[0])
        n_load = n_available if n_samples is None else min(n_samples, n_available)

        for i in range(n_load):
            coords, pcmra = image_to_array(torch.Tensor(data[2][i]))
            _, mask = image_to_array(torch.Tensor(data[3][i]))
            self.data.append((data[0][i],
                              data[1][i],
                              coords.to(DEVICE),
                              pcmra.to(DEVICE),
                              mask.to(DEVICE)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # prepend the index so a training loop can tell samples apart
        return (idx, *self.data[idx])

In [None]:
# data = PrepareData3D(["Aorta Volunteers", "Aorta BaV", "Aorta Resvcue", "Aorta CoA"])
# train_ds = SirenDataset(data.train) 

In [None]:
# train_ds[0]