In [1]:
import torch
import torch.nn.functional as F
from  reorient_nii import reorient_1
import os
import numpy as np
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import nibabel as nib
import nilearn.plotting as nlplt
from nibabel.testing import data_path



In [3]:
folder_path = 'mr_train/'
save_folder_path = 'mr_npy/'

In [4]:
def reset_label_value(label):
    
    label[label == 820] = 4 # AA  # small isolated
    label[label == 500] = 3 # LV  # center 
    label[label == 421] = 2 # LA  # long tail # only 1 case has 421
    label[label == 420] = 2 # LA  # long tail 
    label[label == 205] = 1 # Myo # blue semi-cicle close to red   

    label[label == 550] = 0
    label[label == 600] = 0
    label[label == 850] = 0
    
    return label

In [5]:
def corp_base_on_min_max_label(image , label):

    heart_indices = np.where(label != 0)
    min_coords = np.min(heart_indices, axis=1)
    max_coords = np.max(heart_indices, axis=1)
    
    cropped_image = image[min_coords[0]:max_coords[0], min_coords[1]:max_coords[1], min_coords[2]:max_coords[2]]
    cropped_label = label[min_coords[0]:max_coords[0], min_coords[1]:max_coords[1], min_coords[2]:max_coords[2]]
    
    return np.array(cropped_image), np.array(cropped_label)

In [6]:
def resample_label(nii, label):
    
    label_header = nii.header
    
    target_shape = [int(label.shape[0] * label_header['pixdim'][1]), int(label.shape[1] * label_header['pixdim'][2]),\
                int(label.shape[2] * label_header['pixdim'][3])]
    
    
    label = torch.from_numpy( label.copy() ).unsqueeze(0).unsqueeze(0) # torch.Size([1, 1, 512, 512, 84])
    label = F.interpolate( label, target_shape, mode = "nearest").numpy()[0,0] # (342, 342, 63)
    
    return label

In [7]:
def resample_img(nii, image):
    
    image_header = nii.header
    
    target_shape = [int(image.shape[0] * image_header['pixdim'][1]), int(image.shape[1] * image_header['pixdim'][2]),\
                int(image.shape[2] * image_header['pixdim'][3])]
    
    
    image = torch.from_numpy( image.copy() ).unsqueeze(0).unsqueeze(0) # torch.Size([1, 1, 512, 512, 84])
    image = F.interpolate( image, target_shape, mode = "trilinear").numpy()[0,0] # (342, 342, 63)
    
    return image

In [8]:
for idx in range(1001, 1021):
    
    '''label'''
    label_filepath = os.path.join(folder_path + f'mr_train_{idx}_label.nii.gz')
    label_img_0 = nib.load(label_filepath)
    label_img = reorient_1(label_img_0)
    label_data = label_img.get_fdata(dtype=np.float32)
    label_data = resample_label(label_img_0, label_data)
    label_data = reset_label_value(label_data)
    
    
    image_filepath = os.path.join(folder_path + f'mr_train_{idx}_image.nii.gz')
    img_0 = nib.load(image_filepath)
    img = reorient_1(img_0)
    image_data = img.get_fdata(dtype=np.float32)
    print(image_data.shape)
    image_data = resample_img(img_0, image_data)
    
    assert label_data.shape == image_data.shape
    
    image_data , _ = corp_base_on_min_max_label(image_data , label_data)
    
    # Clip the top 2% of the intensity histogram
    percentile_98 = np.percentile(image_data.ravel(), 98)
    image_data = np.clip(image_data, a_min=None, a_max=percentile_98)
    
    # Subtract the mean and divide by the standard deviation
    mean_val = np.mean(image_data)
    std_val = np.std(image_data)
    image_data = (image_data - mean_val) / std_val
    
    print(image_data.shape)
    print('_'*26)
    np.save(f"{save_folder_path}mr_train_{idx}_image.npy", image_data)

(512, 160, 512)
(137, 56, 274)
__________________________
(512, 128, 512)
(308, 65, 295)
__________________________
(160, 288, 288)
(123, 85, 215)
__________________________
(120, 288, 288)
(87, 90, 186)
__________________________
(130, 288, 288)
(123, 87, 137)
__________________________
(160, 256, 256)
(171, 75, 127)
__________________________
(180, 288, 288)
(147, 69, 141)
__________________________
(130, 288, 288)
(113, 123, 149)
__________________________
(512, 120, 512)
(140, 43, 354)
__________________________
(160, 288, 288)
(138, 90, 220)
__________________________
(160, 288, 288)
(146, 126, 136)
__________________________
(512, 128, 512)
(131, 62, 316)
__________________________
(512, 112, 512)
(137, 48, 333)
__________________________
(512, 160, 512)
(115, 52, 258)
__________________________
(200, 340, 340)
(159, 129, 187)
__________________________
(130, 288, 288)
(100, 78, 149)
__________________________
(140, 288, 288)
(112, 80, 180)
__________________________
(150, 288, 2