In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Palatino"],
})
%config InlineBackend.figure_format = "svg"
from pathlib import Path
import nibabel as nib
import cv2

In [3]:
models_path = Path.cwd() / 'results/swap_channel_spatial'
print(models_path)
model = list(models_path.iterdir())
print(model)
print(len(model))
model = model[1:]
print(model)
print(len(model))

/home/students/arimond/EffFormer/results/224x224_head4
[PosixPath('/home/students/arimond/EffFormer/results/224x224_head4/transfilm_epoch_359.pth'), PosixPath('/home/students/arimond/EffFormer/results/224x224_head4/log'), PosixPath('/home/students/arimond/EffFormer/results/224x224_head4/transfilm_epoch_319.pth'), PosixPath('/home/students/arimond/EffFormer/results/224x224_head4/transfilm_epoch_399.pth'), PosixPath('/home/students/arimond/EffFormer/results/224x224_head4/test'), PosixPath('/home/students/arimond/EffFormer/results/224x224_head4/transfilm_epoch_299.pth'), PosixPath('/home/students/arimond/EffFormer/results/224x224_head4/transfilm_epoch_259.pth'), PosixPath('/home/students/arimond/EffFormer/results/224x224_head4/log.txt'), PosixPath('/home/students/arimond/EffFormer/results/224x224_head4/transfilm_epoch_379.pth'), PosixPath('/home/students/arimond/EffFormer/results/224x224_head4/transfilm_epoch_239.pth'), PosixPath('/home/students/arimond/EffFormer/results/224x224_head4/tra

In [7]:
test_cases = ['case0001', 'case0002', 'case0003', 'case0004', 'case0008', 'case0022',
             'case0025', 'case0029', 'case0032', 'case0035', 'case0036', 'case0038']

In [8]:
img_names = [i + '_img.nii.gz' for i in test_cases]
gt_names = [i + '_gt.nii.gz' for i in test_cases]
pred_names = [i + '_pred.nii.gz' for i in test_cases]
print(gt_names)

['case0001_gt.nii.gz', 'case0002_gt.nii.gz', 'case0003_gt.nii.gz', 'case0004_gt.nii.gz', 'case0008_gt.nii.gz', 'case0022_gt.nii.gz', 'case0025_gt.nii.gz', 'case0029_gt.nii.gz', 'case0032_gt.nii.gz', 'case0035_gt.nii.gz', 'case0036_gt.nii.gz', 'case0038_gt.nii.gz']


In [10]:
def select_best_slice(gt,class_num):
    best_slices = [sl for sl in range(s1_gt.shape[-1]) if len(np.unique(gt[...,sl])) == class_num]
    return best_slices

In [14]:
case_idx = [5, 11]

s1_img = nib.load(model[0] / img_names[case_idx[0]]).get_fdata()
s1_gt = nib.load(model[0] / gt_names[case_idx[0]]).get_fdata()

s2_img = nib.load(model[0] / img_names[case_idx[1]]).get_fdata()
s2_gt = nib.load(model[0] / gt_names[case_idx[1]]).get_fdata()


class_num = 9
s1_best_slices = select_best_slice(s1_gt, class_num)
print(f'best slices for case {test_cases[case_idx[0]]} is: {s1_best_slices}')

s2_best_slices = select_best_slice(s2_gt, class_num)
print(f'best slices for case {test_cases[case_idx[1]]} is: {s2_best_slices}')

Idx1, Idx2 = s1_best_slices[0], s2_best_slices[0]

mask_color = {"1":(0, 0, 255),
                  "2":(0, 255, 0),
                  "3":(255, 0, 0),
                  "4":(0, 255, 255),
                  "5":(255, 0, 255),
                  "6":(255, 255, 0),
                  "7":(63, 208, 244),
                  "8":(241, 240, 234)}
label_name = {"1":"aorta",
              "2":"gallbladder",
              "3":"left kidney",
              "4":"right kidney",
              "5":"liver",
              "6":"pancreas",
              "7":"spleen",
              "8":"stomach"}

COLORS = ["#0000FF","#00FF00","#FF0000","#00FFFF","#FF00FF","#FFFF00","#3FD0F4","#F1F0EA"]
SPECIES_ = ['aorta','gallbladder','left kidney','right kidney','liver','pancreas','spleen','stomach']
handles = [
    Patch(facecolor=color, edgecolor="k", label=label, alpha=1)
    for label, color in zip(SPECIES_, COLORS)
]


if len(s1_best_slices) != 0 and len(s2_best_slices) != 0:
    

    s1_img = cv2.cvtColor(s1_img[...,Idx1].astype('float32'),cv2.COLOR_GRAY2RGB)
    s1_img = cv2.normalize(s1_img, None, alpha = 0, beta = 255, norm_type = cv2.NORM_MINMAX, dtype = cv2.CV_32F)
    s1_img = s1_img.astype(np.uint8)

    s2_img = cv2.cvtColor(s2_img[...,Idx2].astype('float32'),cv2.COLOR_GRAY2RGB)
    s2_img = cv2.normalize(s2_img, None, alpha = 0, beta = 255, norm_type = cv2.NORM_MINMAX, dtype = cv2.CV_32F)
    s2_img = s2_img.astype(np.uint8)


    # preds = []

    s1_pred_0 = nib.load(model[0] / pred_names[case_idx[0]]).get_fdata() # Unet
    s1_pred_1 = nib.load(model[1] / pred_names[case_idx[0]]).get_fdata() # TransUnet
    s1_pred_2 = nib.load(model[2] / pred_names[case_idx[0]]).get_fdata() # Swin Unet
    s1_pred_3 = nib.load(model[3] / pred_names[case_idx[0]]).get_fdata() # proposed method

    s2_pred_0 = nib.load(model[0] / pred_names[case_idx[1]]).get_fdata() # Unet
    s2_pred_1 = nib.load(model[1] / pred_names[case_idx[1]]).get_fdata() # TransUnet
    s2_pred_2 = nib.load(model[2] / pred_names[case_idx[1]]).get_fdata() # Swin Unet
    s2_pred_3 = nib.load(model[3] / pred_names[case_idx[1]]).get_fdata() # proposed method

    ll = [s1_gt, s1_pred_0, s1_pred_1, s2_gt, s2_pred_0, s2_pred_1]



    fig, axes = plt.subplots(2, 3, figsize= (20,15))
    fig.legend(ncol=8,handles=handles, handlelength=4, handleheight=4,loc='upper center',frameon=False,fontsize = 12,bbox_to_anchor=[0.5, 0.95])
    
    s1_ll = [s1_gt, s1_pred_0, s1_pred_1, s1_pred_2, s1_pred_3]
    s2_ll = [s2_gt, s2_pred_0, s2_pred_1, s2_pred_2, s2_pred_3]

    ll_prime = []

    img_hat = s1_img.copy()
    for i in range(len(s1_ll)):
        print(f"i: {i}")
        Idx = Idx1
        img_hat = s1_img.copy()
        print(f"Idx: {Idx}")
        for j in range(1,9):
                img_hat[np.where(s1_ll[i][...,Idx] == float(j))] = mask_color[str(j)]
        ll_prime.append(np.fliplr(img_hat))

    img_hat = s1_img.copy()
    for i in range(len(s2_ll)):
        print(f"i: {i}")
        Idx = Idx2
        img_hat = s2_img.copy()
        print(f"Idx: {Idx}")
        for j in range(1,9):
                img_hat[np.where(s2_ll[i][...,Idx] == float(j))] = mask_color[str(j)]
        ll_prime.append(np.fliplr(img_hat))


    fig, axes = plt.subplots(2, 5, figsize= (20,15))
    fig.legend(ncol=8,handles=handles, handlelength=4, handleheight=4,loc='upper center',frameon=False,fontsize = 12,bbox_to_anchor=[0.5, 0.95])
    j = 0
    for row in range(2):
        for col in range(len(ll_prime)//2):
            axes[row][col].imshow(ll_prime[j])
            axes[row][col].axis("off")
            j += 1
            if row == 1 and col == 0:
                axes[row][col].text(0.5,-0.080, "(a) Ground Truth", size=14, ha="center", transform=axes[row][col].transAxes,weight="bold")
            elif row == 1 and col == 1:
                axes[row][col].text(0.5,-0.080, "(b) U-Net", size=14, ha="center", transform=axes[row][col].transAxes,weight="bold")
            elif row == 1 and col == 2 :
                axes[row][col].text(0.5,-0.080, "(c) Trans U-Net", size=14, ha="center", transform=axes[row][col].transAxes,weight="bold")
            elif row == 1 and col == 3:
                axes[row][col].text(0.5,-0.080, "(c) Swin U-Net", size=14, ha="center", transform=axes[row][col].transAxes,weight="bold")
            elif row ==1 and col == 4:
                axes[row][col].text(0.5,-0.080, "(c) Proposed method", size=14, ha="center", transform=axes[row][col].transAxes,weight="bold")


    fig.subplots_adjust(wspace=0.05, hspace=0.05)
else:
    print(f'There is not slices with {class_num} classes. Select other cases!')

best slices for case case0022 is: [61, 62, 63, 64, 65, 66, 67, 68]
best slices for case case0038 is: [69, 70, 71, 72]
i: 0
Idx: 61
i: 1
Idx: 61
i: 2
Idx: 61
i: 3
Idx: 61
i: 4
Idx: 61
i: 0
Idx: 69
i: 1
Idx: 69
i: 2
Idx: 69
i: 3
Idx: 69
i: 4
Idx: 69


In [13]:
fig.savefig("synapsevisualization.png",dpi=300)