In [1]:
import nibabel as nib
from glob import glob
import os
from tqdm import tqdm
import numpy as np
import random
import pandas as pd
from PIL import Image
import torchio as tio
from volumentations import *

In [2]:
def cube_pos(rand_pos, nod_pos):
    '''
    Enumerate the eight corners of the box anchored at rand_pos.

    args:   rand_pos (x,y,z) anchor corner, nod_pos (dx,dy,dz) box extent (nod.shape)
    output: list of eight (x,y,z) tuples; z varies fastest, x slowest
    '''
    x0, y0, z0 = rand_pos[0], rand_pos[1], rand_pos[2]
    x1 = x0 + nod_pos[0]
    y1 = y0 + nod_pos[1]
    z1 = z0 + nod_pos[2]
    return [(x, y, z) for x in (x0, x1) for y in (y0, y1) for z in (z0, z1)]

def get_minmax(lung_LR):
    '''
    Bounding indices of the non-zero voxels, taken over the first three axes.

    args: lung_LR (array, indexed on axes 0, 1 and 2)
    output: array([[min0, min1, min2],
                   [max0, max1, max2]])
    '''
    hits = np.where(lung_LR != 0)
    per_axis = (hits[0], hits[1], hits[2])
    lower = [np.min(indices) for indices in per_axis]
    upper = [np.max(indices) for indices in per_axis]
    return np.array([lower, upper])

def validate_pos(rand_pos, nod_pos, lung_LR):
    '''
    Check that every corner of the candidate box is a lung position.

    args:    rand_pos (x,y,z) anchor, nod_pos (nod.shape), lung_LR (container of (x,y,z) positions)
    output:  True when all eight corners returned by cube_pos are in lung_LR, otherwise False
    '''
    # all() stops at the first corner that is missing from lung_LR.
    return all(corner in lung_LR for corner in cube_pos(rand_pos, nod_pos))

def subtract_nodule(img, nod):
    '''
    Erase a nodule by overwriting its voxels with the image minimum.

    args:     img, nod (non-zero where the nodule is)
    output:   new array: np.min(img) where nod is non-zero, img elsewhere
    '''
    background = np.min(img)
    is_nodule = nod != 0
    return np.where(is_nodule, background, img)

def nod_img(img, mask, nod, location, threshold=10):
    '''
    Copy-paste a nodule patch into a volume.

    The patch is written into a copy of the mask at `location` (replacing
    whatever the mask held in that box), then every voxel whose mask value is
    at least `threshold` replaces the corresponding image voxel.
    Neither `img` nor `mask` is modified.

    args:    img        volume to paste into
             mask       mask volume (same shape as img)
             nod        nodule patch, zero = background
             location   (x,y,z) index in img/mask where nod[0,0,0] is placed
             threshold  smallest mask value copied into the image (default 10)
    output:  pasted_img, pasted_mask
    '''
    # Copy first so the caller's mask array is left untouched.
    pasted_mask = mask.copy()

    x, y, z = location[0], location[1], location[2]
    dx, dy, dz = nod.shape[0], nod.shape[1], nod.shape[2]
    pasted_mask[x:x+dx, y:y+dy, z:z+dz] = np.where(nod != 0, nod, 0)

    # np.where builds a new array, so img itself needs no defensive copy.
    pasted_img = np.where(pasted_mask >= threshold, pasted_mask, img)

    return pasted_img, pasted_mask

def normalize255(arr):
    '''
    Min-max scale an array to 0..255 and cast it to uint8.

    args:   arr (numeric array)
    output: uint8 array of the same shape; the cast truncates fractional values.
            A constant array (max == min) has no range to stretch and yields all zeros
            instead of dividing by zero.
    '''
    arr = np.asarray(arr)
    lowest = np.min(arr)
    span = np.max(arr) - lowest
    if span == 0:
        # Avoid 0/0 -> NaN -> undefined uint8 values (e.g. an empty mask projection).
        return np.zeros(arr.shape, dtype=np.uint8)
    return np.array(((arr - lowest) / span) * 255).astype(np.uint8)

def make_lung(lung):
    '''
    List the coordinates of every non-zero voxel.

    args:   lung_LR (numpy)
    output: [(x,y,z), ... ,(x,y,z)] as Python ints, in np.where order
    '''
    hits = np.where(lung != 0)
    xs = hits[0].tolist()
    ys = hits[1].tolist()
    zs = hits[2].tolist()
    return [(x, y, z) for x, y, z in zip(xs, ys, zs)]

