In [2]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import h5py
from torch.utils.data.sampler import Sampler
from torchvision.transforms import Compose
import random
import torchio as tio


class AortaDissection(Dataset):
    

    def __init__(self, data_dir, list_dir, split, aug_times=1,
                       scale_range = (0.9,1.1), 
                       rotate_degree = 10):

        self.data_dir = data_dir
        self.list_dir = list_dir
        self.split = split
        self.scale_range = scale_range
        self.rotate_degree = rotate_degree
        self.aug_times = aug_times
        
        if split == 'lab':
            data_path = os.path.join(list_dir,'train_lab.txt')
            self.transform = True
        elif split == 'unlab':
            data_path = os.path.join(list_dir,'train_unlab.txt')
            self.transform = False
        elif split == 'train':
            data_path = os.path.join(list_dir,'train.txt')
            self.transform = True
        else:
            data_path = os.path.join(list_dir,'test.txt')
            self.transform = False

        with open(data_path, 'r') as f:
            self.image_list = f.readlines()

        self.image_list = [item.replace('\n', '') for item in self.image_list]
        self.image_list = [os.path.join(self.data_dir, item) for item in self.image_list]

        print("{} set: total {} samples".format(split, len(self.image_list)))

    def __len__(self):
        if self.split != 'test':
            return len(self.image_list) * self.aug_times
        else:
            return len(self.image_list)

    def __getitem__(self, idx):
        try:
            image_path = self.image_list[idx % len(self.image_list)]
            h5f = h5py.File(image_path, 'r')
            image, label = h5f['image'][:], h5f['label'][:].astype(np.float32)
        except FileNotFoundError:
        # Handle the missing file by adding placeholders to a dictionary
            image = np.zeros((128, 128, 128), dtype=np.float32)
            label = np.zeros((128, 128, 128), dtype=np.float32)
            if self.transform:
                subject = tio.Subject(image=tio.ScalarImage(tensor=image[np.newaxis,]), \
                              label=tio.LabelMap(tensor=label[np.newaxis,]))
                RandomAffine = tio.RandomAffine(scales=self.scale_range, degrees=self.rotate_degree)
                randaff_sample = RandomAffine(subject)
                image = randaff_sample['image']['data']
                label = torch.unsqueeze(randaff_sample['label']['data'], 0)
            else:
                image = torch.from_numpy(image[np.newaxis,]) if image is not None else torch.zeros((1, 128, 128, 128), dtype=torch.float32)
                label = torch.from_numpy(label) if label is not None else torch.zeros((1, 128, 128, 128), dtype=torch.float32)

        return {'image': image.float(), 'label': label.squeeze().long(), 'name': image_path}

    # def display_first_five_images(dataset, title):
    #     plt.figure(figsize=(15, 3))
    #     for i in range(5):
    #         image_path = dataset.image_list[i]
    #         nii_img = nib.load(image_path)
    #         image = nii_img.get_fdata()
    #         slice_index = image.shape[2] // 2
    #         selected_slice = image[:, :, slice_index]
    #         plt.subplot(1, 5, i + 1)
    #         plt.imshow(selected_slice, cmap='gray')
    #         plt.axis('off')
    #         plt.title(f'{title} {i+1}')

    # plt.suptitle(f'First five images from {title} set')
    # plt.show()
        

if __name__ == '__main__':

    data_dir = '/home/amishr17/aryan/aorta-dissec/'
    list_dir = '/home/amishr17/aryan/aorta-dissec/data/datalist/AD/AD_0'

    ''' aorta-dissec/data'''

    labset = AortaDissection(data_dir, list_dir,split='lab')
    unlabset = AortaDissection(data_dir,list_dir,split='unlab')
    trainset = AortaDissection(data_dir,list_dir,split='train')
    testset = AortaDissection(data_dir, list_dir,split='test')

    lab_sample = labset[0]
    unlab_sample = unlabset[0]
    train_sample = trainset[0] 
    test_sample = testset[0]
 
    print(len(trainset), lab_sample['image'].shape, lab_sample['label'].shape) # 100 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])
    print(len(labset), unlab_sample['image'].shape, unlab_sample['label'].shape) # 20 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])
    print(len(unlabset), train_sample['image'].shape, train_sample['label'].shape) # 80 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])
    print(len(testset), test_sample['image'].shape, test_sample['label'].shape) # 24 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])

    labset = AortaDissection(data_dir, list_dir,split='lab', aug_times=5)
    unlabset = AortaDissection(data_dir,list_dir,split='unlab', aug_times=5)
    trainset = AortaDissection(data_dir,list_dir,split='train', aug_times=5)
    testset = AortaDissection(data_dir, list_dir,split='test', aug_times=5)

    lab_sample = labset[0]
    unlab_sample = unlabset[0]
    train_sample = trainset[0] 
    test_sample = testset[0]

    print(len(trainset), lab_sample['image'].shape, lab_sample['label'].shape) # 500 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])
    print(len(labset), unlab_sample['image'].shape, unlab_sample['label'].shape) # 100 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])
    print(len(unlabset), train_sample['image'].shape, train_sample['label'].shape) # 400 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])
    print(len(testset), test_sample['image'].shape, test_sample['label'].shape) # 24 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])

    lab_sample = labset[0]
    print(lab_sample['image'].shape, lab_sample['label'].shape)

lab set: total 20 samples
unlab set: total 80 samples
train set: total 100 samples
test set: total 24 samples
100 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])
20 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])
80 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])
24 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])
lab set: total 20 samples
unlab set: total 80 samples
train set: total 100 samples
test set: total 24 samples
500 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])
100 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])
400 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])
24 torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])
torch.Size([1, 128, 128, 128]) torch.Size([128, 128, 128])


In [12]:
def display_first_five_images(dataset, title):
    plt.figure(figsize=(15, 3))
    for i in range(5):
        image_path = dataset.image_list[i]
        nii_img = nib.load(image_path)
        image = nii_img.get_fdata()
        slice_index = image.shape[2] // 2
        selected_slice = image[:, :, slice_index]
        plt.subplot(1, 5, i + 1)
        plt.imshow(selected_slice, cmap='gray')
        plt.axis('off')
        plt.title(f'{title} {i+1}')

    plt.suptitle(f'First five images from {title} set')
    plt.show()
        

if __name__ == '__main__':
    data_dir = '/home/amishr17/aryan/aorta-dissec/'
    

    labset = AortaDissection(data_dir, split='lab')
    unlabset = AortaDissection(data_dir, split='unlab')
    trainset = AortaDissection(data_dir, split='train')
    testset = AortaDissection(data_dir, split='test')

    display_first_five_images(labset, 'Labeled')
    display_first_five_images(unlabset, 'Unlabeled')
    display_first_five_images(trainset, 'Training')
    display_first_five_images(testset, 'Testing')
    
   
    lab_sample = labset[0]
    print(lab_sample['image'].shape, lab_sample['label'].shape)