# Whole Slide Image Classification Using PyTorch and TIAToolbox

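In this tutorial, we will show how to classify Whole Slide Images (WSIs) using PyTorch deep learning models with help from TIAToolbox. A WSI is a digital image of human tissue that was obtained through surgery or biopsy and scanned with a specialized scanner. Pathologists and computational pathology researchers use WSIs to [study cancer at the microscopic level](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7522141/) in order to understand, for example, tumor growth and to help improve treatment for patients.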

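What makes WSIs challenging to process is their enormous size. For example, a typical slide image is on the order of [100,000x100,000 pixels](https://doi.org/10.1117%2F12.912388), where each pixel corresponds to about 0.25x0.25 microns on the slide. This makes loading and processing such images challenging, not to mention that a single study may contain hundreds or even thousands of WSIs (larger studies tend to produce better results)!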
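To make the scale concrete, the back-of-the-envelope calculation below estimates the uncompressed size of a slide with the dimensions quoted above. It assumes 3 color channels stored as 8-bit values, which is an illustrative assumption rather than a property of any particular file.

```python
# Rough uncompressed size of a 100,000 x 100,000 pixel RGB slide.
# Assumption: 3 channels (RGB), 1 byte per channel.
width_px = 100_000
height_px = 100_000
bytes_per_pixel = 3

total_pixels = width_px * height_px
size_gb = total_pixels * bytes_per_pixel / 1e9  # decimal gigabytes

# Physical extent at roughly 0.25 microns per pixel.
microns_per_pixel = 0.25
width_mm = width_px * microns_per_pixel / 1000

print(f"{total_pixels:,} pixels, about {size_gb:.0f} GB uncompressed")
print(f"slide width is about {width_mm:.0f} mm")
```

Numbers of this size are why a slide is normally read region by region rather than loaded into memory in one piece.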

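Conventional image processing pipelines are not suitable for WSI processing, so we need better tools. This is where [TIAToolbox](https://github.com/TissueImageAnalytics/tiatoolbox) can help, as it brings a set of useful tools to import and process tissue slides in a fast and computationally efficient manner. Typically, WSIs are saved in a pyramid structure with multiple copies of the same image at various magnification levels optimized for visualization. Level 0 (the bottom level) of the pyramid contains the image at the highest magnification or zoom level, whereas the higher levels contain lower-resolution copies of the base image. The pyramid structure is sketched below.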

![WSI pyramid stack](https://tia-toolbox.readthedocs.io/en/latest/_images/read_bounds_tissue.png)
*WSI pyramid stack ([source](https://tia-toolbox.readthedocs.io/en/latest/_autosummary/tiatoolbox.wsicore.wsireader.WSIReader.html#))*
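As a rough illustration of the pyramid layout, the sketch below lists the image dimensions at each level. The downsample factor of 2 between consecutive levels and the number of levels are assumptions chosen for the example; real slides may use other factors. The helper name `pyramid_dimensions` is hypothetical.

```python
def pyramid_dimensions(base_width, base_height, num_levels, downsample=2):
    """Return (width, height) for each pyramid level, level 0 first."""
    dims = []
    for level in range(num_levels):
        factor = downsample ** level
        dims.append((base_width // factor, base_height // factor))
    return dims


# Assumed example: a 100,000 x 100,000 base image with 4 levels.
levels = pyramid_dimensions(100_000, 100_000, num_levels=4)
for level, (w, h) in enumerate(levels):
    print(f"level {level}: {w} x {h} pixels")
```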

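TIAToolbox allows us to automate common downstream analysis tasks such as [tissue classification](https://doi.org/10.1016/j.media.2022.102685). In this tutorial we show how you can:
1. Load WSI images using TIAToolbox; and
2. Use different PyTorch models to classify slides at the patch level. In this tutorial, we provide examples using the TorchVision `ResNet18` model and the custom [`HistoEncoder`](https://github.com/jopo666/HistoEncoder) model.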

Let's get started!

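## Setting up the environment

To run the examples provided in this tutorial, the following packages are required as prerequisites.

1. OpenJpeg
2. OpenSlide
3. Pixman
4. TIAToolbox
5. HistoEncoder (for the custom model example)

Please run the following commands in your terminal to install these packages (the leading `%%bash` line is a Jupyter cell magic and is only needed when running from a notebook cell):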

```bash
%%bash
apt-get -y install libopenjp2-7-dev libopenjp2-tools openslide-tools libpixman-1-dev | tail -n 1
pip install histoencoder | tail -n 1
pip install git+https://github.com/TissueImageAnalytics/tiatoolbox.git@develop | tail -n 1
echo "Installation is done."
```

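Alternatively, on macOS you can run `brew install openjpeg openslide` to install the prerequisite packages instead of using `apt-get`. Further information on installation can be found [here](https://tia-toolbox.readthedocs.io/en/latest/installation.html). You may need to restart the runtime from the Runtime menu at the top of the page before continuing with the rest of the tutorial, so that the newly installed dependencies are picked up.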

In [1]:
"""Import modules required to run the Jupyter notebook."""
# Configure logging
import logging
import warnings
if logging.getLogger().hasHandlers():
    logging.getLogger().handlers.clear()
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")

# Downloading data and files
import shutil
from pathlib import Path
from zipfile import ZipFile

# Data processing and visualization
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import cm
import PIL.Image  # import the submodule explicitly so that PIL.Image is available
import contextlib
import io
from sklearn.metrics import accuracy_score, confusion_matrix

# TIAToolbox for WSI loading and processing
from tiatoolbox import logger
from tiatoolbox.models.architecture import vanilla
from tiatoolbox.models.engine.patch_predictor import (
    IOPatchPredictorConfig,
    PatchPredictor,
)
from tiatoolbox.utils.misc import download_data, grab_files_from_dir
from tiatoolbox.utils.visualization import overlay_prediction_mask
from tiatoolbox.wsicore.wsireader import WSIReader

# Torch-related
import torch
from torchvision import transforms

# Configure plotting
mpl.rcParams["figure.dpi"] = 160  # for high resolution figure in notebook
mpl.rcParams["figure.facecolor"] = "white"  # To make sure text is visible in dark mode

# If you are not using GPU, change ON_GPU to False
ON_GPU = True

# Function to suppress console output for overly verbose code blocks
def suppress_console_output():
    return contextlib.redirect_stderr(io.StringIO())


### Clean-up before a run

To ensure proper clean-up (for example in abnormal termination), all files downloaded or created in this run are saved in a single directory `global_save_dir`, which we set equal to "./tmp/". To simplify maintenance, the name of the directory occurs only at this one place, so that it can easily be changed, if desired.



In [None]:
warnings.filterwarnings("ignore")
global_save_dir = Path("./tmp/")


def rmdir(dir_path: str | Path) -> None:
    """Helper function to delete directory."""
    if Path(dir_path).is_dir():
        shutil.rmtree(dir_path)
        logger.info("Removing directory %s", dir_path)


rmdir(global_save_dir)  # remove  directory if it exists from previous runs
global_save_dir.mkdir()
logger.info("Creating new directory %s", global_save_dir)

### Downloading the data
For our sample data, we will use one whole-slide image and patches from the validation subset of the [Kather 100k](https://zenodo.org/record/1214456#.YJ-tn3mSkuU) dataset.


In [None]:
wsi_path = global_save_dir / "sample_wsi.svs"
patches_path = global_save_dir / "kather100k-validation-sample.zip"
weights_path = global_save_dir / "resnet18-kather100k.pth"

logger.info("Download has started. Please wait...")

# Download a sample whole-slide image
download_data(
    "https://tiatoolbox.dcs.warwick.ac.uk/sample_wsis/TCGA-3L-AA1B-01Z-00-DX1.8923A151-A690-40B7-9E5A-FCBEDFC2394F.svs",
    wsi_path,
)

# Download and unzip a sample of the Kather 100k validation set
download_data(
    "https://tiatoolbox.dcs.warwick.ac.uk/datasets/kather100k-validation-sample.zip",
    patches_path,
)
with ZipFile(patches_path, "r") as zipfile:
    zipfile.extractall(path=global_save_dir)

# Download pretrained model weights for WSI classification using ResNet18 architecture
download_data(
    "https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnet18-kather100k.pth",
    weights_path,
)

logger.info("Download is complete.")

## Reading the data

We create a list of patches and a list of corresponding labels.
For example, the first label in `label_list` will indicate the class of the first image patch in `patch_list`.



In [None]:
# Read the patch data and create a list of patches and a list of corresponding labels

# Set the path to the dataset
dataset_path = global_save_dir / "kather100k-validation-sample"
image_ext = ".tif"  # file extension of each image

# Obtain the mapping between the label ID and the class name
label_dict = {
    "BACK": 0, # Background (empty glass region)
    "NORM": 1, # Normal colon mucosa
    "DEB": 2,  # Debris
    "TUM": 3,  # Colorectal adenocarcinoma epithelium
    "ADI": 4,  # Adipose
    "MUC": 5,  # Mucus
    "MUS": 6,  # Smooth muscle
    "STR": 7,  # Cancer-associated stroma
    "LYM": 8,  # Lymphocytes
}

class_names = list(label_dict.keys())
class_labels = list(label_dict.values())

# Generate a list of patches and derive each label from its class directory
patch_list = []
label_list = []
for class_name, label in label_dict.items():
    dataset_class_path = dataset_path / class_name
    patch_list_single_class = grab_files_from_dir(
        dataset_class_path,
        file_types="*" + image_ext,
    )
    patch_list.extend(patch_list_single_class)
    label_list.extend([label] * len(patch_list_single_class))

# Show some dataset statistics
plt.bar(class_names, [label_list.count(label) for label in class_labels])
plt.xlabel("Patch types")
plt.ylabel("Number of patches")

# Count the number of examples per class
for class_name, label in label_dict.items():
    logger.info(
        "Class ID: %d -- Class Name: %s -- Number of images: %d",
        label,
        class_name,
        label_list.count(label),
    )

# Overall dataset statistics
logger.info("Total number of patches: %d", (len(patch_list)))

As you can see, this patch dataset has 9 classes/labels with IDs 0-8 and associated class names, each describing the dominant tissue type in the patch:

- BACK ⟶ Background (empty glass region)
- LYM  ⟶ Lymphocytes
- NORM ⟶ Normal colon mucosa
- DEB  ⟶ Debris
- MUS  ⟶ Smooth muscle
- STR  ⟶ Cancer-associated stroma
- ADI  ⟶ Adipose
- MUC  ⟶ Mucus
- TUM  ⟶ Colorectal adenocarcinoma epithelium
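Since the predictor returns numeric label IDs, it is often handy to map them back to class names. A minimal sketch, reusing the `label_dict` mapping defined above; the `id_to_name` dictionary and the example predictions are hypothetical names and values introduced only for illustration:

```python
label_dict = {"BACK": 0, "NORM": 1, "DEB": 2, "TUM": 3, "ADI": 4,
              "MUC": 5, "MUS": 6, "STR": 7, "LYM": 8}

# Invert the mapping so that IDs can be translated back into class names.
id_to_name = {label: name for name, label in label_dict.items()}

example_predictions = [3, 3, 8, 0]  # made-up label IDs, not real model output
predicted_names = [id_to_name[p] for p in example_predictions]
print(predicted_names)
```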



## Classify image patches

We first demonstrate how to obtain a prediction for each image patch using the `patch` mode, and then how to do the same for a large slide using the `wsi` mode.

### Define `PatchPredictor` model

The PatchPredictor class runs a CNN-based classifier written in PyTorch.

- `model` can be any trained PyTorch model with the constraint that it should follow the [`tiatoolbox.models.models_abc.ModelABC`](https://tia-toolbox.readthedocs.io/en/latest/_autosummary/tiatoolbox.models.models_abc.ModelABC.html) class structure. For more information on this matter, please refer to [our example notebook on advanced model techniques](https://github.com/TissueImageAnalytics/tiatoolbox/blob/develop/examples/07-advanced-modeling.ipynb). In order to load a custom model, you need to write a small preprocessing function, as in `preproc_func(img)`, which makes sure the input tensors are in the right format for the loaded network.
- Alternatively, you can pass `pretrained_model` as a string argument. This specifies the CNN model that performs the prediction, and it must be one of the models listed [here](https://tia-toolbox.readthedocs.io/en/latest/usage.html?highlight=pretrained%20models#tiatoolbox.models.architecture.get_pretrained_model). The command will look like this: `predictor = PatchPredictor(pretrained_model='resnet18-kather100k', pretrained_weights=weights_path, batch_size=32)`.
- `pretrained_weights`: When using a `pretrained_model`, the corresponding pretrained weights will also be downloaded by default. You can override the default with your own set of weights via the `pretrained_weights` argument.
- `batch_size`: Number of images fed into the model each time. Higher values for this parameter require a larger (GPU) memory capacity.

In [None]:
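# Build a classifier with a ResNet18 backbone (from torchvision.models.resnet18) and 9 output classes
model = vanilla.CNNModel(backbone="resnet18", num_classes=9)
# Load the pretrained Kather 100k weights; map to CPU so this also works without a GPU
model.load_state_dict(torch.load(weights_path, map_location="cpu"), strict=True)


def preproc_func(img):
    """Convert a patch array to a float tensor in [0, 1], returned in HWC layout."""
    img = PIL.Image.fromarray(img)
    img = transforms.ToTensor()(img)  # float tensor in CHW layout
    return img.permute(1, 2, 0)  # back to HWC


model.preproc_func = preproc_func
predictor = PatchPredictor(model=model, batch_size=32)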

### Predict patch labels

We call the `predict` method of the predictor object created above using the `patch` mode. We then compute the classification accuracy and the confusion matrix.

In [None]:
with suppress_console_output():
    output = predictor.predict(imgs=patch_list, mode="patch", on_gpu=ON_GPU)

acc = accuracy_score(label_list, output["predictions"])
logger.info("Classification accuracy: %f", acc)

# Creating and visualizing the confusion matrix for patch classification results
conf = confusion_matrix(label_list, output["predictions"], normalize="true")
df_cm = pd.DataFrame(conf, index=class_names, columns=class_names)
df_cm
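To see what `normalize="true"` does, the sketch below builds a confusion matrix by hand for a toy three-class problem and divides each row by its total, so that each row sums to 1 and the diagonal holds the per-class recall. The labels are made up for illustration, and the helper `row_normalized_confusion` is a hypothetical name, not part of scikit-learn or TIAToolbox.

```python
import numpy as np


def row_normalized_confusion(y_true, y_pred, num_classes):
    """Confusion matrix with rows (true classes) normalized to sum to 1."""
    counts = np.zeros((num_classes, num_classes), dtype=float)
    for t, p in zip(y_true, y_pred):
        counts[t, p] += 1
    row_sums = counts.sum(axis=1, keepdims=True)
    # Avoid division by zero for classes that never occur in y_true.
    return np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums > 0)


# Toy labels, invented for this example.
y_true = [0, 0, 1, 1, 1, 2]
y_pred = [0, 1, 1, 1, 0, 2]
toy_conf = row_normalized_confusion(y_true, y_pred, num_classes=3)
toy_accuracy = float(np.mean(np.array(y_true) == np.array(y_pred)))
print(toy_conf)
print(f"accuracy: {toy_accuracy:.3f}")
```

Reading the matrix row by row shows, for each true class, which classes it tends to be confused with.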

### Predict patch labels for a whole slide

We also introduce `IOPatchPredictorConfig`, a class that specifies the configuration of image reading and prediction writing for the model prediction engine. It is required to tell the classifier at which level of the WSI pyramid it should read data, process it and generate output.

Parameters of `IOPatchPredictorConfig` are defined as:

- `input_resolutions`: A list of dictionaries specifying the resolution of each input. List elements must be in the same order as in the target `model.forward()`. If your model accepts only one input, you just need to provide one dictionary specifying `'units'` and `'resolution'`. Note that TIAToolbox supports models with more than one input. For more information on units and resolution, please see the [TIAToolbox documentation](https://tia-toolbox.readthedocs.io/en/latest/_autosummary/tiatoolbox.wsicore.wsireader.WSIReader.html#tiatoolbox.wsicore.wsireader.WSIReader.read_rect).
- `patch_input_shape`: Shape of the largest input in (height, width) format.
- `stride_shape`: The size of a stride (steps) between two consecutive patches, used in the patch extraction process. If the user sets `stride_shape` equal to `patch_input_shape`, patches will be extracted and processed without any overlap.
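The effect of `stride_shape` can be sketched with a small helper that lists the top-left corners of the patches a sliding window would visit. This is a simplified illustration, not TIAToolbox's own patch extractor: the function name `patch_origins` is hypothetical, and partial patches at the image border are simply dropped here.

```python
def patch_origins(image_shape, patch_shape, stride_shape):
    """Top-left (y, x) corners of all full patches, in (height, width) order."""
    img_h, img_w = image_shape
    patch_h, patch_w = patch_shape
    stride_h, stride_w = stride_shape
    return [
        (y, x)
        for y in range(0, img_h - patch_h + 1, stride_h)
        for x in range(0, img_w - patch_w + 1, stride_w)
    ]


# Stride equal to the patch size: patches tile the image without overlap.
no_overlap = patch_origins((448, 672), (224, 224), (224, 224))
# Half the stride: neighboring patches overlap by 112 pixels.
overlap = patch_origins((448, 672), (224, 224), (112, 112))
print(len(no_overlap), len(overlap))
```

A smaller stride gives a denser prediction map at the cost of proportionally more patches to process.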

In [None]:
wsi_ioconfig = IOPatchPredictorConfig(
    input_resolutions=[{"units": "mpp", "resolution": 0.5}],
    patch_input_shape=[224, 224],
    stride_shape=[224, 224],
)
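With these settings, each patch covers a fixed physical area of tissue. The arithmetic below works this out; the baseline resolution of 0.25 microns per pixel is an assumption taken from the typical figure quoted in the introduction, not a value read from the sample slide.

```python
patch_size_px = 224          # patch side length at the requested resolution
requested_mpp = 0.5          # microns per pixel requested in the IO config
assumed_baseline_mpp = 0.25  # assumption: typical level-0 resolution

# Physical side length of one patch on the slide.
patch_size_microns = patch_size_px * requested_mpp

# Side length of the same region measured in baseline (level-0) pixels.
patch_size_baseline_px = patch_size_microns / assumed_baseline_mpp

print(f"{patch_size_microns} microns per side")
print(f"{patch_size_baseline_px:.0f} baseline pixels per side")
```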

The `predict` method applies the CNN to the input patches and gets the results. Here are the arguments and their descriptions:

- `mode`: Type of input to be processed. Choose from `patch`, `tile` or `wsi` according to your application.
- `imgs`: List of inputs, which should be a list of paths to the input tiles or WSIs.
- `return_probabilities`: Set to `True` to get per-class probabilities alongside the predicted labels of the input patches. If you wish to merge the predictions to generate prediction maps for the `tile` or `wsi` modes, you can set `return_probabilities=True`.
- `ioconfig`: set the IO configuration information using the `IOPatchPredictorConfig` class.
- `resolution` and `units` (not shown below): These arguments specify the level or micron-per-pixel resolution of the WSI levels from which we plan to extract patches, and can be used instead of `ioconfig`. For example, specifying the WSI's level as `'baseline'` is equivalent to level 0, which is, in general, the level of greatest resolution. More information can be found in the [documentation](https://tia-toolbox.readthedocs.io/en/latest/usage.html?highlight=WSIReader.read_rect#tiatoolbox.wsicore.wsireader.WSIReader.read_rect).
- `masks`: A list of paths corresponding to the masks of WSIs in the `imgs` list. These masks specify the regions in the original WSIs from which we want to extract patches. If the mask of a particular WSI is specified as `None`, then the labels for all patches of that WSI (even background regions) will be predicted, which could cause unnecessary computation.
- `merge_predictions`: You can set this parameter to `True` if you need a 2D map of patch classification results. However, for large WSIs this requires a large amount of available memory. An alternative (default) solution is to set `merge_predictions=False`, and then generate the 2D prediction maps using the `merge_predictions` function, as you will see later on.

Since we are using a large WSI, the patch extraction and prediction processes may take some time (make sure to set `ON_GPU=True` if you have access to a CUDA-enabled GPU and a CUDA build of PyTorch).

In [None]:
with suppress_console_output():
    wsi_output = predictor.predict(
        imgs=[wsi_path],
        masks=None,
        mode="wsi",
        merge_predictions=False,
        ioconfig=wsi_ioconfig,
        return_probabilities=True,
        save_dir=global_save_dir / "wsi_predictions",
        on_gpu=ON_GPU,
    )

We can see how the prediction model works on our whole-slide image by visualizing the `wsi_output`. We first need to merge the patch prediction outputs and then visualize them as an overlay on the original image. The `merge_predictions` method is used to merge the patch predictions. Here we set the parameters `resolution=4, units='mpp'` to generate the prediction map at 4 microns per pixel. If you would like higher or lower resolution (bigger or smaller) prediction maps, change these parameters accordingly. Once the predictions are merged, use the `overlay_prediction_mask` function to overlay the prediction map on the WSI thumbnail, which should be extracted at the resolution used for prediction merging.

In [None]:
overview_resolution = (
    4  # the resolution in which we desire to merge and visualize the patch predictions
)
# the unit of the `resolution` parameter. Can be "power", "level", "mpp", or "baseline"
overview_unit = "mpp"
wsi = WSIReader.open(wsi_path)
wsi_overview = wsi.slide_thumbnail(resolution=overview_resolution, units=overview_unit)
plt.figure()
plt.imshow(wsi_overview)
plt.axis("off")

Overlaying the prediction map on this image as below gives:

In [None]:
# Visualization of whole-slide image patch-level prediction
# first set up a label to color mapping
label_color_dict = {}
label_color_dict[0] = ("empty", (0, 0, 0))
colors = mpl.colormaps["Set1"].colors  # colormap registry; cm.get_cmap was removed in Matplotlib 3.9
for class_name, label in label_dict.items():
    label_color_dict[label + 1] = (class_name, 255 * np.array(colors[label]))

pred_map = predictor.merge_predictions(
    wsi_path,
    wsi_output[0],
    resolution=overview_resolution,
    units=overview_unit,
)
overlay = overlay_prediction_mask(
    wsi_overview,
    pred_map,
    alpha=0.5,
    label_info=label_color_dict,
    return_ax=True,
)
plt.show()
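Conceptually, merging turns a list of per-patch labels and patch coordinates into a 2D label image. The NumPy sketch below illustrates the idea on a toy example; it is not the TIAToolbox implementation. The helper name `merge_patch_labels` and the toy coordinates are hypothetical. As in the color mapping above, 0 is reserved for regions without a prediction and class labels are shifted by 1.

```python
import numpy as np


def merge_patch_labels(canvas_shape, coordinates, labels, scale=1.0):
    """Paint per-patch labels onto a 2D map.

    coordinates holds (x_start, y_start, x_end, y_end) boxes at the
    resolution used for prediction; scale converts them to map pixels.
    """
    pred_map = np.zeros(canvas_shape, dtype=np.uint8)  # 0 means "empty"
    for (x0, y0, x1, y1), label in zip(coordinates, labels):
        x0, y0, x1, y1 = (int(round(v * scale)) for v in (x0, y0, x1, y1))
        pred_map[y0:y1, x0:x1] = label + 1  # shift so that 0 stays "empty"
    return pred_map


# Two 224x224 patches predicted at 0.5 mpp, drawn on a map at 4 mpp (scale 1/8).
coords = [(0, 0, 224, 224), (224, 0, 448, 224)]
toy_labels = [3, 8]  # made-up class IDs
toy_map = merge_patch_labels((28, 84), coords, toy_labels, scale=0.5 / 4)
print(np.unique(toy_map))
```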

## Feature extraction with a pathology-specific model

In this section, we will show how to extract features from a pretrained PyTorch model that exists outside TIAToolbox, using the WSI inference engines provided by TIAToolbox. To illustrate this we will use HistoEncoder, a computational pathology-specific model that has been trained in a self-supervised fashion to extract features from histology images. The model has been made available here:

'HistoEncoder: Foundation models for digital pathology' (https://github.com/jopo666/HistoEncoder) by Pohjonen, Joona and team at the University of Helsinki.

We will plot a UMAP reduction into 3D (RGB) of the feature map to visualize how the features capture the differences between some of the tissue types mentioned above.

In [None]:
# Import some extra modules
import histoencoder.functional as F
import torch.nn as nn

from tiatoolbox.models.engine.semantic_segmentor import DeepFeatureExtractor, IOSegmentorConfig
from tiatoolbox.models.models_abc import ModelABC
import umap

TIAToolbox defines `ModelABC`, a class that inherits from PyTorch's [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) and specifies how a model should look in order to be used in the TIAToolbox inference engines. The HistoEncoder model doesn't follow this structure, so we need to wrap it in a class whose output and methods are those that the TIAToolbox engine expects.

In [None]:
class HistoEncWrapper(ModelABC):
    """Wrapper for HistoEnc model that conforms to tiatoolbox ModelABC interface."""

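    # Note: `self` is left unannotated; annotating it with the class name would
    # raise a NameError here unless `from __future__ import annotations` is used.
    def __init__(self, encoder) -> None:
        super().__init__()
        self.feat_extract = encoder

    def forward(self, imgs: torch.Tensor) -> torch.Tensor: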
        """Pass input data through the model.

        Args:
            imgs (torch.Tensor):
                Model input.

        """
        out = F.extract_features(self.feat_extract, imgs, num_blocks=2, avg_pool=True)
        return out

    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: torch.Tensor,
        *,
        on_gpu: bool,
    ) -> list[np.ndarray]:
        """Run inference on an input batch.

        Contains logic for forward operation as well as i/o aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (torch.Tensor):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            on_gpu (bool):
                Whether to run inference on a GPU.

        """
        img_patches_device = batch_data.to('cuda') if on_gpu else batch_data
        model.eval()
        # Do not compute the gradient (not training)
        with torch.inference_mode():
            output = model(img_patches_device)
        return [output.cpu().numpy()]

Now that we have our wrapper, we will create our feature extraction model and instantiate a [DeepFeatureExtractor](https://tia-toolbox.readthedocs.io/en/v1.4.1/_autosummary/tiatoolbox.models.engine.semantic_segmentor.DeepFeatureExtractor.html) to allow us to use this model over a WSI. We will use the same WSI as above, but this time we will extract features from the patches of the WSI using the HistoEncoder model, rather than predicting some label for each patch.

In [None]:
# create the model
encoder = F.create_encoder("prostate_medium")
model = HistoEncWrapper(encoder)

# set the pre-processing function
norm = transforms.Normalize(mean=[0.662, 0.446, 0.605], std=[0.169, 0.190, 0.155])
trans = [
    transforms.ToTensor(),
    norm,
]
model.preproc_func = transforms.Compose(trans)

wsi_ioconfig = IOSegmentorConfig(
    input_resolutions=[{"units": "mpp", "resolution": 0.5}],
    patch_input_shape=[224, 224],
    output_resolutions=[{"units": "mpp", "resolution": 0.5}],
    patch_output_shape=[224, 224],
    stride_shape=[224, 224],
)

When we create the `DeepFeatureExtractor`, we will pass the `auto_generate_mask=True` argument. This will automatically create a mask of the tissue region using Otsu thresholding, so that the extractor processes only those patches containing tissue.
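For intuition, Otsu's method picks the gray-level threshold that maximizes the between-class variance of the two resulting pixel groups. The sketch below is a plain NumPy version applied to a synthetic image; it illustrates the technique and is not the masking code that TIAToolbox runs. The function name `otsu_threshold` and the synthetic image are hypothetical.

```python
import numpy as np


def otsu_threshold(gray):
    """Return the 8-bit threshold that maximizes between-class variance."""
    hist = np.bincount(gray.ravel(), minlength=256).astype(float)
    prob = hist / hist.sum()
    levels = np.arange(256)

    weight_low = np.cumsum(prob)         # P(pixel <= t)
    weight_high = 1.0 - weight_low       # P(pixel > t)
    cum_mean = np.cumsum(prob * levels)  # sum of i * p(i) for i <= t
    total_mean = cum_mean[-1]

    # Between-class variance for every candidate threshold t.
    numerator = (total_mean * weight_low - cum_mean) ** 2
    denominator = weight_low * weight_high
    variance = np.divide(
        numerator, denominator, out=np.zeros(256), where=denominator > 0
    )
    return int(np.argmax(variance))


# Synthetic slide: bright background (230) with a darker tissue block (90).
image = np.full((64, 64), 230, dtype=np.uint8)
image[16:48, 16:48] = 90
threshold = otsu_threshold(image)
tissue_mask = image <= threshold  # tissue is darker than the glass background
print(threshold, int(tissue_mask.sum()))
```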

In [None]:
# create the feature extractor and run it on the WSI
extractor = DeepFeatureExtractor(model=model, auto_generate_mask=True, batch_size=32, num_loader_workers=4, num_postproc_workers=4)
with suppress_console_output():
    out = extractor.predict(imgs=[wsi_path], mode="wsi", ioconfig=wsi_ioconfig, on_gpu=ON_GPU, save_dir=global_save_dir / "wsi_features")

These features could be used to train a downstream model, but here in order to get some intuition for what the features represent, we will use a UMAP reduction to visualize the features in RGB space. The points labeled in a similar color should have similar features, so we can check if the features naturally separate out into the different tissue regions when we overlay the UMAP reduction on the WSI thumbnail. We will plot it along with the patch-level prediction map from above to see how the features compare to the patch-level predictions in the following cells.
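Before an embedding can be used as RGB colors, each of its three components has to be rescaled to the range [0, 1], which is what the min-max step in the following cell does. A small NumPy sketch on random data, which stands in for a real UMAP embedding purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
embedding = rng.normal(size=(100, 3))  # stand-in for a 3D UMAP embedding

# Shift each column so its minimum is 0, then scale so its maximum is 1.
rgb = embedding - embedding.min(axis=0)
rgb = rgb / rgb.max(axis=0)

print(rgb.min(axis=0), rgb.max(axis=0))
```

Each row of `rgb` can then be passed to Matplotlib as one color, so that patches with similar features receive similar colors.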

In [None]:
# First we define a function to calculate the umap reduction
def umap_reducer(x, dims=3, nns=10):
    """UMAP reduction of the input data."""
    reducer = umap.UMAP(n_neighbors=nns, n_components=dims, metric="manhattan", spread=0.5, random_state=2)
    reduced = reducer.fit_transform(x)
    reduced -= reduced.min(axis=0)
    reduced /= reduced.max(axis=0)
    return reduced

# load the features output by our feature extractor
pos = np.load(global_save_dir / "wsi_features" / "0.position.npy")
feats = np.load(global_save_dir / "wsi_features" / "0.features.0.npy")
pos = pos / 8 # as we extracted at 0.5mpp, and we are overlaying on a thumbnail at 4mpp

# reduce the features into 3 dimensional (rgb) space
reduced = umap_reducer(feats)

# plot the classifier's prediction map again
overlay = overlay_prediction_mask(
    wsi_overview,
    pred_map,
    alpha=0.5,
    label_info=label_color_dict,
    return_ax=True,
)

# plot the feature map reduction
plt.figure()
plt.imshow(wsi_overview)
plt.scatter(pos[:,0], pos[:,1], c=reduced, s=1, alpha=0.5)
plt.axis("off")
plt.title("UMAP reduction of HistoEnc features")
plt.show()
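The division by 8 follows from the ratio of the two resolutions: one pixel of the thumbnail at 4 mpp spans 8 pixels of the patches extracted at 0.5 mpp. A short sketch of that conversion, with a hypothetical helper `rescale_position` and made-up coordinates:

```python
def rescale_position(position, source_mpp, target_mpp):
    """Convert pixel coordinates from one microns-per-pixel scale to another."""
    factor = target_mpp / source_mpp  # source pixels spanned by one target pixel
    return tuple(coord / factor for coord in position)


# A patch corner at (2240, 4480) at 0.5 mpp lands at (280, 560) at 4 mpp.
thumb_xy = rescale_position((2240, 4480), source_mpp=0.5, target_mpp=4)
print(thumb_xy)
```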

We see that the prediction map from our patch-level predictor and the feature map from our self-supervised feature encoder capture similar information about the tissue types in the WSI. This is a good sanity check that our models are working as expected. It also shows that the features extracted by the HistoEncoder model capture the differences between the tissue types, and therefore encode histologically relevant information.

## Where to Go From Here

In this notebook, we show how we can use the `PatchPredictor` and `DeepFeatureExtractor` classes and their `predict` method to predict the label, or extract features, for patches of big tiles and WSIs. We introduce `merge_predictions` and `overlay_prediction_mask` helper functions that merge the patch prediction outputs and visualize the resulting prediction map as an overlay on the input image/WSI.

All the processes take place within TIAToolbox and we can easily put the pieces together, following our example code. Please make sure to set inputs and options correctly. We encourage you to further investigate the effect on the prediction output of changing `predict` function parameters. We have demonstrated how to use your own pretrained model or one provided by the research community for a specific task in the TIAToolbox framework to do inference on large WSIs even if the model structure is not defined in the TIAToolbox model class.

You can learn more through the following resources:

- [Advanced model handling with PyTorch and TIAToolbox](https://tia-toolbox.readthedocs.io/en/latest/_notebooks/jnb/07-advanced-modeling.html)
- [Creating slide graphs for WSI with a custom PyTorch graph neural network](https://tia-toolbox.readthedocs.io/en/latest/_notebooks/jnb/full-pipelines/slide-graph.html)