In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
import os
import nibabel as nib
import glob
import shutil
from matplotlib import pyplot as plt
import numpy as np

In [None]:
from plot import draw_evaluate, four_in_all

# Output directories for the training and validation figures; each one is
# filled by draw_evaluate and then passed to four_in_all.
TRAIN_FIG_DIR = 'log/training_figs'
VAL_FIG_DIR = 'log/val_figs'

draw_evaluate('data/results/last/Stats_Training_final.csv', TRAIN_FIG_DIR, val=False)
draw_evaluate('data/results/last/Stats_Validation_final.csv', VAL_FIG_DIR)
four_in_all(TRAIN_FIG_DIR)
four_in_all(VAL_FIG_DIR)

In [None]:
# PLOT SEARCHING LOG
# Loss curves recorded during the architecture search, read from the
# 'history' entry of the search checkpoint.
import torch

# Epoch marked by the dashed red line and labelled on the x axis.
# NOTE(review): presumably the epoch whose architecture was selected — confirm.
MARKED_EPOCH = 56

fig, ax = plt.subplots(figsize=(5*0.8, 4*0.8))
fig.subplots_adjust(left=0.15, right=0.95, bottom=0.15, top=0.95)

# map_location='cpu': the checkpoint may hold tensors saved from a GPU run;
# plotting needs none of them on the GPU, so load everything onto the CPU.
sd = torch.load('log/last_search.pt', map_location='cpu')
his = sd['history']
shell_loss = his['shell_loss']
kernel_loss = his['kernel_loss']
epochs = range(len(shell_loss))

ax.plot(epochs, shell_loss, label='Hybrid Loss', linewidth=3)
ax.plot(epochs, kernel_loss, label='Kernel Loss', linewidth=3)
ax.set_xlabel('Epochs', size=15, labelpad=0)
ax.set_ylabel('Weighted Dice Loss', size=15, labelpad=0)
ax.tick_params(labelsize=14, direction='in', pad=2.5)
ax.set_xlim(-2, 60)
ax.set_ylim(0.48, 0.73)
ax.vlines(MARKED_EPOCH, 0, 1, colors='r', linestyles='--', linewidth=1.2)
ax.legend(fontsize=14, framealpha=1)
ax.grid(alpha=0.5)
# Ticks every 10 epochs (only every other one labelled) plus the marked epoch.
ax.set_xticks(list(np.arange(0, 52, 10)) + [MARKED_EPOCH])
ax.set_xticklabels([0, '', 20, '', 40, '', MARKED_EPOCH])
ax.set_yticks(list(np.arange(0.5, 0.74, 0.05)))
ax.set_yticklabels([0.5, '', 0.6, '', 0.7])

# fig.savefig('log/searching_log.png', dpi=200)
plt.show()

In [None]:
# PLOT DC & UC
# Plot the DC and UC cells (plot_cell), then the searched cells taken from the
# best genotype's `down` and `up` parts (plot_searched_cell).
import pickle
# Genotype is not referenced directly: eval() below needs the name in scope to
# rebuild the object from its stored repr string.
from genotype import Genotype
from plot import plot_cell, plot_searched_cell, plot_ops

plot_cell('log/dc')
plot_cell('log/uc', dc=False)

with open('log/best_genotype.pkl', 'rb') as f:
    genotype_repr = pickle.load(f)[0]
# SECURITY: both pickle.load and eval can execute arbitrary code; only run this
# on a best_genotype.pkl produced by your own search.
g = eval(genotype_repr)

plot_searched_cell(g.down, 'log/searched_dc')
plot_searched_cell(g.up, 'log/searched_uc', dc=False)

In [None]:
# FIG. 1
# One BraTS case shown in every modality (columns) and in three orthogonal
# centre slices (rows: sagittal, coronal, axial).

FONTSIZE = 15
CASE_PATH = ('data/MICCAI_BraTS_2019_Data_Training/HGG/'
             'BraTS19_CBICA_AAB_1/BraTS19_CBICA_AAB_1_{}.nii.gz')
files = ['t1', 't2', 'flair', 't1ce', 'seg']
titles = ['T1', 'T2', 'FLAIR', 'T1Gd', 'Ground Truth']
views = ['Sagittal', 'Coronal', 'Axial']
# Zero rows added above/below the sagittal and coronal slices so they line up
# with the axial panel. NOTE(review): the 42/43 split presumably assumes the
# 240x240x155 BraTS volume shape — confirm before reusing on other data.
ROW_PAD = ((42, 43), (0, 0))

# plt.subplots creates the figure itself; an extra plt.figure() beforehand
# would leave an empty "<Figure ... with 0 Axes>" behind.
fig, axs = plt.subplots(3, 5, figsize=(15, 15))
fig.subplots_adjust(wspace=0, hspace=0.05)

