In [1]:
#saliency map paper : https://github.com/LLNL/fastcam.git and adapted/modified  to our usecase

!git clone https://github.com/LLNL/fastcam.git

!pip install pytorch_gradcam 

Cloning into 'fastcam'...
remote: Enumerating objects: 209, done.[K
remote: Counting objects: 100% (209/209), done.[K
remote: Compressing objects: 100% (130/130), done.[K
remote: Total 623 (delta 127), reused 150 (delta 78), pack-reused 414[K
Receiving objects: 100% (623/623), 19.58 MiB | 14.02 MiB/s, done.
Resolving deltas: 100% (367/367), done.
Collecting pytorch_gradcam
[?25l  Downloading https://files.pythonhosted.org/packages/e6/0a/55251f7cbea464581c6fb831813d38a41fdeb78f3dd8193522248cb98744/pytorch-gradcam-0.2.1.tar.gz (6.0MB)
[K     |████████████████████████████████| 6.0MB 6.4MB/s 
Building wheels for collected packages: pytorch-gradcam
  Building wheel for pytorch-gradcam (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch-gradcam: filename=pytorch_gradcam-0.2.1-cp36-none-any.whl size=5270 sha256=6219c71f188c78a8ec580a058832c33e5f44f3a95c323aabee6ba243a0ea6289
  Stored in directory: /root/.cache/pip/wheels/e8/1e/35/d24150a078a90ce0ad093586814d4665e945466baa8990730

#**Compute Saliency,Gradcam,Gradcam++ modules**

### This cell contains all the functions to compute the various maps. It also contains functions which combine two different maps eg: Saliency Map + Gradcam++ .

In [2]:
# -*- coding: utf-8 -*-
%cd fastcam

#importing libraries 


import os
import torch
import warnings

import tensorflow as tf
import math
import numpy as np
import torch
import torchvision
import torch.utils.data
from torch.utils.data import DataLoader,TensorDataset
import copy
import logging
from keras.layers import Input
from keras.layers.merge import concatenate
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime
from torch.utils import data
import time
from scipy import ndimage
from skimage.transform import resize
import yaml
import skimage.io as sio
import shutil
from random import shuffle
from skimage.transform import resize
import skimage.io as sio
from scipy.io import savemat,loadmat
import cv2
import mask
import draw
import norm
import misc
from torchvision import models
from random import shuffle
from torchvision.utils import make_grid, save_image
import pandas as pd
from gradcam.utils import visualize_cam
from matplotlib import pyplot as plt



'''''''''
A: Saliency Maps 
'''''''''

def get_smoe_map(x,relu=False):

  '''
  Scaled map order equivalent map computation fumction:
  reference from the saliency map paper : https://github.com/LLNL/fastcam.git and adapted/modified  to our usecase

  Arguments:  numpy array : x -> intermediate layer output.
              bool: relu -> to avoid any negative values for log , 
              if we pass a relu activated conv. output then this argument can be set to False.
          
  Returns: np array : smoe_map for the intermediate layer output
  '''
  print(f' smoe input shape={x.shape}')
  if relu:
    x=tf.nn.relu(x).numpy()
  print(f'x range={np.amax(x),np.amin(x)}')
  
  m   = np.mean(x,axis=-1)+0.0000001 

  
  x   = x + 0.0000001

  k   = np.log2(m) - np.mean(np.log2(x), axis=-1)
  print(f'log of mean={np.log2(m)}, mean of log={np.mean(np.log2(x), axis=-1)}')
  print(f'k={k}')
  k   = k + 0.0000001
 
  print(np.array_equal(np.zeros(k.shape),k))
  print(f'{x.shape,k.shape,np.amin(k)}')
  print(f'kmax, kmin={np.min(k),np.max(k)}')
  print(f'mean={m}')
  smoe_map  = k * m
  print(f'smoe map={smoe_map}')
  print(f'smoe output shape={smoe_map.shape}')
  return smoe_map

def get_std_map(x):
  '''
  STD based map alternative to SMOE.

  Arguments:  numpy array : x -> intermediate layer output.
          
  Returns: np array : m (standard deviation based map for the intermediate layer output)

  '''
  print(f'before std map shape ={x.shape}')
  m = np.std(x,axis=-1)

  print(f'std map shape ={m.shape}')

  return m

def get_norm(x,const_mean=None,const_std=None):
  '''
  get norm refrence from the saliency map paper : https://github.com/LLNL/fastcam.git and adapted/modified  to our usecase

  Arguments:  numpy array : x -> intermediate layer output.
              float: const_mean (optional) -> only if a constant mean need to be used
              float: const_std (optional) -> only if a constant std. dev. need to be used
       

  Returns: torch.Tensor: csal_maps (combined saliency maps )
  '''
  s0      = x.shape[0]
  s1      = x.shape[1]
  s2      = x.shape[2]


  x       = np.reshape(x,(1,s1*s2))
  print(f'get norm func x after reshape={x.shape} ')

  '''
      Compute Mean
  '''
  if const_mean is None:
      m       = np.mean(x,axis=1)
      m       = np.reshape(m,(m.shape[0],1))
  else:
      m       = const_mean

  print(f'get norm func x after mean reshape={m.shape} ') 
  '''
      Compute Standard Deviation
  '''
  if const_std is None:
      s       = np.std(x,axis=1)
      s       = np.reshape(s,(s.shape[0],1))
  else:
      s       = const_std
  
  '''
      The normal cumulative distribution function is used to squash the values from within the range of 0 to 1
  '''

  s=torch.tensor(s)
  x       = 0.5*(1.0 + torch.erf((x-m)/(s*torch.sqrt(torch.tensor(2.0)))))
  print(x.shape)    

  x       = x.reshape(1,s1,s2)

  print(f'map after norm={x,x.shape}')
  return x


def combine_sal_maps(smaps,output_size,weights,map_num,resize_mode='bilinear',do_relu=False):
  '''
  Combined saliency maps are computed here .
  '''
  bn  = smaps[0].shape[0]
  cm  = torch.zeros((bn, 1, output_size[0], output_size[1]), dtype=smaps[0].dtype, device=smaps[0].device)
  ww  = []
  
  '''
      Now get each saliency map and resize it. Then store it and also create a combined saliency map.
  '''
  for i in range(len(smaps)):
   
      wsz = smaps[i].shape
      w   = np.reshape(smaps[i],(wsz[0], 1, wsz[1], wsz[2]))
   
      w   = nn.functional.interpolate(w, size=output_size, mode=resize_mode, align_corners=False) 
      ww.append(w)  
      
      cm  += (w * weights[i])

  '''
      Finish the combined saliency map to make it a weighted average.
  '''
  weight_sum =sum(weights)
  cm  = cm / weight_sum
  cm  = cm.reshape(bn, output_size[0],output_size[1])
  
  ww  = torch.stack(ww,dim=1)
  ww  = ww.reshape(bn, map_num, output_size[0], output_size[1])
  

  
  return cm, ww



def compute_saliency_tf(base_path,inputs,tf_model):
  '''
   Saliency maps are computed for specicied layers and then combine them.

   Arguments: str: base_path -> path to save the  map.
              list: inputs -> [input_image_tensor,gender_tensor]
              tf model: model -> tensorflow pretrained model
       

  Returns: torch.Tensor: csal_maps (combined saliency maps )
  '''

  gender=inputs[1]
  gender=tf.reshape(gender,[1,1])
  img=inputs[0]
  img_chunk=tf.convert_to_tensor(img)
  print(img_chunk.shape)
  img_chunk = tf.reshape(img_chunk,[1,121,145,6])
  layers=[layer.name for layer in tf_model.layers]
  outputs=[]

  #select all layer activations after conv for eg: if there ae 66 conv layers then there are 66 activation layers.
  for l in layers:
   
    if l.startswith('activation'):
    
        outputs.append(tf_model.get_layer(name=l).output) 

  outputs.append(tf_model.output)                                         
  test_tf_model=tf.keras.models.Model([tf_model.inputs], outputs)
 
  predictions = test_tf_model([img_chunk,gender])

  # Specify or experiment  with layers we want to compute saliency maps for.

  # hooks=[predictions[0],predictions[1],predictions[2],predictions[8],predictions[14],predictions[20],predictions[23]\
  #        ,predictions[29],predictions[35],predictions[41],predictions[47],predictions[50],\
  #        predictions[56],predictions[62],predictions[65]]#predictions[:layer_end]
  # hooks= [predictions[0],predictions[2],predictions[17],predictions[47],predictions[62]] 


  #these layers were picked as the outputs have diffrent scale dimensions,  we can experimentwith other layers as well.
  hooks=[predictions[0],predictions[2],predictions[14],predictions[47],predictions[65]] 
  
  # choose specific channels / filters
  for x in hooks:
    print('ouput shapes layerwise')
    print(x.shape)

  
  # sal_maps       = [ get_norm(get_smoe_map(np.expand_dims(np.mean(x.numpy()[:,:,:,:,:],axis=-2)[:,:,:,2],axis=-1))) for x in hooks ]

  #smoe saliency map
  sal_maps       = [ get_norm(get_smoe_map(np.mean(x.numpy()[:,:,:,:,:],axis=-2))) for x in hooks ]

  #std dev saliency maps
  # sal_maps       = [ get_norm(get_std_map(np.mean(x.numpy()[:,:,:,:,:],axis=-2))) for x in hooks ]

  
  for smaps in sal_maps:
    print(smaps.shape)
    
  # all layer scale maps with equal weightage
  weights=np.ones(len(hooks))
  
  # all layer scale maps with progressive increasing weightage
  # weights=[i+1 for i in range(len(hooks))]
  # weights = [i for i in range(len(hooks),0,-1)]
  
  map_num=len(hooks)

  f, axarr = plt.subplots(1,1,figsize=(10,10))
  raw=np.mean(img_chunk[0,:,:,:],axis=-1)
  raw= raw/np.max(raw)
  r=axarr.imshow(raw,cmap='jet')
  axarr.set_title('Input image mean along 3rd dimension')
  plt.colorbar(r,fraction=0.01, pad=0.04)
  plt.savefig(base_path+'mean_input_chunk.png')
  plt.close()

  csal_maps,sal_maps = combine_sal_maps(sal_maps,output_size=[in_height,in_width],weights=weights,map_num=map_num)
  output_path = base_path +'Map_Combined.png'
  f, axarr = plt.subplots(1,1,figsize=(10,10))
  csal_map=csal_maps[0,:,:].numpy()
  imcs=csal_map/np.max(csal_map)
  im = axarr.imshow(imcs,cmap='jet')
  axarr.set_title('Combined saliency map')
  plt.colorbar(im,fraction=0.01, pad=0.04)
  plt.savefig(output_path)
  plt.close()

  il = [sal_maps[0,i,:,:] for i in range(map_num)] # Put each saliency map into the figure
  il.append(csal_maps[0,:,:])                       # add in the combined map at the end of the figure
  images        = [torch.stack(il, 0)]          
  images        = make_grid(images, nrow=5)
  sal_img=images.unsqueeze(1)
  output_path=base_path +'Sal_Maps.png'
  save_image(sal_img,output_path)

  input_path = output_path
  f, axarr = plt.subplots(1,1,figsize=(10,10))
  im=sio.imread(input_path)
  im=axarr.imshow(np.mean(im,axis=-1)/255, cmap='jet');
  axarr.set_title('layerwise saliency maps')
  plt.colorbar(im,fraction=0.01, pad=0.04)
  output_path=base_path +'Sal_Maps_jet.png'
  plt.savefig(output_path)
  plt.close()
  return csal_maps


'''''''''
B: GradCAM /GradCAM++
'''''''''

def get_grads(layer_name,tf_model,inputs):
  '''
  computes gradients for GCAM/GCAM++

  Arguments:  str: layer_name -> name of the last convolution activation layer.
              tf model: model -> tensorflow pretrained model
              list: inputs -> [input_image_tensor,gender_tensor]
       

  Returns: cam_list,grads,y,weights,output,img_chunk
  '''

  cam_list=[]
  gender= inputs[1] #check the gender tensor dimensions tf.constant([[1]],dtype=tf.float32)
  gender=tf.reshape(gender,[1,1])
  img=inputs[0]
  model = tf.keras.models.Model([tf_model.inputs], [tf_model.get_layer(name=layer_name).output, tf_model.output])

  img_chunk=tf.convert_to_tensor(img)
  img_chunk = tf.reshape(img_chunk,[1,121,145,6])

  # cdr_ohe_dict={0:[1.0,0.0,0.0,0.0],0.5:[0.0,1.0,0.0,0.0],1:[0.0,0.0,1.0,0.0],2:[0.0,0.0,0.0,1.0]}
  # cdr_keys= list(cdr_ohe_dict.keys())
  with tf.GradientTape() as tape:
      conv_outputs, predictions = model([img_chunk,gender])
      print(f'predictions={predictions}')
     
      y = predictions[0] # in case of an extra dimension [[]]
     

  output = conv_outputs[0]#[0,:,:,:,100]
  print(f'entering tape gradients')

  grads = tape.gradient(y, conv_outputs)[0]#[0,:,:,:,100]
  print(type(grads))
  print(f'Crossed tape gradients')

  # now there are 2 choice either use grads(raw grads) or use guided grads)
  # guided_grads = tf.cast(output > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads

  # print(f'Entering reduce mean using guided_grads with shape={guided_grads.shape}')
  
  #guided grads
  # weights = tf.reduce_mean(guided_grads, axis=(0,1,2))

  weights =  tf.reduce_mean(grads, axis=(0,1,2))


  print(f'Computing CAM using output with shape:{output.shape}')

  print(f'weights={weights.shape}')
  cam = np.zeros(output.shape[0:3], dtype=np.float32)
  print(cam.shape)


  cam=tf.reduce_sum(tf.multiply(output,weights),axis=-1)
  cam_list.append(cam)
  return cam_list,grads,y,weights,output,img_chunk




def compute_gcam_and_gcam_pp(layer_name,model,inputs):
  '''
  Generates GCAM/GCAM++ maps

  Arguments:  str: layer_name -> name of the last convolution activation layer.
              tf model: model -> tensorflow pretrained model
              list: inputs -> [input_image_tensor,gender_tensor]
       

  Returns: all numpy arrays: image, gcam_img,gcam_pp_img,y ( y is prediction)
  '''
  cam_list,grads,y,weights,output,img_chunk = get_grads(layer_name,model,inputs)
  
  heatmap_list=[]
  for i,cam in enumerate(cam_list):#as we are doing chunk wise so this camlist will have only one cam

    print(f'cam shape={cam.shape}')
    
    #gcam
    cam_map=resize(cam,(img_chunk.shape[1],img_chunk.shape[2],img_chunk.shape[3]))

    cam_map = np.maximum(cam_map,0)
    original_image=img_chunk.numpy()
   
    heatmap = (cam_map - cam_map.min()) / (cam_map.max() - cam_map.min())

  
    print(original_image.shape)
    image=np.mean(original_image[0,:,:,:],axis=-1)
    print(image.shape)

    mri_img=image#np.squeeze(image)
    heatmap_list.append(heatmap)


    heatmap_gcam = (cam_map - cam_map.min()) / (cam_map.max() - cam_map.min())

      
      
    gcam_img=(np.mean(heatmap_gcam,axis=-1)* 255).astype("uint8")
   
    #gcam++
    print(f'grads shape ={grads.shape},tf.exp(y) shape={tf.exp(y).shape}')
    conv_first_grad = tf.exp(y)[0]*grads
    #second_derivative
    conv_second_grad = tf.exp(y)[0]*grads*grads
    #triple_derivative
    conv_third_grad = tf.exp(y)[0]*grads*grads*grads
    
    
    global_sum = np.sum(tf.reshape(output,(-1,conv_first_grad[0].shape[2])), axis=0)
    print(f'conv_first_grad shape={conv_first_grad.shape},conv_second_grad shape={conv_second_grad.shape} ,  conv_third_grad shape={conv_third_grad.shape}, global_sum.shape={global_sum.shape}  ')
    alpha_num = conv_second_grad[0]

    alpha_denom = conv_second_grad*2.0 + conv_third_grad*global_sum.reshape((1,1,1,conv_first_grad[0].shape[2]))
    alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, np.ones(alpha_denom.shape))
    alphas = alpha_num/alpha_denom
    #missing line added after refering Gcam++ paper
    weights= np.maximum(conv_first_grad[0], 0.0)
    

    alphas_thresholding = np.where(weights, alphas, 0.0)
    print(f'alphas_thresholding shape={alphas_thresholding.shape}')
    alpha_normalization_constant = np.sum(np.sum(alphas_thresholding, axis=0),axis=0)
    alpha_normalization_constant_processed = np.where(alpha_normalization_constant != 0.0, alpha_normalization_constant, np.ones(alpha_normalization_constant.shape))
    print(f'alpha_normalization_constant_processed shape={alpha_normalization_constant_processed.shape}')
    
    # alphas /= alpha_normalization_constant_processed.reshape((1,1,conv_first_grad[0].shape[2]))
    alphas /= alpha_normalization_constant_processed.reshape((1,1,3,conv_first_grad[0].shape[2]))
    print(f'weights.shape={weights.shape},alphas.shape={alphas.shape}')
    weights_alpha=tf.reduce_sum(tf.multiply(weights,alphas),axis=0)
    
    cam=tf.reduce_sum(tf.multiply(output,weights_alpha),axis=-1)
    
    cam_map=resize(cam,(img_chunk.shape[1],img_chunk.shape[2],img_chunk.shape[3]))
  
    
    print(f'cam_map={cam_map.shape}')
    cam_map = np.maximum(cam_map, 0)

    heatmap_gcam_pp = (cam_map - cam_map.min()) / (cam_map.max() - cam_map.min())


    gcam_pp_img=(np.mean(heatmap_gcam_pp,axis=-1) * 255).astype("uint8")
    
    print(img_chunk.shape,mri_img.shape,gcam_img.shape,type(mri_img),type(gcam_img))
  
        
    return image, gcam_img,gcam_pp_img,y


'''''''''
C: Combine Saliency Maps with GradCAM /GradCAM++
'''''''''

def combine_sal_gcam(base_path,csmap,gcam_img,gcam_pp_img,image,layer_name='',angle=0,result_path='' ):

  '''
  Arguments:  str: base_path -> base path to store .mat files for maps
              torch.Tensor: csmap -> combined saliency map 
              numpy array: gcam_pp_img -> gradcam map
              numpy array: gcam_pp_img -> gradcam ++ map
              numpy array: image -> image with size (121,145) , a mean is performed on third axis to create 2D maps for 3D inputs.
              str: layer_name -> name of the last convolution activation layer.
              float: angle  -> rotation angle for the final result.
              str: result_path -> path to store only final results and exclude supplementary files which are save in base_path

  Returns: all lists: scans,labels,gender,ids,cdr


  3 kinds of map computed : saliency map , saliency map combined  with GCAM, saliency map combined with GCAM++ 

  For each map following intermediate output arrays are important in this function block:

  - Gray matter:
  raw_tensor

  -  Saliency Map
  csmap

  - GradCAM Map
  gcam_img

  - GradCAM ++ Map
  gcam_pp_img
  
  - Alpha Blending : eg: 0.75*map + 0.25*gray

  result_*      (i.e. result_csmap --> only saliency map, result_gcam---> saliency+gradcam,  result_gcam_pp --> saliency + gcam++)
  
  - Hard Masked top x% pixels
  hard_masked_* (i.e. hard_masked_csmap --> only saliency map, hard_masked_gcam---> saliency+gradcam,  hard_masked_gcam_pp --> saliency + gcam++)

  - Alpha Mask : gray*map
  ***** This is the one we used in our to be published paper and results , we use this and then remove its blue background and overlay on gray *****
  masked_*  (i.e. masked_csmap --> only saliency map, masked_gcam---> saliency+gradcam,  masked_gcam_pp --> saliency + gcam++)

  '''
  print(gcam_img.shape,csmap.shape,gcam_pp_img.shape,image.shape)

  '''
  I : Only the Saliency Map
  '''
  raw_tensor=torch.from_numpy(image).unsqueeze(0)

  # saving gray matter
  
  output_path   = base_path+"raw_img.mat"
  savemat(output_path,{'data':raw_tensor.numpy() ,'shape':raw_tensor.shape})
  background_img=loadmat(output_path)['data']

  base_path+='_'+layer_name


  heatmap_csmap, result_csmap = visualize_cam(csmap, raw_tensor) 
  getMask                 = mask.SaliencyMaskDropout(keep_percent = 0.1, scale_map=False)
  hard_masked_csmap,_       = getMask(raw_tensor.unsqueeze(0),csmap)#.squeeze(0))
  hard_masked_csmap        = hard_masked_csmap.squeeze(0)
  masked_csmap             = misc.AlphaMask(raw_tensor, csmap.squeeze(0)).squeeze(0)
  
  

  # Supplementary plots not important

  vmin=0
  vmax=1.0
  f, axarr = plt.subplots(2,3,figsize=(20,20))
  img_plot = axarr[0][0].imshow(torch.mean(raw_tensor,axis=0),vmin=vmin,vmax=vmax, cmap='jet');
  axarr[0][0].set_title('input')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[0][0])
  cbar.set_clim(0,1)
  img_plot = axarr[0][1].imshow(torch.mean(csmap,axis=0),vmin=vmin,vmax=vmax, cmap='jet');
  axarr[0][1].set_title('combined saliency map')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[0][1])
  cbar.set_clim(0,1)
  img_plot = axarr[0][2].imshow(torch.mean(heatmap_csmap,axis=0),vmin=vmin,vmax=vmax, cmap='jet');
  axarr[0][2].set_title('saliency map')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[0][2])
  cbar.set_clim(0,1)
  img_plot = axarr[1][0].imshow(torch.mean(result_csmap,axis=0),vmin=vmin,vmax=vmax, cmap='jet');
  axarr[1][0].set_title('saliency map with alpha blend')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[1][0])
  cbar.set_clim(0,1)
  img_plot = axarr[1][1].imshow(masked_csmap,vmin=vmin,vmax=vmax, cmap='jet');
  axarr[1][1].set_title('saliency map with alpha mask')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[1][1])
  cbar.set_clim(0,1)
  img_plot = axarr[1][2].imshow(hard_masked_csmap[0],vmin=vmin,vmax=vmax, cmap='jet');
  axarr[1][2].set_title('hard mask')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[1][2])
  cbar.set_clim(0,1)
  plt.savefig(base_path+'saliency_only_fig.png')
  plt.close()

  print(hard_masked_csmap.permute([2,0,1]).shape)

  #IMPORTANT files for Saliency MAP: Actual .mat files.#########################
  csmap_img = torch.mean(csmap,axis=0).numpy()
  output_path   = base_path+"csmap.mat"
  savemat(output_path,{'data':csmap_img ,'shape':csmap_img.shape})
  output_path   = base_path+"heatmap_csmap.mat"
  savemat(output_path,{'data':heatmap_csmap.permute([1,2,0]).numpy() ,'shape':heatmap_csmap.permute([1,2,0]).numpy().shape})
  output_path   = base_path+"result_csmap.mat"
  savemat(output_path,{'data':result_csmap.permute([1,2,0]).numpy() ,'shape':result_csmap.permute([1,2,0]).numpy().shape})
  output_path   = base_path+"hard_masked_csmap.mat" 
  savemat(output_path,{'data':hard_masked_csmap.permute([1,2,0]).numpy() ,'shape':hard_masked_csmap.permute([1,2,0]).shape})
  output_path   = base_path+"masked_csmap.mat" 
  savemat(output_path,{'data':masked_csmap.numpy() ,'shape':masked_csmap.numpy().shape})
  masked_csmap_mat = loadmat(output_path)['data']
  plt.clf()

  ##############################################################################
  
  # to avoid any divide by zero
  if np.max(gcam_img) ==0:
    gcam_img = gcam_img+0.0000001
  if np.max(gcam_pp_img) ==0:
    gcam_pp_img = gcam_pp_img+0.0000001
  gcam_img_tensor=torch.from_numpy(gcam_img).unsqueeze(0)
  mask_gcam = csmap*(gcam_img_tensor)
  mask_gcam=mask_gcam/mask_gcam.max()


  
  
  #save gcam and gcam++ map side by side fig
  vmin=np.amin([np.min(gcam_img),np.min(gcam_pp_img)])
  vmax=np.amax([np.max(gcam_img),np.max(gcam_pp_img)])

  f, axarr = plt.subplots(1,2,figsize=(10,10))
  img_plot = axarr[0].imshow(gcam_img,vmin=vmin,vmax=vmax, cmap='jet');
  axarr[0].set_title('Gradcam')
  img_plot = axarr[1].imshow(gcam_pp_img,vmin=vmin,vmax=vmax, cmap='jet');
  axarr[1].set_title('Gradcam++')
  plt.colorbar(img_plot,fraction=0.046, pad=0.04)
  plt.savefig(base_path+'gcam_gcam++_fig.png')
  plt.close()

  
  
  '''
  II: Saliency + GRADCAM
  '''
  
  heatmap_gcam, result_gcam = visualize_cam(mask_gcam, raw_tensor) 
  getMask                 = mask.SaliencyMaskDropout(keep_percent = 0.1, scale_map=False)
  hard_masked_gcam,_       = getMask(raw_tensor.unsqueeze(0),mask_gcam)#.squeeze(0))
  hard_masked_gcam        = hard_masked_gcam.squeeze(0)
  masked_gcam             = misc.AlphaMask(raw_tensor, mask_gcam.squeeze(0)).squeeze(0)
  # mx= str(np.max(masked_gcam.numpy()))
  # plt.imsave(base_path+'masked_gcam_unnormalized_{0}max.png'.format(mx),masked_gcam.numpy(),cmap='jet')
  # masked_gcam              = misc.RangeNormalize(masked_gcam)

  # Supplementary plots not important

  vmin=0
  vmax=1.0
  f, axarr = plt.subplots(2,3,figsize=(20,20))
  img_plot = axarr[0][0].imshow(torch.mean(raw_tensor,axis=0),vmin=vmin,vmax=vmax, cmap='jet');
  axarr[0][0].set_title('input')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[0][0])
  cbar.set_clim(0,1)
  img_plot = axarr[0][1].imshow(torch.mean(csmap,axis=0),vmin=vmin,vmax=vmax, cmap='jet');
  axarr[0][1].set_title('combined saliency map')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[0][1])
  cbar.set_clim(0,1)
  img_plot = axarr[0][2].imshow(torch.mean(heatmap_gcam,axis=0),vmin=vmin,vmax=vmax, cmap='jet');
  axarr[0][2].set_title('saliency map + gradcam')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[0][2])
  cbar.set_clim(0,1)
  img_plot = axarr[1][0].imshow(torch.mean(result_gcam,axis=0),vmin=vmin,vmax=vmax, cmap='jet');
  axarr[1][0].set_title('saliency map+gradcam with alpha blend')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[1][0])
  cbar.set_clim(0,1)
  img_plot = axarr[1][1].imshow(masked_gcam,vmin=vmin,vmax=vmax, cmap='jet');
  axarr[1][1].set_title('saliency map+gradcam with alpha mask')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[1][1])
  cbar.set_clim(0,1)
  img_plot = axarr[1][2].imshow(hard_masked_gcam[0],vmin=vmin,vmax=vmax, cmap='jet');
  axarr[1][2].set_title('hard mask')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[1][2])
  cbar.set_clim(0,1)
  plt.savefig(base_path+'sal+gcam_fig.png')
  plt.close()

  print(hard_masked_gcam.permute([2,0,1]).shape)

  #IMPORTANT files for Saliency MAP + GradCAM : Actual .mat files.##############
  output_path   = base_path+"gcam_img.mat"
  savemat(output_path,{'data':gcam_img ,'shape':gcam_img.shape})
  output_path   = base_path+"heatmap_gcam.mat"
  savemat(output_path,{'data':heatmap_gcam.permute([1,2,0]).numpy() ,'shape':heatmap_gcam.permute([1,2,0]).numpy().shape})
  output_path   = base_path+"result_gcam.mat"
  savemat(output_path,{'data':result_gcam.permute([1,2,0]).numpy() ,'shape':result_gcam.permute([1,2,0]).numpy().shape})
  output_path   = base_path+"hard_masked_gcam.mat" 
  savemat(output_path,{'data':hard_masked_gcam.permute([1,2,0]).numpy() ,'shape':hard_masked_gcam.permute([1,2,0]).shape})
  output_path   = base_path+"masked_gcam.mat" 
  savemat(output_path,{'data':masked_gcam.numpy() ,'shape':masked_gcam.numpy().shape})
  masked_gcam_mat = loadmat(output_path)['data']
  plt.clf()
 

  ##############################################################################
  '''
  III: Saliency + GRADCAM++
  '''
  gcam_pp_img_tensor=torch.from_numpy(gcam_pp_img).unsqueeze(0)
  mask_gcam_pp = csmap*(gcam_pp_img_tensor)
  mask_gcam_pp=mask_gcam_pp/mask_gcam_pp.max()
  raw_tensor=torch.from_numpy(image).unsqueeze(0)
  heatmap_gcam_pp, result_gcam_pp = visualize_cam(mask_gcam_pp, raw_tensor)

  hard_masked_gcam_pp,_       = getMask(raw_tensor.unsqueeze(0),mask_gcam_pp)#.squeeze(0))
  hard_masked_gcam_pp         = hard_masked_gcam_pp.squeeze(0)
  masked_gcam_pp           = misc.AlphaMask(raw_tensor, mask_gcam_pp.squeeze(0)).squeeze(0)
  # mx= str(np.max(masked_gcam_pp.numpy()))
  # plt.imsave(base_path+'masked_gcam_pp_unnormalized_{0}max.png'.format(mx),masked_gcam_pp.numpy(),cmap='jet')
  # masked_gcam_pp           = misc.RangeNormalize(masked_gcam_pp) # avoid this step as it will normalize to 0 to 1 hence not good while comparing multiple scans

  

  
  
  
  # Supplementary plots not important

  vmin=np.amin([torch.min(raw_tensor),torch.min(csmap),torch.min(heatmap_gcam_pp),torch.min(result_gcam_pp),torch.min(masked_gcam_pp),torch.min(hard_masked_gcam_pp)])
  vmax=np.amax([torch.max(raw_tensor),torch.max(csmap),torch.max(heatmap_gcam_pp),torch.max(result_gcam_pp),torch.max(masked_gcam_pp),torch.max(hard_masked_gcam_pp)])

  vmin=0
  vmax=1.0

  f, axarr = plt.subplots(2,3,figsize=(20,20))
  img_plot = axarr[0][0].imshow(torch.mean(raw_tensor,axis=0),vmin=vmin,vmax=vmax, cmap='jet');
  axarr[0][0].set_title('input')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[0][0])
  cbar.set_clim(0,1)
  img_plot = axarr[0][1].imshow(torch.mean(csmap,axis=0),vmin=vmin,vmax=vmax, cmap='jet');
  axarr[0][1].set_title('combined saliency map')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[0][1])
  cbar.set_clim(0,1)
  img_plot = axarr[0][2].imshow(torch.mean(heatmap_gcam_pp,axis=0),vmin=vmin,vmax=vmax, cmap='jet');
  axarr[0][2].set_title('saliency map + gradcam++')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[0][2])
  cbar.set_clim(0,1)
  img_plot = axarr[1][0].imshow(torch.mean(result_gcam_pp,axis=0),vmin=vmin,vmax=vmax, cmap='jet');
  axarr[1][0].set_title('saliency map+gradcam++ with alpha blend')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[1][0])
  cbar.set_clim(0,1)
  img_plot = axarr[1][1].imshow(masked_gcam_pp,vmin=vmin,vmax=vmax, cmap='jet');
  axarr[1][1].set_title('saliency map+gradcam++ with alpha mask')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[1][1])
  cbar.set_clim(0,1)
  img_plot = axarr[1][2].imshow(hard_masked_gcam_pp[0],vmin=vmin,vmax=vmax, cmap='jet');
  axarr[1][2].set_title('hard mask')
  cbar=plt.colorbar(img_plot,fraction=0.046, pad=0.04,ax=axarr[1][2])
  cbar.set_clim(0,1)
  plt.savefig(base_path+'sal+gcam++_fig.png')
  plt.close()


  raw_img = torch.mean(raw_tensor,axis=0).numpy()
  output_path   = base_path+"raw_input.png"
  savemat(output_path.split('.png')[0] +'.mat',{'data':raw_img ,'shape':raw_img.shape})

  f, axarr = plt.subplots(1,1,figsize=(10,10))
  
  r=axarr.imshow(raw_img,cmap='gray')
  axarr.set_title('raw gray image')
  cbar=plt.colorbar(r,fraction=0.046, pad=0.04)
  cbar.set_clim(0,1)
  plt.savefig(base_path+'raw_gray_cbar.png')
  plt.close()

  #IMPORTANT files for Saliency MAP + GradCAM : Actual .mat files.##############
  
  output_path   = base_path+"gcam_pp_img.mat"
  savemat(output_path,{'data':gcam_pp_img ,'shape':gcam_pp_img.shape})
  output_path   = base_path+"heatmap_gcam_pp.mat"
  savemat(output_path,{'data':heatmap_gcam_pp.permute([1,2,0]).numpy() ,'shape':heatmap_gcam_pp.permute([1,2,0]).numpy().shape})
  output_path   = base_path+"result_gcam_pp.mat"
  savemat(output_path,{'data':result_gcam_pp.permute([1,2,0]).numpy() ,'shape':result_gcam_pp.permute([1,2,0]).numpy().shape})
  output_path   = base_path+"hard_masked_gcam_pp.mat" 
  savemat(output_path,{'data':hard_masked_gcam_pp.permute([1,2,0]).numpy() ,'shape':hard_masked_gcam_pp.permute([1,2,0]).numpy().shape})
  output_path   = base_path+"masked_gcam_pp.mat" 
  savemat(output_path,{'data':masked_gcam_pp.numpy() ,'shape':masked_gcam_pp.numpy().shape})
  masked_gcam_pp_mat = loadmat(output_path)['data']
  plt.clf()
  # p=plt.imshow(masked_gcam_pp_mat,cmap='jet')
  # plt.colorbar(p)      
  # plt.clim(0.8,1)
  # output_path   = base_path+"masked_gcam_pp_0.8.png" 
  # plt.savefig(output_path)
  ##############################################################################




  ### VERY IMPORTANT: Final Result images begin ################################

  # Just the Background Removed Mask
  im=ndimage.rotate(masked_gcam_pp_mat,angle)
  max=1
  im = im/max #optional step as in our case max is 1 also use any contant val.

  im[im<0.3]=np.nan
  plt.imshow(im,cmap='jet')
  plt.axis('off')
  plt.clim(0,1)
  plt.savefig(result_path+'_result.png')
  plt.close()
  
  
  '''
  These are the images used in the final results the ones inside result_path
  '''

  # Background Removed Mask + Overlay 

  plt.clf() # clear existing figure
  #mask overlaid on gray matter
  print(f'background shape before={background_img.shape}')
  im2=background_img[0]
  print(f'background shape after={im2.shape}')
  im2=ndimage.rotate(im2,angle)
  im2=1-im2
  gray=plt.imshow(im2,cmap='gray')
  plt.axis('off')
  im=ndimage.rotate(masked_gcam_pp_mat,angle)
  im[im<0.3]=np.nan
  heat=plt.imshow(im,cmap='jet')
  plt.axis('off')
  plt.clim(0,1)
  plt.colorbar()
  plt.savefig(result_path+'_result_overlay.png')
  plt.close()

  ######### Final Result images end ############################################
 

  ## Below section is to experiment different masking thresholds################# 


  ## Saliency only
  masked_csmap=masked_csmap.numpy()

  t='masked_only_saliency'



  max =  1 #or use any other constant
  
  frac=0.3 #0.5,0.8
  r1=(masked_csmap/max)

  r1[np.where(r1<frac*np.max(r1))]=0

  plt.imsave(base_path+'nodiff_{0}_{1}.png'.format(frac,t),r1,cmap='jet')

  ## GCAM
  masked_gcam=masked_gcam.numpy()
  t='masked_gcam'

  max =  1 #np.amax(masked_gcam)

  frac=0.3 # 0.5, 0.8 
  r1=(masked_gcam/max)

  r1[np.where(r1<frac*np.max(r1))]=0

  plt.imsave(base_path+'nodiff_{0}_{1}.png'.format(frac,t),r1,cmap='jet')



  ## GCAM++
  ## fraction mask  map for overlaying GCAM++
  masked_gcam_pp=masked_gcam_pp.numpy()

  

  t='masked_gcam_pp'
  max = 1 #np.amax(masked_gcam_pp)

  frac=0.3 #0.5,0.8
  r1=(masked_gcam_pp/max)

  r1[np.where(r1<frac*np.max(r1))]=0

  plt.imsave(base_path+'nodiff_{0}_{1}.png'.format(frac,t),r1,cmap='jet')
  


