In [1]:
import torch
import torch.nn.functional as F
from  reorient_nii import reorient_1
import os
import numpy as np
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import nibabel as nib
import nilearn.plotting as nlplt
from nibabel.testing import data_path



In [3]:
folder_path1 = 'ct_train1/'
folder_path2 = 'ct_train2/'
save_folder_path = 'ct_npy/'

In [4]:
def reset_label_value(image_data):
    
    image_data[image_data == 820] = 4 # AA  # small isolated
    image_data[image_data == 500] = 3 # LV  # center 
    image_data[image_data == 420] = 2 # LA  # long tail 
    image_data[image_data == 205] = 1 # Myo # blue semi-cicle close to red   

    image_data[image_data == 550] = 0
    image_data[image_data == 600] = 0
    image_data[image_data == 850] = 0
    
    return image_data

In [5]:
def corp_base_on_min_max_label(image_data):

    heart_indices = np.where(image_data != 0)
    min_coords = np.min(heart_indices, axis=1)
    max_coords = np.max(heart_indices, axis=1)
    
    cropped_volume = image_data[min_coords[0]:max_coords[0], min_coords[1]:max_coords[1], min_coords[2]:max_coords[2]]
    
    
    return np.array(cropped_volume)

In [6]:
def resample(img, image_data):
    
    image_header = img.header
    target_shape = [256, 256, 256]
    
    image_data = torch.from_numpy( image_data.copy() ).unsqueeze(0).unsqueeze(0) # torch.Size([1, 1, 512, 512, 84])
    image_data = F.interpolate( image_data, target_shape, mode = "nearest").numpy()[0,0] # (342, 342, 63)
    
    return image_data

In [7]:
def resample_2(img, image_data):
    
    image_header = img.header
    
    target_shape = [int(image_data.shape[0] * image_header['pixdim'][1]), int(image_data.shape[1] * image_header['pixdim'][2]),\
                int(image_data.shape[2] * image_header['pixdim'][3])]
    
    
    image_data = torch.from_numpy( image_data.copy() ).unsqueeze(0).unsqueeze(0) # torch.Size([1, 1, 512, 512, 84])
    image_data = F.interpolate( image_data, target_shape, mode = "nearest").numpy()[0,0] # (342, 342, 63)
    
    return image_data

In [8]:
for idx in range(1001, 1011):
    
    filepath = os.path.join(folder_path1 + f'ct_train_{idx}_label.nii.gz')
    img_0 = nib.load(filepath)
    img = reorient_1(img_0)
    image_data = img.get_fdata(dtype=np.float32)
    image_data = resample_2(img_0, image_data)
    image_data = reset_label_value(image_data)
    image_data = corp_base_on_min_max_label(image_data)
    
    print(image_data.shape)
    print(np.unique(image_data))
    #np.save(f"{save_folder_path}ct_train_{idx}_label.npy", image_data)

(133, 107, 123)
[0. 1. 2. 3. 4.]
(95, 95, 113)
[0. 1. 2. 3. 4.]
(121, 84, 109)
[0. 1. 2. 3. 4.]
(90, 93, 111)
[0. 1. 2. 3. 4.]
(124, 117, 94)
[0. 1. 2. 3. 4.]
(108, 88, 116)
[0. 1. 2. 3. 4.]
(113, 94, 97)
[0. 1. 2. 3. 4.]
(102, 104, 99)
[0. 1. 2. 3. 4.]
(106, 113, 139)
[0. 1. 2. 3. 4.]
(119, 107, 118)
[0. 1. 2. 3. 4.]


In [9]:
for idx in range(1011, 1021):
    
    filepath = os.path.join(folder_path2 + f'ct_train_{idx}_label.nii.gz')
    img_0 = nib.load(filepath)
    img = reorient_1(img_0)
    image_data = img.get_fdata(dtype=np.float32)
    image_data = resample_2(img_0, image_data)
    image_data = reset_label_value(image_data)
    image_data = corp_base_on_min_max_label(image_data)
    
    print(image_data.shape)
    print(np.unique(image_data))
    #np.save(f"{save_folder_path}ct_train_{idx}_label.npy", image_data)

(105, 117, 133)
[0. 1. 2. 3. 4.]
(118, 118, 98)
[0. 1. 2. 3. 4.]
(97, 129, 111)
[0. 1. 2. 3. 4.]
(87, 86, 159)
[0. 1. 2. 3. 4.]
(108, 117, 126)
[0. 1. 2. 3. 4.]
(108, 98, 119)
[0. 1. 2. 3. 4.]
(119, 113, 128)
[0. 1. 2. 3. 4.]
(115, 94, 100)
[0. 1. 2. 3. 4.]
(101, 117, 124)
[0. 1. 2. 3. 4.]
(133, 108, 127)
[0. 1. 2. 3. 4.]
