In [1]:
import torch
import math
import numpy as np
import os
from natsort import natsorted
import matplotlib.pyplot as plt
from torch import optim
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from torch.nn.utils import clip_grad_value_
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset, DataLoader
from collections import OrderedDict
from scipy.ndimage import zoom
from glob import glob
import nibabel as nib
from tqdm import tqdm
import cv2
import torch.nn.functional as F


# Outline for dataloader
- reads in the original nifti 3d image
- reads in the gradCAM
- based on the gradCAM, choose the most significant slice of the nifti image
    - choose 3 different slices from the 3 directions   
- output the 3 slices in 3 channels

In [47]:
def downsize_transform(data): 
    target_size = (192, 192, 192)
    data = torch.from_numpy(data).unsqueeze(0).unsqueeze(0)
    downsampled = torch.nn.functional.interpolate(data, size=target_size, mode='trilinear')

    return downsampled.squeeze(0).squeeze(0)

class TransformerDataset(Dataset):
    def __init__(self, img_dir, grad_dir, transforms = None):
        self.img_dir = img_dir
        self.grad_dir = grad_dir
        self.transforms = transforms
        self.cn_dir = os.path.join(self.img_dir, "MNI152_affine_WB_iso1mm/CN")
        self.scz_dir = os.path.join(self.img_dir, "MNI152_affine_WB_iso1mm/schiz")
        self.grad_cn_dir = os.path.join(self.grad_dir, "MNI152_affine_WB_iso1mm/CN")
        self.grad_scz_dir = os.path.join(self.grad_dir, "MNI152_affine_WB_iso1mm/CN") # change this later when rerun the extract grad script
        self.samples, self.labels = self._load_samples()

    def _load_samples(self):
        samples = []
        
        samples = [file for file in os.listdir(self.cn_dir) if file.endswith(".nii.gz")]
        labels = [0] * len(samples)
        samples += [file for file in os.listdir(self.scz_dir) if file.endswith(".nii.gz")]
        labels += [1] * (len(samples) - len(labels))

        return samples, labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        k = 3
        label = self.labels[idx]
        grad_file = self.samples[idx].split(".")[0]
        if label == 0:
            file_path = os.path.join(self.cn_dir, self.samples[idx])
            grad_path = os.path.join(self.grad_cn_dir, grad_file+ ".nii_activation.nii.gz")
        else:
            file_path = os.path.join(self.scz_dir, self.samples[idx])
            grad_path = os.path.join(self.grad_scz_dir, grad_file+ ".nii_activation.nii.gz")
        one_hot_label = torch.zeros(2)
        one_hot_label[label] = 1
        label = one_hot_label

        # Load the NIfTI image
        img = nib.load(file_path)

        
        grad = nib.load(grad_path)

        # Get the image data array
        img_data = np.float32(img.get_fdata())
        img_data = self.transforms(img_data)

        grad_data = np.float32(grad.get_fdata())
        sums_x = np.sum(grad_data, axis=(1, 2))
        sums_y = np.sum(grad_data, axis=(0, 2))
        sums_z = np.sum(grad_data, axis=(0, 1))

        # Find the indices of the maximum sum along each axis
        # max_x_index = np.argmax(sums_x)
        x_slices = np.argsort(sums_x)[::-1][:k]
        x_copy = x_slices.copy()
        y_slices = np.argsort(sums_y)[::-1][:k]
        y_copy = y_slices.copy()
        z_slices = np.argsort(sums_z)[::-1][:k]
        z_copy = z_slices.copy()

        # max_y_index = np.argmax(sums_y)
        # max_z_index = np.argmax(sums_z)

        x_slice = img_data[x_copy, :, :]
        y_slice = img_data[:, y_copy, :].reshape((3, 192, 192))
        z_slice = img_data[:, :, z_copy].reshape((3, 192, 192))

        return np.concatenate((x_slice, y_slice, z_slice), axis = 0), label

In [38]:
x = np.random.rand(3, 2)

y = np.random.rand(3, 2)
np.array([x, y])

array([[[0.91557547, 0.09639896],
        [0.9508808 , 0.07190435],
        [0.79231897, 0.15145252]],

       [[0.0461586 , 0.53150262],
        [0.75110032, 0.81654726],
        [0.85463303, 0.65851167]]])

In [43]:
root_dir = "/media/youzhi/SSD/bme_project/data"
grad_root_dir = "/media/youzhi/SSD/bme_project/activations"
folds_dir = [dir for dir in os.listdir(root_dir) if dir.startswith("fold")]
grads_dir = [os.path.join(grad_root_dir, dir) for dir in folds_dir]
folds_dir = [os.path.join(root_dir, dir) for dir in folds_dir]
folds_dir = natsorted(folds_dir)
grads_dir = natsorted(grads_dir)
grads_dir

['/media/youzhi/SSD/bme_project/activations/fold1',
 '/media/youzhi/SSD/bme_project/activations/fold2',
 '/media/youzhi/SSD/bme_project/activations/fold3',
 '/media/youzhi/SSD/bme_project/activations/fold4',
 '/media/youzhi/SSD/bme_project/activations/fold5',
 '/media/youzhi/SSD/bme_project/activations/fold6',
 '/media/youzhi/SSD/bme_project/activations/fold7',
 '/media/youzhi/SSD/bme_project/activations/fold8',
 '/media/youzhi/SSD/bme_project/activations/fold9',
 '/media/youzhi/SSD/bme_project/activations/fold10']

In [48]:
dataloaders = []
for i in range(len(folds_dir)):
    fold_dir = folds_dir[i]
    grad_dir = grads_dir[i]
    dataset = TransformerDataset(fold_dir, grad_dir, downsize_transform) #, downsize_transform)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
    dataloaders.append(dataloader)
    print("fold", i+1, ": ", len(dataset))

fold 1 :  196
fold 2 :  188
fold 3 :  187
fold 4 :  185
fold 5 :  187
fold 6 :  185
fold 7 :  187
fold 8 :  182
fold 9 :  188
fold 10 :  187


In [49]:
sample = next(iter(dataloaders[0]))
# plt.imshow(sample[0][0][0][:, int(96/2), :], cmap = 'bone')
print(sample[0].shape)
print(sample[0].type)
print(sample[1])
print(sample[1].shape)
# plt.imshow(sample[0][0][2], cmap = 'bone')

torch.Size([8, 9, 192, 192])
<built-in method type of Tensor object at 0x75fe2ed656d0>
tensor([[1., 0.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 0.]])
torch.Size([8, 2])
