##### Import Image Registration Libraries

In [None]:
import json 
import cv2
from matplotlib import pyplot as plt
import numpy as np
from preprocess import *
import SimpleITK as sitk
from collections import OrderedDict
import os
from register_images_GNN import *
import sys
sys.path.insert(0, '../')
from parse_registration_json import ParserRegistrationJson
from parse_study_dict import ParserStudyDict
import time

### Load Input, Set and Create Output Dirs

In [None]:
####### INPUTS
json_path = "./jsonData/TCIA_FUSION.json"
preprocess_moving = True    # run preprocess_hist on the moving images
preprocess_fixed = True     # run preprocess_mri on the fixed images
run_registration = True     # run the registration step itself
extension = 'nii.gz'        # file extension used for the written volumes
timings = {}                # study key -> registration time in minutes

# Reuse the coordinates cached by a previous run. A missing or unreadable
# cache is not an error: start from an empty dict instead. Only I/O and
# decoding failures are caught (json.JSONDecodeError and UnicodeDecodeError
# are ValueError subclasses), so unrelated errors are no longer hidden.
try:
    with open('coord.txt') as f:
        coord = json.load(f)
except (OSError, ValueError):
    coord = {}

############### START REGISTRATION HERE

json_obj = ParserRegistrationJson(json_path)

studies = json_obj.studies
toProcess = json_obj.ToProcess
outputPath = json_obj.output_path

###### PREPROCESSING DESTINATIONS ######################################
preprocess_moving_dest = os.path.join(outputPath, 'preprocess', 'hist')
preprocess_fixed_dest = os.path.join(outputPath, 'preprocess', 'mri')

# makedirs creates outputPath, 'preprocess' and the leaf directory in one
# call, also when the parent of outputPath does not exist yet, and is a
# no-op for directories that are already there.
for _dest in (preprocess_moving_dest, preprocess_fixed_dest):
    os.makedirs(_dest, exist_ok=True)

### Register Image and Save Output in nii.gz

In [None]:
# Preprocess (optionally) and register every selected study, then write the
# resulting 3D volumes to outputPath and record the registration time.
for s in json_obj.studies:
    # A falsy ToProcess means "process every study"; otherwise it acts as an
    # allow-list of study keys.
    if json_obj.ToProcess and s not in json_obj.ToProcess:
        print("Skipping", s)
        continue

    print("x"*30, "Processing", s,"x"*30)
    studyParser = ParserStudyDict(json_obj.studies[s])

    sid = studyParser.id
    fixed_img_mha = studyParser.fixed_filename
    fixed_seg = studyParser.fixed_segmentation_filename
    moving_dict = studyParser.ReadMovingImage()

    ###### PREPROCESSING HISTOLOGY HERE #############################################################
    if preprocess_moving:
        print('Preprocessing moving sid:', sid, '...')
        print(preprocess_moving_dest)
        preprocess_hist(moving_dict, preprocess_moving_dest, sid)
        print('Finished preprocessing', sid)

    ###### PREPROCESSING MRI HERE #############################################################
    if preprocess_fixed:
        print ("Preprocessing fixed case:", sid, '...')
        print(preprocess_fixed_dest)
        coord = preprocess_mri(fixed_img_mha, fixed_seg, preprocess_fixed_dest, coord, sid)

        print("Finished processing fixed mha", sid)

        # Persist the coordinates after every study so a later run can
        # reuse them from coord.txt.
        with open('coord.txt', 'w') as json_file:
            json.dump(coord, json_file)

    ##### ALIGNMENT HERE ########################################################################
    if not run_registration:
        continue

    ######## LOAD MODELS
    print('.'*30, 'Begin deep learning registration for ' + sid + '.'*30)

    # Load the models only once: model_cache survives across loop iterations
    # (and across re-runs of this cell in the same session), so the
    # NameError branch is taken only the first time.
    try:
        model_cache
    except NameError:
        model_aff_path = './trained_models/best_CombinedLoss_affine_GNN_LR0.00045_segments100.pth.tar'
        model_tps_path = './trained_models/best_CombinedLoss_tps_GNN_LR0.00045_segments100.pth.tar'

        model_cache = load_models(model_aff_path, model_tps_path, do_deformable=True)

    start = time.time()
    output3D_cache = register(preprocess_moving_dest, preprocess_fixed_dest, coord, model_cache, sid)
    (out3Dhist_highRes, out3Dmri_highRes, out3Dcancer_highRes,
     out3D_region00, out3D_region10, out3D_region09, out3Dmri_mask) = output3D_cache
    end = time.time()

    # Only the register() call is timed; writing the outputs is excluded.
    elapsed_min = (end - start) / 60.0
    print("Registration done in {:6.3f}(min)".format(elapsed_min))

    imMri = sitk.ReadImage(fixed_img_mha)
    # Origin of the MRI sub-volume selected by the first/last entries of
    # coord[sid]['slice'] (presumably the registered slice indices - confirm
    # against preprocess_mri); spacing and direction come from the full image.
    slice_range = coord[sid]['slice']
    mriOrigin = imMri[:, :, slice_range[0]:slice_range[-1]].GetOrigin()
    mriSpace = imMri.GetSpacing()
    mriDirection = imMri.GetDirection()

    imSpatialInfo = (mriOrigin, mriSpace, mriDirection)

    # (volume, file-name suffix, use the high-res writer?) in write order.
    outputs = [
        (out3Dhist_highRes, '_moved_highres_rgb.', True),
        (out3Dmri_highRes, '_fixed_image.', False),
        (out3Dcancer_highRes, '_moved_highres_region01_label.', True),
        (out3D_region00, '_moved_highres_region00_label.', True),
        (out3D_region10, '_moved_highres_region10_label.', True),
        (out3D_region09, '_moved_highres_region09_label.', True),
        (out3Dmri_mask, '_fixed_mask_label.', False),
    ]
    for volume, suffix, use_high_res in outputs:
        print(suffix.rstrip('.'))
        # The configured `extension` is honoured instead of a hard-coded
        # "nii.gz", so changing the input setting actually takes effect.
        if use_high_res:
            output_results_high_res(preprocess_moving_dest, preprocess_fixed_dest, outputPath, volume, sid, suffix, imSpatialInfo, coord, imMri, extension=extension)
        else:
            output_results(outputPath, volume, sid, suffix, imSpatialInfo, extension=extension)

    timings[s] = elapsed_min
    print('Done!')

