# Supplementary Vignette 2

## Example workflow for H&E images

Here we demonstrate a typical workflow for preprocessing H&E images, consisting of the following steps:

1. Loading the raw image
2. Defining a simple preprocessing pipeline for tissue detection
3. Creating a PyTorch DataLoader for interfacing with any downstream machine learning model

The image used in this example, `CMU-1.svs`, is publicly available for download: http://openslide.cs.cmu.edu/download/openslide-testdata/Aperio/

**a. Load the image**

In [87]:
import os

In [2]:
import pathml

In [3]:
from pathml.core import SlideData, types

# load the image
wsi = SlideData("../../../bucket_data/CMU-1.svs", name="example", slide_type=types.HE)

**b. Define a preprocessing pipeline**

Pipelines are created by composing a sequence of modular transformations. In this example, we apply a box blur to reduce noise in the image, followed by tissue detection.

In [2]:
from pathml.preprocessing import Pipeline, BoxBlur, TissueDetectionHE

pipeline = Pipeline([
    BoxBlur(kernel_size=15),
    TissueDetectionHE(mask_name="tissue", min_region_size=500,
                      threshold=30, outer_contours_only=True)
])
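PathML's `BoxBlur` and `TissueDetectionHE` are implemented inside the library. To make the blur-then-threshold idea concrete, here is a simplified NumPy-only illustration. It assumes tissue is more saturated than the white slide background and thresholds an HSV-style saturation channel at `threshold=30`; it omits morphological cleanup, `min_region_size` filtering and contour handling, and it is not PathML's implementation. The helpers `box_blur`, `saturation` and `naive_tissue_mask` are hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def box_blur(rgb, kernel_size=15):
    """Mean filter over a kernel_size x kernel_size window, applied per channel.

    Borders are edge-padded so the output keeps the input's H x W.
    kernel_size is assumed to be odd.
    """
    pad = kernel_size // 2
    padded = np.pad(rgb.astype(np.float64), ((pad, pad), (pad, pad), (0, 0)), mode="edge")
    # windows has shape (H, W, C, kernel_size, kernel_size)
    windows = sliding_window_view(padded, (kernel_size, kernel_size), axis=(0, 1))
    return windows.mean(axis=(-2, -1))


def saturation(rgb):
    """HSV-style saturation scaled to 0-255 (0 for grey or white pixels)."""
    mx = rgb.max(axis=-1)
    mn = rgb.min(axis=-1)
    return (mx - mn) / np.maximum(mx, 1e-8) * 255.0


def naive_tissue_mask(rgb, kernel_size=15, threshold=30):
    """Blur, then keep pixels whose saturation exceeds `threshold`."""
    return saturation(box_blur(rgb, kernel_size)) > threshold


# Synthetic 64 x 64 "slide": white background with a pink, eosin-like square.
img = np.full((64, 64, 3), 255, dtype=np.uint8)
img[22:42, 22:42] = (200, 100, 180)
mask = naive_tissue_mask(img)
print(mask.shape, mask[32, 32], mask[0, 0])
```

Blurring first means isolated noisy pixels are averaged with their neighbours before thresholding, so they are less likely to be flagged as tissue.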

**c. Run preprocessing**

Now that we have constructed our pipeline, we are ready to run it on our WSI.
PathML supports distributed computing: instead of processing tiles one at a time on a single worker, it processes them in parallel across many workers. This is implemented with [Dask.distributed](https://distributed.dask.org/en/latest/index.html) on the backend and scales to very large datasets.

The first step is to create a `Client` object. In this case, we will use a simple cluster running locally; however, Dask supports other setups including Kubernetes, SLURM, etc. See the [PathML documentation](https://pathml.readthedocs.io/en/latest/running_pipelines.html#distributed-processing) for more information.

In [None]:
from dask.distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=6)
client = Client(cluster)
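Conceptually, tile-parallel processing splits the slide into tiles and maps the same per-tile function over them on several workers. The sketch below shows that pattern with the standard library's `concurrent.futures.ThreadPoolExecutor` instead of Dask, on a small in-memory array. It illustrates the idea only and is not how PathML schedules work; `process_tile` is a hypothetical stand-in for a real pipeline, and the 256-pixel tile size is an assumption.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor


def iter_tiles(image, tile_size):
    """Yield (row, col, tile) for non-overlapping tiles; partial edge tiles are dropped."""
    h, w = image.shape[:2]
    for r in range(0, h - tile_size + 1, tile_size):
        for c in range(0, w - tile_size + 1, tile_size):
            yield r, c, image[r:r + tile_size, c:c + tile_size]


def process_tile(item):
    """Stand-in per-tile work: fraction of pixels that are not near-white."""
    r, c, tile = item
    return (r, c), float((tile.min(axis=-1) < 220).mean())


def run_tiles_in_parallel(image, tile_size=256, n_workers=6):
    """Map process_tile over all tiles with a worker pool and collect results by tile position."""
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        return dict(pool.map(process_tile, iter_tiles(image, tile_size)))


image = np.full((512, 768, 3), 255, dtype=np.uint8)
image[0:256, 256:512] = 120  # one dark "tissue" tile
results = run_tiles_in_parallel(image)
print(len(results), results[(0, 256)], results[(0, 0)])
```

Because each tile is processed independently, adding workers (threads here, Dask workers in PathML) increases throughput without changing the per-tile logic.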


In [3]:
wsi.run(pipeline, distributed=True, client=client)

In [4]:
print(f"Total number of tiles extracted: {len(wsi.tiles)}")

Total number of tiles extracted: 22912
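The tile count can be sanity-checked by hand. Assuming the level-0 size of CMU-1.svs is 46000 x 32914 pixels (width x height) and that tiles are non-overlapping 256 x 256 patches with incomplete edge tiles dropped, the grid has floor(46000/256) x floor(32914/256) = 179 x 128 = 22912 positions, matching the count printed above. This suggests the pipeline annotates every tile with a tissue mask rather than discarding background tiles. The helper `count_tiles` is hypothetical.

```python
def count_tiles(width, height, tile_size=256, stride=None):
    """Number of full tiles on a regular grid (partial edge tiles are not counted)."""
    stride = stride or tile_size
    n_cols = (width - tile_size) // stride + 1
    n_rows = (height - tile_size) // stride + 1
    return n_cols * n_rows


print(count_tiles(46000, 32914))  # 179 columns x 128 rows
```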




**d. Save results to disk**

The resulting preprocessed data is written to disk as an `.h5path` file, which leverages the HDF5 data specification, optimized for efficiently manipulating larger-than-memory data.

In [5]:
wsi.write("../../../bucket_data/CMU-1-preprocessed.h5path")

**e. Create PyTorch DataLoader**

The `DataLoader` provides an interface to any machine learning model built on the PyTorch ecosystem.

In [22]:
from pathml.ml import TileDataset
from torch.utils.data import DataLoader

dataset = TileDataset("../../../bucket_data/CMU-1-preprocessed.h5path")
dataloader = DataLoader(dataset, batch_size = 64, num_workers = 4)
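The `.h5path` file read by `TileDataset` is HDF5-based, and HDF5 lets a reader fetch individual records from disk without loading the whole file. The h5py sketch below illustrates that access pattern with a chunked dataset of tiles. The dataset name `tiles` and the one-tile-per-chunk layout are assumptions for illustration, not the schema PathML writes.

```python
import os
import tempfile

import h5py
import numpy as np

# Hypothetical layout: one chunked dataset of tiles, one chunk per tile.
path = os.path.join(tempfile.mkdtemp(), "tiles_demo.h5")
n_tiles, tile = 8, 256

with h5py.File(path, "w") as f:
    tiles = f.create_dataset(
        "tiles", shape=(n_tiles, tile, tile, 3), dtype="uint8", chunks=(1, tile, tile, 3)
    )
    for i in range(n_tiles):
        # write one tile at a time; the full array never has to sit in memory
        tiles[i] = np.full((tile, tile, 3), i, dtype=np.uint8)

with h5py.File(path, "r") as f:
    dset = f["tiles"]
    print(dset.shape, dset.chunks)
    one_tile = dset[5]  # reads only the chunk that holds tile 5
    print(one_tile.shape, int(one_tile.max()))
```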

In [4]:
dataset

<pathml.ml.dataset.TileDataset at 0x7faed7d653a0>

In [5]:
import jax
import jax.numpy as jnp

In [10]:
from jax import jit

In [29]:
class JaxTileDataset():
    """
    JAX Dataset class for h5path files.
    Each item is a tuple of (``tile_image``, ``tile_masks``, ``tile_labels``, ``slide_labels``) where:
        - ``tile_image`` is a jnp.ndarray of shape (C, H, W) or (T, Z, C, H, W)
        - ``tile_masks`` is a jnp.ndarray of shape (n_masks, tile_height, tile_width)
        - ``tile_labels`` is a dict
        - ``slide_labels`` is a dict
    Args:
        file_path (str): Path to .h5path file on disk
    """
    def __init__(self, file_path):
        self.dataset = TileDataset(file_path)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        tile_image, tile_masks, tile_labels, slide_labels = self.dataset[index]

        # convert to JAX arrays
        tile_image = jnp.array(tile_image)
        if tile_masks is not None:
            tile_masks = jnp.array(tile_masks)
        tile_labels = jax.tree_util.tree_map(jnp.array, tile_labels)
        slide_labels = jax.tree_util.tree_map(jnp.array, slide_labels)

        return tile_image, tile_masks, tile_labels, slide_labels


def jax_collate_fn(batch):
    """
    A custom collate_fn for JAX DataLoader that handles stacking of labels into a batch.
    """
    tile_images, tile_masks, tile_labels, slide_labels = zip(*batch)
    tile_images = jax.device_put(jnp.stack(tile_images))
    if tile_masks[0] is not None:
        tile_masks = jax.device_put(jnp.stack(tile_masks))
    else:
        tile_masks = None
    # jax.tree_multimap was removed from JAX; tree_util.tree_map accepts several trees
    tile_labels = jax.tree_util.tree_map(lambda *x: jnp.stack(x), *tile_labels)
    slide_labels = jax.tree_util.tree_map(lambda *x: jnp.stack(x), *slide_labels)

    return tile_images, tile_masks, tile_labels, slide_labels
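The collate step relies on `jax.tree_util.tree_map` accepting several label dicts with the same keys and stacking matching leaves. The self-contained sketch below shows what that produces on toy samples shaped like `TileDataset` items (CHW image, `(n_masks, H, W)` mask, two label dicts). The `stack_samples` helper and the integer labels are hypothetical; real tile labels depend on the slide.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _stack_leaves(*leaves):
    """Stack matching leaves from several label dicts into one array."""
    return jnp.stack([jnp.asarray(leaf) for leaf in leaves])


def stack_samples(batch):
    """Collate (image, mask, tile_labels, slide_labels) tuples into batched JAX arrays."""
    images, masks, tile_labels, slide_labels = zip(*batch)
    images = jnp.stack(images)
    masks = jnp.stack(masks) if masks[0] is not None else None
    # tree_map walks all dicts in parallel and calls _stack_leaves once per key
    tile_labels = jax.tree_util.tree_map(_stack_leaves, *tile_labels)
    slide_labels = jax.tree_util.tree_map(_stack_leaves, *slide_labels)
    return images, masks, tile_labels, slide_labels


rng = np.random.default_rng(0)
batch = [
    (rng.integers(0, 256, (3, 8, 8), dtype=np.uint8),
     np.zeros((1, 8, 8), dtype=np.uint8),
     {"tumor": i % 2},
     {"slide_id": 7})
    for i in range(4)
]
images, masks, tile_labels, slide_labels = stack_samples(batch)
print(images.shape, masks.shape, tile_labels["tumor"], slide_labels["slide_id"])
```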


In [80]:
import numpy as np
from torch.utils import data


class FlattenAndCast(object):
    """Flatten an image into a 1-D float32 vector."""
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=np.float32))


def numpy_collate(batch):
    """Collate TileDataset items into NumPy arrays instead of torch tensors."""
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        # each item is (tile_image, tile_masks, tile_labels, slide_labels)
        images = np.stack([FlattenAndCast()(im) for im, _, _, _ in batch])
        masks = np.stack([mask for _, mask, _, _ in batch]) if batch[0][1] is not None else None
        labels = [label for _, _, label, _ in batch]
        slide_labels = [slide_label for _, _, _, slide_label in batch]
        return images, masks, labels, slide_labels
    else:
        return np.array(batch)


class NumpyLoader(data.DataLoader):
    """PyTorch DataLoader that yields NumPy batches, which JAX can consume directly."""
    def __init__(self, dataset, batch_size=1,
                 shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0,
                 pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        super().__init__(dataset,
                         batch_size=batch_size,
                         shuffle=shuffle,
                         sampler=sampler,
                         batch_sampler=batch_sampler,
                         num_workers=num_workers,
                         collate_fn=numpy_collate,
                         pin_memory=pin_memory,
                         drop_last=drop_last,
                         timeout=timeout,
                         worker_init_fn=worker_init_fn)

In [81]:
training_generator = NumpyLoader(dataset, batch_size=32, num_workers=0)


In [72]:
# create NumpyLoader
numpy_loader = NumpyLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)


In [82]:
a, b, c, d = next(iter(numpy_loader))

In [83]:
a.shape

(32, 196608)

In [84]:
b.shape

(32, 1, 256, 256)
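The flattened width follows from the tile geometry: with 3 channels and 256 x 256 pixels per tile (the tile size implied by the mask shape above), `FlattenAndCast` produces 3 * 256 * 256 = 196608 float32 values per tile. A quick check on a dummy tile, assuming the `(C, H, W)` layout described in the `JaxTileDataset` docstring:

```python
import numpy as np

# Dummy (C, H, W) tile with the same layout as the images in TileDataset items.
tile = np.zeros((3, 256, 256), dtype=np.uint8)
flat = np.ravel(np.asarray(tile, dtype=np.float32))  # what FlattenAndCast does

# Stacking 32 flattened tiles gives a (batch_size, C*H*W) batch.
batch = np.stack([flat] * 32)
print(flat.shape, flat.dtype, batch.shape)
```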

In [30]:
jax_dataset = JaxTileDataset(file_path='../../../bucket_data/CMU-1-preprocessed.h5path')

In [14]:
import numpy as np


class JaxDataLoader:
    """Minimal batching iterator over a map-style dataset such as JaxTileDataset.

    ``num_workers`` is stored for API compatibility but is not used:
    batches are assembled sequentially in the main process.
    """
    def __init__(self, dataset, batch_size, shuffle=False, num_workers=0):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

    def __iter__(self):
        # regenerate (and reshuffle) indices on every pass so the loader can be reused across epochs
        for batch_indices in self._generate_indices():
            yield self._get_batch(batch_indices)

    def __len__(self):
        return int(np.ceil(len(self.dataset) / self.batch_size))

    def _generate_indices(self):
        indices = np.arange(len(self.dataset))
        if self.shuffle:
            np.random.shuffle(indices)
        for i in range(0, len(indices), self.batch_size):
            yield indices[i:i + self.batch_size]

    # no @jax.jit here: jit cannot trace `self` or the HDF5 reads inside the dataset
    def _get_batch(self, batch_indices):
        batch = [self.dataset[int(i)] for i in batch_indices]
        return jax_collate_fn(batch)  # defined in the JaxTileDataset cell


batch_size = 32
num_workers = 4
jax_dataset = JaxTileDataset('../../../bucket_data/CMU-1-preprocessed.h5path')
jax_dataloader = JaxDataLoader(
    dataset=jax_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)


In [16]:
type(jax_dataloader)

__main__.JaxDataLoader
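`jax.jit` compiles functions of arrays: it tries to turn every argument into an abstract array, so decorating a method makes JAX try to trace `self`, which fails with a `TypeError` for an object such as `JaxDataLoader`. It also cannot trace the HDF5 reads inside the dataset. A common pattern, sketched here under that assumption, is to keep batching in plain Python and jit only the numerical step applied to each batch. `preprocess_batch` and `batches` are hypothetical helpers, and the random uint8 array stands in for real tiles.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def preprocess_batch(images):
    """Scale uint8 (B, C, H, W) images to [0, 1] and subtract the per-channel batch mean."""
    x = images.astype(jnp.float32) / 255.0
    return x - x.mean(axis=(0, 2, 3), keepdims=True)


def batches(n_items, batch_size, rng):
    """Plain-Python batching: shuffled index chunks, regenerated on every call."""
    idx = rng.permutation(n_items)
    for start in range(0, n_items, batch_size):
        yield idx[start:start + batch_size]


rng = np.random.default_rng(0)
data = rng.integers(0, 256, size=(10, 3, 4, 4), dtype=np.uint8)  # stand-in for tiles
outputs = [preprocess_batch(jnp.asarray(data[idx])) for idx in batches(len(data), 4, rng)]
print([o.shape for o in outputs])
```

The final, smaller batch has a different shape, so JAX compiles `preprocess_batch` a second time; dropping the incomplete last batch (as `drop_last=True` does for the PyTorch `DataLoader`) or padding it avoids that retrace.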

In [19]:
batch = next(iter(jax_dataloader))

In [21]:
import jax_dataloader as jdl

In [23]:
# TileDataset is a PyTorch dataset, which jax_dataloader only supports with the 'pytorch' backend
dataloader = jdl.DataLoader(dataset, 'pytorch', batch_size=5, shuffle=True)

In [24]:
# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [25]:
X.shape

(10, 10)

In [26]:
y.shape

(10,)

In [27]:
y

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
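For data that already fits in memory as arrays, like the toy `X` and `y` above, shuffled batching can also be written directly with `jax.random.permutation`. The `iterate_batches` helper below is hypothetical and illustrates the idea; it is not `jax_dataloader`'s implementation.

```python
import jax
import jax.numpy as jnp


def iterate_batches(key, X, y, batch_size, shuffle=True):
    """Yield (X_batch, y_batch) pairs; rows of X and y stay aligned after shuffling."""
    n = X.shape[0]
    order = jax.random.permutation(key, n) if shuffle else jnp.arange(n)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        yield X[idx], y[idx]


X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
key = jax.random.PRNGKey(0)
batches = list(iterate_batches(key, X, y, batch_size=5))
print(len(batches), batches[0][0].shape, batches[0][1].shape)
```

Indexing `X` and `y` with the same permuted indices is what keeps each feature row paired with its label; passing a new `key` each epoch gives a different shuffle.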

In [28]:
dataset

<pathml.ml.dataset.TileDataset at 0x7faeccc02f40>

In [17]:
for batch in jax_dataloader:
    # check that the batch has the expected shapes
    print(batch[0].shape)  # (batch_size, C, H, W) or (batch_size, T, Z, C, H, W)
    print(None if batch[1] is None else batch[1].shape)  # (batch_size, n_masks, tile_height, tile_width) or None
    # check that the elements in the batch are JAX arrays
    print(isinstance(batch[0], jax.Array))  # should print True
    print(batch[1] is None or isinstance(batch[1], jax.Array))  # should print True
    break  # inspect only the first batch

### Summary

Here we demonstrated a complete `PathML` workflow for analyzing brightfield (H&E) images:

1. Loading the raw image
2. Defining a simple preprocessing pipeline for tissue detection
3. Creating a PyTorch DataLoader for interfacing with any downstream machine learning model

Full documentation of the `PathML` API is available at https://pathml.org.  

Full code for this vignette is available at https://github.com/Dana-Farber-AIOS/pathml/tree/master/examples/manuscript_vignettes_stable