for col, (mod, title) in enumerate(zip(files, titles)):
    # np.asanyarray(img.dataobj) is the documented replacement for the
    # deprecated get_data() (removed in nibabel 5) and keeps the stored dtype.
    img = np.asanyarray(nib.load(CASE_PATH.format(mod)).dataobj)
    dim = img.shape
    slices = [
        np.pad(np.rot90(img[round(dim[0]/2), :, :]), ROW_PAD, 'constant', constant_values=0),
        np.pad(np.rot90(img[:, round(dim[1]/2), :]), ROW_PAD, 'constant', constant_values=0),
        np.rot90(img[:, :, round(dim[2]/2)]),
    ]
    for row, (view, sl) in enumerate(zip(views, slices)):
        ax = axs[row, col]
        ax.imshow(sl, cmap=plt.cm.gray)
        if row == 0:
            ax.set_title(title, fontsize=FONTSIZE)
        if col == 0:
            # The first column keeps its (tick-less) frame so the view name
            # can be shown as the y-label; all other panels are frameless.
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_ylabel(view, fontsize=FONTSIZE)
        else:
            ax.axis('off')

# fig.savefig('log/all_mods.png', dpi=200)

In [None]:
'''
PLOT METRICS
'''

import pandas as pd

def metric_plot(csv_3dunet, csv_nas, csv_unet_test, csv_nas_test, save_name):
    df_3dunet = pd.read_csv(csv_3dunet)
    df_nas = pd.read_csv(csv_nas)
    df_unet_test = pd.read_csv(csv_unet_test)
    df_nas_test = pd.read_csv(csv_nas_test)
    metrics = list(df_nas.keys())[1:]
    FS = 6
    fig, axs = plt.subplots(2, 4, figsize=(FS*2,5),sharex=True)
    fig.subplots_adjust(left=0.04, right=0.96, bottom=0.06, top=0.94,
                        hspace=0.05, wspace=0.2)
    
    for row in range(2):
        metric_i = 0
        df0 = df_3dunet if row == 0 else df_unet_test
        df1 = df_nas if row == 0 else df_nas_test
        for col in range(4):
            ax = axs[row,col]
            values = []
            labels = []
            positions = [0]
            means = []
            DIST = 0.55
            for metric in metrics[3*metric_i:3*(metric_i+1)]:
                for df in [df0, df1]:
                    values.append([x for x in df[metric][:-5] if not np.isnan(x)])
                    means.append(list(df[metric])[-5])
                    positions.append(positions[-1]+DIST)
                positions[-1] += 0.4
                labels += [metric.split('_')[-1]]
            positions = positions[:6]    
            box = ax.boxplot(values, positions=positions,
                             showfliers=False,showmeans=False,widths=0.5,
                             patch_artist=True,
                             notch=False,
                             medianprops={'linewidth':1,'color':'r'})
            colors = ['pink', 'skyblue']
            for i,bar in enumerate(box['boxes']):
                bar.set_facecolor(colors[i%2])
            ax.scatter(positions, means,
                         c='k', marker='*', s=100 ,zorder=10)
            ax.set_xticks([(positions[i*2]+positions[i*2+1])/2 for i in range(3)])
            ax.set_xticklabels(labels,size=14)
            
            if row == 0:
                ax.set_title('{} '.format(metric.split('_')[0][:4]),size=17, fontstyle='italic')
            ax.vlines(positions,0,30,alpha=0.8,linewidth=0.1)
            ax.yaxis.grid(True,alpha=0.5)
            ax.tick_params(bottom=False)
            ax.tick_params(axis='x',labelsize=15, pad=2)
            ax.tick_params(axis='y',labelsize=15, length=2, pad=2, direction='in')
            metric_i += 1
        ax_right = ax.twinx()
        ax_right.tick_params(right=False,labelright=False)
        ax_right.set_ylabel('Training' if row==0 else 'Testing', size=17, labelpad=20, rotation=270)

    
    ax = axs[0,0]
    offset_x = 1.15
    offset_y = 0.08
    ax.text(positions[-2]-offset_x, 0.61+offset_y, '3D-U-Net',
             backgroundcolor='pink', color='pink', weight='roman',
             size=11,)
    ax.text(positions[-2]-offset_x, 0.545+offset_y, '3D-NAS-U-Net',
             backgroundcolor='skyblue', color='skyblue', weight='roman',
             size=11,)
    ax.text(positions[-2]-offset_x-0.1, 0.605+offset_y, '3D-U-Net',
             color='k', weight='roman',
             size=11,)
    ax.text(positions[-2]-offset_x-0.1, 0.54+offset_y, 'NAS-3D-U-Net',
            color='k', weight='roman',
             size=11,)
    ax.set_ylim(0.6,1.04)
    ax.set_yticks([0.6,0.7,0.8,0.9,1.0])
    ax.set_yticklabels([0.6,'',0.8,'',1.0])

    ax = axs[0,1]
    ax.set_ylim(0.6,1.04)
    ax.set_yticks(np.arange(0.6,1.01,0.1))
    ax.set_yticklabels([0.6,'',0.8,'',1.0])

    ax = axs[0,2]
    ax.set_ylim(0.9880,1.001)
    ax.set_yticks(np.arange(0.990, 1.001, 0.0025))
    ax.set_yticklabels([0.990,'',0.995,'',1.0], size=10)
    
    ax = axs[0,3]
    ax.set_ylim(0,9.5)
    ax.set_yticks(np.arange(0, 10, 1.5))
    ax.set_yticklabels([0,'',3,'',6,'',9])
    
    ax = axs[1,0]
    ax.set_ylim(0.4,1.04)
    ax.set_yticks(np.arange(0.4,1.04,0.1))
    ax.set_yticklabels([0.4,'',0.6,'',0.8,'',1.0])


    ax = axs[1,1]
    ax.set_ylim(0.4,1.04)
    ax.set_yticks(np.arange(0.4,1.01,0.1))
    ax.set_yticklabels([0.4,'',0.6,'',0.8,'',1.0])

    ax = axs[1,2]
    ax.set_ylim(0.9834,1.001)
    ax.set_yticks(np.arange(0.985, 1.001, 0.0025))
    ax.set_yticklabels([0.985,'',0.99,'',0.995,'',1.0],size=10)
    

    ax = axs[1,3]
    ax.set_ylim(0,16)
    ax.set_yticks(np.arange(0, 16, 2.5))
    ax.set_yticklabels([0,'',5,'',10,'',15],)
    
    
    plt.savefig(save_name,dpi=300)
    plt.show()