/content/fastcam


# **Tf records loading and parsing utility**

### This section parses through the tfrecords and reads through corresponding cdrs , genders and labels and returns along with the parsed image.

In [4]:
def get_selected_scan_from_subjects(data_path,subject_ids,label_df,selected_scans):
    '''
    Same as get_scan_from_subjects in case we need specific scan path and not all scans of a patient.
    Arguments: str: data_path -> path to tfrecords data
               str: subject_ids -> list of subjects
              list: label_df -> dataframe of full csv file i.e.  oasis1_oasis3_labels.csv
              list: selected_scans -> which scans for the above subjects is needed
    Returns: all lists: scans,labels,gender,ids,cdr
    '''
    scans=[]
    labels=[]
    gender=[]
    cdr=[]
    ids=[]
    subject_ids = set(subject_ids)
    for subject in subject_ids :
        path=os.path.join(data_path,subject)
        paths=os.listdir(path)

        ids.extend([scan_id.split('.')[0] for scan_id in paths  if scan_id.split('/')[-1].split('.')[0] in selected_scans ])
        scans.extend([ os.path.join(path,scan_id) for scan_id in paths   if scan_id.split('/')[-1].split('.')[0] in selected_scans ])
        
    
        labels.extend([label_df[label_df['MRI ID']==scan_id.split('.')[0]]['Age'].to_list()[0] for scan_id in paths   if scan_id.split('/')[-1].split('.')[0] in selected_scans ])
        gender.extend([label_df[label_df['MRI ID']==scan_id.split('.')[0]]['M/F'].to_list()[0] for scan_id in paths   if scan_id.split('/')[-1].split('.')[0] in selected_scans])
        cdr.extend([label_df[label_df['MRI ID']==scan_id.split('.')[0]]['CDR'].to_list()[0] for scan_id in paths  if scan_id.split('/')[-1].split('.')[0] in selected_scans])

    return scans,labels,gender,ids,cdr

