# Investigation of tampered areas in images from the LuNoTim dataset

### Functions

# In[ ]:
import numpy as np
import matplotlib.pyplot as plt
import pydicom as dicom
from os.path import join
from skimage import measure
import pylidc as pl
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle
import cv2
import glob

def get_bbox_all(mask, kernel_size=11):
    """Return a single bounding box enclosing every foreground region of ``mask``.

    The mask is dilated with a ``kernel_size`` x ``kernel_size`` square kernel
    before contour extraction, so the returned box is padded by about
    ``kernel_size // 2`` pixels on each side (it can never exceed the image).

    Parameters
    ----------
    mask : numpy.ndarray
        2-D mask. It is cast to ``uint8`` first, so only values that are
        non-zero after that cast count as foreground.
    kernel_size : int, optional
        Side length of the square dilation kernel (default 11).

    Returns
    -------
    tuple of int or None
        ``(x, y, w, h)`` with ``x``/``w`` along columns and ``y``/``h`` along
        rows, or ``None`` when the mask has no foreground pixel.
    """
    mask_img = mask.astype(np.uint8)  # astype copies, so the caller's mask is untouched
    if not mask_img.any():
        return None

    mask_img *= 255
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    mask_img = cv2.morphologyEx(mask_img, cv2.MORPH_DILATE, kernel, iterations=1)

    # OpenCV 3.x returns (image, contours, hierarchy) while 2.x/4.x return
    # (contours, hierarchy); index -2 is the contour list in every version.
    cnts = cv2.findContours(mask_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    bboxes = np.array([cv2.boundingRect(cnt) for cnt in cnts])

    # Union of the per-contour boxes.
    x = int(bboxes[:, 0].min())
    y = int(bboxes[:, 1].min())
    w = int((bboxes[:, 0] + bboxes[:, 2]).max()) - x
    h = int((bboxes[:, 1] + bboxes[:, 3]).max()) - y
    return x, y, w, h


def crop_image(im, bb):
    """Crop ``im`` to the box ``bb = (x, y, w, h)``.

    ``x``/``w`` select columns (axis 1) and ``y``/``h`` select rows (axis 0).
    """
    col0, row0, width, height = bb[0], bb[1], bb[2], bb[3]
    rows = slice(row0, row0 + height)
    cols = slice(col0, col0 + width)
    return im[rows, cols]


def group_consecutive(ids):
    """Split an ascending sequence of ints into runs of consecutive values.

    >>> group_consecutive([1, 2, 3, 5, 6, 9])
    [[1, 2, 3], [5, 6], [9]]
    """
    runs = []
    for curr_id in ids:
        if runs and curr_id == runs[-1][-1] + 1:
            runs[-1].append(curr_id)
        else:
            runs.append([curr_id])
    return runs


def find_sequences(folder):
    """Group the tampered-slice ids found in ``folder`` into consecutive runs.

    The ids come from ``get_list_ids_for_masks`` (one per ``<id>_mask.npy``
    file, sorted ascending). Each returned list holds ids that increase by
    exactly one.

    Returns
    -------
    list of list of int
        The runs in ascending order, or an empty list when ``folder``
        contains no mask files.
    """
    return group_consecutive(get_list_ids_for_masks(folder))


def get_list_ids_for_masks(folder):
    """Return the sorted slice ids of the ``<id>_mask.npy`` files in ``folder``.

    Only files directly inside ``folder`` are considered. The id is the part
    of the file name before the first underscore, parsed as an int.

    Returns
    -------
    numpy.ndarray
        1-D integer array of ids in ascending order (empty if none match).
    """
    import os

    # Without recursive=True, glob treats '**' exactly like '*', so this is a
    # flat, single-directory match.
    mask_paths = glob.glob(join(folder, '*_mask.npy'))
    # basename handles both '/' and the platform separator (e.g. '\\' on Windows).
    ids = [int(os.path.basename(path).split('_')[0]) for path in mask_paths]
    return np.sort(np.array(ids, dtype=int))


def get_mask(id, folder=None):
    """Load the tampering mask ``%06d_mask.npy`` for slice ``id``.

    Parameters
    ----------
    id : int
        Slice id (zero-padded to six digits in the file name).
    folder : str, optional
        Directory holding the mask files. Defaults to the module-level
        ``folder_root`` set by the driver cell, keeping old call sites working.
    """
    if folder is None:
        folder = folder_root
    return np.load(join(folder, '%06d_mask.npy' % id))


def get_image(id, folder=None):
    """Load the tampered image ``%06d.npy`` for slice ``id``.

    Parameters
    ----------
    id : int
        Slice id (zero-padded to six digits in the file name).
    folder : str, optional
        Directory holding the image files. Defaults to the module-level
        ``folder_root`` set by the driver cell, keeping old call sites working.
    """
    if folder is None:
        folder = folder_root
    return np.load(join(folder, '%06d.npy' % id))


def get_rectangle(bbox):
    """Build an unfilled matplotlib ``Rectangle`` outlining ``bbox``.

    ``bbox`` is ``(x, y, w, h)``: anchored at ``(x, y)``, ``w`` wide and ``h``
    tall, drawn with a red 1-point edge and no face colour.
    """
    anchor = (bbox[0], bbox[1])
    width, height = bbox[2], bbox[3]
    return Rectangle(anchor, width, height, fc='none', ec='r', lw=1)


def _show_panel(ax, img, title, fontsize, bbox=None):
    """Draw ``img`` in grayscale on ``ax`` with a title, an optional bbox outline and no axes."""
    ax.imshow(img, cmap=plt.cm.gray)
    if bbox is not None:
        ax.add_patch(get_rectangle(bbox))
    ax.set_title(title, fontsize=fontsize)
    ax.axis('off')


def plot_prix(seq, scan):
    """Show each tampered slice of ``seq`` next to its mask and pristine slice.

    For every slice id, the tampered image and its mask are loaded with
    ``get_image``/``get_mask``, and the pristine slice comes from ``scan``.
    When the mask has foreground, a 5-panel figure is shown:
    - the full tampered image and the full mask, both with the tampering bbox drawn;
    - crops of the mask, the tampered image and the pristine image to that bbox.

    Otherwise a 3-panel figure shows the full tampered image, the empty mask
    and the pristine slice.

    Parameters
    ----------
    seq : sequence of int
        Slice ids to plot (e.g. one run returned by ``find_sequences``).
    scan : pylidc Scan
        Original scan (as returned by ``get_LIDC_scan``); its DICOM images
        supply the pristine slices.
    """
    # Pristine slices from the original LIDC-IDRI scan.
    lidc_imgs = scan.load_all_dicom_images()
    fontsize = 28

    for id_slice in seq:
        # Load per slice so only one mask/image pair is held in memory at a time.
        mask = get_mask(id_slice)
        im = get_image(id_slice)
        bbox = get_bbox_all(mask)
        # NOTE(review): assumes slice ids are 1-based positions in the scan's
        # DICOM image list — confirm against the dataset layout.
        pristine = lidc_imgs[id_slice - 1].pixel_array
        tampered_title = 'Tampered image - slice [%d]' % id_slice

        if bbox is not None:
            _, ax = plt.subplots(1, 5, figsize=(25, 10))
            _show_panel(ax[0], im, tampered_title, fontsize, bbox)
            _show_panel(ax[1], mask, 'Mask for tampering', fontsize, bbox)
            _show_panel(ax[2], crop_image(mask, bbox), 'Mask for tampering', fontsize)
            _show_panel(ax[3], crop_image(im, bbox), 'Tampered image', fontsize)
            _show_panel(ax[4], crop_image(pristine, bbox), 'Pristine image', fontsize)
        else:
            _, ax = plt.subplots(1, 3, figsize=(25, 10))
            _show_panel(ax[0], im, tampered_title, fontsize)
            _show_panel(ax[1], mask, 'Null Mask for tampering', fontsize)
            _show_panel(ax[2], pristine, 'Pristine image', fontsize)

        plt.tight_layout()
        plt.show()
            
            
def get_LIDC_scan(id_patient):
    """Query pylidc for the CT scan of one LIDC-IDRI patient.

    Parameters
    ----------
    id_patient : str
        Patient number as a string (e.g. '0631'). It is appended to the
        'LIDC-IDRI-' prefix to form the patient id.

    Returns
    -------
    The first matching pylidc ``Scan``. This is presumably ``None`` when
    nothing matches, per SQLAlchemy ``Query.first()`` semantics.
    """
    patient_key = 'LIDC-IDRI-' + id_patient
    matches = pl.query(pl.Scan).filter(pl.Scan.patient_id == patient_key)
    return matches.first()


def get_nodules_info(scan):
    """Print how many nodules ``scan`` has and how many annotations each received.

    Nodules are the clusters returned by ``scan.cluster_annotations()``.
    Nodule numbers in the output start at 1.
    """
    clusters = scan.cluster_annotations()
    print("%s has %d nodules." % (scan, len(clusters)))

    for number, annotations in enumerate(clusters, start=1):
        print("Nodule %d has %d annotations" % (number, len(annotations)))

# Showing tampered and original regions

# In[ ]:
## id patient
id_patient = '0631'
# Another combination worth inspecting: 'added_ct_gan_inpainting' with patient '0144'.

## Type of tampering
type_folder = {0:'added_inner_tissue_different_slice',
               1:'added_inner_tissue_same_slice',
               2:'added_outer_tissue_different_slice',
               3:'added_outer_tissue_same_slice',
               4:'added_ct_gan_inpainting',
               5:'removed_patchmatch_guided_inpainting',
               6:'removed_simple_inpainting',
               7:'removed_ct_gan_inpainting',
              }

id_type_tp = 4
tampering_name = type_folder[id_type_tp]

# "added_*" datasets keep the patient's slices in an extra "0/" sub-folder;
# "removed_*" datasets store them directly in the patient folder.
# folder_root stays a module-level global: get_mask/get_image read it.
if tampering_name.startswith('added'):
    folder_root = '../%s/LIDC-IDRI/LIDC-IDRI-%s/0/' % (tampering_name, id_patient)
else:
    folder_root = '../%s/LIDC-IDRI/LIDC-IDRI-%s/' % (tampering_name, id_patient)


# Original CT scan from the LIDC-IDRI dataset
lidc_scan = get_LIDC_scan(id_patient)
if lidc_scan is None:
    # Fail here with a clear message instead of an AttributeError inside plot_prix.
    raise ValueError('No LIDC-IDRI scan found for patient %s' % id_patient)


# Find sequence of id slices that were tampered in the LuNoTim
tp_id_seqs = find_sequences(folder_root)


# show original, tampered images and its masks
for i, tp_id_seq in enumerate(tp_id_seqs):
    print("\n\n---------------------------------------------------------------------------")
    print("[%s] - Patient [%s] - Seq [%d]" % (tampering_name, id_patient, i))
    print("---------------------------------------------------------------------------")
    plot_prix(tp_id_seq, lidc_scan)
    