In [1]:
import torch
import torch.nn.functional as F


import os
from scipy.ndimage import rotate
from  reorient_nii import reorient_1
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import SimpleITK as sitk

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import nibabel as nib
import nilearn.plotting as nlplt
from nibabel.testing import data_path
from scipy.ndimage import affine_transform



In [3]:
folder_path = 'mr_train/'
save_folder_path = 'mr_npy/'

In [4]:
def reset_label_value(image_data):
    
    image_data[image_data == 820] = 4 # AA  # small isolated
    image_data[image_data == 500] = 3 # LV  # center 
    image_data[image_data == 421] = 2 # LA  # long tail # only 1 case has 421
    image_data[image_data == 420] = 2 # LA  # long tail 
    image_data[image_data == 205] = 1 # Myo # blue semi-cicle close to red   

    image_data[image_data == 550] = 0
    image_data[image_data == 600] = 0
    image_data[image_data == 850] = 0
    
    return image_data

In [5]:
def resample(image_data):
    
    target_shape = [256, 256, 256]
    
    image_data = torch.from_numpy( image_data.copy() ).unsqueeze(0).unsqueeze(0) # torch.Size([1, 1, 512, 512, 84])
    image_data = F.interpolate( image_data, target_shape, mode = "nearest").numpy()[0,0] # (342, 342, 63)
    
    return image_data

In [6]:
def resample_2(img, image_data):
    
    image_header = img.header
    
    target_shape = [int(image_data.shape[0] * image_header['pixdim'][1]), int(image_data.shape[1] * image_header['pixdim'][2]),\
                int(image_data.shape[2] * image_header['pixdim'][3])]
    
    
    image_data = torch.from_numpy( image_data.copy() ).unsqueeze(0).unsqueeze(0) # torch.Size([1, 1, 512, 512, 84])
    image_data = F.interpolate( image_data, target_shape, mode = "nearest").numpy()[0,0] # (342, 342, 63)
    
    return image_data

In [7]:
def corp_base_on_min_max_label(image_data):

    heart_indices = np.where(image_data != 0)
    min_coords = np.min(heart_indices, axis=1)
    max_coords = np.max(heart_indices, axis=1)
    
    cropped_volume = image_data[min_coords[0]:max_coords[0], min_coords[1]:max_coords[1], min_coords[2]:max_coords[2]]
    #print('cropped_volume', cropped_volume.shape)
    
    return cropped_volume

In [8]:
for idx in range(1001, 1021):
    filepath = os.path.join(folder_path + f'mr_train_{idx}_label.nii.gz')
    img_0 = nib.load(filepath)
    img = reorient_1(img_0)
    image_data = img.get_fdata(dtype=np.float32)
    image_data = resample_2(img_0, image_data)
    #print(image_data.shape)
    image_data = reset_label_value(image_data)
    image_data = corp_base_on_min_max_label(image_data)
    print(image_data.shape)
    print(np.unique(image_data))
    print('_'*26)
    
    np.save(f"{save_folder_path}mr_train_{idx}_label.npy", image_data)

(137, 56, 274)
[0. 1. 2. 3. 4.]
__________________________
(308, 65, 295)
[0. 1. 2. 3. 4.]
__________________________
(123, 85, 215)
[0. 1. 2. 3. 4.]
__________________________
(87, 90, 186)
[0. 1. 2. 3. 4.]
__________________________
(123, 87, 137)
[0. 1. 2. 3. 4.]
__________________________
(171, 75, 127)
[0. 1. 2. 3. 4.]
__________________________
(147, 69, 141)
[0. 1. 2. 3. 4.]
__________________________
(113, 123, 149)
[0. 1. 2. 3. 4.]
__________________________
(140, 43, 354)
[0. 1. 2. 3. 4.]
__________________________
(138, 90, 220)
[0. 1. 2. 3. 4.]
__________________________
(146, 126, 136)
[0. 1. 2. 3. 4.]
__________________________
(131, 62, 316)
[0. 1. 2. 3. 4.]
__________________________
(137, 48, 333)
[0. 1. 2. 3. 4.]
__________________________
(115, 52, 258)
[0. 1. 2. 3. 4.]
__________________________
(159, 129, 187)
[0. 1. 2. 3. 4.]
__________________________
(100, 78, 149)
[0. 1. 2. 3. 4.]
__________________________
(112, 80, 180)
[0. 1. 2. 3. 4.]
______________________