def get_scan_from_subjects(data_path,subject_ids,label_df):
    '''
    Computes scan path for the subjects passed.

    Arguments: str: data_path -> path to tfrecords data
               str: subject_ids -> list of subjects
              list: label_df -> dataframe of full csv file i.e.  oasis1_oasis3_labels.csv
              list: selected_scans -> which scans for the above subjects is needed

    Returns: scans,labels,gender,ids,cdr
    '''
    scans=[]
    labels=[]
    gender=[]
    cdr=[]
    ids=[]
    subject_ids = set(subject_ids)
    for subject in subject_ids :
        path=os.path.join(data_path,subject)
        paths=os.listdir(path)
        ids.extend([scan_id.split('.')[0] for scan_id in paths ])
        scans.extend([ os.path.join(path,scan_path) for scan_path in paths  ])
 
        labels.extend([label_df[label_df['MRI ID']==scan_id.split('.')[0]]['Age'].to_list()[0] for scan_id in paths   ])
        gender.extend([label_df[label_df['MRI ID']==scan_id.split('.')[0]]['M/F'].to_list()[0] for scan_id in paths  ])
        cdr.extend([label_df[label_df['MRI ID']==scan_id.split('.')[0]]['CDR'].to_list()[0] for scan_id in paths ])
    #print(labels,gender)
           
    # shuffle(scans)
    
    return scans,labels,gender,ids,cdr