In [15]:
# Nodule augmentation pipeline (torchio)

# With p = 0.75 apply one spatial transform, chosen between an affine (0.8)
# and an elastic deformation (0.2).
spatial = tio.OneOf(
    {
        tio.RandomAffine(): 0.8,
        tio.RandomElasticDeformation(): 0.2,
    },
    p=0.75,
)

# Three independent flips, each applied with probability 0.5.
long_flip = tio.RandomFlip(axes=('inferior-superior',), flip_probability=0.5)
lat_flip = tio.RandomFlip(axes=('AP',), flip_probability=0.5)
lr_flip = tio.RandomFlip(axes=('lr',), flip_probability=0.5)

# Order of application: spatial, then the three flips.
transforms = [spatial, long_flip, lat_flip, lr_flip]
transform = tio.Compose(transforms)

def get_augmentation():
    '''Build the second augmentation stage: RandomRotate90 over axes (1, 2) with p=0.5, wrapped in a Compose with p=1.0.'''
    rotate = RandomRotate90((1, 2), p=0.5)
    return Compose([rotate], p=1.0)

aug = get_augmentation()

def transform_nod(nod):
    '''
    Augment one nodule patch with the module-level `transform` and `aug` pipelines.

    args:   nod (x,y,z) array
    output: augmented array; the patch is zero-padded by one voxel on every
            side of each spatial axis before the transforms run
    '''
    channel_first = np.expand_dims(nod, axis=0)        # (x,y,z) -> (1,x,y,z)
    pad_width = ((0, 0), (1, 1), (1, 1), (1, 1))       # leave the channel axis unpadded
    padded = np.pad(channel_first, pad_width)

    spatially_augmented = transform(padded).squeeze()  # drop the channel axis again

    return aug(image=spatially_augmented)['image']

In [None]:
## Optional: close every progress bar tqdm still tracks (run when stale bars clutter the output)
while tqdm._instances:
    tqdm._instances.pop().close()

In [64]:
# Validation image volumes, in sorted order
val_img_path = glob('../data/NII_normwinall_val_10mm_delO/*img*')
val_img_path.sort()

## Data with masks
- images that already have nodule mask files; no copy-paste is applied

In [67]:
# png: render every validation volume that has at least one mask, plus each of its masks

out_dir = '../data/fake_DRR_valid_10mm_delO'
os.makedirs(out_dir, exist_ok=True)   # Image.save does not create missing directories

# Look up each image's mask files once and keep only images that have some.
img_with_masks = []
for one_img_path in val_img_path:
    found_masks = sorted(glob(one_img_path.replace('img', 'mask*')))
    if len(found_masks) != 0:
        img_with_masks.append((one_img_path, found_masks))
img_path = [one_img_path for one_img_path, _ in img_with_masks]

for one_img_path, mask_path in tqdm(img_with_masks):
    fname = os.path.basename(one_img_path).replace('.nii.gz', '.png')
    img = nib.load(one_img_path).get_fdata()

    # Transpose reverses the axes; averaging over the new last axis collapses
    # the volume's first axis into a 2D projection.
    mean_img = np.mean(img.T, axis=2)
    norm_img = normalize255(mean_img)
    Image.fromarray(norm_img).save(f'{out_dir}/{fname}')

    for one_mask_path in mask_path:
        mask_fname = os.path.basename(one_mask_path).replace('.nii.gz', '.png')

        mask = nib.load(one_mask_path).get_fdata()
        mask_10 = np.where(mask >= 10, mask, 0)     # drop mask values below 10
        mean_mask = np.mean(mask_10.T, axis=2)
        mean_mask[mean_mask != 0] = 255             # binarise the projection
        norm_mask = normalize255(mean_mask)
        Image.fromarray(norm_mask).save(f'{out_dir}/{mask_fname}')
    

100%|██████████| 78/78 [03:26<00:00,  2.64s/it]


## Data without masks
- images with no nodule mask files; synthetic nodules are added by copy-paste

In [None]:
# Images without any mask file are the candidates for nodule copy-paste.
mask_free = [
    one_img_path for one_img_path in val_img_path
    if len(glob(one_img_path.replace('img', 'mask*'))) == 0
]

