# HPA21 Cellpose EDA

I wanted to be able to try the cellpose library for cell segmentation for the HPA21 competition.

Unfortunately cellpose is not one of the standard libraries available to the competition notebooks
and this presented a number of technical challenges to load the necessary library and the pre-trained
model weights that the library uses without internet access (submission notebooks have to have internet
disabled).

This notebook demonstrates a method to offline load both the necessary libraries and the model weights.
You will also need to utilise the following notebooks if you want to further pursue use of cellpose.

## Offline package preparation

[Offline Package Wheeler](https://www.kaggle.com/andrewscholan/offline-package-wheeler-public)

This workbook is actually a general purpose notebook that imports a set of libraries using an internet
connection then builds the necessary wheel files so that they can be packaged into a Kaggle dataset. This
is then used to offline load the libraries (see below).

## Offline model weights preparation

[Cellpose Model Collector](https://www.kaggle.com/andrewscholan/cellpose-model-collector-public)

This workbook downloads the necessary model weights used by cellpose so that they too can be
packaged up into a Kaggle dataset.

# Exploratory Data Analysis

The final part of the notebook simply loads in some of the data and plots it with the masks detected by
cellpose.

# Torch Custom Dataset/Rescale Transform

There is also a custom dataset for Pytorch that loads the images into suitable tensors and
a transform that is used to rescale the images before passing them to cellpose. These might prove
useful...?

# GPU/Internet settings

This notebook should run correctly with GPU enabled and Internet turned off.

In [None]:
from __future__ import annotations

# Imports

In [None]:
import json
import glob
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import PIL
import requests
import socket

from tqdm.notebook import tqdm as show_progress
from typing import Optional, Tuple, List, Any
from shutil import rmtree, copytree

In [None]:
import seaborn as sns

# Notebook configuration

In [None]:
# Title of this notebook.
NOTEBOOK_NAME = "HPA21 Cellpose EDA"

# Quick-test switch: when True only the first 1000 entries of a file list
# are kept by the slices below; when False the slices keep everything.
QUICK_TEST = True

# A stop value of None makes the slice select the whole sequence.
TRAIN_SLICE = slice(1000 if QUICK_TEST else None)
TEST_SLICE = slice(1000 if QUICK_TEST else None)

# Dataframe column names
ID = "ID"
IMAGE_HEIGHT = "ImageHeight"
IMAGE_WIDTH = "ImageWidth"
IMAGE_SHAPE = "ImageShape"
LABEL = "Label"

In [None]:
# List of the dataset channel names. Each value is the colour suffix used in
# the image file names (files are looked up as "<id>_<channel><ext>").
PROTEIN_OF_INTEREST = "green"
NUCLEUS = "blue"
MICROTUBULE = "red"
ENDOPLASMIC_RETICULUM = "yellow"
# Maps each channel colour to a descriptive name.
# NOTE(review): the descriptive names mix capitalised and snake_case styles.
CHANNEL_NAMES = {
    PROTEIN_OF_INTEREST: "Protein of interest", 
    NUCLEUS: "Nucleus", 
    MICROTUBULE: "microtubule",
    ENDOPLASMIC_RETICULUM: "endoplasmic_reticulum",
}

# List of the label names. The position in this list is the label number
# (the notebook indexes LABELS with the integer label when titling plots).
LABELS = [
    "Nucleoplasm",
    "Nuclear membrane",
    "Nucleoli",
    "Nucleoli fibrillar center",
    "Nuclear speckles",
    "Nuclear bodies",
    "Endoplasmic reticulum",
    "Golgi apparatus",
    "Intermediate filaments",
    "Actin filaments",
    "Microtubules",
    "Mitotic spindle",
    "Centrosome",
    "Plasma membrane",
    "Mitochondria",
    "Aggresome",
    "Cytosol",
    "Vesicles and punctate cytosolic patterns",
    "Negative",
]
# Number of distinct labels; used to build one truth column per label.
NUM_LABELS = len(LABELS)

In [None]:
# Make plots larger
plt.rcParams['figure.figsize'] = [16, 9]
plt.rcParams['figure.dpi'] = 120

# Folder Roots

In [None]:
# Absolute locations of the notebook folders, resolved against the current
# working directory: "input" and "temp" are siblings of the working folder.
INPUT_ROOT = os.path.abspath(os.path.join(os.pardir, "input"))
TEMP_ROOT = os.path.abspath(os.path.join(os.pardir, "temp"))
WORKING_ROOT = os.path.abspath(os.curdir)

# Echo the resolved paths so they are visible in the notebook output.
print(f"INPUT_ROOT='{INPUT_ROOT}'")
print(f"TEMP_ROOT='{TEMP_ROOT}'")
print(f"WORKING_ROOT='{WORKING_ROOT}'")

In [None]:
HPA21_DATASET_PATH = os.path.join(INPUT_ROOT, "hpa-single-cell-image-classification")
print(f"HPA21_DATASET_PATH='{HPA21_DATASET_PATH}'")

# Offline load additional packages

See [Offline Package Wheeler](https://www.kaggle.com/andrewscholan/offline-package-wheeler-public)

In [None]:
EXTRA_PACKAGES_ROOT = os.path.join(INPUT_ROOT, "hpa21-extra-packages")
requirements_txt = os.path.join(EXTRA_PACKAGES_ROOT, "requirements.txt")
wheels_path = os.path.join(EXTRA_PACKAGES_ROOT, "wheels")

In [None]:
with open(requirements_txt, "r") as f:
    requirements = f.readlines()
    for requirement in requirements:
        print(requirement.strip())

In [None]:
!pip install \
    --requirement {requirements_txt} \
    --no-index \
    --find-links file://{wheels_path}

# Offline load the Cellpose model weights

See [Cellpose Model Collector](https://www.kaggle.com/andrewscholan/cellpose-model-collector-public)

This copies the model weights to the cache folder that cellpose uses so that it doesn't
attempt to download them from a non-existent internet connection.

In [None]:
cellpose_cache_path = os.path.join("/", "root", ".cellpose")
cellpose_cache_path

In [None]:
cellpose_models_as_dataset = os.path.join(INPUT_ROOT, "cellpose-models", "cellpose_models")
cellpose_models_as_dataset

In [None]:
# Seed the cellpose cache folder with the weights shipped in the dataset.
# The copy is skipped whenever the cache folder already exists, so an
# existing cache (even an incomplete one) is left untouched.
if not os.path.exists(cellpose_cache_path):
    copytree(cellpose_models_as_dataset, cellpose_cache_path)

## And import the modules

In [None]:
import torch
import torchvision
import torchaudio
import cellpose
print(f"pytorch=={torch.__version__}")
print(f"torchvision=={torchvision.__version__}")
print(f"torchaudio=={torchaudio.__version__}")

In [None]:
# Set torchaudio back-end for cross-platform use
torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE = False
torchaudio.set_audio_backend("soundfile")

In [None]:
from cellpose import io as cellpose_io
import torchvision.transforms.functional as TF

# GPU detection/setup

In [None]:
# Execute the following command. If we are in a notebook on a system with
# an NVIDIA GPU it will give us the CUDA Version. This needs to match the
# version of PyTorch that we installed above.
# If we are not on a GPU enabled host then it will say command not found.
!nvidia-smi

In [None]:
import torch.cuda

# Decide where the heavy tensor maths will run: on the GPU when CUDA is
# usable, otherwise on the CPU. A plain CPU device is always kept as well.
use_gpu = torch.cuda.is_available()
device_cpu = torch.device("cpu")
device_math = torch.device("cuda" if use_gpu else "cpu")

if not use_gpu:
    print("GPU not available, all tensors loaded to 'device_math' will reside "
          "in CPU memory.")
else:
    # Report the active device, then dump the CUDA memory summary.
    print(
        "Using GPU when tensors are loaded to 'device_math'.\n"
        f"  Device: {torch.cuda.get_device_name()}\n"
        f"  Number of GPUs: {torch.cuda.device_count()}\n"
        f"  GPU initialised: {torch.cuda.is_initialized()}\n"
    )
    print(torch.cuda.memory_summary())

# Visualise some images

## Load in the CSV file

In [None]:
# Load the training CSV
train_csv = os.path.join(HPA21_DATASET_PATH, "train.csv")
train_df = pd.read_csv(train_csv)
train_df.head()

## One-hot encode the labels

In [None]:
labels = train_df[LABEL].str.split("|")
labels.head()

In [None]:
def has_label(label_list: Any, label_num: int) -> float:
    """
    Reports whether a label number is present in a sample's label list.

    Args:
        label_list: The labels of one sample as a list of strings (the result
            of splitting the label column on "|"). A raw "|"-separated string
            is also accepted and is split first, so that e.g. label 1 is not
            matched as a substring of "10". A missing value (None or NaN) is
            treated as having no labels.
        label_num: The label number to look for.

    Returns:
        1.0 if the label is present, otherwise 0.0.
    """
    if isinstance(label_list, str):
        label_list = label_list.split("|")
    try:
        return 1.0 if str(label_num) in label_list else 0.0
    except TypeError:
        # label_list does not support membership tests (e.g. a NaN float
        # produced by a missing label), so the sample has no labels.
        return 0.0

In [None]:
label_truth = {label_num: labels.apply(has_label, label_num=label_num)
               for label_num in range(NUM_LABELS)}

In [None]:
label_truth[ID] = train_df[ID]
label_truth[ID].head()

In [None]:
train_truth = pd.DataFrame(label_truth).set_index(ID)
train_truth.head()

In [None]:
train_truth_melted = train_truth.melt(value_vars=train_truth.columns)
train_truth_melted = train_truth_melted[train_truth_melted["value"] != 0.0]

In [None]:
sns.catplot(x="variable", kind="count", data=train_truth_melted, 
            order=train_truth_melted["variable"].value_counts().index);

# Get the image height and width

In [None]:
# Function that computes the image height and width for all images
# in a folder.
def read_image_dimensions(
    folder_path: str, 
    file_ext: str = ".png",
    id: str = ID, 
    height: str = IMAGE_HEIGHT, 
    width: str = IMAGE_WIDTH,
    shape: str = IMAGE_SHAPE,
    channel: str = PROTEIN_OF_INTEREST,
    file_slice: slice = slice(None),
) -> pd.DataFrame:
    """
    Reads the image width and height of every file in the folder whose name
    ends with "_<channel><file_ext>".
    
    Args:
        folder_path: The path to the folder containing the image files.
        file_ext: The file extension that we are interested in.
        id: The column name for the file id. The file id is acquired from
            the filename which is assumed to be in the format 
            <id>_<channel><file_ext>.
        height: The column name for the image height.
        width: The column name for the image width.
        shape: The column name for the image shape.
        channel: The channel image to use for sizing (see id).
        file_slice: Slice to apply to the sorted file list (used for testing).
        
    Returns:
        Dataframe indexed by {id} with the columns {width}, {height} and
        {shape}; the shape column holds strings of the form
        "W<width>xH<height>".
    """
    # Imported here so that the Image submodule is guaranteed to be loaded;
    # a bare "import PIL" does not import PIL.Image by itself.
    from PIL import Image

    # Get all the file paths that match the channel. The list is sorted
    # because glob returns files in arbitrary order, which would make the
    # subset chosen by file_slice differ from run to run.
    suffix = f"_{channel}{file_ext}"
    file_filter = os.path.join(folder_path, f"*{suffix}")
    file_paths = sorted(glob.glob(file_filter))[file_slice]
    
    def image_size(file_path: str) -> Tuple[int, int]:
        # The context manager closes the file handle as soon as the size
        # has been read.
        with Image.open(file_path) as image:
            return image.size

    def file_id(file_path: str) -> str:
        # Every matched name ends with the suffix, so stripping it leaves the
        # id even when the id or the channel itself contains an underscore.
        return os.path.basename(file_path)[: -len(suffix)]

    ids = [file_id(file_path) for file_path in file_paths]
    widths_and_heights = [image_size(file_path) for file_path in show_progress(file_paths)]
    
    df_dict = {
        id: ids,
        width: [image_width for image_width, _ in widths_and_heights],
        height: [image_height for _, image_height in widths_and_heights],
        shape: [f"W{image_width}xH{image_height}" 
                for image_width, image_height in widths_and_heights]
    }
    
    return pd.DataFrame(df_dict).set_index(id)

In [None]:
train_path = os.path.join(HPA21_DATASET_PATH, "train")
train_image_sizes = read_image_dimensions(train_path, file_slice=TRAIN_SLICE)
train_image_sizes.head()

In [None]:
test_path = os.path.join(HPA21_DATASET_PATH, "test")
test_image_sizes = read_image_dimensions(test_path, file_slice=TEST_SLICE)
test_image_sizes.head()

In [None]:
# Distinct image shape strings seen in each split, then their sorted union
# so that the train and test plots can share one category order.
train_image_shapes = set(train_image_sizes[IMAGE_SHAPE].unique())
test_image_shapes = set(test_image_sizes[IMAGE_SHAPE].unique())
image_shapes = sorted(train_image_shapes.union(test_image_shapes))

In [None]:
sns.catplot(x=IMAGE_SHAPE, kind="count", data=train_image_sizes,
            order=image_shapes);

In [None]:
sns.catplot(x=IMAGE_SHAPE, kind="count", data=test_image_sizes,
            order=image_shapes);

# Add shapes to truth dataframes

In [None]:
# Get the subset of the truth that corresponds with the slice that we are using
# for this note-book by using the index of the images we sized
train_truth_slice = train_truth.loc[train_image_sizes.index, :]
train = pd.concat([train_truth_slice, train_image_sizes], axis="columns")
train

In [None]:
# And given that the test data is not labelled then our "truth" for the test
# data is the image size dataframe
test = test_image_sizes
test

# Transform to rescale the images
The images are very high resolution, but I suspect that for cell segmentation we don't need to use them that size.

In [None]:
class RescaleImage:
    """
    Transform callable class that can be used to transform the tensor size by a
    scale value. Under the hood it uses the torchvision.functional.resize transform
    to do the work.
    """
    
    def __init__(self, scale: float = 0.25, interpolation: int = PIL.Image.BILINEAR):
        """
        This initialises the transform function.
        
        Args:
            scale: The scale factor for the image W and H dimensions (E.g. 0.5 will
                reduce the image dimensions by 50% in W and H). Must be positive.
            interpolation: The interpolation technique to use for the rescaling.
                It is passed unchanged to the resize call.
                NOTE(review): newer torchvision releases document an
                InterpolationMode value for this argument - confirm against
                the installed version.
                
        Raises:
            ValueError: If scale is not greater than zero.
        """
        if scale <= 0:
            raise ValueError(f"scale must be positive, got {scale!r}")
        self._scale = scale
        self._interpolation = interpolation
        
    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        """
        This is the transforming function. It takes as input a tensor of shape
        [..., H, W] and returns a tensor with the same leading dimensions but
        with H and W scaled as set in the initialiser.
        
        Args:
            t: Input tensor of shape [..., H, W]
            
        Returns:
            Tensor of shape [..., H, W] with scaled dimensions
        """
        height, width = t.shape[-2:]
        # Truncate to whole pixels, but never ask for a 0-pixel dimension
        # (which happens when a small image meets a small scale).
        new_size = [
            max(1, int(height * self._scale)),
            max(1, int(width * self._scale)),
        ]
        return TF.resize(
            img=t, size=new_size, interpolation=self._interpolation
        )


# Data loader

This is a data loader that takes images of the same size and combines the individual images into composites.

In [None]:
# Dataset for HPA21 images.
class CellImageDataset(torch.utils.data.Dataset):
    """
    Defines an image dataset for the HPA21 images.
    """
    
    def __init__(
        self, 
        images: pd.DataFrame, 
        root_dir: str,
        file_ext: str = ".png",
        id: str = ID,
        channels: List[List[str]] = [
            [ENDOPLASMIC_RETICULUM, NUCLEUS, MICROTUBULE], 
            [PROTEIN_OF_INTEREST]
        ],
        target_cols=None, 
        transform=None, 
        transform_target=None
    ) -> List[torch.Tensor]:
        """
        Creates the dataset based on the images dataframe.
        
        Args:
            images: Dataframe indexed by id; the id is the root of the filename.
            root_dir: The directory containing all of the images.
            id: The name of the ID column.
            channels: List of tesnsors to return and the channels to include in each tensor
            target_cols: List of column names that constitututes the target labels. Set
                to None if there are no labels.
            transform: Transform to perform on the image after loading.
            transform_target: Transform to perform on the target tensor.
        """
        super().__init__()
        self._images = images.reset_index()
        self._root_dir = root_dir
        self._file_ext = file_ext
        self._id = id
        self._channels = channels
        self._target_cols = target_cols
        self._transform = transform
        self._transform_target = transform_target
        
    def __len__(self):
        """
        Return the size of the dataset. This is defined by the dataframe that was input.
        
        Returns:
            The size of the dataset (based on the dataframe passed in).
        """
        return len(self._images)
    
    def __getitem__(self, idx):
        """
        Gets the data associated with a specific index.
        
        Args:
            idx: Index of data to get.
            
        Returns:
            A list of tensors; the data tensors are defined by the channels
            parameter; the labels tensor is appended if target_cols is specified.
        """
        # Get this image ID
        this_image = self._images.iloc[idx, :]
        this_image_id = this_image[self._id]
        
        # Build the filename lists to load and combine
        channel_file_paths = [
            [
                os.path.join(
                    self._root_dir, f"{this_image_id}_{channel}{self._file_ext}"
                )
                for channel in tensor_channels
            ]
            for tensor_channels in self._channels
        ]
        
        # Now load the images as numpy arrays using the cellpose_io imread
        channel_image_arrays = [
            np.stack(
                [
                    cellpose_io.imread(file_path)       
                    for file_path in tensor_channel_file_paths
                ]
            )
            for tensor_channel_file_paths in channel_file_paths
        ]
        
        # Convert to torch tensors
        channel_tensors = [
            torch.tensor(channel_image_array.astype(np.float32) / 256.0)
            for channel_image_array in channel_image_arrays
        ]
        
        # Now apply any transforms to them
        if self._transform is not None:
            channel_tensors = [self._transform(channel_tensor) 
                               for channel_tensor in channel_tensors]
        
        # Now process the targets, if necessary
        if self._target_cols is not None:
            # Extract the labels as a numpy array
            targets = this_image.loc[target_cols].values.astype(np.float32)
            # Convert to a tensor and shape to be CT
            target_tensors = [torch.tensor(targets).reshape(1, -1)]
        else:
            target_tensors = []
    
        # And apply the target transform
        if self._transform_target is not None:
            target_tensors = [self._transform_target(target_tensor)
                              for target_tensor in target_tensors]
    
        return tuple(channel_tensors + target_tensors)
    

In [None]:
target_cols = list(range(NUM_LABELS))
train_dataset = CellImageDataset(
    images = train,
    root_dir = train_path,
    target_cols = target_cols,
    transform = RescaleImage()
)

## Build the datasets by image label - for visualisations

In [None]:
train_datasets = {
    label: CellImageDataset(
        images = train[train[label]==1.0],
        root_dir = train_path,
        target_cols = target_cols,
        transform = RescaleImage()
    )
    for label in target_cols
}
train_datasets

In [None]:
for label, dataset in train_datasets.items():
    print(f"{label}: {len(dataset)}")

# Plot an image for each label type

In [None]:
cell_image, protein_image, labels = next(iter(train_datasets[0]))
cell_image.size()

In [None]:
# Take the first sample of every label that has at least one image.
plot_images = {label: next(iter(train_datasets[label]))
               for label in show_progress(target_cols)
               if len(train_datasets[label])>0}


def make_im(t):
    """
    Converts a 3-dimensional tensor into a numpy array with its first axis
    moved to the end (assumes a channel-first [C, H, W] tensor, giving the
    [H, W, C] layout used for plotting - TODO confirm).
    """
    return t.detach().cpu().numpy().transpose((1, 2, 0))


# The first tensor of each sample is the cell composite, the second is the
# protein image; the third (the labels) is not needed for plotting.
cell_images = {label: make_im(image)
               for label, (image, _, _) in plot_images.items()}
protein_images = {label: make_im(image)
                  for label, (_, image, _) in plot_images.items()}

In [None]:
# Show, for every label that has a sample, the cell composite next to the
# protein channel. Only slots 1 and 3 of the four-column grid are filled.
for label in sorted(plot_images.keys()):
    label_name = LABELS[label]
    cell_image = cell_images[label]
    protein_image = protein_images[label]
    
    plt.subplot(1, 4, 1)
    plt.imshow(cell_image, origin="lower")
    
    plt.subplot(1, 4, 3)
    plt.imshow(protein_image, cmap="Greens", origin="lower")
    plt.title(f"{label}: {label_name}")
    
    plt.show()

# Segmentation with Cellpose

In [None]:
from cellpose import models

# DEFINE CELLPOSE MODEL
# model_type='cyto' or model_type='nuclei'
# We want the cell extent so set it as cyto
# The GPU is used only when it was detected earlier (use_gpu).
model = models.Cellpose(gpu=use_gpu, model_type='cyto', torch=True)

# define CHANNELS to run segmentation on
# grayscale=0, R=1, G=2, B=3
# channels = [cytoplasm, nucleus]
# if NUCLEUS channel does not exist, set the second channel to 0
# channels = [0,0]
# IF ALL YOUR IMAGES ARE THE SAME TYPE, you can give a list with 2 elements
# channels = [0,0] # IF YOU HAVE GRAYSCALE
# channels = [2,3] # IF YOU HAVE G=cytoplasm and B=nucleus
# channels = [2,1] # IF YOU HAVE G=cytoplasm and R=nucleus
# NOTE(review): the cell composites are presumably stacked in the order
# [endoplasmic reticulum, nucleus, microtubule], so [1, 2] would pick the
# first plane as cytoplasm and the second as nucleus - confirm against the
# dataset's channel order and the cellpose channel numbering.
channels = [1, 2]

# if diameter is set to None, the size of the cells is estimated on a per image basis
# you can set the average cell `diameter` in pixels yourself (recommended) 
# diameter can be a list or a single number for all images
diameter = None

# One cell composite per label, in the insertion order of cell_images.
images = list(cell_images.values())

In [None]:
# you can run all in a list e.g.
masks, flows, styles, diams = model.eval(images, diameter=diameter, channels=channels)

In [None]:
# Plot, for every label that has a sample: the cell composite, the detected
# masks, the protein channel and the protein channel with the masks overlaid.
# masks[index] is matched to the label by position, so this relies on the
# images handed to the model being in the same order as the sorted labels.
for index, label in enumerate(sorted(plot_images.keys())):
    label_name = LABELS[label]
    cell_image = cell_images[label]
    protein_image = protein_images[label]
    mask = masks[index]
    
    # Column 1: the cell composite that was segmented
    plt.subplot(1, 4, 1)
    plt.imshow(cell_image, origin="lower")
    
    # Column 2: the masks rendered as a colour image
    plt.subplot(1, 4, 2)
    mask_image = cellpose.plot.mask_rgb(mask)
    plt.imshow(mask_image, origin="lower")
    
    # Column 3: the protein channel, titled with the label
    plt.subplot(1, 4, 3)
    plt.title(f"{label}: {label_name}")
    plt.imshow(protein_image, cmap="Greens", origin="lower")
    
    # Column 4: the masks overlaid on the protein channel
    plt.subplot(1, 4, 4)
    mask_image = cellpose.plot.mask_overlay(protein_image, mask)
    plt.imshow(mask_image, origin="lower")
    
    plt.show()
