### Imports

In [None]:
import scipy.io as scio
import numpy as np    
import matplotlib.pyplot as plt
import sys
import os
import math
import pprint
import cv2
from scipy.misc import imsave
from helper import *
from create_labels import *
from stats_helper import *

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'

### Setting Directory

In [None]:
# Setting the directories
import os

# Dataset split to evaluate (uncomment exactly one).
# wanted_folder = 'alldata/'
wanted_folder = 'pruned/'
# wanted_folder = 'Atrium/'
# wanted_folder = 'Ventricle/'

# Location of the OCT data relative to the notebook directory on each known
# machine. Previously an unrecognised directory left `cwd` as the bare working
# directory (no trailing '/'), so later concatenations such as
# cwd + 'whole_raw_image/' silently produced an invalid path.
DEFAULT_DATA_SUBDIR = '/datasets/OCTData/'
DATA_SUBDIR_BY_MACHINE = {
    '/home/sim/notebooks/relaynet_pytorch': '/datasets/OCTData/',
    '/Users/sim/Desktop/Imperial/Project/PreTrained/relaynet_pytorch': '/datasets-24-aug/OCTData/',
}

cwd = os.getcwd()
check_directory = cwd
if check_directory not in DATA_SUBDIR_BY_MACHINE:
    # NOTE(review): fallback layout for unknown machines -- confirm it.
    print('Unrecognised working directory; assuming data under '
          + check_directory + DEFAULT_DATA_SUBDIR)
cwd = (cwd
       + DATA_SUBDIR_BY_MACHINE.get(check_directory, DEFAULT_DATA_SUBDIR)
       + wanted_folder)

print(cwd)

### Raw Files

In [None]:
# Folder holding the raw OCT scans for the chosen dataset split.
whole_raw_image_folder = ''.join((cwd, 'whole_raw_image/'))
print(whole_raw_image_folder)

In [None]:
# Load every .tif file in the raw-image folder; the helper's result is
# unpacked into the list of names and the list of images.
filenames, raw_images = get_data(
    whole_raw_image_folder,
    '.tif',
)

In [None]:
# Sanity check: report how many names / images were loaded, then preview
# the first scan.
for loaded in (filenames, raw_images):
    print(len(loaded))
label = raw_images[0]
plt.imshow(label, cmap='gray')

### Labels

In [None]:
# Manually labelled images (.JPG). The evaluation loop pairs them with the raw
# scans purely by list position, so surface a count mismatch here instead of
# as an IndexError (or silent misalignment) later on.
manual_label_folder = cwd + 'manual_label/'
manual_label_filenames, manual_labels = get_data(manual_label_folder, '.JPG')
if len(manual_labels) != len(filenames):
    print('WARNING: {} manual labels but {} raw images; index-based pairing '
          'will be wrong.'.format(len(manual_labels), len(filenames)))
print(filenames[0])

In [None]:
# Side-by-side preview of the first raw scan and its manual label.
# Draw directly on the Axes returned by plt.subplots: re-selecting them with
# plt.subplot(121/122) mixes the object and state-machine APIs, and whether
# that reuses or replaces the existing Axes depends on the matplotlib version.
f, axs = plt.subplots(1, 2, figsize=(20, 20))
panels = (
    (raw_images[0], 'Raw OCT Image', {'cmap': "gray"}),
    (manual_labels[0], 'Manually Labelled Image', {}),
)
for ax, (image, title, imshow_kwargs) in zip(axs, panels):
    ax.imshow(image, **imshow_kwargs)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

### Ids

In [None]:
# Ground-truth label-id images (.png); only the image list is kept.
ids_folder = ''.join((cwd, 'png_labels_method/'))
_, gnd_ids = get_data(
    ids_folder,
    '.png',
)

print(len(gnd_ids))
plt.imshow(gnd_ids[0])

### Results

In [None]:
# Saved predictions are read from <working dir>/results/<folder_of_interest>.
# Note this uses the untouched working directory, not the dataset `cwd`.
folder_of_interest = 'Pruned/'
results_folder = ''.join([os.getcwd(), '/results/', folder_of_interest])
print(results_folder)

In [None]:
# For raw vs normalised 
# chosen_result = np.load(results_folder + 'tf_nonnormalised_raw.npy')
# chosen_result = np.load(results_folder + 'tf_normalised.npy')
# chosen_result = np.load(results_folder + 'tf_atrven_normalised.npy')
# chosen_result = np.load(results_folder + 'torch_nonnormalised_raw.npy')
# chosen_result = np.load(results_folder + 'torch_normalised.npy')
# chosen_result = np.load(results_folder + 'torch_atrven_normalised.npy')

