<a href="https://colab.research.google.com/github/Nimanui/PyHealth-fitzpa15/blob/SaliencyMappingGradient/ChestXrayClassificationWithSaliencyMapping.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Medical Image Classification with PyHealth

Welcome to the PyHealth tutorial on image classification. In this notebook, we will use PyHealth to load chest X-ray images and classify each image into one of several chest conditions.

## Environment Setup

To begin, we need to install PyHealth and a few additional packages to support our analysis.

In [None]:
!pip install mne pandarallel rdkit transformers

In [None]:
!rm -rf PyHealth-fitzpa15
# !git clone https://github.com/sunlabuiuc/PyHealth.git
!git clone -b SaliencyMappingGradient https://github.com/Nimanui/PyHealth-fitzpa15.git

In [None]:
%pip install -e ./PyHealth-fitzpa15

In [None]:
import sys


# The editable install above already makes `pyhealth` importable; adding the
# clone directory (named exactly as `git clone` created it) is only a fallback.
sys.path.append("./PyHealth-fitzpa15")

## Download Data

Next, we will download the dataset containing COVID-19 data. This dataset includes chest X-ray images of normal cases, lung opacity, viral pneumonia, and COVID-19 patients. You can find more information about the dataset [here](https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database).

This dataset is hosted on Google Cloud, so the download speed should be relatively fast, taking approximately 10 seconds to complete.

In [None]:
!wget -N https://storage.googleapis.com/pyhealth/covid19_cxr_data/archive.zip

In [None]:
!unzip -q -o archive.zip

In [None]:
!ls -1 COVID-19_Radiography_Dataset

Next, we will proceed with the chest X-ray classification task using PyHealth, following a five-stage pipeline.

## Step 1. Load Data in PyHealth

The initial step involves loading the data into PyHealth's internal structure. This process is straightforward: import the appropriate dataset class from PyHealth and specify the root directory where the raw dataset is stored. PyHealth will handle the dataset processing automatically.

In [None]:
from pyhealth.datasets import COVID19CXRDataset


root = "/content/COVID-19_Radiography_Dataset"
base_dataset = COVID19CXRDataset(root)

Once the data is loaded, we can perform simple queries on the dataset.

In [None]:
base_dataset.stats()

## Step 2. Define the Task

The next step is to define the machine learning task. This step instructs the package to generate a list of samples with the desired features and labels based on the data for each individual patient. Please note that in this dataset, patient identification information is not available. Therefore, we will assume that each chest X-ray belongs to a unique patient.
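
To make this concrete, here is a minimal sketch of what a task does conceptually: it maps each raw record to a sample dictionary holding features and labels. The function name `cxr_task_sketch`, the record layout, and the file paths are hypothetical illustrations, not PyHealth's actual implementation. Only the `image` and `disease` keys mirror the ones used later in this notebook.

```python
from typing import Dict, List

# Hypothetical raw records: one entry per X-ray file (paths are made up).
demo_records = [
    {"path": "COVID/images/example-1.png", "label": "COVID"},
    {"path": "Normal/images/example-2.png", "label": "Normal"},
]


def cxr_task_sketch(records: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Turn raw records into samples with one feature and one label."""
    samples = []
    for i, record in enumerate(records):
        samples.append(
            {
                # No patient IDs are available, so each image gets its own ID.
                "patient_id": str(i),
                "image": record["path"],     # input feature
                "disease": record["label"],  # prediction target
            }
        )
    return samples


demo_samples = cxr_task_sketch(demo_records)
print(demo_samples[0])
```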

For this dataset, PyHealth offers a default task specifically for chest X-ray classification. This task takes the image as input and aims to predict the chest diseases associated with it.

In [None]:
base_dataset.default_task

In [None]:
sample_dataset = base_dataset.set_task()

Here is an example of a single sample, represented as a dictionary. The dictionary contains keys for feature names, label names, and other metadata associated with the sample.

In [None]:
sample_dataset[0]

We can also check the input and output schemas, which specify the data types of the features and labels.

In [None]:
sample_dataset.input_schema

In [None]:
sample_dataset.output_schema

Below, we plot the number of samples per class and visualize a few samples.

In [None]:
print(sample_dataset)

In [None]:
label2id = sample_dataset.output_processors["disease"].label_vocab
print(sample_dataset.output_schema)
id2label = {v: k for k, v in label2id.items()}

In [None]:
from collections import defaultdict
import matplotlib.pyplot as plt


label_counts = defaultdict(int)
for sample in sample_dataset.samples:
    label_counts[id2label[sample["disease"].item()]] += 1
print(label_counts)
plt.bar(label_counts.keys(), label_counts.values())

In [None]:
import random

label_to_idxs = defaultdict(list)
for idx, sample in enumerate(sample_dataset.samples):
    label_to_idxs[sample["disease"].item()].append(idx)

fig, axs = plt.subplots(1, 4, figsize=(15, 3))
for ax, label in zip(axs, label_to_idxs.keys()):
    ax.set_title(id2label[label], fontsize=15)
    idx = random.choice(label_to_idxs[label])
    sample = sample_dataset[idx]
    image = sample["image"][0]
    ax.imshow(image, cmap="gray")

Finally, we will split the entire dataset into training, validation, and test sets using the ratios of 70%, 10%, and 20%, respectively. We will then obtain the corresponding data loaders for each set.

In [None]:
from pyhealth.datasets import split_by_sample


train_dataset, val_dataset, test_dataset = split_by_sample(
    dataset=sample_dataset,
    ratios=[0.7, 0.1, 0.2]
)

In [None]:
from pyhealth.datasets import get_dataloader


train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
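
For intuition, a sample-level split with these ratios can be pictured as shuffling the sample indices and cutting the shuffled list into three parts. The sketch below is a dependency-free illustration under that assumption. The helper `split_indices` and its `seed` parameter are hypothetical, and this is not necessarily how `split_by_sample` is implemented.

```python
import random
from typing import List, Sequence, Tuple


def split_indices(
    n: int, ratios: Sequence[float], seed: int = 0
) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle indices 0..n-1 and cut them into train/val/test parts."""
    if abs(sum(ratios) - 1.0) > 1e-9:
        raise ValueError("ratios must sum to 1")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_train = round(n * ratios[0])
    n_val = round(n * ratios[1])
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # the remainder goes to the test set
    return train, val, test


train_idx, val_idx, test_idx = split_indices(1000, [0.7, 0.1, 0.2])
print(len(train_idx), len(val_idx), len(test_idx))
```

Because the three parts are slices of one shuffled list, no sample can end up in more than one set.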

## Step 3. Define the Model

Next, we will define the deep learning model we want to use for our task. PyHealth supports all major vision models available in the Torchvision package. You can load any of these models using the `model_name` argument.

In [None]:
from pyhealth.models import TorchvisionModel


resnet = TorchvisionModel(
    dataset=sample_dataset,
    model_name="resnet18",
    model_config={"weights": "DEFAULT"}
)

resnet

In [None]:
from pyhealth.models import TorchvisionModel


vit = TorchvisionModel(
    dataset=sample_dataset,
    model_name="vit_b_16",
    model_config={"weights": "DEFAULT"}
)

vit


## Step 4. Training


In this step, we will train the model using PyHealth's Trainer class, which simplifies the training process and provides standard functionalities.

Let us first train the ResNet model.

In [None]:
from pyhealth.trainer import Trainer


resnet_trainer = Trainer(model=resnet)

Before we begin training, let's first evaluate the initial performance of the model.

In [None]:
print(resnet_trainer.evaluate(test_dataloader))

Now, let's start the training process. Due to computational constraints, we will train the model for only one epoch.

In [None]:
resnet_trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=1,
    monitor="accuracy"
)

After training the model, we can compare its performance before and after. We should expect to see an increase in the accuracy score as the model learns from the training data.

## Step 5. Evaluation

Lastly, we can evaluate the ResNet model on the test set. This can be done using PyHealth's `Trainer.evaluate()` method.

In [None]:
print(resnet_trainer.evaluate(test_dataloader))

Additionally, you can perform inference using the `Trainer.inference()` function.

In [None]:
y_true, y_prob, loss = resnet_trainer.inference(test_dataloader)
y_pred = y_prob.argmax(axis=1)
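
The `argmax` step above picks, for each image, the class with the highest predicted probability. The small sketch below reproduces that step and an accuracy calculation in plain Python. The probabilities and labels are made-up illustrative values, and `argmax_rows` is a hypothetical helper, not part of PyHealth.

```python
from typing import List

# Hypothetical class probabilities for three images over four classes.
y_prob_demo = [
    [0.10, 0.70, 0.15, 0.05],
    [0.60, 0.20, 0.10, 0.10],
    [0.05, 0.05, 0.20, 0.70],
]
y_true_demo = [1, 0, 2]


def argmax_rows(probs: List[List[float]]) -> List[int]:
    """Return the index of the largest probability in each row."""
    return [max(range(len(row)), key=row.__getitem__) for row in probs]


y_pred_demo = argmax_rows(y_prob_demo)
# Accuracy is the fraction of images whose predicted class matches the label.
accuracy_demo = sum(p == t for p, t in zip(y_pred_demo, y_true_demo)) / len(y_true_demo)
print(y_pred_demo, accuracy_demo)
```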

Below we show a confusion matrix of the trained ResNet model.

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns


cf_matrix = confusion_matrix(y_true, y_pred)
ax = sns.heatmap(cf_matrix, linewidths=1, annot=True, fmt='g')
ax.set_xticklabels([id2label[i] for i in range(4)])
ax.set_yticklabels([id2label[i] for i in range(4)])
ax.set_xlabel("Pred")
ax.set_ylabel("True")

## Step 6. Gradient Saliency Mapping

As a bonus, let's compute simple gradient saliency maps for images in our sample dataset. A saliency map highlights the pixels whose changes most affect the model's prediction.

In [None]:
def add_requires_grad(in_dataset):
    """Enable gradient tracking, in place, on each sample's image tensor."""
    for sample in in_dataset:
        sample['image'].requires_grad_()

In [None]:
from pyhealth.datasets import get_dataloader
from pyhealth.interpret.methods.saliency import GradientSaliencyMapping
batch_size = 32

sample_dataloader = get_dataloader(sample_dataset.samples, batch_size=batch_size, shuffle=True)

saliency_maps = GradientSaliencyMapping(resnet, sample_dataloader, 20)
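
`GradientSaliencyMapping` comes from the branch installed above, and its internals are not shown in this notebook. The general idea of gradient saliency is to measure how strongly a class score changes when each pixel changes, that is, the absolute value of the score's gradient with respect to the input. Deep learning frameworks obtain this gradient by backpropagation. The sketch below swaps in central finite differences on a hypothetical toy scoring function so that it runs without any dependencies. All names in it are illustrative assumptions.

```python
from typing import Callable, List

Image = List[List[float]]


def toy_score(image: Image) -> float:
    """Hypothetical class score: strong on the centre pixel, weak on the top-left one."""
    return 3.0 * image[1][1] ** 2 + 0.5 * image[0][0]


def saliency_sketch(
    score: Callable[[Image], float], image: Image, eps: float = 1e-4
) -> Image:
    """Approximate |d score / d pixel| with central finite differences."""
    saliency = [[0.0] * len(row) for row in image]
    for r, row in enumerate(image):
        for c, value in enumerate(row):
            plus = [list(x) for x in image]   # copy with one pixel nudged up
            minus = [list(x) for x in image]  # copy with one pixel nudged down
            plus[r][c] = value + eps
            minus[r][c] = value - eps
            grad = (score(plus) - score(minus)) / (2 * eps)
            saliency[r][c] = abs(grad)
    return saliency


toy_image = [[0.2, 0.1, 0.0], [0.3, 1.0, 0.4], [0.0, 0.5, 0.6]]
toy_saliency = saliency_sketch(toy_score, toy_image)
for saliency_row in toy_saliency:
    print([round(v, 3) for v in saliency_row])
```

The centre pixel receives the largest saliency because the toy score depends on it most, which is the same reading applied to the heatmaps overlaid on the X-rays.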

In [None]:
import torchvision
import numpy as np

def imshow(img, title):
    npimg = img.detach().cpu().numpy()  # safe for GPU tensors and tensors that track gradients
    plt.figure(figsize=(15, 7))
    plt.axis('off')
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(title)
    plt.show()

def imshowSaliencyCompFromDict(saliency_dict_list, batch_index, image_index, title, alpha=0.3):
    img = saliency_dict_list[batch_index]['image'][image_index]
    saliency = saliency_dict_list[batch_index]['saliency'][image_index]
    label = saliency_dict_list[batch_index]['label'][image_index]
    new_title = str(title + " " + id2label[label.item()])
    imshowSaliencyComp(img, saliency, new_title, alpha)

def imshowSaliencyComp(img, saliency, title, alpha=0.3):
    npimg = img.detach().cpu().numpy()  # move to CPU first in case the tensor lives on the GPU
    npimg = np.transpose(npimg, (1, 2, 0))
    plt.figure(figsize=(15, 7))
    plt.axis('off')
    plt.imshow(npimg.squeeze(), cmap='gray')
    plt.imshow(saliency, cmap='hot', alpha=alpha)
    plt.title(title)
    plt.show()

In [None]:
batch_count = int(len(saliency_maps[0])/3)
batch_size = len(saliency_maps[0]['saliency'])
batch_random = random.randint(0, batch_count - 1)
image_index_random = random.randint(0, batch_size - 1)
imshowSaliencyCompFromDict(saliency_maps, batch_random, image_index_random, "Gradient Saliency", .6)