In [1]:
# Importing necessary libraries
import scipy.io
import math
import numpy as np
import sys
import timeit

import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix

from skimage.transform import rotate
import scipy.fftpack as fft
from skimage.transform import rotate

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

Loading CT Scans

In [2]:
# Read the MATLAB file and take the 'ctscan' volume, whose slices are indexed on the last axis.
data = scipy.io.loadmat("./ctscan_hw1.mat")
temp_ct_scans = data['ctscan']
# Re-stack the slices so that the slice index becomes the first axis.
ct_scans = np.array([temp_ct_scans[:, :, idx] for idx in range(temp_ct_scans.shape[2])])
print(ct_scans.shape)

(3554, 512, 512)


Loading Infection Masks

In [3]:
# Read the MATLAB file and take the 'infmsk' volume, whose slices are indexed on the last axis.
data = scipy.io.loadmat("./infmsk_hw1.mat")
infmask = data['infmsk']
# Re-stack the masks so that the slice index becomes the first axis.
infection_masks = np.array([infmask[:, :, idx] for idx in range(infmask.shape[2])])
print(infection_masks.shape)

# Total number of CT slices.
N = len(ct_scans)

(3554, 512, 512)


# Part B, Functions

Functions to correct a predicted mask, i.e., relabel its clusters so that the background, infection and healthy regions carry the expected labels (0, 1 and 2 respectively).

In [4]:
def find_counts(mask):
  """Count how many pixels of a mask carry each of the three labels.

  Args:
    mask: array of labels. Only the values 0, 1 and 2 are counted.

  Returns:
    np.ndarray of three counts: [label 0 (background), label 1 (infection),
    label 2 (healthy)].
  """
  return np.array([np.count_nonzero(mask == label) for label in (0, 1, 2)])

def check_pred_mask(pred_mask):
  """Relabel a 3-cluster mask so its labels follow the expected convention.

  Cluster ids are arbitrary, so the clusters are ranked by pixel count and
  mapped onto the labels ranked by their expected size: the largest cluster
  becomes 0 (background), the second largest becomes 2 (healthy) and the
  smallest becomes 1 (infection). Ties are resolved in favour of the lower
  cluster id.

  Args:
    pred_mask: integer array of any shape whose values are all in {0, 1, 2}.

  Returns:
    Integer np.ndarray with the same shape as pred_mask holding the
    corrected labels.

  Raises:
    ValueError: if pred_mask contains a value outside {0, 1, 2}.
  """
  labels = np.asarray(pred_mask).astype(np.intp, copy=False)
  if labels.size and (labels.min() < 0 or labels.max() > 2):
    raise ValueError("pred_mask may only contain the labels 0, 1 and 2")

  pred_count_list = find_counts(labels)
  # Reference sizes for [background, infection, healthy]; only their ranking is used.
  original_count_list = np.array([100, 10, 50])

  # A stable sort on the negated counts gives "largest first, lowest index first on ties".
  pred_order = np.argsort(-pred_count_list, kind='stable')
  original_order = np.argsort(-original_count_list, kind='stable')

  # lookup[cluster id] -> corrected label; applied to the whole mask in one indexing step.
  lookup = np.empty(3, dtype=int)
  lookup[pred_order] = original_order
  return lookup[labels]

Using k-means for Image Segmentation

In [5]:
def get_predicted_mask(ct_scans, n_subset=100):
  """Segment CT images into three regions with k-means on pixel intensity.

  Each image is clustered independently into 3 clusters, and the cluster ids
  are then passed through check_pred_mask to obtain the final labels.

  Args:
    ct_scans: sequence of 2-D images (array or list of arrays).
    n_subset: maximum number of leading images to process. Defaults to 100
      (the previously hard-coded value). Pass None to process every image.

  Returns:
    np.ndarray stacking one predicted mask per processed image.
  """
  start = timeit.default_timer()

  # Clamp to the number of images supplied so a short input cannot be indexed past its end.
  total = len(ct_scans)
  N = total if n_subset is None else min(n_subset, total)

  pred_masks = []
  for i in range(N):
    sys.stdout.write('\r'+"Processing Image "+str(i))
    sample = np.asarray(ct_scans[i])
    kmeans_obj = KMeans(n_clusters=3, random_state=0)
    # One feature (the intensity) per pixel; the image shape is taken from the data.
    clusters = kmeans_obj.fit_predict(sample.reshape(-1, 1))
    curr_pred_mask = check_pred_mask(clusters.reshape(sample.shape))
    pred_masks.append(curr_pred_mask)

  pred_masks = np.array(pred_masks)
  print('\n', pred_masks.shape)

  stop = timeit.default_timer()
  print('Time Taken = ', stop - start)

  return pred_masks

