In [6]:
import os
import torch
import pickle
import random
import h5py as h5
import pandas as pd
import numpy as np
import nibabel as nib
import seaborn as sns
from scipy import ndimage
from skimage import morphology
from torchio import transforms
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast
from torch.utils.data import Dataset
from torchio.data.subject import Subject
from skimage import filters, exposure
from torch.utils.data import Dataset, DataLoader

from dataloaders.ixi_torchiowrap import IXI_H5DSImage
import ceVae
from ceVae.aes import VAE
from ceVae.helpers import kl_loss_fn, rec_loss_fn, geco_beta_update, get_ema, get_square_mask

In [7]:
def dice(true_mask, pred_mask, non_seg_score=1.0):
    """
        Computes the Dice coefficient between two binary segmentations.

        Both inputs are converted to boolean arrays first, so any non-zero
        value counts as foreground and plain lists/tuples are accepted.

        Args:
            true_mask : Array-like of arbitrary shape.
            pred_mask : Array-like with the same shape as true_mask.
            non_seg_score : Value returned when both masks are completely
                empty (the coefficient would otherwise be 0/0). Defaults to 1.0.

        Returns:
            A scalar representing the Dice coefficient between the two
            segmentations, i.e. 2 * |A & B| / (|A| + |B|), or non_seg_score
            when both masks are empty.

        Raises:
            AssertionError: If the two masks do not have the same shape.

    """
    # Convert before checking the shape so that array-likes without a
    # `.shape` attribute (lists, tuples) are handled as well.
    true_mask = np.asarray(true_mask).astype(bool)
    pred_mask = np.asarray(pred_mask).astype(bool)

    assert true_mask.shape == pred_mask.shape, (
        "true_mask and pred_mask must have the same shape, got "
        f"{true_mask.shape} and {pred_mask.shape}"
    )

    # If both segmentations are all zero, the dice will be 1. (Developer decision)
    im_sum = true_mask.sum() + pred_mask.sum()
    if im_sum == 0:
        return non_seg_score

    # Compute Dice coefficient
    intersection = np.logical_and(true_mask, pred_mask)
    return 2. * intersection.sum() / im_sum

### Load Model checkpoint

#### From a local checkpoint file

In [8]:
# Load the model from a local checkpoint and map its tensors onto the first GPU.
# The loaded object is called directly as `model(img)` later in the notebook, so the
# file appears to hold a fully pickled module rather than a state dict.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code -- only
# load checkpoints from a trusted source. Recent PyTorch releases (>= 2.6) default to
# weights_only=True, which rejects fully pickled modules; TODO confirm whether
# weights_only=False is needed for the installed torch version.
model = torch.load('checkpoint/brain.ptrh', map_location='cuda:0')

#### From Hugging Face

In [None]:
# Alternative to the local checkpoint: fetch the pretrained model from the Hugging Face Hub.
# The import lives in this cell, so `transformers` is only needed when this path is executed.
from transformers import AutoModel
# NOTE(review): trust_remote_code=True runs Python code shipped inside the model repository;
# only use it with repositories you trust.
modelHF = AutoModel.from_pretrained("soumickmj/StRegA_cceVAE2D_Brain_MOOD_IXIT1_IXIT2_IXIPD", trust_remote_code=True)
# Take the wrapper's `.model` attribute, move it to the first GPU, and bind it to `model`
# (the name the later cells call).
model = modelHF.model.to('cuda:0')

### Load Data files

In [None]:
# Run the pipeline on the device that already holds the model loaded in the
# "Load Model checkpoint" section. Re-loading a checkpoint here onto a hard-coded
# GPU index would silently discard that model and fail on machines with fewer GPUs.
device = next(model.parameters()).device

# File names of the volumes to process; sorted so the processing order is deterministic.
# The pipeline loads a file of the same name from ./brats/seg and ./brats/mask as well.
file_list = sorted(os.listdir('./brats/non_seg'))

### StRegA Pipeline

In [None]:
# Index (along the first axis of the cropped volume) of the slice shown for every subject.
s_index = 90
# Size every volume is cropped/padded to before inference.
TARGET_SHAPE = (155, 256, 256)


def _load_volume(folder, file):
    """Load the NIfTI file ``folder + file`` as a CPU tensor of shape (155, 1, 256, 256).

    The third array axis is moved to the front, the volume is cropped/padded to
    TARGET_SHAPE, and a singleton axis is inserted at position 1 so that the
    first axis can be used as a batch of single-channel 2D images.
    """
    volume = np.moveaxis(nib.load(folder + file).get_fdata(), 2, 0)
    volume = transforms.CropOrPad(TARGET_SHAPE)(torch.tensor(volume).unsqueeze(dim=0))
    return volume.squeeze(dim=0).unsqueeze(dim=1)


for file in file_list:
    # The same file name is read from three sibling folders.
    orig_out = _load_volume('./brats/non_seg/', file)  # only displayed (first panel)
    seg_out = _load_volume('./brats/seg/', file)       # input to the model
    m_out = _load_volume('./brats/mask/', file)        # only displayed (last panel)
    # Binarise the mask: every positive label becomes 1.
    # NOTE(review): presumably the reference lesion mask -- confirm against the dataset layout.
    m_out[m_out > 0] = 1

    # Min-max normalise the whole volume; a constant volume gives 0/0 = NaN, mapped to 0.
    img = seg_out.float().to(device)
    img_min, img_max = torch.min(img), torch.max(img)
    img = torch.nan_to_num((img - img_min) / (img_max - img_min), nan=0.0)

    # Inference only: no_grad avoids building an autograd graph for the whole volume.
    with torch.no_grad(), autocast():
        x_r, _ = model(img)

    img_np = img.detach().cpu().numpy()
    recon_np = x_r.float().detach().cpu().numpy()

    # Difference of reconstruction and input
    diff_mask = recon_np - img_np

    # Manual Thresholding: drop negative residuals, saturate residuals above 0.2 to 1
    m_diff_mask = diff_mask.copy()
    m_diff_mask[m_diff_mask < 0] = 0
    m_diff_mask[m_diff_mask > 0.2] = 1

    # Otsu Thresholding (the result is already a boolean array)
    thr = m_diff_mask > filters.threshold_otsu(m_diff_mask)

    # Morphological Opening, slice by slice; then clear every voxel where the input is 0
    final = np.zeros_like(thr)
    for i in range(thr.shape[0]):
        final[i, 0] = morphology.area_opening(thr[i, 0], area_threshold=256)
    final[img_np == 0] = 0

    # Panels, left to right: original image, normalised model input, reconstruction,
    # Otsu threshold, morphological opening, mask.
    panels = [
        orig_out[s_index, 0].numpy(),
        img_np[s_index, 0],
        recon_np[s_index, 0],
        thr[s_index, 0],
        final[s_index, 0],
        m_out[s_index, 0].numpy(),
    ]
    fig, axes = plt.subplots(1, 6, figsize=(15, 15))
    for ax, panel in zip(axes, panels):
        ax.imshow(ndimage.rotate(panel, -90), cmap='gray')
        ax.axis('off')
    fig.tight_layout()

    # Show and release the figure so that looping over many subjects does not
    # accumulate open figures in memory.
    plt.show()
    plt.close(fig)