In [1]:
import torch
import torch.nn.functional as F
from  reorient_nii import reorient_1
import os
import numpy as np
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import nibabel as nib
import nilearn.plotting as nlplt
from nibabel.testing import data_path



In [3]:
folder_path1 = 'ct_train1/'
folder_path2 = 'ct_train2/'
save_folder_path = 'ct_npy/'

In [4]:
def reset_label_value(label):
    
    label[label == 820] = 4 # AA  # small isolated
    label[label == 500] = 3 # LV  # center 
    label[label == 420] = 2 # LA  # long tail 
    label[label == 205] = 1 # Myo # blue semi-cicle close to red   

    label[label == 550] = 0
    label[label == 600] = 0
    label[label == 850] = 0
    
    return label

In [5]:
def corp_base_on_min_max_label(image , label):

    heart_indices = np.where(label != 0)
    min_coords = np.min(heart_indices, axis=1)
    max_coords = np.max(heart_indices, axis=1)
    
    cropped_image = image[min_coords[0]:max_coords[0], min_coords[1]:max_coords[1], min_coords[2]:max_coords[2]]
    cropped_label = label[min_coords[0]:max_coords[0], min_coords[1]:max_coords[1], min_coords[2]:max_coords[2]]
    
    return np.array(cropped_image), np.array(cropped_label)

In [6]:
def resample_label(nii, label):
    
    label_header = nii.header
    
    target_shape = [int(label.shape[0] * label_header['pixdim'][1]), int(label.shape[1] * label_header['pixdim'][2]),\
                int(label.shape[2] * label_header['pixdim'][3])]
    
    
    label = torch.from_numpy( label.copy() ).unsqueeze(0).unsqueeze(0) # torch.Size([1, 1, 512, 512, 84])
    label = F.interpolate( label, target_shape, mode = "nearest").numpy()[0,0] # (342, 342, 63)
    
    return label

In [7]:
def resample_img(nii, image):
    
    image_header = nii.header
    
    target_shape = [int(image.shape[0] * image_header['pixdim'][1]), int(image.shape[1] * image_header['pixdim'][2]),\
                int(image.shape[2] * image_header['pixdim'][3])]
    
    
    image = torch.from_numpy( image.copy() ).unsqueeze(0).unsqueeze(0) # torch.Size([1, 1, 512, 512, 84])
    image = F.interpolate( image, target_shape, mode = "trilinear").numpy()[0,0] # (342, 342, 63)
    
    return image

In [8]:
for idx in range(1001, 1011):
    
    '''label'''
    label_filepath = os.path.join(folder_path1 + f'ct_train_{idx}_label.nii.gz')
    label_img_0 = nib.load(label_filepath)
    label_img = reorient_1(label_img_0)
    label_data = label_img.get_fdata(dtype=np.float32)
    label_data = resample_label(label_img_0, label_data)
    label_data = reset_label_value(label_data)
    
    '''image'''
    image_filepath = os.path.join(folder_path1 + f'ct_train_{idx}_image.nii.gz')
    img_0 = nib.load(image_filepath)
    img = reorient_1(img_0)
    image_data = img.get_fdata(dtype=np.float32)
    image_data = resample_img(img_0, image_data)
    
    assert label_data.shape == image_data.shape
    
    image_data , _ = corp_base_on_min_max_label(image_data , label_data)
    
    # Clip the top 2% of the intensity histogram
    percentile_98 = np.percentile(image_data.ravel(), 98)
    image_data = np.clip(image_data, a_min=None, a_max=percentile_98)
    
    # Subtract the mean and divide by the standard deviation
    mean_val = np.mean(image_data)
    std_val = np.std(image_data)
    image_data = (image_data - mean_val) / std_val
    
    print(image_data.shape)
    np.save(f"{save_folder_path}ct_train_{idx}_image.npy", image_data)

(133, 107, 123)
(95, 95, 113)
(121, 84, 109)
(90, 93, 111)
(124, 117, 94)
(108, 88, 116)
(113, 94, 97)
(102, 104, 99)
(106, 113, 139)
(119, 107, 118)


In [9]:
for idx in range(1011, 1021):
    
    '''label'''
    label_filepath = os.path.join(folder_path2 + f'ct_train_{idx}_label.nii.gz')
    label_img_0 = nib.load(label_filepath)
    label_img = reorient_1(label_img_0)
    label_data = label_img.get_fdata(dtype=np.float32)
    label_data = resample_label(label_img_0, label_data)
    label_data = reset_label_value(label_data)
    
    '''image'''
    image_filepath = os.path.join(folder_path2 + f'ct_train_{idx}_image.nii.gz')
    img_0 = nib.load(image_filepath)
    img = reorient_1(img_0)
    image_data = img.get_fdata(dtype=np.float32)
    image_data = resample_img(img_0, image_data)
    
    assert label_data.shape == image_data.shape
    
    image_data , _ = corp_base_on_min_max_label(image_data , label_data)
    
    # Clip the top 2% of the intensity histogram
    percentile_98 = np.percentile(image_data.ravel(), 98)
    image_data = np.clip(image_data, a_min=None, a_max=percentile_98)
    
    # Subtract the mean and divide by the standard deviation
    mean_val = np.mean(image_data)
    std_val = np.std(image_data)
    image_data = (image_data - mean_val) / std_val
    
    print(image_data.shape)
    np.save(f"{save_folder_path}ct_train_{idx}_image.npy", image_data)

(105, 117, 133)
(118, 118, 98)
(97, 129, 111)
(87, 86, 159)
(108, 117, 126)
(108, 98, 119)
(119, 113, 128)
(115, 94, 100)
(101, 117, 124)
(133, 108, 127)