Evaluating the model performance using several evaluation metrics

In [6]:
def get_confusion_metric(true_y, pred_y):
  """Return the confusion matrix over labels 0, 1 and 2 for two masks, compared pixel by pixel."""
  flat_true = true_y.flatten()
  flat_pred = pred_y.flatten()
  return confusion_matrix(y_true=flat_true, y_pred=flat_pred, labels=[0, 1, 2])
  
def _one_vs_rest_counts(confusion, label):
  """Reduce a confusion matrix (rows = true, columns = predicted) to TP, TN, FP, FN for one label."""
  tp = confusion[label, label]
  fn = confusion[label, :].sum() - tp
  fp = confusion[:, label].sum() - tp
  tn = confusion.sum() - tp - fn - fp
  return tp, tn, fp, fn

def _safe_ratio(numerator, denominator):
  """Return numerator / denominator as a float, or None when the denominator is zero."""
  if denominator == 0:
    return None
  return float(numerator) / float(denominator)

def _slice_metrics(tp, tn, fp, fn):
  """Return (dice, sensitivity, specificity, accuracy); an undefined metric is reported as None."""
  # Sensitivity = Recall = TP/(TP+FN)
  # Specificity = TN/(TN+FP)
  # Dice Score  = 2.TP / (2.TP + FP + FN)   (algebraically identical to the F1 score)
  return (
    _safe_ratio(2*tp, 2*tp + fp + fn),
    _safe_ratio(tp, tp + fn),
    _safe_ratio(tn, tn + fp),
    _safe_ratio(tp + tn, tp + tn + fp + fn),
  )

def get_req_avg_eval_metrics(infection_masks, pred_masks, N):
  """Average per-slice segmentation metrics for the infection and healthy classes.

  For each of the first N slices the confusion matrix between the true and
  the predicted mask is reduced to one-vs-rest counts for label 1 (infection)
  and label 2 (healthy), from which Dice score, sensitivity, specificity and
  accuracy are computed.

  A metric whose denominator is zero on a slice (for example sensitivity on a
  slice without any true infection pixel) is undefined there, so that slice
  is left out of that metric's average instead of contributing NaN. A metric
  that is undefined on every slice is returned as NaN.

  Args:
    infection_masks: indexable collection of ground-truth masks.
    pred_masks: indexable collection of predicted masks, aligned by index.
    N: number of leading slices to evaluate.

  Returns:
    Tuple of 8 floats: (infection dice, infection sensitivity, infection
    specificity, infection accuracy, healthy dice, healthy sensitivity,
    healthy specificity, healthy accuracy).
  """
  class_labels = (1, 2)                     # infection, healthy
  totals = np.zeros((len(class_labels), 4))
  defined = np.zeros((len(class_labels), 4), dtype=int)

  for i in range(N):
    confusion = np.asarray(get_confusion_metric(infection_masks[i], pred_masks[i]))
    for row, label in enumerate(class_labels):
      metrics = _slice_metrics(*_one_vs_rest_counts(confusion, label))
      for col, value in enumerate(metrics):
        if value is not None:
          totals[row, col] += value
          defined[row, col] += 1

  # Each metric is averaged over the slices on which it was defined.
  averages = [
    float(total) / count if count else float('nan')
    for total, count in zip(totals.ravel(), defined.ravel())
  ]
  return tuple(averages)

In [7]:
def find_eval_metrics(infection_masks, pred_masks, N):
  """Print the averaged evaluation metrics, infection block first, then the healthy block."""
  averages = get_req_avg_eval_metrics(infection_masks, pred_masks, N)

  infection_captions = (
    "Average Dice Score for Infection: ",
    "Average Sensitivity for Infection: ",
    "Average Specificity for Infection: ",
    "Average Accuracy for Infection: ",
  )
  healthy_captions = (
    "Average Dice Score for Healthy: ",
    "Average Sensitivity for Healthy: ",
    "Average Specificity for Healthy: ",
    "Average Accuracy for Healthy: ",
  )

  # The first four values belong to the infection class, the last four to the healthy class.
  for caption, value in zip(infection_captions, averages[:4]):
    print(caption, value)
  print()
  for caption, value in zip(healthy_captions, averages[4:]):
    print(caption, value)

# Part C, Reconstruction