def get_test_files(label_path,data_path,debug_mode_subject=None,selected_scans=[]):
    '''
    Primary function to get scans.
    Arguments: str: label_path -> path to full csv file i.e.  oasis1_oasis3_labels.csv
              str: data_path -> path to tfrecord data (organised as : path_to_tfrecord/subject/scan_id.tfrecord, eg: path_to_tfrecord/OAS30001/OAS30001_MR_d0001.tfrecord )
              list: debug_mode_subject -> list of subjects to be used
              list: selected_scans(optional) -> which scans for the above subjects is needed

    Returns: all lists: test_patients,scan_ids, test_labels,test_gender,test_cdr

    '''

    data = pd.read_csv(label_path)
    data = data.rename(columns={'MR ID':'MRI ID'})
    print(data.columns)
    
    data['M/F'] = encode_gender(data)
    
    if debug_mode_subject is None:
      test_ids = os.listdir(data_path)
    else:
      test_ids=debug_mode_subject
    
    shuffle(test_ids)
    if len(selected_scans)>0:
      test_patients,test_labels,test_gender,scan_ids,test_cdr = get_selected_scan_from_subjects(data_path,test_ids,data,selected_scans)
    else:
      test_patients,test_labels,test_gender,scan_ids,test_cdr = get_scan_from_subjects(data_path,test_ids,data)
   
    return test_patients,scan_ids, test_labels,test_gender,test_cdr
   