# Whole vs Pruned
# chosen_result = np.load(results_folder + 'tf_whole.npy')
# chosen_result = np.load(results_folder + 'tf_pruned.npy')
chosen_result = np.load(results_folder + 'torch_whole.npy')
print(np.unique(chosen_result[0]))
# chosen_result = np.load(results_folder + 'torch_pruned.npy')
print(chosen_result.shape)

In [None]:
# Per-image evaluation: Dice statistics and layer-thickness metrics for every
# scan. Set SHOW_PER_IMAGE_PLOTS = True to also display the raw scan, manual
# label and colourised prediction for each image (this used to be commented
# out, while the prediction was still colourised on every iteration for no use).
SHOW_PER_IMAGE_PLOTS = False

dice_stats = []
avg_thickness_list = []
err_thickness_list = []
sqrerr_thickness_list = []
for ind, filename in enumerate(filenames):
    # Raw scans, manual labels, ground-truth ids and predictions are all
    # paired by list position.
    true_id = gnd_ids[ind]
    predicted_id = chosen_result[ind]

    # Creating one hot encoding of true labels and predicted labels
    true_labels = list_of_labels(true_id, 8)
    pred_labels = list_of_labels(predicted_id, 8)

    # Crop both to their common height/width so they can be compared.
    th, tw, _ = true_labels.shape
    ph, pw, _ = pred_labels.shape
    h, w = min(th, ph), min(tw, pw)
    true_labels = true_labels[:h, :w, :]
    pred_labels = pred_labels[:h, :w, :]

    stats = find_stats(true_labels, pred_labels)
    _, avg_pred_thickness_list, mean_abs_error_list, mean_squared_error_list, _ = \
        thickness_metrics(true_labels, pred_labels)
    dice_stats.append(stats)
    avg_thickness_list.append(avg_pred_thickness_list)
    err_thickness_list.append(mean_abs_error_list)
    sqrerr_thickness_list.append(mean_squared_error_list)

    if SHOW_PER_IMAGE_PLOTS:
        # Colourising is only needed for display.
        color = label_img_to_rgb(predicted_id)
        axis_name = filename[4:-4]
        print(axis_name)
        f, axs = plt.subplots(1, 3, figsize=(20, 20))
        plt.suptitle(axis_name, size=14)
        plt.subplots_adjust(top=1.58)
        panels = (
            (raw_images[ind], 'Raw OCT Image', {'cmap': "gray"}),
            (manual_labels[ind], 'Manually Labelled Image', {}),
            (color, 'Automated Label', {}),
        )
        for ax, (image, title, imshow_kwargs) in zip(axs, panels):
            ax.imshow(image, **imshow_kwargs)
            ax.set_title(title)
            ax.set_xticks([])
            ax.set_yticks([])
        plt.show()

dice_stats = np.asarray(dice_stats)
avg_thickness_list = np.asarray(avg_thickness_list)
err_thickness_list = np.asarray(err_thickness_list)
sqrerr_thickness_list = np.asarray(sqrerr_thickness_list)

# Display names indexed by label id (column index in the stats arrays).
layers = ['Void - Black', 'Myocardium - Red', 'Endocardium - Blue', 'Fibrosis - Purple', 'Fat - Green', 'Dense Collagen - Orange', 'Loose Collagen - Yellow', 'Smooth Muscle - Pink']

def get_layer_stats(input_list, num_labels=8, threshold=0.001, layer_names=None):
    """Print (and return) the per-layer average of a per-image metric table.

    Label 0 is skipped; its slot in the result is fixed at 0. For every other
    label only entries greater than ``threshold`` are averaged.

    Parameters
    ----------
    input_list : 2-D array-like
        Rows are images; column ``i`` holds the metric for label ``i``.
    num_labels : int
        Number of label columns to report, counting label 0.
    threshold : float
        Entries ``<= threshold`` are ignored when averaging.
    layer_names : sequence of str, optional
        Display names indexed by label id. Defaults to the notebook's ``layers``.

    Returns
    -------
    list
        ``[0, avg_1, ..., avg_{num_labels-1}]`` with each average rounded to
        3 decimals. A label with no entries above ``threshold`` gets ``nan``.
    """
    if layer_names is None:
        layer_names = layers
    input_list = np.asarray(input_list)
    averages = [0]
    for i in range(1, num_labels):
        column = input_list[:, i]
        kept = column > threshold
        new_stats = column[kept]
        if new_stats.size > 0:
            # argmax indexes the *filtered* values; map it back to the row
            # (image) index of input_list so the printed index is meaningful.
            # NOTE(review): for error metrics the maximum is the worst image.
            best_val = np.flatnonzero(kept)[np.argmax(new_stats)]
            print('Best_val is at index: ', best_val)
            val = round(np.average(new_stats), 3)
        else:
            # np.average([]) warns and yields nan; make the empty case explicit.
            val = np.nan
        averages.append(val)
        print('Label: {} {}, Average Score: {}'.format(i, layer_names[i], averages[i]))
    print('Average Scores', np.around(averages, 3))
    # Average only over layers that had data, so one absent layer does not
    # turn the overall score into nan.
    print('Average Overall Score', np.nanmean(averages[1:]))
    return averages
    