Functions for obtaining reconstructed images

In [8]:
# CT Scan Image -> Sinogram
def radon_transform(ct_scan, rotation = 4):
    """Build a sinogram by summing the image columns at a series of rotations.

    The image is rotated by multiples of -rotation degrees, and each rotated
    image is summed along axis 0 to give one projection. 180 // rotation
    projections are taken.

    Args:
        ct_scan: 2-D image.
        rotation: angular step in degrees between projections (e.g. 4 or 8).

    Returns:
        2-D np.ndarray with one projection per row.

    Raises:
        ValueError: if rotation is not positive.
    """
    if rotation <= 0:
        raise ValueError("rotation must be a positive angular step in degrees")
    d_theta = -rotation
    n_projections = int(180 // rotation)
    projections = [
        rotate(ct_scan, i*d_theta).sum(axis=0)
        for i in range(n_projections)
    ]
    return np.vstack(projections)

def fft_translate(projections):
    """Apply scipy.fftpack's real FFT along axis 1, i.e. to every projection (row) separately."""
    spectrum = fft.rfft(projections, axis=1)
    return spectrum

def ramp_filter(ffts):
    """Weight packed real-FFT coefficients by their frequency index (ramp filter).

    The weights follow the packed layout produced by scipy.fftpack.rfft,
    [y(0), Re y(1), Im y(1), Re y(2), Im y(2), ...], so the coefficient at
    position k is multiplied by (k + 1) // 2: 0, 1, 1, 2, 2, ...

    The ramp is built with integer arithmetic, so it always has exactly one
    weight per coefficient, for both even and odd widths.

    Args:
        ffts: 2-D array with one packed spectrum per row.

    Returns:
        Array of the same shape with every row multiplied by the ramp.
    """
    n_coeffs = ffts.shape[1]
    ramp = ((np.arange(n_coeffs) + 1) // 2).astype(float)
    return ffts * ramp

def inverse_fft_translate(sinogram):
    """Apply scipy.fftpack's inverse real FFT along axis 1, i.e. to every row separately."""
    restored = fft.irfft(sinogram, axis=1)
    return restored

def inverse_radon_transform(sinogram):
  """Reconstruct an image from a sinogram by filtering and back-projecting it.

  Every projection (row) is taken to the frequency domain, weighted by the
  ramp filter and transformed back. Each filtered projection is then tiled
  into a square image, rotated by its angle and accumulated.

  Args:
    sinogram: 2-D array with one projection per row.

  Returns:
    Square np.ndarray whose side equals the projection length.
  """
  filtered = inverse_fft_translate(ramp_filter(fft_translate(sinogram)))

  n_angles = filtered.shape[0]
  width = filtered.shape[1]
  # The projections are taken to be evenly spread over 180 degrees.
  angle_step = 180.0 / n_angles

  laminogram = np.zeros((width, width))
  for idx, projection in enumerate(filtered):
    # Smear the 1-D projection across the image, then turn it to its angle.
    smear = np.tile(projection, (width, 1))
    laminogram += rotate(smear, angle_step*idx)
  return laminogram

In [9]:
class reconstruction_sinogram:
  """Turns a collection of CT scans into sinograms and reconstructs them again.

  Attributes are filled by calling ct_scans_to_sinograms() and then
  sinogram_to_ct_scans(). Each of the two calls rebuilds its result list from
  scratch, so calling them again (for example with another angle) never
  mixes results from different runs.
  """
  def __init__(self, ct_scans):
    self.ct_scans = ct_scans              # input images, one per entry
    self.sinograms = []                   # one sinogram per input image
    self.reconstructed_ct_scans = []      # one reconstruction per sinogram

  def get_sinogram(self, ct_scan, angle):
    """Return the sinogram of one image, sampled every `angle` degrees."""
    return radon_transform(ct_scan, rotation = angle)

  def ct_scans_to_sinograms(self, angle):
    """Compute the sinogram of every stored CT scan with the given angular step."""
    # Drop earlier results: reconstructions made from old sinograms are stale as well.
    self.sinograms = []
    self.reconstructed_ct_scans = []
    for i, ct_scan in enumerate(self.ct_scans):
      sys.stdout.write('\r'+"CT Scans -> Sinogram; Image No. "+str(i))
      self.sinograms.append(self.get_sinogram(ct_scan, angle))

  def get_reconstructed_ct_scan(self, sinogram):
    """Return the image reconstructed from one sinogram."""
    return inverse_radon_transform(sinogram)

  def sinogram_to_ct_scans(self):
    """Reconstruct an image from every stored sinogram.

    Raises:
      RuntimeError: if the sinograms do not match the stored CT scans one to
        one, i.e. ct_scans_to_sinograms() has not been run.
    """
    if len(self.sinograms) != len(self.ct_scans):
      raise RuntimeError("call ct_scans_to_sinograms() before sinogram_to_ct_scans()")
    self.reconstructed_ct_scans = []
    for i, sinogram in enumerate(self.sinograms):
      sys.stdout.write('\r'+"Sinogram -> CT Scans; Image No. "+str(i))
      self.reconstructed_ct_scans.append(self.get_reconstructed_ct_scan(sinogram))

# 4x Limited Angle Sinogram

In [None]:
# Project every CT scan with an angular step of 4, then reconstruct the images from those sinograms.
reconstruct_4x = reconstruction_sinogram(ct_scans)
reconstruct_4x.ct_scans_to_sinograms(4)
reconstruct_4x.sinogram_to_ct_scans()

CT Scans -> Sinogram; Image No. 1957

Finding PSNR and SSIM

In [None]:
N = len(ct_scans)

# The reconstructions are floating point, so the intensity span cannot be derived from
# the dtype. Use the span of the reference scans and pass it explicitly to both metrics.
data_range_4x = float(np.max(ct_scans)) - float(np.min(ct_scans))

avg_psnr_4x = 0
avg_ssim_4x = 0
for i in range(N):
  sys.stdout.write('\r'+"Image No. "+str(i))
  reference = np.asarray(ct_scans[i], dtype=float)
  reconstructed = np.asarray(reconstruct_4x.reconstructed_ct_scans[i], dtype=float)
  avg_psnr_4x += psnr(reference, reconstructed, data_range=data_range_4x)
  avg_ssim_4x += ssim(reference, reconstructed, data_range=data_range_4x)

avg_psnr_4x = avg_psnr_4x/N
avg_ssim_4x = avg_ssim_4x/N

print("Average Peak Signal to Noise Ratio for 4x Reconstruction: ", avg_psnr_4x)
print("Average Structute Similarity Index Measure for 4x Reconstruction: ", avg_ssim_4x)

# 8x Limited Angle Sinogram

In [None]:
# Project every CT scan with an angular step of 8, then reconstruct the images from those sinograms.
reconstruct_8x = reconstruction_sinogram(ct_scans)
reconstruct_8x.ct_scans_to_sinograms(8)
reconstruct_8x.sinogram_to_ct_scans()

Finding PSNR and SSIM

In [None]:
# Number of scans to compare (len() of an int would raise TypeError).
N = len(ct_scans)

# The reconstructions are floating point, so the intensity span cannot be derived from
# the dtype. Use the span of the reference scans and pass it explicitly to both metrics.
data_range_8x = float(np.max(ct_scans)) - float(np.min(ct_scans))

avg_psnr_8x = 0
avg_ssim_8x = 0
for i in range(N):
  sys.stdout.write('\r'+"Image No. "+str(i))
  reference = np.asarray(ct_scans[i], dtype=float)
  reconstructed = np.asarray(reconstruct_8x.reconstructed_ct_scans[i], dtype=float)
  avg_psnr_8x += psnr(reference, reconstructed, data_range=data_range_8x)
  avg_ssim_8x += ssim(reference, reconstructed, data_range=data_range_8x)

avg_psnr_8x = avg_psnr_8x/N
avg_ssim_8x = avg_ssim_8x/N

print("Average Peak Signal to Noise Ratio for 8x Reconstruction: ", avg_psnr_8x)
print("Average Structute Similarity Index Measure for 8x Reconstruction: ", avg_ssim_8x)

Evaluating Segmentation on 4x and 8x Reconstruction

In [None]:
# Segment the 4x reconstructions and score them against the ground-truth masks.
# Only as many slices as were actually segmented are evaluated; the masks are aligned by index.
pred_masks_4x = get_predicted_mask(reconstruct_4x.reconstructed_ct_scans)
print("Evaluation Metrics for 4x Reconstruction")
find_eval_metrics(infection_masks, pred_masks_4x, len(pred_masks_4x))

print()
print()
print()

# Same for the 8x reconstructions, scored with their own predicted masks.
pred_masks_8x = get_predicted_mask(reconstruct_8x.reconstructed_ct_scans)
print("Evaluation Metrics for 8x Reconstruction")
find_eval_metrics(infection_masks, pred_masks_8x, len(pred_masks_8x))