def encode_gender(data):
    '''
    Categorical encoding. for gender required only if the column has string data eg: 'F', 'M'
    Female : 0
    Male : 1

    Arguments:  DataFrame df
    Returns: Encoded gender column.
    '''
    data['M/F'] = pd.Categorical(data['M/F'])
    
    return data['M/F'].cat.codes

def parse_function_image(example_proto):

    features = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'image_shape': tf.io.FixedLenFeature([], tf.string)
    }

    content = tf.io.parse_single_example(example_proto, features=features)

    content['image_shape'] = tf.io.decode_raw(content['image_shape'], tf.int32)
    content['image'] = tf.io.decode_raw(content['image'], tf.float32)
    content['image'] = tf.reshape(content['image'], content['image_shape'])

    return content['image']


# **Visualize maps on BA estimation agenet model on testset**

### We first use the below config dictionary to access the BA estimation model path along with the tfrecords data path and load the model. 

### There are in total 3 csv files:
**.../BA_Estimation/csv_data/oasis1_oasis3_labels.csv** #csv data for entire dataset

**.../BA_Estimation/models/exp_ba/non_outliers.csv** #csv data for non-outliers obtained after cleaning using BA Estimation model

**.../BA_Estimation/models/exp_ba/outlier.csv** #csv data for outliers obtained after cleaning using BA Estimation model

In [5]:
#load model

cf={'Pretrained_Model':{'path':'/content/drive/My Drive/BA_Estimation/models/exp_ba/age_net.hdf5'},'Paths':\
      {'labels':'/content/drive/My Drive/BA_Estimation/csv_data/oasis1_oasis3_labels.csv',\
       'test_tfrecord':'/content/drive/My Drive/BA_Estimation/tf_records_data/training_testing_exp4'}}#testing_all_cdr
model =  tf.python.keras.models.load_model(cf['Pretrained_Model']['path'],compile=False)


### Once the model is loaded we specify the subjects by passing a list argument to the function **get_test_files()** as follows: 

### **For eg:**

