# Download example image
We can download some example images from this site: https://openslide.cs.cmu.edu/download/openslide-testdata/
Use the following code (or manually) to download the example image data.

In [1]:
import urllib.request
download_url = "https://openslide.cs.cmu.edu/download/openslide-testdata/Aperio/CMU-2.svs"
urllib.request.urlretrieve(download_url, "CMU-2.svs")


('CMU-2.svs', <http.client.HTTPMessage at 0x7f5e70608820>)

# Extract image patches from whole slide images
Please refer to: https://github.com/smujiang/WSITools

In [1]:
import os
wsi_fn = "./CMU-2.svs"             # Define a sample image that can be read by OpenSlide
output_dir = "./patches"    # Define an output directory
log_dir = "./logs"

if not os.path.exists(output_dir):
    os.makedirs(output_dir)
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
print(log_dir)
print(output_dir)

./logs
./patches


In [None]:
from wsitools.tissue_detection.tissue_detector import TissueDetector
from wsitools.patch_extraction.patch_extractor import ExtractorParameters, PatchExtractor

# Define the parameters for Patch Extraction, including generating an thumbnail from which to traverse over to find
# tissue.
parameters = ExtractorParameters(output_dir, # Where the patches should be extracted to
    save_format = '.png',                      # Can be '.jpg', '.png', or '.tfrecord'
    sample_cnt = -1,                           # Limit the number of patches to extract (-1 == all patches)
    patch_size = 128,                          # Size of patches to extract (Height & Width)
    rescale_rate = 128,                        # Fold size to scale the thumbnail to (for faster processing)
    patch_filter_by_area = 0.5,                # Amount of tissue that should be present in a patch
    with_anno = True,                          # If true, you need to supply an additional XML file
    extract_layer = 0,                          # OpenSlide Level
    log_dir=log_dir
    )

# Choose a method for detecting tissue in thumbnail image
tissue_detector = TissueDetector("LAB_Threshold",   # Can be LAB_Threshold or GNB
    threshold = 85,                                   # Number from 1-255, anything less than this number means there is tissue
    training_files = None                             # Training file for GNB-based detection
    )

# Create the extractor object
patch_extractor = PatchExtractor(tissue_detector,
    parameters,
    feature_map = None,                       # See note below
    annotations = None                        # Object of Annotation Class (see other note below)
    )

patch_extractor.extract([wsi_fn])

2024-10-07 10:42:08.796734: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Show where the patches were extracted from

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
log_img_fn = os.listdir(log_dir)[0]
log_img = os.path.join(log_dir, log_img_fn)
img = Image.open(log_img)
plt.imshow(img)
plt.show()

# Normalize the extracted patches
Need to install staintools. Please refer to https://github.com/Peter554/StainTools

In [None]:
import staintools
from PIL import Image
template_img = "./template.png"  # TODO: select an template image, should be the same size of the image tiles to be normalized
normalized_patches_dir = "./normalized_patches"
if not os.path.exists(normalized_patches_dir):
    os.makedirs(normalized_patches_dir)

img_fn_list = os.listdir(output_dir)
for img_fn in img_fn_list:
    target = staintools.read_image(os.path.join(output_dir, img_fn))
    to_transform = staintools.read_image(template_img)

    # Standardize brightness (optional, can improve the tissue mask calculation)
    target = staintools.LuminosityStandardizer.standardize(target)
    to_transform = staintools.LuminosityStandardizer.standardize(to_transform)

    # Stain normalize
    normalizer = staintools.StainNormalizer(method='vahadane')
    normalizer.fit(target)
    transformed = normalizer.transform(to_transform)

    sv_fn = os.path.join(normalized_patches_dir, img_fn)
    transformed.save(sv_fn)

# Download pretrained CTransPath model
Please refer to this site: https://github.com/Xiyue-Wang/TransPath.
The model can be downloaded from here: https://drive.google.com/file/d/1DoDx_70_TLj98gTf6YTXnu4tFhsFocDX/view?usp=sharing
Download the file, and save to "./CTransPath/ctranspath.pth"

In [None]:
model_sv = "./CTransPath/ctranspath.pth"

# Get image embedding

In [None]:
import pandas as pd
import numpy as np
import time
import torch, torchvision
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset
from ctran import ctranspath

class roi_dataset(Dataset):
    def __init__(self, img_csv,):
        super().__init__()
        self.transform = trnsfrms_val
        self.images_lst = img_csv

    def __len__(self):
        return len(self.images_lst)

    def __getitem__(self, idx):
        path = self.images_lst.filename[idx]
        image = Image.open(path).convert('RGB')
        image = self.transform(image)
        return image

embeddings_csv = "./embeddings/img_embeddings.csv"
if not os.path.exists(os.path.split(embeddings_csv)[0]):
    os.makedirs(os.path.split(embeddings_csv)[0])

if __name__ == "__main__":
    start_time = time.time()
    mean = (0.6373, 0.5260, 0.7438)
    std = (0.1089, 0.1249, 0.0710)
    trnsfrms_val = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ]
    )

    img_csv = pd.read_csv(r'All_HE_img_samples.csv')
    test_datat = roi_dataset(img_csv)
    database_loader = torch.utils.data.DataLoader(test_datat, batch_size=10, shuffle=False)

    model = ctranspath()
    model.head = nn.Identity()
    td = torch.load(model_sv)
    model.load_state_dict(td['model'], strict=True)

    model.eval()

    embed_list = []
    with torch.no_grad():
        for batch in database_loader:
            features = model(batch)
            features = features.cpu().numpy()

            embed_list.append(features)

    all_embeds = np.concatenate(embed_list)
    ##embeddings_standardized = StandardScaler().fit_transform(all_embeds)
    print("--- %s minutes ---" % ((time.time() - start_time) / 60))

    np.savetxt(embeddings_csv, all_embeds, delimiter=",")

# Result visualization

In [None]:
import pandas as pd
import umap

import numpy as np

df = pd.read_csv(embeddings_csv).astype(float)
img_features = np.array(df.iloc[:, :-1])
lb = list(df.iloc[:, -1])
lb_int = [int(x) for x in lb]

dm_red = umap.UMAP(random_state=12)
pca_cell_f = dm_red.fit_transform(img_features)

plt.scatter(pca_cell_f[:, 0], pca_cell_f[:, 1], marker=".", s=1)
plt.title("Image embeddings UMap")
plt.show()