In [1]:
import os
from torch.utils.data import Dataset, DataLoader, ConcatDataset,  WeightedRandomSampler
from skimage import io, transform
from torchvision import transforms, utils
import torch
import numpy as np
import nibabel as nib
from random import randint
from PIL import Image
import torch.optim as optim
import time
import QuickNAT as QN
import torch.nn as nn
import matplotlib.pyplot as plt

### Dataset pairing axial-plane T1-weighted (T1w) slices with their brain-mask slices

In [2]:
class TrainDataset(Dataset):
    """Axial T1-weighted slices paired with their brain-mask slices.

    Files in ``T1a_dir`` and ``brainmask_dir`` are paired by position after
    sorting each directory listing by filename. The two directories are
    therefore expected to hold the same number of slices, with names that sort
    into the same order.

    Each sample is a dict with keys:
        'T1a'       -- float tensor of the T1w slice, resized to 128x128
                       (nearest neighbour) and normalised with mean 0.5 /
                       std 0.5, then passed through ``transform`` if given.
        'brainmask' -- uint8 tensor (128x128) of integer mask labels.
    """

    def __init__(self, T1a_dir, brainmask_dir, transform=None):
        """
        Args:
            T1a_dir (string): Directory with T1w images in the axial plane.
            brainmask_dir (string): Directory with brain masks in the axial plane.
            transform (callable, optional): Extra transform applied to the
                already resized and normalised T1w tensor.
        """
        self.T1a_dir = T1a_dir
        self.transform = transform
        self.brainmask_dir = brainmask_dir
        # os.listdir returns entries in arbitrary order, and that order need not
        # agree between two directories. Sort once so index i pairs the same slice.
        self.T1a_files = sorted(os.listdir(T1a_dir))
        self.brainmask_files = sorted(os.listdir(brainmask_dir))

    def __len__(self):
        """Number of T1w slices in ``T1a_dir``."""
        return len(self.T1a_files)

    @staticmethod
    def _scaled_to_labels(scaled_mask):
        """Map a ToTensor-scaled mask (label / 255) back to integer uint8 labels.

        Multiplying by 255 is the exact inverse of the scaling. The previous
        division by the rounded constant 0.0039 drifted for labels >= 91
        (e.g. 180 came back as 181).
        """
        return torch.round(scaled_mask * 255).byte()

    def __getitem__(self, idx):
        """Load the idx-th T1w slice and its mask, resize both and decode labels."""
        T1a_arr = io.imread(os.path.join(self.T1a_dir, self.T1a_files[idx]))
        # Add a leading channel axis so ToPILImage receives a C x H x W tensor.
        T1a_tensor = torch.unsqueeze(torch.from_numpy(T1a_arr), dim=0)

        # The slice has a single channel, so mean/std get one entry each.
        # A 3-entry Normalize on a 1-channel tensor raises a shape-mismatch
        # error in recent torchvision versions.
        compose_T1 = transforms.Compose([transforms.ToPILImage(),
                                         transforms.Resize((128, 128), interpolation=Image.NEAREST),
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.5,), (0.5,))])
        T1a_tensor = compose_T1(T1a_tensor)

        # The original brainmask values are 0, 1, 2.
        brainmask_arr = io.imread(os.path.join(self.brainmask_dir, self.brainmask_files[idx]))
        brainmask_tensor = torch.unsqueeze(torch.from_numpy(brainmask_arr), dim=0)

        # Nearest-neighbour resizing keeps mask values on the original label set.
        compose = transforms.Compose([transforms.ToPILImage(),
                                      transforms.Resize((128, 128), interpolation=Image.NEAREST),
                                      transforms.ToTensor()])
        brainmask_tensor = compose(brainmask_tensor).squeeze()

        # ToTensor (not the resize) divides 8-bit values by 255, giving
        # 0, 0.0039, 0.0078, ... ; undo that to recover the integer labels.
        brainmask_tensor = self._scaled_to_labels(brainmask_tensor)

        if self.transform:
            T1a_tensor = self.transform(T1a_tensor)

        return {'T1a': T1a_tensor, 'brainmask': brainmask_tensor}

### Counting the occurrences of mask labels 0, 1 and 2 across the `MRIdata` directory, in order to define class weights for the target
* The `MRIdata` directory contains all of the sliced data.

In [3]:
# # subject 0, construct the first data in total_data
# sub_idx = 0
# T1a_dir = '/home/xiaoyu/MRIdata/T1w/axial/sub{}'.format(sub_idx)
# brainmask_dir = '/home/xiaoyu/MRIdata/brainmask/axial/sub{}'.format(sub_idx)
# total_data = TrainDataset(T1a_dir=T1a_dir, brainmask_dir = brainmask_dir)
# print(len(total_data))

182


In [4]:
# # Add up all the subject data
# for sub_idx in range(1,330):
#     T1a_dir = '/home/xiaoyu/MRIdata/T1w/axial/sub{}'.format(sub_idx)
#     brainmask_dir = '/home/xiaoyu/MRIdata/brainmask/axial/sub{}'.format(sub_idx)
#     train_data = TrainDataset(T1a_dir=T1a_dir, brainmask_dir = brainmask_dir)
#     total_data = total_data + train_data
# print(len(total_data))

60060


Total data per group ID. (One group has 10 slices per subject, so across 330 subjects a group contains 330 × 10 = 3300 slices in total.)

In [6]:
# mask_total = torch.tensor([])

# for i in range(len(total_data)):
#     sample = total_data[i]
#     mask = sample['brainmask'].float()
#     mask_total = torch.cat((mask_total, mask))
    
# print(mask_total.size())

In [5]:
# gpu_id = 1
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

# device = torch.device('cuda')
# print(device)

cuda


In [None]:
# mask_total = torch.tensor([])
# mask_total= mask_total.to(device)
# # total_data = total_data.to(device)
# for i in range(len(total_data)):
#     sample = total_data[i]
#     mask = sample['brainmask'].float()
#     mask = mask.to(device)
#     mask_total = torch.cat((mask_total, mask))
    
# print(mask_total.size())

In [None]:
# unique_color, count = np.unique(mask_total.cpu(), return_counts = True)
# print(count)
# print(unique_color.size)
# print(unique_color)