### **get_test_files**(label_path,data_path,debug_mode_subject=['OAS31098','OAS31156'],selected_scans=[ ]) \# gets all scans belonging to these 2 subjects
**OR**
### **get_test_files**(label_path,data_path,debug_mode_subject=['OAS31098','OAS31156'],selected_scans=['OAS31098_MR_d7178','OAS301156_MR_d0001']) \# gets only specified scans belonging to these 2 subjects
**OR**
### **get_test_files**(label_path,data_path,debug_mode_subject=None) \# gets all scans belonging to all subjects in the tfrecord directory cf['test_tfrecord']

**OR**
### **sub,scans**=**get_subject_scan_names_from_filtered_data**(df) \# dataframe: df,  can be either of healthy subject belonging to non_outliers.csv or AD subjects belonging to outlier.csv
### **get_test_files**(label_path,data_path,debug_mode_subject=sub,selected_scans=scans) \# for every subject in list: sub, it gets the corresponding scans in list: selected_scans .

In [6]:
def get_subject_scan_names_from_filtered_data(df): 
  '''
  Arguments: dataframe df can be either of healthy subject belonging to non_outliers.csv or AD subjects belonging to outlier.csv

  Returns: 2 lists : sub,scans

  For eg: 
  sub = ['OAS30756','OAS30756','OAS30020', 'OAS30535']
  scans=['OAS30756_MR_d0014', 'OAS30756_MR_d0022','OAS30020_MR_d0092','OAS30535_MR_d0139']

  (Note: length of sub and scans will be same : there can be multiple entries for same subject in sub this indicates there are multiple scans for this subject
  As we can see here 'OAS30756' is occuring twice in the list : sub, because there are two scans i.e. 'OAS30756_MR_d0014', 'OAS30756_MR_d0022' associated to  'OAS30756' )
  '''
  sub =  df['patient_id'].values.tolist()
  scans= copy.deepcopy(sub)
  for i,s in enumerate(sub) :
    if s.startswith('OAS1'):
      s= s[:9]                  #OAS1_0123_MR1 take first 9 characters
      sub[i] = s.replace('_','')
    elif s.startswith('OAS3'):  #OAS31098_MR_d7178 #take just subject id
      sub[i] = s.split('_')[0]
  
  print(f'subjects={sub}') 
  print(f'scans={scans}')
  return sub,scans

In [7]:
dt_string = datetime.now().strftime('%d-%m-%Y-%H-%M')+'_smoe_maps_blockend_scale_endlayers_equal_weights'
print(dt_string)

08-10-2020-16-14_smoe_maps_blockend_scale_endlayers_equal_weights


In [8]:
# Main Cell to run and generate the Visualization maps

print(tf.__version__)



tf_model = model


label_path= cf['Paths']['labels']
data_path= cf['Paths']['test_tfrecord']


exp_prefix='exp_ba'
base_path_prefix = '/content/drive/My Drive/BA_Estimation'

############################## SELECT CASE ################################################

# case='healthy'
case = 'outliers' 


if case == 'healthy':
  #healthy subjects
  healthy_path = base_path_prefix+'/models/{0}/non_outliers.csv'.format(exp_prefix)
  df = pd.read_csv(healthy_path)
  exp='exp_ba_healthy'
else:
  #Outlier Subjects
  ad_path = base_path_prefix+'/models/{0}/outlier.csv'.format(exp_prefix)
  df = pd.read_csv(ad_path)
  exp='exp_ba_outliers'

###########################################################################################


## Either Manually specify  scans if we need specific scans


scans=['OAS30440_MR_d0163','OAS30773_MR_d0044','OAS30189_MR_d0072',]#'OAS30294_MR_d0462','OAS30091_MR_d0092','OAS30190_MR_d0082','OAS30665_MR_d4735','OAS30971_MR_d0077']
# scans=['OAS30030_MR_d0170','OAS30253_MR_d1288','OAS30608_MR_d1272','OAS30395_MR_d1241','OAS30559_MR_d0431','OAS30590_MR_d0085','OAS31072_MR_d4621']
# scans=['OAS30189_MR_d0072','OAS30194_MR_d8874','OAS30387_MR_d0616','OAS30387_MR_d3401','OAS30543_MR_d0185','OAS30658_MR_d0237','OAS31002_MR_d5329','OAS31076_MR_d0071','OAS30294_MR_d0462','OAS30369_MR_d2819']
sub=[s.split('_')[0] for s in scans]
test_patients,scan_ids, test_labels,test_gender,test_cdr = get_test_files(label_path,data_path,debug_mode_subject=sub,selected_scans=scans)

## Or read through all subjects and scans from no_outliers.csv or outliers.csv and use these subjects and its corresponding scans

# sub,scans= get_subject_scan_names_from_filtered_data(df)

# test_patients,scan_ids, test_labels,test_gender,test_cdr = get_test_files(label_path,data_path,\
#                                                                           debug_mode_subject=sub,selected_scans=scans) #incase we want to do on smaller subsets we can pass sub[:50],scans[:50]

tfr=tf.data.TFRecordDataset(test_patients)
img_tf=tfr.map(map_func=lambda a:parse_function_image(a))

gender_dict={0:'Female',1:'Male'}
counter =0

for i,im in enumerate(img_tf): 
    #get tf records
    
    print(type(im),im.shape)
    # if test_cdr[i] !=1 : #if block to only include specific CDR subjects for visualizations
    #   continue 
    
    counter+=1
    img=im.numpy()
    print(img.shape)
  
    max_intensity=0
    csmap_list=[]
    
    for chunk_id in [6,7,9,10,11,12]: #specify chunk id between  1 to 20 #range(1,21):

      start = (chunk_id-1)*6
      end = chunk_id*6
      
      img_chunk=torch.tensor(img[:,:,start:end])
      img_chunk = img_chunk.unsqueeze(0)
      input_tensor = img_chunk.unsqueeze(0)

      in_height   = input_tensor.size()[2]
      in_width    = input_tensor.size()[3]
      print(test_gender[i],input_tensor.shape,scan_ids[i])
     
      n=66 # last conv layer  for gcam/++
      layer_name='activation_'+str(n)
      conv_path ='conv1_'+str(n)
      
      cdr_val = str(test_cdr[i])


      #axial

      base_dir= base_path_prefix+'/results/sal_map_axial/'+dt_string\
      +'_'+exp+'/'+scan_ids[i]+'_cdr'+str(test_cdr[i])
      path = base_dir+'/'+str(chunk_id)+'_'+conv_path+'/'
      print(path)
      if not os.path.exists(path):
        os.makedirs(path)
      
      base_path  = path+scan_ids[i]+'_chunk_'+str(chunk_id)
      result_path = base_path_prefix+'/final_results/{0}/'.format(case)+\
      dt_string+'/cdr'+cdr_val+ '/axial/'+str(chunk_id)+'/'

      if not os.path.exists(result_path):
        os.makedirs(result_path)

      csmap_a=compute_saliency_tf(base_path,inputs=[img[:,:,start:end],test_gender[i]],tf_model=tf_model)
      image,gcam_img,gcam_pp_img,pred = compute_gcam_and_gcam_pp(layer_name,tf_model,[img[:,:,start:end],test_gender[i]])
      result_path += scan_ids[i] #+'_'+ str(pred.numpy()[0])
      combine_sal_gcam(path+scan_ids[i]+'_cdr'+str(test_cdr[i])+'_'+gender_dict[test_gender[i]],csmap_a,gcam_img,gcam_pp_img,image,layer_name=layer_name,angle=-270,result_path=result_path ) 
      
      
      
      #sagittal

      base_dir= base_path_prefix+'/results/sal_map_sagittal/'+dt_string\
      +'_'+exp+'/'+scan_ids[i]+'_cdr'+str(test_cdr[i])
      path = base_dir+'/'+str(chunk_id)+'_'+conv_path+'/'
      print(path)
      
      if not os.path.exists(path):
        os.makedirs(path)
      base_path  = path+scan_ids[i]+'_chunk_'+str(chunk_id)
      img_s= torch.from_numpy(img[start:end,:,:]).permute(2,1,0)
      result_path = base_path_prefix+'/final_results/{0}/'.format(case)+dt_string +'/cdr'+cdr_val+'/sagittal/'+str(chunk_id)+'/'
      if not os.path.exists(result_path):
        os.makedirs(result_path)
      
      csmap_s=compute_saliency_tf(base_path,inputs=[img_s.numpy(),test_gender[i]],tf_model=tf_model)
      image,gcam_img,gcam_pp_img,pred = compute_gcam_and_gcam_pp(layer_name,tf_model,[img_s.numpy(),test_gender[i]])
      result_path += scan_ids[i] #+'_'+str(pred.numpy()[0])
      combine_sal_gcam(path+scan_ids[i]+'_cdr'+str(test_cdr[i])+'_'+gender_dict[test_gender[i]],csmap_s,gcam_img,gcam_pp_img,image,layer_name=layer_name,angle=180,result_path=result_path ) 



      #coronal

      base_dir= base_path_prefix+'/results/sal_map_coronal/'+dt_string\
      +'_'+exp+'/'+scan_ids[i]+'_cdr'+str(test_cdr[i])
      # path = base_dir+'/'+str(chunk_id)+'_old_'+conv_path+'/'
      result_path = base_path_prefix+'/final_results/{0}/'.format(case)+dt_string+'/cdr'+cdr_val+'/coronal/'+str(chunk_id)+'/'
      path = base_dir+'/'+str(chunk_id)+'_'+conv_path+'/'
      print(path)
      if not os.path.exists(result_path):
        os.makedirs(result_path)
      
      if not os.path.exists(path):
        os.makedirs(path)
      base_path  = path+scan_ids[i]+'_chunk_'+str(chunk_id)
      img_c= torch.from_numpy(img[:,start:end,:]).permute(2,0,1)
      img_c=img_c.unsqueeze(0)
      img_c = torch.nn.functional.upsample(img_c.unsqueeze(0), size=(121,145,6), mode='nearest') #interpolation required for coronal otherwise we will get 121,121,6 instead of 121,145,6
      csmap_c=compute_saliency_tf(base_path,inputs=[img_c.numpy(),test_gender[i]],tf_model=tf_model)
      image,gcam_img,gcam_pp_img,pred = compute_gcam_and_gcam_pp(layer_name,tf_model,[img_c.numpy(),test_gender[i]])
      result_path += scan_ids[i] #+'_'+str(pred.numpy()[0])
      combine_sal_gcam(path+scan_ids[i]+'_cdr'+str(test_cdr[i])+'_'+gender_dict[test_gender[i]],csmap_c,gcam_img,gcam_pp_img,image,layer_name=layer_name,angle=180,result_path=result_path ) 

      
    print(f'scan count={counter}')

     
    