# Per-layer summaries for each metric table, each followed by a blank line.
# (Squared-error stats are collected in sqrerr_thickness_list but not reported.)
for heading, table in (
    ('Dice Stats\n', dice_stats),
    ('Average Thickness Stats\n', avg_thickness_list),
    ('Absolute Error Stats\n', err_thickness_list),
):
    print(heading)
    get_layer_stats(table)
    print()

def get_layer_stats2(input_list, num_labels=8, threshold=0.001):
    """Print the per-layer averages of a per-image metric table on one line.

    Compact counterpart of ``get_layer_stats``. Label 0 is skipped, and for
    every other label only entries greater than ``threshold`` are averaged.
    Each average is rounded to 2 decimals.

    Parameters
    ----------
    input_list : 2-D array-like
        Rows are images; column ``i`` holds the metric for label ``i``.
    num_labels : int
        Number of label columns, counting label 0.
    threshold : float
        Entries ``<= threshold`` are ignored when averaging.

    Returns
    -------
    list
        ``[0, avg_1, ..., avg_{num_labels-1}]``. A label with no entries above
        ``threshold`` gets ``nan``.
    """
    input_list = np.asarray(input_list)
    averages = [0]
    for i in range(1, num_labels):
        column = input_list[:, i]
        new_stats = column[column > threshold]
        # np.average([]) warns and yields nan; make the empty case explicit.
        val = round(np.average(new_stats), 2) if new_stats.size else np.nan
        averages.append(val)
    print('Average Scores', np.around(averages[1:], 2))
    return averages
# One compact line each: Dice, average thickness, absolute thickness error.
for table in (dice_stats, avg_thickness_list, err_thickness_list):
    get_layer_stats2(table)

In [None]:
# Colour palette (label id -> RGB) used by label_img_to_rgb2.
# NOTE(review): ids 0 and 1 are both red, ids 3 and 5 are both purple, and
# id 8 reuses the name "Smooth Muscle". Also, the `layers` list in this
# notebook describes Void as black and Dense Collagen as orange.
# Confirm this palette is intentional.
SEG_LABELS_LIST2 = [
    dict(id=-1, name="void", rgb_values=[0, 0, 0]),
    dict(id=0, name="void", rgb_values=[255, 0, 0]),                # red
    dict(id=1, name="Myocardium", rgb_values=[255, 0, 0]),          # red
    dict(id=2, name="Endocardium", rgb_values=[0, 0, 255]),         # blue
    dict(id=3, name="Fibrosis", rgb_values=[177, 10, 255]),         # purple
    dict(id=4, name="Fat", rgb_values=[0, 255, 0]),                 # green
    dict(id=5, name="Dense Collagen", rgb_values=[177, 10, 255]),   # purple
    dict(id=6, name="Loose Collagen", rgb_values=[255, 255, 0]),    # yellow
    dict(id=7, name="Smooth Muscle", rgb_values=[255, 0, 255]),     # magenta/pink
    dict(id=8, name="Smooth Muscle", rgb_values=[0, 0, 0]),
]

def label_img_to_rgb2(label_img, seg_labels=None):
    """Colour a label-id image using an id -> RGB palette.

    Size-1 axes are squeezed out first. Every pixel whose id appears in the
    palette is painted with that entry's ``rgb_values``. Pixels whose id is not
    in the palette keep their raw id in all three channels. The result is cast
    to ``uint8``.

    Parameters
    ----------
    label_img : array-like
        Label-id image, possibly with extra size-1 axes.
    seg_labels : list of dict, optional
        Palette entries with ``'id'`` and ``'rgb_values'`` keys. Defaults to
        ``SEG_LABELS_LIST2``.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array with the squeezed shape plus a trailing RGB axis.
    """
    if seg_labels is None:
        seg_labels = SEG_LABELS_LIST2
    label_img = np.squeeze(label_img)
    present = np.unique(label_img)

    # Start from the raw ids replicated into 3 channels; np.stack works for
    # any rank, unlike transpose(1, 2, 0), which only fits 2-D inputs.
    label_img_rgb = np.stack([label_img] * 3, axis=-1)
    for info in seg_labels:
        if info['id'] in present:
            label_img_rgb[label_img == info['id']] = info['rgb_values']

    return label_img_rgb.astype(np.uint8)