if __name__ == '__main__':
    # 3D-U-Net baseline statistics (training / testing).
    csv_3dunet = 'data/results/Stats_Training_best_3dunet.csv'
    csv_unet_test = 'data/results/Stats_Validation_best_3dunet.csv'
    # NAS model statistics (training / testing).
    csv_nas = 'data/results/best_records/best_training/Stats_Training_final.csv'
    csv_nas_test = 'data/results/best_records/best_val/Stats_Validation_final.csv'
    metric_plot(csv_3dunet=csv_3dunet, csv_nas=csv_nas,
                csv_unet_test=csv_unet_test, csv_nas_test=csv_nas_test,
                save_name='log/metrics.png')
    

In [None]:
# FIG. 8
# Centre slices (rows: sagittal, coronal, axial) of one BraTS case: T1, the
# ground-truth labels, and the 3D-U-Net and NAS predictions.

import nibabel as nib
from helper import minmax_normalize as norm  # NOTE(review): not used in this cell

H = 50  # NOTE(review): not used in this cell
FONTSIZE = 15

# plt.subplots creates the figure itself; an extra plt.figure() beforehand
# would leave an empty "<Figure ... with 0 Axes>" behind.
fig, axs = plt.subplots(3, 4, figsize=(15, 15))
fig.subplots_adjust(wspace=0.05, hspace=0.05)

# np.asanyarray(img.dataobj) is the documented replacement for the deprecated
# get_data() (removed in nibabel 5) and keeps the stored dtype.
t1, truth, pred_nas, pred_unet = (
    np.asanyarray(nib.load(path).dataobj) for path in (
        'data/MICCAI_BraTS_2019_Data_Training/HGG/BraTS19_CBICA_AAB_1/BraTS19_CBICA_AAB_1_t1.nii.gz',
        'data/MICCAI_BraTS_2019_Data_Training/HGG/BraTS19_CBICA_AAB_1/BraTS19_CBICA_AAB_1_seg.nii.gz',
        'log/BraTS19_CBICA_AAB_1.nii.gz',
        'log/pred0_BraTS19_CBICA_AAB_1.nii.gz',
    )
)
dim = t1.shape

cols = [t1, truth, pred_unet, pred_nas]
titles = ['t1', 'ground truth', 'unet', 'nas']
# Columns 1-3 are label maps (label values 1, 2, 4 plus presumably 0 background).
# A fixed colour range gives each label the same grey in every panel; with
# imshow's per-image autoscaling a label's shade depended on which other labels
# happened to be present in that slice.
LABEL_RANGE = {'vmin': 0, 'vmax': 4}

for col, (vol, title) in enumerate(zip(cols, titles)):
    scale = {} if col == 0 else LABEL_RANGE
    slices = [
        np.rot90(vol[round(dim[0]/2), :, :]),
        np.rot90(vol[:, round(dim[1]/2), :]),
        np.rot90(vol[:, :, round(dim[2]/2)]),
    ]
    for row, sl in enumerate(slices):
        ax = axs[row, col]
        ax.imshow(sl, cmap=plt.cm.gray, **scale)
        ax.set_xticks([])
        ax.set_yticks([])
    axs[0, col].set_title(title)

# fig.savefig('log/two.png', dpi=200)