In [None]:
# Save the per-study registration times; the context manager guarantees the
# file is flushed and closed (the bare open() call left that to the GC).
with open("timings_LR0_00045_seg_100.txt", 'w') as timings_file:
    json.dump(timings, timings_file)

### Extract and Save output in png format from nii.gz

In [None]:
import SimpleITK as sitk
import numpy as np
import cv2
import os
from pathlib import Path

def nii_to_image(nii_file_path, output_dir, slice_index=None):
    """
    Convert every image volume in a directory to per-slice PNG files.

    Each regular file directly inside ``nii_file_path`` is read with
    SimpleITK. Every slice along the first axis of the resulting array is
    min-max normalised to the 0-255 range and written to ``output_dir`` as
    ``<name>_slice_<i>.png``, where ``<name>`` is the file name up to its
    first dot.

    Parameters:
    - nii_file_path: Directory containing the .nii.gz files. Sub-directories
      are skipped, so ``output_dir`` may live inside it.
    - output_dir: Directory where the images will be saved; created
      (including parents) when missing.
    - slice_index: Optional int. All slices are always saved; when this is a
      valid index for a volume, an extra "Saved specific slice" message is
      printed for that slice. Default is None.

    Raises:
    - OSError: if OpenCV reports that a PNG could not be written.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Sorted so the files are processed (and logged) in a reproducible order.
    for f in sorted(os.listdir(nii_file_path)):
        volume_path = os.path.join(nii_file_path, f)
        if os.path.isdir(volume_path):
            continue

        # Read the volume and get it as a NumPy array
        image_array = sitk.GetArrayFromImage(sitk.ReadImage(volume_path))
        stem = f.split('.')[0]
        num_slices = image_array.shape[0]

        # Save each slice as an image
        for i in range(num_slices):
            slice_img_normalized = cv2.normalize(image_array[i, :, :], None, 0, 255, cv2.NORM_MINMAX)
            output_path = os.path.join(output_dir, f"{stem}_slice_{i}.png")
            # imwrite signals failure through its return value; do not
            # report a slice as saved when nothing was written.
            if not cv2.imwrite(output_path, slice_img_normalized):
                raise OSError(f"Could not write {output_path}")
            print(f"Saved slice {i} as {output_path}")

        # The requested slice was already written by the loop above under the
        # same file name, so it is only reported here instead of being
        # normalised and encoded a second time.
        if isinstance(slice_index, int) and 0 <= slice_index < num_slices:
            specific_output_path = os.path.join(output_dir, f"{stem}_slice_{slice_index}.png")
            print(f"Saved specific slice {slice_index} as {specific_output_path}")


In [None]:
# Export every slice of the volumes written for case aaa0069 as PNG files
# into an "images" sub-directory of that case's results.
output_path = "./results_LR0_00045_seg100/aaa0069/images"

nii_to_image(
    nii_file_path="./results_LR0_00045_seg100/aaa0069",
    output_dir=output_path,
)

### Segment and View Images with SLIC Superpixels, Using the CPU's AVX2 Instructions for 7x Faster Segmentation

In [None]:
import numpy as np

# Much faster than the standard class
from fast_slic.avx2 import SlicAvx2
from PIL import Image
import os
import cv2
import matplotlib.pyplot as plt

# All relative paths below are resolved against the project root.
os.chdir("/home/ubuntu/Document/ProsGraphNet/")

img_path = "./datasets/datasets/training/mri_TCIA-0002_10.jpg"

num_of_seg = 60     # requested number of superpixels
compactness = 30    # SLIC compactness setting

image = cv2.imread(img_path)   # You can convert the image to CIELAB space if you need.
# cv2.imread returns None instead of raising when the file is missing or
# unreadable; fail here with a clear message rather than inside fast_slic.
if image is None:
    raise FileNotFoundError(f"Could not read image {img_path!r} from {os.getcwd()!r}")

slic = SlicAvx2(num_components=num_of_seg, compactness=compactness)
assignment = slic.iterate(image) # Cluster Map
# slic.slic_model.clusters holds the cluster information of the superpixels.

plt.figure(figsize=(8, 8))
plt.imshow(assignment)
plt.title(f" MRI SLIC Segmentation:- {num_of_seg} segments, {compactness} compactness")
plt.axis('off')
plt.show()

### Create CSV from logged Training and Testing data

In [None]:
import pandas as pd
import os

In [None]:
training_data = []
names = ['Epoch', 'LR-Segments', 'Training Loss', 'Test Loss', 'Dice']
base_path = "./training_data/training_logs"
geometric_model = 'tps'


def _metric_rows(lines, marker, fname):
    """Return one ``[epoch, label, value]`` row per line containing ``marker``.

    Epochs are numbered from 1 in order of appearance. The value is the last
    whitespace-separated token of the line, parsed as a float. The label is
    built from the underscore-separated parts of ``fname``: the text after
    'LR' in part 1, a dash, then part 4. It is only computed when at least
    one line matches, so files without matches may have any name.
    """
    matching = [line for line in lines if marker in line]
    if not matching:
        return []
    parts = fname.split('_')
    label = f"{parts[1].split('LR')[1]}-{parts[4]}"
    # split() without arguments tolerates tabs and trailing whitespace, which
    # split(' ')[-1] turned into an unparsable token.
    return [[epoch, label, float(line.split()[-1])]
            for epoch, line in enumerate(matching, start=1)]


testing_data_x = []   # test loss rows
testing_data_y = []   # Dice rows
# Single pass over the log directory: each file is listed, reported and read
# once, then routed by the 'train' / 'test' tag in its name.
for fname in os.listdir(base_path):
    if geometric_model not in fname:
        continue
    print(f"{geometric_model} Current File: {fname}")

    is_train = 'train' in fname
    is_test = 'test' in fname
    if not (is_train or is_test):
        continue

    with open(os.path.join(base_path, fname)) as f:
        f_data = f.readlines()

    if is_train:
        training_data.extend(_metric_rows(f_data, 'Average loss', fname))
    if is_test:
        testing_data_x.extend(_metric_rows(f_data, 'Average loss', fname))
        testing_data_y.extend(_metric_rows(f_data, 'Dice', fname))


In [None]:
# One frame per metric. Every frame starts with the two key columns
# (names[0], names[1]) followed by its own value column.
key_columns = names[:2]
df1 = pd.DataFrame.from_records(training_data, columns=key_columns + [names[2]])
df2 = pd.DataFrame.from_records(testing_data_x, columns=key_columns + [names[3]])
df3 = pd.DataFrame.from_records(testing_data_y, columns=key_columns + [names[4]])

In [None]:
# Inner-join the three metric frames on their shared key columns, so only
# (epoch, run) pairs present in all of them are kept.
merge_keys = ['Epoch', 'LR-Segments']
df_m1 = df1.merge(df2, on=merge_keys, how='inner')
df_m2 = df_m1.merge(df3, on=merge_keys, how='inner')

In [None]:
# Name the CSV after the geometric model whose logs were parsed, so a run
# with a different geometric_model does not overwrite or mislabel this file.
# With geometric_model = 'tps' the path is unchanged.
df_m2.to_csv(f"./training_data/training_testing_loss_dice_{geometric_model}.csv", index=False)

### Overlay detected outputs and save as image in results

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os

img_dir = "./results_LR0_00045_seg100/aaa0069/images/"

MASK_SIZE = (320, 320)      # size the moved region masks are resized to
CONTOUR_SHIFT = (15, 15)    # (x, y) offset subtracted from the region contours
BASE_MASK_COLOR = (0, 255, 0)
# (file-name tag, BGR colour) for each moved region, in drawing order.
REGION_LAYERS = [
    ('region00', (255, 0, 255)),
    ('region01', (0, 0, 254)),
    ('region09', (0, 0, 255)),
    ('region10', (0, 0, 255)),
]


def _slice_number(fname):
    """Return the integer after the last '_' of the file stem, or None."""
    try:
        return int(fname.split('.')[0].split('_')[-1])
    except ValueError:
        return None


def _layer_path(paths, tag, slice_idx):
    """Return the first path whose file name contains ``tag``."""
    for path in paths:
        if tag in os.path.basename(path):
            return path
    raise FileNotFoundError(f"No '{tag}' image for slice {slice_idx} in {img_dir}")


def _read_image(path, flags=cv2.IMREAD_COLOR):
    """Read an image, raising instead of returning None when it cannot be read."""
    image = cv2.imread(path, flags)
    if image is None:
        raise OSError(f"Could not read image {path}")
    return image


# Group the exported PNGs by slice number. Every slice gets its own group
# (the previous three-bucket logic lumped all slices >= 2 together), and
# files without a numeric slice suffix are ignored instead of raising.
imgs = sorted(os.listdir(img_dir))
grpd_imgs = {}
for img in imgs:
    slice_idx = _slice_number(img)
    if slice_idx is None:
        continue
    grpd_imgs.setdefault(slice_idx, []).append(os.path.join(img_dir, img))

for i in sorted(grpd_imgs):
    grp = grpd_imgs[i]

    mri_image = _read_image(_layer_path(grp, 'image_slice', i))
    overlay_image = mri_image.copy()

    # Outline of the fixed mask, drawn at its native size and position.
    mask1 = _read_image(_layer_path(grp, 'mask_label_slice', i), cv2.IMREAD_GRAYSCALE)
    contours1, _ = cv2.findContours(mask1, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(overlay_image, contours1, -1, BASE_MASK_COLOR, 1)

    for tag, color in REGION_LAYERS:
        mask = _read_image(_layer_path(grp, tag, i), cv2.IMREAD_GRAYSCALE)
        # interpolation must be passed by keyword: the third positional
        # parameter of cv2.resize is `dst`, so the flag was not applied.
        mask = cv2.resize(mask, MASK_SIZE, interpolation=cv2.INTER_AREA)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Shift the region contours before drawing; the offset is cast to the
        # contour dtype so the arrays keep the integer type OpenCV returned.
        shifted_contours = [c - np.array(CONTOUR_SHIFT, dtype=c.dtype) for c in contours]
        cv2.drawContours(overlay_image, shifted_contours, -1, color, 1)

    # Display the final image
    plt.figure(figsize=(6, 6))
    plt.imshow(cv2.cvtColor(overlay_image, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.show()

    # NOTE(review): the file name says LR0_00035 / seg_80 while img_dir is the
    # LR0_00045 / seg100 results folder - confirm which one is intended.
    cv2.imwrite(os.path.join(img_dir, f'aaa0069_LR0_00035_seg_80_slice_{i}.png'), overlay_image)
