# **Interpretable AI For Protein Expression Images and Associated Clinical Metadata (Concept whitening)**

## **1. Introduction**

With a conservative frequency estimate of about 1:5,000, mitochondrial disorders are among the most prevalent inherited diseases [1]. Diagnosing and understanding the different mitochondrial diseases is extremely difficult because symptoms vary widely between patients and affect different organs and tissues of the body [2]. However, recent studies show that deep learning algorithms, especially Convolutional Neural Networks (CNNs), combined with interpretability and explainability methods, can help diagnose and evaluate diseases automatically by detecting patterns in images. In this study, we developed deep learning models using transfer learning to predict mitochondrial diseases, and applied existing interpretability and explainability approaches for computer vision, such as Grad-CAM [3] and Neural Disentanglement (concept whitening) [4], to understand the features that drive the prediction of mitochondrial disease from protein expression images.

## **2. Aim**

To determine whether Deep Learning (DL) can be used as a reliable method to classify mitochondrial diseases, using interpretability and explainability approaches.

• Interpretability Method: Neural Disentanglement

• Explainability Method: Saliency Map 

## **3. Objective**

This investigation has two primary objectives:

3.1. Determine whether it is possible to accurately diagnose mitochondrial diseases for single channel protein images.

> Adapt pre-trained deep learning models such as VGG16 and ResNet-50 to classify different mitochondrial diseases using protein expression images obtained by imaging mass cytometry. Evaluate, compare, and fine-tune the different pre-trained architectures using metrics such as accuracy, precision, recall, F1 score, and the confusion matrix.
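
As a reference for the evaluation metrics named above, here is a minimal sketch that derives them from a binary confusion matrix. The helper name `binary_metrics` and the label convention (1 = patient, 0 = control) are assumptions made for illustration; the notebook itself imports the scikit-learn equivalents.

```python
def binary_metrics(y_true, y_pred, positive=1):
    """Return the confusion matrix and derived metrics for binary labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    tn = sum(1 for t, p in pairs if t != positive and p != positive)

    # Guard every ratio against an empty denominator.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    accuracy = (tp + tn) / len(pairs) if pairs else 0.0

    return {
        # Rows are true labels, columns are predictions: [[tn, fp], [fn, tp]]
        "confusion_matrix": [[tn, fp], [fn, tp]],
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
```

Precision and recall are reported separately because, in a diagnostic setting, a missed patient (false negative) and a misclassified control (false positive) carry different costs.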

3.2. Understand the underlying pathology of mitochondrial diseases.

> Once we know that the pre-trained deep learning models can classify different mitochondrial diseases, we can then apply interpretability and explainability approaches such as Grad-CAM and Neural Disentanglement on top of the pre-trained models to understand the underlying pathology of mitochondrial diseases.
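
For orientation, the concept whitening module of [4] can be summarised in two steps. The notation below follows our reading of that paper and is not taken from this notebook's code: a batch of latent activations $Z \in \mathbb{R}^{d \times n}$ is first whitened (centred and decorrelated), and the result is then rotated by an orthogonal matrix $Q$ whose columns $q_j$ are chosen so that samples of concept $c_j$ activate axis $j$ as strongly as possible.

```latex
\psi(Z) = W\,\bigl(Z - \mu \mathbf{1}_{n}^{\top}\bigr), \qquad
\mu = \tfrac{1}{n}\, Z\, \mathbf{1}_{n}, \qquad
W^{\top} W = \Sigma^{-1}, \qquad
\Sigma = \tfrac{1}{n}\,\bigl(Z - \mu \mathbf{1}_{n}^{\top}\bigr)\bigl(Z - \mu \mathbf{1}_{n}^{\top}\bigr)^{\top}

\max_{q_1,\dots,q_k} \; \sum_{j=1}^{k} \frac{1}{n_j}\; q_j^{\top}\, \psi\bigl(Z_{c_j}\bigr)\, \mathbf{1}_{n_j}
\quad \text{s.t.} \quad Q^{\top} Q = I_d
```

Here $Z_{c_j}$ holds the latent activations of the $n_j$ example images for concept $c_j$. In this notebook there are $k = 2$ concepts, Controls and Patients.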

# **Importing essential libraries**

In [1]:
# initiating gpu using tensorflow.
import tensorflow as tf
from tensorflow.compat.v1.keras.backend import set_session
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
config.log_device_placement = True
sess = tf.compat.v1.Session(config=config)
set_session(sess)

Device mapping:
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5



In [2]:
import numpy as np
import pandas as pd
import glob
import random 
import os
import cv2
from shutil import copy,copytree
import shutil

!pip install patchify
from patchify import patchify

from google.colab import files
import os
import zipfile

# visualization
import matplotlib.pyplot as plt
from google.colab.patches import cv2_imshow


# image
from PIL import Image

# sklearn
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
#tensorflow
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img,img_to_array
from tensorflow.keras.applications import VGG16, InceptionV3, ResNet101V2, Xception, ResNet50V2,VGG19
from tensorflow.keras import optimizers
from tensorflow.keras.optimizers import Adam
#from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.models import load_model
from tensorflow.keras import models, layers
from keras.utils.vis_utils import plot_model
from tensorflow.keras.layers import Dense,Conv2D,MaxPool2D,Flatten,Dropout,BatchNormalization,Activation,GlobalAveragePooling2D

from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import confusion_matrix

#setting seed to reproduce the same result every time the code is re-run
seed_num = 1
tf.random.set_seed(seed_num)
np.random.seed(seed_num)

# install split-folder
!pip install split-folders

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting patchify
  Downloading patchify-0.2.3-py3-none-any.whl (6.6 kB)
Installing collected packages: patchify
Successfully installed patchify-0.2.3
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting split-folders
  Downloading split_folders-0.5.1-py3-none-any.whl (8.4 kB)
Installing collected packages: split-folders
Successfully installed split-folders-0.5.1


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


**Unzipping the dataset**

In [4]:
%%capture
!unzip /content/drive/MyDrive/Dataset.zip

In [5]:
%%capture
!unzip /content/drive/MyDrive/Concept_whitening/checkpoints.zip

**Upload, or copy from Google Drive, the Python scripts needed to implement concept whitening:**

plot_functions.py : plots the required graphs

train_isic.py : trains the model

resnet18365places.py : builds the initial ResNet-18 model

iterative_normalization.py : contains the code that implements concept whitening

model_resnet.py : builds the concept whitening ResNet model
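
Before running the copy commands, it can help to confirm that all five scripts are present in the source folder. The sketch below is one way to do that; `missing_scripts` is a hypothetical helper, and the file names are the ones used by the `cp` commands in the next cell.

```python
import os

# File names as they appear in the copy commands of this notebook.
REQUIRED_SCRIPTS = [
    "iterative_normalization.py",
    "model_resnet.py",
    "plot_functions.py",
    "resnet18365places.py",
    "train_isic.py",
]


def missing_scripts(folder, required=REQUIRED_SCRIPTS):
    """Return the required script names that are not regular files in `folder`."""
    return [
        name for name in required
        if not os.path.isfile(os.path.join(folder, name))
    ]


# Example call (the path is the Drive folder used in this notebook):
# missing_scripts("/content/drive/MyDrive/Concept_whitening")
```

An empty list means every script was found; any returned name points to a file that still needs to be uploaded.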

In [6]:
!cp /content/drive/MyDrive/Concept_whitening/iterative_normalization.py /content/
!cp /content/drive/MyDrive/Concept_whitening/model_resnet.py /content/
!cp /content/drive/MyDrive/Concept_whitening/plot_functions.py /content/
!cp /content/drive/MyDrive/Concept_whitening/resnet18365places.py /content/
!cp /content/drive/MyDrive/Concept_whitening/train_isic.py /content/


# **Data Engineering**
**Renaming all the files in the subfolder for both control and patient with subfolder names**

In [7]:
d = ["/content/Dataset/controls/C01","/content/Dataset/controls/C02","/content/Dataset/controls/C03","/content/Dataset/controls/C04"]
for i in d:
  for path in os.listdir(i):
      full_path = os.path.join(i, path)
      if os.path.isfile(full_path):
          new_path = os.path.join(i, "C0"+str(d.index(i)+1)+"_"+ path)
          os.rename(full_path, new_path)

In [8]:
d = [f"/content/Dataset/patients/P{n:02d}" for n in range(1, 11)]  # P01 ... P10
for i in d:
  # Take the prefix from the folder name itself; building it as "P0" + index
  # would turn the tenth folder into "P010" instead of "P10".
  prefix = os.path.basename(i)
  for path in os.listdir(i):
      full_path = os.path.join(i, path)
      if os.path.isfile(full_path):
          new_path = os.path.join(i, prefix + "_" + path)
          os.rename(full_path, new_path)
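
Note that re-running either renaming cell adds the prefix a second time. The following is a minimal sketch of an idempotent variant that takes the prefix from the folder name and skips files that already carry it; `prefix_files_with_folder` is a hypothetical helper, not part of the original notebook.

```python
import os


def prefix_files_with_folder(folder):
    """Prefix each file in `folder` with the folder's own name, at most once.

    Returns the list of new file names that were created by this call.
    """
    prefix = os.path.basename(os.path.normpath(folder)) + "_"
    renamed = []
    for name in sorted(os.listdir(folder)):
        src = os.path.join(folder, name)
        # Skip sub-directories and files that are already prefixed.
        if not os.path.isfile(src) or name.startswith(prefix):
            continue
        os.rename(src, os.path.join(folder, prefix + name))
        renamed.append(prefix + name)
    return renamed
```

Calling it once per subject folder (for example in a loop over `d`) gives the same file names as the cells above on the first run and leaves them unchanged on later runs.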

**Copying all the JPG images into the Dataset_JPG directory while maintaining the directory tree structure**

In [9]:
shutil.copytree('/content/Dataset', '/content/Dataset_JPG' , ignore=shutil.ignore_patterns('*.ome.tiff', '*.db'))

'/content/Dataset_JPG'

**Copying all the TIFF images into the Dataset_TIFF directory while maintaining the directory tree structure**

In [10]:
shutil.copytree('/content/Dataset', '/content/Dataset_TIFF' , ignore=shutil.ignore_patterns('*.jpg', '*.db'))

'/content/Dataset_TIFF'

# **Concept Whitening on single-channel protein images**



In [11]:
#Creating a new path to store all jpg images(Control SDHA Images)
newpath = r'/content/SDHA_Images/Controls' 
if not os.path.exists(newpath):
    os.makedirs(newpath)
dir_src = r"/content/Dataset_JPG/controls"
dir_dst = r"/content/SDHA_Images/Controls"
for file in glob.iglob('%s/**/*SDHA.jpg' % dir_src, recursive=True):
    copy(file, dir_dst)

In [12]:
#Creating a new path to store all jpg images(Patients SDHA Images)
newpath = r'/content/SDHA_Images/Patients' 
if not os.path.exists(newpath):
    os.makedirs(newpath)
dir_src = r"/content/Dataset_JPG/patients"
dir_dst = r"/content/SDHA_Images/Patients"
for file in glob.iglob('%s/**/*SDHA.jpg' % dir_src, recursive=True):
    copy(file, dir_dst)

**TIFF IMAGES**

In [None]:
#Creating a new path to store all TIFF images (Control SDHA Images)
newpath = r'/content/SDHA_Images_TIFF/Controls' 
if not os.path.exists(newpath):
    os.makedirs(newpath)
dir_src = r"/content/Dataset_TIFF/controls"
dir_dst = r"/content/SDHA_Images_TIFF/Controls"
for file in glob.iglob('%s/**/*SDHA.ome.tiff' % dir_src, recursive=True):
    copy(file, dir_dst)

In [None]:
#Creating a new path to store all TIFF images (Patients SDHA Images)
newpath = r'/content/SDHA_Images_TIFF/Patients' 
if not os.path.exists(newpath):
    os.makedirs(newpath)
dir_src = r"/content/Dataset_TIFF/patients"
dir_dst = r"/content/SDHA_Images_TIFF/Patients"
for file in glob.iglob('%s/**/*SDHA.ome.tiff' % dir_src, recursive=True):
    copy(file, dir_dst)

**Splitting each image into sub-images using the patchify function**

**For Control:**

In [13]:
#Creating a new path to store all split jpg images (Control SDHA Images)
newpath = r'/content/SDHA_Split_Images/Controls' 
if not os.path.exists(newpath):
    os.makedirs(newpath)
my_path = "/content/SDHA_Images/Controls"
files = glob.glob(my_path + '/**/*.jpg', recursive=True)
for file in files:
  img = cv2.imread(file)
  patches_img = patchify(img, (224,224,3), step=128)
  for i in range(patches_img.shape[0]):
      for j in range(patches_img.shape[1]):
          single_patch_img = patches_img[i, j, 0, :, :, :]
          if not cv2.imwrite(r'/content/SDHA_Split_Images/Controls/' + "C0"+str(files.index(file)+1)+'_image_' + '_'+ str(i)+str(j)+'.jpg', single_patch_img):
              raise Exception("Could not write the image")

**For Patient:**

In [14]:
#Creating a new path to store all split jpg images (Patients SDHA Images)
newpath = r'/content/SDHA_Split_Images/Patients' 
if not os.path.exists(newpath):
    os.makedirs(newpath)
my_path = "/content/SDHA_Images/Patients"
files = glob.glob(my_path + '/**/*.jpg', recursive=True)
for file in files:
  img = cv2.imread(file)
  patches_img = patchify(img, (224,224,3), step=128)
  for i in range(patches_img.shape[0]):
      for j in range(patches_img.shape[1]):
          single_patch_img = patches_img[i, j, 0, :, :, :]
          if not cv2.imwrite(r'/content/SDHA_Split_Images/Patients/' + "P0"+str(files.index(file)+1)+'_image_' + '_'+ str(i)+str(j)+'.jpg', single_patch_img):
              raise Exception("Could not write the image")
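
To make the patch arithmetic explicit: a sliding window of size `patch` moved in increments of `step` yields `(size - patch) // step + 1` positions along each axis, which is why a step of 128 with 224-pixel patches produces overlapping tiles. The sketch below computes the grid size and builds a file name that keeps the row and column indices separate. Both helper names are hypothetical, and the 1000 x 1000 image size used in the example is an assumption, since the notebook does not state the source image dimensions.

```python
def patch_grid(height, width, patch=224, step=128):
    """Number of (rows, cols) a sliding window of size `patch` yields."""
    if height < patch or width < patch:
        return (0, 0)  # the image is too small for even one patch
    rows = (height - patch) // step + 1
    cols = (width - patch) // step + 1
    return (rows, cols)


def patch_name(prefix, i, j):
    """File name with zero-padded, separated row and column indices."""
    return f"{prefix}_image_{i:03d}_{j:03d}.jpg"


# Hypothetical 1000 x 1000 image:
#   patch_grid(1000, 1000)            -> (7, 7), overlapping patches
#   patch_grid(1000, 1000, step=224)  -> (4, 4), non-overlapping patches
```

The separator matters once a grid has more than ten rows or columns: the cells above build the suffix as `str(i)+str(j)`, so positions (1, 12) and (11, 2) would both map to `112` and the later patch would overwrite the earlier one.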

**TIFF**

In [None]:
#Creating a new path to store all split TIFF images (Control SDHA Images)
newpath = r'/content/SDHA_Split_Images_TIFF/Controls' 
if not os.path.exists(newpath):
    os.makedirs(newpath)
my_path = "/content/SDHA_Images_TIFF/Controls"
files = glob.glob(my_path + '/**/*.ome.tiff', recursive=True)
for file in files:
  img = cv2.imread(file)
  patches_img = patchify(img, (224,224,3), step=224)
  for i in range(patches_img.shape[0]):
      for j in range(patches_img.shape[1]):
          single_patch_img = patches_img[i, j, 0, :, :, :]
          if not cv2.imwrite(r'/content/SDHA_Split_Images_TIFF/Controls/' + "C0"+str(files.index(file)+1)+'_image_' + '_'+ str(i)+str(j)+'.ome.tiff', single_patch_img):
              raise Exception("Could not write the image")

In [None]:
#Creating a new path to store all split TIFF images (Patients SDHA Images)
newpath = r'/content/SDHA_Split_Images_TIFF/Patients' 
if not os.path.exists(newpath):
    os.makedirs(newpath)
my_path = "/content/SDHA_Images_TIFF/Patients"
files = glob.glob(my_path + '/**/*.ome.tiff', recursive=True)
for file in files:
  img = cv2.imread(file)
  patches_img = patchify(img, (224,224,3), step=224)
  for i in range(patches_img.shape[0]):
      for j in range(patches_img.shape[1]):
          single_patch_img = patches_img[i, j, 0, :, :, :]
          if not cv2.imwrite(r'/content/SDHA_Split_Images_TIFF/Patients/' + "P0"+str(files.index(file)+1)+'_image_' + '_'+ str(i)+str(j)+'.ome.tiff', single_patch_img):
              raise Exception("Could not write the image")

**Using the split-folders package, we divide the data into train, validation, and test sets**

In [15]:
import splitfolders 

# input dataset to split
input_folder = '/content/SDHA_Split_Images'  

output_folder= '/content/Second_Assumption'

splitfolders.ratio(input_folder, output= output_folder, seed=1337, ratio = (0.8, 0.1,0.1))

Copying files: 4482 files [00:00, 5358.66 files/s]
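
One caveat with a per-file random split: the JPG patches are cut with a step of 128 pixels, which is smaller than the 224-pixel patch size, so neighbouring patches overlap, and overlapping patches from the same source image can land in both the training and the test set. The following is a minimal sketch of a group-wise split that keeps all patches of one source image together. It is an alternative for comparison, not what `splitfolders.ratio` does; the helper name is hypothetical, and it assumes the patch naming used above, where the text before `_image_` identifies the source image.

```python
import random
from collections import defaultdict


def split_by_source_image(filenames, ratio=(0.8, 0.1, 0.1), seed=1337):
    """Split patch file names into train/val/test by source-image prefix."""
    # Group patches by the part of the name before "_image_".
    groups = defaultdict(list)
    for name in filenames:
        groups[name.split("_image_")[0]].append(name)

    # Shuffle the group keys reproducibly, then cut them by ratio.
    keys = sorted(groups)
    random.Random(seed).shuffle(keys)
    n_train = int(len(keys) * ratio[0])
    n_val = int(len(keys) * ratio[1])
    parts = {
        "train": keys[:n_train],
        "val": keys[n_train:n_train + n_val],
        "test": keys[n_train + n_val:],  # remainder goes to test
    }
    return {
        split: [f for key in split_keys for f in groups[key]]
        for split, split_keys in parts.items()
    }
```

Because whole groups are assigned to a split, the resulting proportions are only approximately 80/10/10 when source images contribute different numbers of patches.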


# **PyTorch Implementation**

**Setting Directory**

In [16]:
traindir="/content/Second_Assumption/train"
valdir="/content/Second_Assumption/val"
testdir="/content/Second_Assumption/test"

In [17]:
%%capture
!pip install --upgrade trax

**Building a ResNet18 model and saving the weights**

In [None]:
!python3 /content/resnet18365places.py  /content/Second_Assumption/ --workers 2 --arch resnet18  --batch-size 2 --lr 0.005 --epochs 25


Streaming output truncated to the last 5000 lines.
Epoch: [0][1040/1793]	Time 0.018 (0.021)	Data 0.000 (0.000)	Loss 0.7590 (0.7913)	Prec@1 50.000 (70.221)	Prec@5 100.000 (100.000)
Epoch: [0][1050/1793]	Time 0.018 (0.021)	Data 0.000 (0.000)	Loss 0.3018 (0.7880)	Prec@1 100.000 (70.409)	Prec@5 100.000 (100.000)
Epoch: [0][1060/1793]	Time 0.017 (0.021)	Data 0.000 (0.000)	Loss 0.3680 (0.7881)	Prec@1 100.000 (70.358)	Prec@5 100.000 (100.000)
Epoch: [0][1070/1793]	Time 0.018 (0.021)	Data 0.000 (0.000)	Loss 0.3769 (0.7864)	Prec@1 100.000 (70.355)	Prec@5 100.000 (100.000)
Epoch: [0][1080/1793]	Time 0.017 (0.021)	Data 0.000 (0.000)	Loss 0.2540 (0.7856)	Prec@1 100.000 (70.398)	Prec@5 100.000 (100.000)
Epoch: [0][1090/1793]	Time 0.019 (0.021)	Data 0.000 (0.000)	Loss 0.4721 (0.7846)	Prec@1 100.000 (70.394)	Prec@5 100.000 (100.000)
Epoch: [0][1100/1793]	Time 0.018 (0.021)	Data 0.000 (0.000)	Loss 1.2962 (0.7825)	Prec@1 50.000 (70.481)	Prec@5 100.000 (100.000)
Epoch: [0][1110/1793]	Time 
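
The training log is easier to inspect once it is parsed into numbers. The sketch below extracts the epoch, step, loss, and top-1 precision (current value and running average in parentheses) from one log line. The regular expression is written against the line format shown in the output above, and `parse_log_line` is a hypothetical helper name.

```python
import re

# Matches lines such as:
# Epoch: [0][1040/1793]  Time ...  Loss 0.7590 (0.7913)  Prec@1 50.000 (70.221) ...
LOG_LINE = re.compile(
    r"Epoch: \[(?P<epoch>\d+)\]\[(?P<step>\d+)/(?P<total>\d+)\]"
    r".*?Loss\s+(?P<loss>[\d.]+)\s+\((?P<loss_avg>[\d.]+)\)"
    r".*?Prec@1\s+(?P<top1>[\d.]+)\s+\((?P<top1_avg>[\d.]+)\)"
)


def parse_log_line(line):
    """Return the parsed fields of one log line, or None if it does not match."""
    match = LOG_LINE.search(line)
    if match is None:
        return None  # e.g. a truncated or unrelated line
    fields = match.groupdict()
    return {
        "epoch": int(fields["epoch"]),
        "step": int(fields["step"]),
        "total": int(fields["total"]),
        "loss": float(fields["loss"]),
        "loss_avg": float(fields["loss_avg"]),
        "top1": float(fields["top1"]),
        "top1_avg": float(fields["top1_avg"]),
    }
```

The sketch ignores Prec@5 because, with only two classes, any top-k with k of at least 2 always contains the true label. With a batch size of 2, the per-batch Prec@1 can only be 0, 50, or 100, so the running average in parentheses is the more informative value.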

**Saving the file after building the model:**

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
save_path = '/content/checkpoints/resnet18_isic_model_best.pth.tar'
#torch.save(model.state_dict(), save_path)

**Loading the model**

In [None]:
# Model class must be defined somewhere
model = torch.load('/content/resnet18_best.pth.tar')
#model= model.load_state_dict(torch.load('/content/checkpoints/resnet18_isic_model_best.pth.tar'))
#model.eval(test_loader)

In [None]:
print(model)

**Creating the folder structure, adding both patients and controls as concepts:**

In [21]:
newpath = r'/content/Second_Assumption/concept_train/Controls/Controls' 
if not os.path.exists(newpath):
    os.makedirs(newpath)
src_dir = "/content/Second_Assumption/train/Controls/*.jpg"
dst_dir = "/content/Second_Assumption/concept_train/Controls/Controls/"
files =glob.iglob(src_dir)
for jpgfile in sorted(files)[:25]:
    shutil.copy(jpgfile, dst_dir)
newpath = r'/content/Second_Assumption/concept_test/Controls' 
if not os.path.exists(newpath):
    os.makedirs(newpath)
src_dir_1 = "/content/Second_Assumption/test/Controls/*.jpg"
dst_dir = "/content/Second_Assumption/concept_test/Controls"
files_1 =glob.iglob(src_dir_1)
for jpgfile in sorted(files_1)[:15]:
    shutil.copy(jpgfile, dst_dir)
newpath = r'/content/Second_Assumption/concept_train/Patients/Patients' 
if not os.path.exists(newpath):
    os.makedirs(newpath)
src_dir_2 = "/content/Second_Assumption/train/Patients/*.jpg"
dst_dir = "/content/Second_Assumption/concept_train/Patients/Patients/"
files_2 =glob.iglob(src_dir_2)
for jpgfile in sorted(files_2)[:25]:
    shutil.copy(jpgfile, dst_dir)
newpath = r'/content/Second_Assumption/concept_test/Patients' 
if not os.path.exists(newpath):
    os.makedirs(newpath)
src_dir_3 = "/content/Second_Assumption/test/Patients/*.jpg"
dst_dir = "/content/Second_Assumption/concept_test/Patients/"
files_3 =glob.iglob(src_dir_3)
for jpgfile in sorted(files_3)[:15]:
    shutil.copy(jpgfile, dst_dir)
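
The four blocks in the cell above repeat the same steps: create a destination folder, sort the matching files, and copy the first few. A minimal sketch of a helper that captures this pattern is shown here; `copy_first_n` is a hypothetical name, and the directory layout passed to it is left to the caller so the nested `concept_train/<concept>/<concept>` structure used above can be kept.

```python
import glob
import os
import shutil


def copy_first_n(pattern, dst_dir, n):
    """Copy the first `n` files (in sorted order) matching `pattern` to `dst_dir`.

    Returns the base names of the copied files.
    """
    os.makedirs(dst_dir, exist_ok=True)
    selected = sorted(glob.glob(pattern))[:n]
    for path in selected:
        shutil.copy(path, dst_dir)
    return [os.path.basename(path) for path in selected]


# Example call mirroring the first block of the cell above:
# copy_first_n("/content/Second_Assumption/train/Controls/*.jpg",
#              "/content/Second_Assumption/concept_train/Controls/Controls/", 25)
```

Sorting before slicing keeps the selection reproducible, since `glob` returns files in an arbitrary order.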

# **Concept whitening**

In [25]:
!python3 /content/train_isic.py --ngpu 1 --workers 2 --arch resnet_cw --depth 18 --batch-size 2 --lr 0.005 --whitened_layers 8 --epochs 2 --start-epoch 1 --concepts Controls,Patients --prefix resnet18_isic /content/Second_Assumption/


args Namespace(act_mode='pool_max', arch='resnet_cw', batch_size=2, concepts='Controls,Patients', data='/content/Second_Assumption/', depth=18, epochs=2, evaluate=None, lr=0.005, momentum=0.9, ngpu=1, prefix='resnet18_isic', print_freq=10, resume='', seed=1234, start_epoch=1, weight_decay=0.0001, whitened_layers='8', workers=2)
25 tensor(78.4838)
25
model
DataParallel(
  (module): ResidualNetTransfer(
    (model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu):

## **Plot the most highly activated images for each concept**

In [26]:
!python3 /content/train_isic.py --ngpu 1 --workers 2 --arch resnet_cw --depth 18 --batch-size 2 --lr 0.005 --whitened_layers 8 --epochs 5 --start-epoch 1 --concepts Controls,Patients --prefix resnet18_isic /content/Second_Assumption/ --evaluate plot_top50


args Namespace(act_mode='pool_max', arch='resnet_cw', batch_size=2, concepts='Controls,Patients', data='/content/Second_Assumption/', depth=18, epochs=5, evaluate='plot_top50', lr=0.005, momentum=0.9, ngpu=1, prefix='resnet18_isic', print_freq=10, resume='', seed=1234, start_epoch=1, weight_decay=0.0001, whitened_layers='8', workers=2)
25 tensor(78.4838)
25
model
DataParallel(
  (module): ResidualNetTransfer(
    (model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         