# Inspect one example: raw scan, manual label, and the prediction coloured
# with the SEG_LABELS_LIST2 palette. Draw on the Axes returned by
# plt.subplots rather than re-selecting them with plt.subplot(13x), which mixes
# the object and state-machine APIs in a version-dependent way.
ind = 5
color = label_img_to_rgb2(chosen_result[ind])
f, axs = plt.subplots(1, 3, figsize=(20, 20))
panels = (
    (raw_images[ind], 'Raw OCT Image', {'cmap': "gray"}),
    (manual_labels[ind], 'Manually Labelled Image', {}),
    (color, 'Automated Label', {}),
)
for ax, (image, title, imshow_kwargs) in zip(axs, panels):
    ax.imshow(image, **imshow_kwargs)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

In [None]:
def get_layer_stats(input_list, num_labels=8, threshold=0.001, layer_names=None):
    """Print (and return) the per-layer average of a per-image metric table.

    Label 0 is skipped; its slot in the result is fixed at 0. For every other
    label only entries greater than ``threshold`` are averaged.

    Parameters
    ----------
    input_list : 2-D array-like
        Rows are images; column ``i`` holds the metric for label ``i``.
    num_labels : int
        Number of label columns to report, counting label 0.
    threshold : float
        Entries ``<= threshold`` are ignored when averaging.
    layer_names : sequence of str, optional
        Display names indexed by label id. Defaults to the notebook's ``layers``.

    Returns
    -------
    list
        ``[0, avg_1, ..., avg_{num_labels-1}]`` with each average rounded to
        3 decimals. A label with no entries above ``threshold`` gets ``nan``.
    """
    if layer_names is None:
        layer_names = layers
    input_list = np.asarray(input_list)
    averages = [0]
    for i in range(1, num_labels):
        column = input_list[:, i]
        kept = column > threshold
        new_stats = column[kept]
        if new_stats.size > 0:
            # argmax indexes the *filtered* values; map it back to the row
            # (image) index of input_list so the printed index is meaningful.
            # NOTE(review): for error metrics the maximum is the worst image.
            best_val = np.flatnonzero(kept)[np.argmax(new_stats)]
            print('Best_val is at index: ', best_val)
            val = round(np.average(new_stats), 3)
        else:
            # np.average([]) warns and yields nan; make the empty case explicit.
            val = np.nan
        averages.append(val)
        print('Label: {} {}, Average Score: {}'.format(i, layer_names[i], averages[i]))
    print('Average Scores', np.around(averages, 3))
    # Average only over layers that had data, so one absent layer does not
    # turn the overall score into nan.
    print('Average Overall Score', np.nanmean(averages[1:]))
    return averages
    
# Dice summary using the get_layer_stats defined in this cell.
dice_heading = 'Dice Stats\n'
print(dice_heading)
get_layer_stats(dice_stats)


In [None]:
# Per-image inspection (disabled). Uncomment to:
# - print one image's per-label Dice scores;
# - show its raw / manual / predicted images;
# - plot each label channel of the ground-truth and predicted encodings.
# NOTE(review): `overall_stats` is not defined anywhere in this notebook; the
# per-image Dice table is `dice_stats` -- rename before re-enabling.
# ind = 4
# for i in range(8):
#     print('Label: {} {}, Average Dice Score: {}'.format(i, layers[i], overall_stats[ind,i]))

# color = label_img_to_rgb(chosen_result[ind]) 
# f, axs = plt.subplots(1,3,figsize=(20,20))
# plt.subplot(131), plt.imshow(raw_images[ind], cmap = "gray")
# plt.title('Raw OCT Image'), plt.xticks([]), plt.yticks([])
# plt.subplot(132), plt.imshow(manual_labels[ind])
# plt.title('Manually Labelled Image'), plt.xticks([]), plt.yticks([])
# plt.subplot(133),plt.imshow(color)
# plt.title('Automated Label'), plt.xticks([]), plt.yticks([])
# plt.show()

# t_label = list_of_labels(gnd_ids[ind],8)
# p_label = list_of_labels(chosen_result[ind],8)
# # Plotting Labels of the layers
# fig, axes = plt.subplots(nrows=1, ncols=8, figsize=(20,20))
# for i, ax in enumerate(axes):
#     ax.imshow((t_label[:,:,i]), alpha=0.2)
#     ax.set_title("label " + str(i))

# fig, axes = plt.subplots(nrows=1, ncols=8, figsize=(20,20))
# for i, ax in enumerate(axes):
#     ax.imshow((p_label[:,:,i]), alpha=0.2)
#     ax.set_title("label " + str(i))