# Keep only images whose lung segmentations exist. The generation cell loads
# both the Right and the Left file, so both are required here.
img_list = []
empty_list = []   # images skipped because a lung .npy file is missing
for one_img_path in tqdm(mask_free):
    # File name without its last two dot-separated parts (e.g. ".nii.gz").
    fname = '.'.join(os.path.basename(one_img_path).split('.')[:-2])
    has_both_lungs = all(
        os.path.exists(f'../npy_lungLR/{fname}_{side}.npy')
        for side in ('Right', 'Left')
    )
    if has_both_lungs:
        img_list.append(one_img_path)
    else:
        empty_list.append(one_img_path)

# nod list
nod_list = sorted(glob('../data/NII_nod_smooth10_10/valid/*.nii.gz'))
nod_len = len(nod_list)

In [None]:
# png
# Build ten synthetic sets (valid_1 .. valid_10). For every image in img_list,
# paste 1-3 randomly drawn, augmented nodules at random lung positions, then save
# the projected image and one projected mask PNG per pasted nodule.
# NOTE(review): the valid_{dir_num} output directories are not created here.
# NOTE(review): no random seed is set in the visible code, so runs are not repeatable.
for dir_num in range(1,11):
    for num_ in tqdm(range(len(img_list))):
        rand_count = 0    # failed placement attempts; reset only here and when a replacement nodule is drawn
        mask_count = 0    # index of the next mask PNG for this image
        copy_count = random.randint(1,3)   # how many nodules to paste (1..3 inclusive)
        img = nib.load(img_list[num_]).get_fdata()
        img_shape = img.shape

        mask = np.zeros(img_shape)          # accumulated nodule mask
        unique_mask = np.zeros(img_shape)   # (axis1, axis2) columns already taken by a nodule, marked along all of axis 0

        # File name without its last two dot-separated parts (e.g. ".nii.gz").
        fname = '.'.join(img_list[num_].split('/')[-1].split('.')[:-2]) 
        lung_npy_R = np.load(f'../npy_lungLR/{fname}_Right.npy')
        lung_npy_L = np.load(f'../npy_lungLR/{fname}_Left.npy')  
        # NOTE(review): L_minmax is not read anywhere in the visible code.
        L_minmax = get_minmax(lung_npy_L) #lung_npy_L

        # Lists of (x,y,z) positions of the non-zero voxels of each lung.
        new_lung_L = make_lung(lung_npy_L)
        new_lung_R = make_lung(lung_npy_R)

        for idx in range(copy_count):
            rand_nod = random.randint(0, nod_len-1)    # get nod
            nod = nib.load(nod_list[rand_nod]).get_fdata()
            nod = transform_nod(nod)     #transform
            nod_pos = nod.shape          # extent of the augmented patch

            # Pick a lung at random: 0 = left, 1 = right.
            LR = random.randint(0,1)         
            if LR == 0:
                lung_LR = new_lung_L
            else:
                lung_LR = new_lung_R

            len_lung_LR = len(lung_LR)

            # Rejection sampling of a paste position.
            # NOTE(review): there is no cap on the number of attempts; the loop
            # only ends when a position passes all three checks below.
            while(1):
                if rand_count == 50:       # change left and right
                    if LR == 0:
                        LR = 1
                        lung_LR = new_lung_R
                    else:
                        LR = 0
                        lung_LR = new_lung_L

                    len_lung_LR = len(lung_LR)

                elif rand_count == 100:  # change nodule
                    rand_nod = random.randint(0, nod_len-1)              # get nod
                    nod = nib.load(nod_list[rand_nod]).get_fdata()  
                    nod = transform_nod(nod)
                    nod_pos = (nod.shape[0],nod.shape[1],nod.shape[2])   # nod_x,nod_y,nod_z = nod.shape
                    rand_count = 0

                # Candidate anchor: a random position of the chosen lung.
                rand_num = random.randint(0, len_lung_LR-1)
                rand_pos = lung_LR[rand_num]

                # Check 1: the patch must not extend past the volume on any axis.
                if (rand_pos[0]+nod_pos[0] > img_shape[0]) | (rand_pos[1]+nod_pos[1] > img_shape[1]) | (rand_pos[2]+nod_pos[2] > img_shape[2]):
                    rand_count+=1
                    continue

                # Check 2: the box must not touch any column already marked in unique_mask.
                overlap = np.sum(unique_mask[rand_pos[0]:rand_pos[0]+nod_pos[0], rand_pos[1]:rand_pos[1]+nod_pos[1], rand_pos[2]:rand_pos[2]+nod_pos[2]])
                if overlap != 0:
                    rand_count+=1
                    continue
                else:
                    # Check 3: all eight box corners must be positions of the chosen lung.
                    valid = validate_pos(rand_pos, nod_pos, lung_LR)    # (validate pos)
                    if valid == False:
                        rand_count+=1
                        continue
                    else:
                        break
                # NOTE(review): unreachable; every branch above ends in continue or break.
                rand_count+=1

            # Distinct (axis1, axis2) index pairs at which the patch has a non-zero voxel.
            nod_pos_ = np.where(nod!=0)
            nod_unique = np.zeros((img_shape[0],nod_pos[1],nod_pos[2]))   

            np_xy = np.array(list(zip(nod_pos_[1], nod_pos_[2]))) 
            unique_xy = np.unique(np_xy, axis=0)

            cp_img, cp_mask = nod_img(img, mask, nod, rand_pos)  # copy-paste, save_mask -> mask

            # Mark each of those pairs with 255 at every index along axis 0.
            # NOTE(review): the inner loop variable reuses the name `idx` of the enclosing
            # nodule loop; the outer iteration is unaffected, but the name is shadowed.
            for idx2 in range(nod_unique.shape[0]):
                for idx in range(len(unique_xy)):
                    nod_unique[idx2][unique_xy[idx][0]][unique_xy[idx][1]] = 255

            # NOTE(review): plain assignment replaces the whole box, zeros included, so marks
            # left there by an earlier nodule would be cleared; confirm that is intended.
            unique_mask[:, rand_pos[1]:rand_pos[1]+nod_pos[1], rand_pos[2]:rand_pos[2]+nod_pos[2]] = nod_unique

            ############### save mask png 
            # Mask of this nodule alone (patch values >= 10), projected and binarised to 0/255.
            save_mask = np.zeros(img_shape)
            save_mask[rand_pos[0]:rand_pos[0]+nod_pos[0], rand_pos[1]:rand_pos[1]+nod_pos[1], rand_pos[2]:rand_pos[2]+nod_pos[2]]\
                = np.where(nod>=10, nod, 0)
            mean_mask = np.mean(save_mask.T, axis=2)
            mean_mask[mean_mask!=0]=255
            norm_mask=normalize255(mean_mask)

            Image.fromarray(norm_mask).save(f'../data/fake_DRR_valid_10mm_delO/valid_{dir_num}/paste_drr_{str(num_+1).zfill(4)}_mask{mask_count}.png')    

            mask_count+=1


            ################# copy-paste nod 
            # Zero the marked columns in the lung that was used, then rebuild its position
            # list so later nodules cannot be anchored there.
            if LR == 0:                            
                lung_npy_L[:, rand_pos[1]:rand_pos[1]+nod_pos[1], rand_pos[2]:rand_pos[2]+nod_pos[2]] \
                = np.where(nod_unique!=0, 0, lung_npy_L[:, rand_pos[1]:rand_pos[1]+nod_pos[1], rand_pos[2]:rand_pos[2]+nod_pos[2]])
                new_lung_L = make_lung(lung_npy_L)
            else: #rand_pos[0]:rand_pos[0]+nod_pos[0]
                lung_npy_R[:, rand_pos[1]:rand_pos[1]+nod_pos[1], rand_pos[2]:rand_pos[2]+nod_pos[2]] \
                = np.where(nod_unique!=0, 0, lung_npy_R[:, rand_pos[1]:rand_pos[1]+nod_pos[1], rand_pos[2]:rand_pos[2]+nod_pos[2]])
                new_lung_R = make_lung(lung_npy_R)

            # Carry the pasted image and mask into the next nodule iteration.
            img = cp_img
            mask = cp_mask


        # img2png
        # Project the final volume the same way as the masks and save it.
        fimg = np.mean(img.T, axis=2)
        norm_fimg = normalize255(fimg)

        Image.fromarray(norm_fimg).save(f'../data/fake_DRR_valid_10mm_delO/valid_{dir_num}/paste_drr_{str(num_+1).zfill(4)}.png')