2.3.0
Index(['Unnamed: 0', 'MRI ID', 'Subject', 'M/F', 'Hand', 'Age', 'CDR'], dtype='object')
<class 'tensorflow.python.framework.ops.EagerTensor'> (121, 145, 121)
(121, 145, 121)
0 torch.Size([1, 1, 121, 145, 6]) OAS30440_MR_d0163
/content/drive/My Drive/BA_Estimation/results/sal_map_axial/08-10-2020-16-14_smoe_maps_blockend_scale_endlayers_equal_weights_exp_ba_outliers/OAS30440_MR_d0163_cdr2.0/6_conv1_66/
(121, 145, 6)
ouput shapes layerwise
(1, 61, 73, 6, 64)
ouput shapes layerwise
(1, 31, 37, 6, 192)
ouput shapes layerwise
(1, 16, 19, 6, 64)
ouput shapes layerwise
(1, 8, 10, 3, 128)
ouput shapes layerwise
(1, 4, 5, 3, 256)
 smoe input shape=(1, 61, 73, 64)
x range=(2.0798614, 0.0)
log of mean=[[[-2.2742128 -2.2742128 -2.2742128 ... -2.2742128 -2.2742128 -2.2742128]
  [-2.2742128 -2.2742128 -2.2742128 ... -2.2742128 -2.2742128 -2.2742128]
  [-2.2742128 -2.2742128 -2.2742128 ... -2.2742128 -2.2742128 -2.2742128]
  ...
  [-2.2742128 -2.2742128 -2.2742128 ... -2.2742128 -2.2742128 -2.2

The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.


torch.Size([145, 1, 121])


The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.


torch.Size([145, 1, 121])


The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.
The set_clim function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use ScalarMappable.set_clim instead.


background shape before=(1, 121, 145)
background shape after=(121, 145)
/content/drive/My Drive/BA_Estimation/results/sal_map_sagittal/08-10-2020-16-14_smoe_maps_blockend_scale_endlayers_equal_weights_exp_ba_outliers/OAS30440_MR_d0163_cdr2.0/6_conv1_66/
(121, 145, 6)
ouput shapes layerwise
(1, 61, 73, 6, 64)
ouput shapes layerwise
(1, 31, 37, 6, 192)
ouput shapes layerwise
(1, 16, 19, 6, 64)
ouput shapes layerwise
(1, 8, 10, 3, 128)
ouput shapes layerwise
(1, 4, 5, 3, 256)
 smoe input shape=(1, 61, 73, 64)
x range=(2.0148308, 0.0)
log of mean=[[[-2.2742128 -2.2742128 -2.2742128 ... -2.2742128 -2.2742128 -2.2742128]
  [-2.2742128 -2.2742128 -2.2742128 ... -2.2742128 -2.2742128 -2.2742128]
  [-2.2742128 -2.2742128 -2.2742128 ... -2.2742128 -2.2742128 -2.2742128]
  ...
  [-2.2742128 -2.2742128 -2.2742128 ... -2.2742128 -2.2742128 -2.2742128]
  [-2.2742128 -2.2742128 -2.2742128 ... -2.2742128 -2.2742128 -2.2742128]
  [-2.2742128 -2.2742128 -2.2742128 ... -2.2742128 -2.2742128 -2.2742128]]]



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
   -8.065555 ]
  [-7.9893365 -4.903755  -4.5329676 -4.25421   -3.8384686 -3.051575
   -3.82358   -3.5740337 -3.7643502 -3.789906  -3.5149517 -3.6206484
   -2.7950718 -3.2250667 -3.2845073 -2.9412217 -3.5824409 -4.5924463
   -8.750191 ]
  [-8.229824  -6.000616  -4.455201  -4.708456  -4.260369  -3.574072
   -4.5821238 -3.6377954 -3.4341068 -2.249676  -3.8895292 -4.3092337
   -4.6564035 -4.217499  -3.3563251 -4.034939  -5.3230386 -4.1913276
   -8.967519 ]
  [-8.02641   -5.2006645 -4.876198  -4.3053226 -4.0250406 -4.319597
   -4.2547617 -4.161668  -4.192554  -2.5128372 -3.0919542 -3.3415725
   -3.0181003 -3.2375603 -3.6589284 -4.3500504 -3.878634  -5.4148607
   -9.492826 ]
  [-8.221429  -5.2610044 -4.8959284 -4.6905413 -4.6204805 -4.933116
   -4.5246563 -3.7863212 -3.653244  -3.1961567 -3.4312217 -3.122566
   -4.11899   -4.742709  -4.565643  -4.9416237 -5.2059894 -5.023781
   -9.712852 ]
  [-7.98959   -5.9730253 -6.2111893 -4

# **End of code notebook (the following cells below are only for experimental trying outs.)**

## **Visualizing mean testset on BA models.**
### The mean was computed on the test set at each cdr value for male and female separately and then visualized .

In [None]:

mat_path = '/content/drive/My Drive/BA_Estimation/means/exp_ba/full/'
files = os.listdir(mat_path)

# dt_string = datetime.now().strftime('%d-%m-%Y-%H-%M')+'_smoe_maps_blockend_scale_endlayers_equal_weights'

label_path= cf['Paths']['labels'] 
data_path= cf['Paths']['test_tfrecord'] 
# exp='exp_siam_ad'
exp='exp_ba_means'
tf_model = model

gender_dict={0:'Female',1:'Male'}
reverse_gender_dict = {'Female':0,'Male':0}
for i,f in enumerate(files): #OAS30686_d0030
  # print(type(f),im.shape)
  
  test_gender = f.split('_')[-1].split('.')[0]
  gender_tensor= np.array(reverse_gender_dict[test_gender])
  test_cdr = f.split('_')[-2][-3:]
  img=loadmat(mat_path+f)['data']
  print(img.shape)
  csmap_list=[]

  for chunk_id in [6,7]:#range(1,21): #select chunks from 1,21 (significant ones are from chunk 6-11)

    start = 6*(chunk_id-1)
    end = start+6
    # img = loadmat(f)['data']

    img_chunk=torch.tensor(img[:,:,start:end]).unsqueeze(0)
    img_chunk = img_chunk.unsqueeze(-1)
    
    print(f'input shape={img_chunk.shape}')
   
    input_tensor = img_chunk

    in_height   = input_tensor.size()[1]
    in_width    = input_tensor.size()[2]
    print(test_gender,input_tensor.shape,f)

    a=51
    n=35

    layer_name='activation_'+str(a)

    conv_path ='conv1_'+str(n)


    #axial
    base_dir= '/content/drive/My Drive/BA_Estimation/means/'+dt_string+ '/sal_map_axial/' \
    +exp+'/'+f+'_cdr'+str(test_cdr)
    path = base_dir+'/'+str(chunk_id)+'_'+conv_path+'/'
    print(path)
    if not os.path.exists(path):
      os.makedirs(path)
    base_path  = path+f+'_chunk_'+str(chunk_id)
    csmap_a=compute_saliency_tf(base_path,inputs=[input_tensor,gender_tensor],tf_model=tf_model)
    image,gcam_img,gcam_pp_img = compute_gcam_and_gcam_pp(layer_name,tf_model,[input_tensor,gender_tensor])
    combine_sal_gcam(path+f+'_cdr'+str(test_cdr)+'_'+test_gender,csmap_a,gcam_img,gcam_pp_img,image,layer_name=layer_name ) 


    #sagittal
    base_dir= '/content/drive/My Drive/BA_Estimation/means/'+dt_string+'/sal_map_sagittal/' \
    +exp+'/'+f+'_cdr'+str(test_cdr)
    path = base_dir+'/'+str(chunk_id)+'_'+conv_path+'/'
    print(path)
    if not os.path.exists(path):
      os.makedirs(path)
    base_path  = path+f+'_chunk_'+str(chunk_id)
    img_s= torch.from_numpy(img[start:end,:,:]).permute(2,1,0)
    csmap_s=compute_saliency_tf(base_path,inputs=[img_s,gender_tensor],tf_model=tf_model)
    image,gcam_img,gcam_pp_img = compute_gcam_and_gcam_pp(layer_name,tf_model,[img_s,gender_tensor])
    combine_sal_gcam(path+f+'_cdr'+str(test_cdr)+'_'+test_gender,csmap_s,gcam_img,gcam_pp_img,image,layer_name=layer_name ) 


    #coronal
    base_dir= '/content/drive/My Drive/BA_Estimation/means/'+dt_string + '/sal_map_coronal/' \
    +exp+'/'+f+'_cdr'+str(test_cdr)
    path = base_dir+'/'+str(chunk_id)+'_old_'+conv_path+'/'
    print(path)
    if not os.path.exists(path):
      os.makedirs(path)
    base_path  = path+f+'_chunk_'+str(chunk_id)
    # img_c= torch.from_numpy(img[:,58:64,:]).permute(2,0,1)
    img_c= torch.from_numpy(img[:,start:end,:]).permute(2,0,1)
    img_c=img_c.unsqueeze(0)
    img_c = torch.nn.functional.upsample(img_c.unsqueeze(0), size=(121,145,6), mode='nearest') 
    csmap_c=compute_saliency_tf(base_path,inputs=[img_c,gender_tensor],tf_model=tf_model)
    image,gcam_img,gcam_pp_img = compute_gcam_and_gcam_pp(layer_name,tf_model,[img_c,gender_tensor])
    combine_sal_gcam(path+f+'_cdr'+str(test_cdr)+'_'+test_gender,csmap_c,gcam_img,gcam_pp_img,image,layer_name=layer_name ) 

    #coronal
    base_dir= '/content/drive/My Drive/BA_Estimation/means/'+dt_string + '/sal_map_coronal/' \
    +exp+'/'+f+'_cdr'+str(test_cdr)
    path = base_dir+'/'+str(chunk_id)+'_new_'+conv_path+'/'
    print(path)
    if not os.path.exists(path):
      os.makedirs(path)
    base_path  = path+f+'_chunk_'+str(chunk_id)
    # img_c= torch.from_numpy(img[:,58:64,:]).permute(2,0,1)
    img_c= torch.from_numpy(img[:,end:end+6,:]).permute(2,0,1)
    img_c=img_c.unsqueeze(0)
    img_c = torch.nn.functional.upsample(img_c.unsqueeze(0), size=(121,145,6), mode='nearest') 
    csmap_c=compute_saliency_tf(base_path,inputs=[img_c,gender_tensor],tf_model=tf_model)
    image,gcam_img,gcam_pp_img = compute_gcam_and_gcam_pp(layer_name,tf_model,[img_c,gender_tensor])
    combine_sal_gcam(path+f+'_cdr'+str(test_cdr)+'_'+test_gender,csmap_c,gcam_img,gcam_pp_img,image,layer_name=layer_name ) 




    

In [None]:
img=loadmat('/content/drive/My Drive/BA_Estimation/results/sal_map_axial/25-08-2020-11-36_smoe_maps_blockend_scale_endlayers_equal_weights_exp_ba/OAS30190_MR_d0082_cdr1.0/9_conv1_66/OAS30190_MR_d0082_cdr1.0_Female_activation_66masked_gcam_pp.mat')

img['data'].shape

(121, 145)

In [None]:
image=img['data']
image[np.where(image<0.3*np.max(image))]=0
im=plt.imshow(image,cmap='jet')
plt.colorbar(im)
plt.clim(0.6,1)

In [None]:
image=img['data']
image[np.where(image<0.3*np.max(image))]=0
im=plt.imshow(image,cmap='jet')
plt.colorbar(im)
plt.clim(0.6,0.9)

In [None]:
import nibabel
import matplotlib.pyplot as plt
import numpy as np
# nii_img=nibabel.load('/content/sub-OAS30440_ses-d0163_T1w.nii').get_fdata() /content/OAS30070_MR_d0070.tfrecord
# nii_img=nibabel.load('/content/sub-OAS30070_ses-d0070_T1w.nii').get_fdata() /content/smwc1sub-OAS30282_ses-d0040_T1w.nii
# nii_img=nibabel.load('/content/smwc1sub-OAS30282_ses-d0040_T1w.nii').get_fdata() /content/sub-OAS31035_ses-d5659_run-02_T1w.nii
# nii_img=nibabel.load('/content/sub-OAS31035_ses-d5659_run-02_T1w.nii').get_fdata() /content/sub-OAS30310_ses-d0191_T1w.nii
# nii_img=nibabel.load('/content/sub-OAS31002_ses-d4948_run-02_T1w.nii').get_fdata()
# nii_img=nibabel.load('/content/sub-OAS30102_ses-d0024_T1w.nii').get_fdata()
# nii_img=nibabel.load('/content/sub-OAS30383_ses-d0134_run-02_T1w.nii').get_fdata()
# 
# nii_img.shape

(176, 256, 256)

In [None]:


# plt.imshow(nii_img[79,:,:],cmap='gray')
# plt.imshow(nii_img[70,:,:],cmap='gray') #79
# plt.imshow(nii_img[88,:,:],cmap='gray')
# plt.imshow(nii_img[79,:,:],cmap='gray')
# plt.imshow(nii_img[67,:,:],cmap='gray')
# plt.imshow(nii_img[67,:,:],cmap='gray')
# plt.imshow(np.mean(nii_img[70:79,:,:],axis=0),cmap='gray')


# plt.imshow(np.mean(nii_img[80:86,:,:],axis=0),cmap='gray')
# plt.axis('off')
# plt.savefig('sagittal.png')


In [None]:
# plt.imshow(nii_img[:,101,:],cmap='gray')
plt.imshow(np.mean(nii_img[:,85:91,:],axis=1),cmap='gray')
# plt.imshow(nii_img[:,90,:],cmap='gray') #96
# plt.imshow(nii_img[:,103,:],cmap='gray')
# plt.imshow(np.mean(nii_img[:,103:110,:],axis=1),cmap='gray')
plt.axis('off')
plt.savefig('coronal.png')

In [None]:
# plt.imshow(nii_img[:,:,101],cmap='gray')
# plt.imshow(np.mean(nii_img[:,:,103:110],axis=-1),cmap='gray')
plt.imshow(np.mean(nii_img[:,:,116:122],axis=-1),cmap='gray')
# plt.imshow(nii_img[:,:,102],cmap='gray') #115
# plt.imshow(nii_img[:,:,90],cmap='gray') 
# plt.imshow(nii_img[:,:,96],cmap='gray')
# plt.imshow(nii_img[:,:,103:110],cmap='gray') #102, 108,112,110
plt.axis('off')
plt.savefig('axial.png')

In [None]:
nii_img2= nibabel.load('/content/smwc1sub-OAS31002_ses-d4948_run-02_T1w.nii').get_fdata()

In [None]:
# plt.imshow(nii_img2[54,:,:],cmap='gray') 
# plt.imshow(nii_img2[36,:,:],cmap='gray') 
plt.imshow(np.mean(nii_img2[30:36,:,:],axis=0),cmap='gray') 
# plt.imshow(np.mean(nii_img[70:79,:,:],axis=0),cmap='gray')
plt.axis('off')
plt.savefig('sagittal_preprocessed.png')

In [None]:
# plt.imshow(nii_img2[:,54,:],cmap='gray')
plt.imshow(np.mean(nii_img2[:,42:48,:],axis=1),cmap='gray') 
plt.axis('off')
plt.savefig('coronal_preprocessed.png')

In [None]:
# plt.imshow(nii_img2[:,:,36],cmap='gray') #54
plt.imshow(np.mean(nii_img2[:,:,36:42],axis=-1),cmap='gray') 
# plt.imshow(nii_img2[:,:,36],cmap='gray')
# plt.imshow(np.mean(nii_img2[:,:,36:42],axis=-1),cmap='gray')  
plt.axis('off')
plt.savefig('axial_preprocessed.png')

In [None]:
np.max(img['data']),np.max(img['data']),np.min(img['data']),np.min(img['data'])

(1.0, 1.0, 0.0, 0.0)