In [None]:
# Render matplotlib figures inline, directly below the cell that creates them.
%matplotlib inline

<center>
<h1> How to apply Feature Aware Normalization to a novel dataset </h1>
</center>
<center>
Steffen Schneider and Daniel Bug <br />
Institute of Imaging & Computer Vision <br />
steffen.schneider@rwth-aachen.de
</center>


In this tutorial, we will demonstrate the feature aware normalization module [1] on the digital pathology dataset by [2,3]. Details on the approach are available in our [paper](https://arxiv.org/abs/1708.04099) as well as on the [project page](https://stes.github.io/fan).

If you use this code in your research, please cite our [paper](https://arxiv.org/abs/1708.04099):

```
@incollection{bug2017context,
  title={Context-based Normalization of Histological Stains using Deep Convolutional Features},
  author={Bug, Daniel and Schneider, Steffen and Grote, Anne and Oswald, Eva and Feuerhake, Friedrich and Sch{\"u}ler, Julia and Merhof, Dorit},
  booktitle={Deep Learning in Medical Image Analysis and Multimodal Learning for Clinical Decision Support},
  pages={135--142},
  year={2017},
  publisher={Springer}
}
```

We also have a pre-print available on arxiv: [abs/1708.04099](https://arxiv.org/abs/1708.04099).

## Introduction

Feature Aware Normalization is a technique for normalizing images based on context information estimated by a Feature Extraction Network.
Instead of normal batch normalization layers, in FAN, the shift and scaling parameters $\beta$ and $\gamma$ are functions of a feature representation $z$ computed from the input image.

![Unnormalized Images](docs/img/BAS_unnormalized_A.jpg)
![FAN normalized Images](docs/img/FAN_HoEoTp_A.jpg)

## Dataset and Network weights

In this tutorial, we will use the publicly available dataset of Kather et al. (2016) [2,3].

The data is licensed under the [Creative Commons Attribution 4.0 International License](http://creativecommons.org/licenses/by/4.0/).
The data can be accessed via the following DOI: [10.5281/zenodo.53169](https://dx.doi.org/10.5281/zenodo.53169).

#### Download the data

In [None]:
# Uncomment the `!wget` lines to download the archives from Zenodo.
# NOTE: the cells below look for the training archive and for an MD5SUM
# file inside the data/ folder.

# Training dataset, 5000 patches with resolution 150x150 (258.1 MB)
#!wget https://zenodo.org/record/53169/files/Kather_texture_2016_image_tiles_5000.zip
# MD5SUM 0ddbebfc56344752028fda72602aaade

# Validation dataset, 10 patches with resolution 5000x5000 (742.0 MB)
# !wget https://zenodo.org/record/53169/files/Kather_texture_2016_larger_images_10.zip
# MD5SUM ff6e18f484c5d324b049ed2ec133d9cc

Integrity check

In [None]:
# Uncomment to verify the downloaded archives; the expected checksums are
# read from a file named MD5SUM inside data/.
#! cd data && md5sum --check MD5SUM

Extract data

In [None]:
# Uncomment to unpack the training archive into data/
# (-qq: quiet, -o: overwrite existing files without prompting).
#! unzip -qq -o data/Kather_texture_2016_image_tiles_5000.zip -d data

#### Download the network weights

We will need both the weights for the feature extractor (defaults to the VGG19 network) and the normalization network. The latter was trained as outlined in our original publication [1].

In [None]:
# Uncomment to fetch the pretrained VGG19 weights into weights/
# (the download only starts if changing into the weights/ folder succeeds).
# ! cd weights && wget "https://s3.amazonaws.com/lasagne/recipes/pretrained/imagenet/vgg19_normalized.pkl"

Integrity check

In [None]:
# Verify the MD5 checksum of the downloaded VGG19 weights file.
! echo "cb8ee699c50a64f8fef2a82bfbb307c5  weights/vgg19_normalized.pkl" | md5sum --check

## Normalization with FAN

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
from skimage import io
from tifffile import imread

In [None]:
# Folder produced by unpacking the training archive; it is expected to
# contain one sub-directory per class.
root = "data/Kather_texture_2016_image_tiles_5000"
# The loader lists this path as a directory, so require a directory (not
# merely an existing path) and say what to do when it is missing.
assert os.path.isdir(root), (
    "Dataset folder {!r} not found - download and extract the training "
    "archive into data/ first".format(root))

In [None]:
def load_small_patches(root):
    """Load every image below ``root`` together with an integer class label.

    Each sub-directory of ``root`` is treated as one class; entries of
    ``root`` that are not directories are ignored.

    Parameters
    ----------
    root : str
        Directory that contains one sub-directory per class.

    Returns
    -------
    X : numpy.ndarray
        All images stacked along a new first axis (so every image must have
        the same shape).
    y : numpy.ndarray
        Class index of every image; ``lbl[y[k]]`` is the class name of ``X[k]``.
    lbl : list of str
        Class (sub-directory) names in sorted order.
    fnames : list of str
        Path of every image, in the same order as ``X``.

    Raises
    ------
    ValueError
        If no image file is found below ``root``.
    """
    # Filter the directories *before* enumerating: otherwise a stray file in
    # ``root`` would consume an index and shift ``y`` out of step with ``lbl``.
    lbl = sorted(cl for cl in os.listdir(root)
                 if os.path.isdir(os.path.join(root, cl)))
    X = []
    y = []
    fnames = []

    for i, cl in enumerate(lbl):
        classdir = os.path.join(root, cl)
        # os.listdir returns entries in arbitrary order; sorting makes
        # positional indexing into X reproducible between runs and machines.
        for fname in sorted(os.listdir(classdir)):
            path = os.path.join(classdir, fname)
            X.append(imread(path))
            y.append(i)
            fnames.append(path)

    if not X:
        raise ValueError("no images found below {!r}".format(root))

    X = np.stack(X, axis=0)
    y = np.asarray(y)

    return X, y, lbl, fnames

In [None]:
# Read the whole patch collection into memory and report how many images
# were found, along with the sizes of the next two array axes.
X, y, lbl, fnames = load_small_patches(root)
n_patches, size_a, size_b = X.shape[:3]
print("Loaded {} images of size {}x{}".format(n_patches, size_a, size_b))

In [None]:
# `fan` provides the NormalizationNetwork used below.
from stainnorm import fan

# Pick up edits to imported modules automatically on subsequent cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Build the normalization network from the stored weights.
# NOTE(review): none of the download cells shown in this notebook fetches
# this .npz file - make sure it is present under weights/ before running.
model = fan.NormalizationNetwork(
    fname='weights/171028-weights-dlmia.npz',
    patch_size=300,
    batch_size=10,
)

In [None]:
# Quick sanity check: the network object and the shape of the loaded data.
(model, X.shape)

In [None]:
# The tiles are 150x150 px while the network was created with patch_size=300,
# so each pixel is duplicated along both spatial axes (2x nearest-neighbour
# upsampling) - presumably to match the expected patch size; confirm in
# stainnorm.fan.  Build the upsampled batch once and reuse it for both calls.
samples_2x = X[100:103].repeat(2, axis=1).repeat(2, axis=2)

# NOTE(review): `crop` presumably trims the input to the extent of the
# network output so both can be compared side by side - confirm.
unnormed = model.crop(samples_2x)
normed = model(samples_2x)

In [None]:
# Inspect the shape of the cropped, un-normalized reference patches.
unnormed.shape

In [None]:
# Describe every patch of class index 3 by simple intensity statistics taken
# over axes 1 and 2 (the spatial axes, assuming X is laid out as N x H x W x C),
# then join max / min / mean / std into one feature vector per patch.
class_patches = X[y == 3]
patch_stats = [stat_fn(class_patches, axis=(1, 2))
               for stat_fn in (np.max, np.min, np.mean, np.std)]

colors = np.concatenate(patch_stats, axis=-1)
colors.shape

## Dataset preprocessing

We'll preprocess the dataset for visualization purposes. Patches are labeled according to eight different classes.
For purposes of comparison, we'll employ a very simple clustering scheme based on simple statistics of the RGB channels (maximum, minimum, mean and standard deviation per patch) to estimate the staining protocols.

In [None]:
from sklearn.manifold import TSNE

# t-SNE is a randomised method: without a fixed seed every run produces a
# different embedding (and therefore different clusters further down).
tsne = TSNE(random_state=0)
embedding = tsne.fit_transform(colors)

In [None]:
# 2-D t-SNE embedding of the colour statistics, one point per patch.
fig, ax = plt.subplots()
ax.scatter(embedding[:, 0], embedding[:, 1])
ax.set(title="t-SNE embedding of patch colour statistics",
       xlabel="t-SNE dimension 1",
       ylabel="t-SNE dimension 2")
plt.show()

In [None]:
from sklearn.mixture import GaussianMixture

# Ten mixture components: one per column of the 10x10 overview grid built
# below.  The seed is fixed because the mixture is initialised randomly and
# the cluster assignment would otherwise change from run to run.
gmm = GaussianMixture(n_components=10, random_state=0)
gmm.fit(embedding)
yd = gmm.predict(embedding)

In [None]:
# Same embedding as before, coloured by the mixture component of each patch.
fig, ax = plt.subplots()
ax.scatter(embedding[:, 0], embedding[:, 1], c=yd)
ax.set(title="t-SNE embedding coloured by GMM cluster",
       xlabel="t-SNE dimension 1",
       ylabel="t-SNE dimension 2")
plt.show()

In [None]:
from stainnorm.tools import panelize
from itertools import product

# Build a 10x10 panel of example patches: column j shows (up to) the first
# ten patches of class index 3 that were assigned to cluster j; cells
# without a patch stay black (zeros).
n_rows, n_cols = 10, 10
overview = np.zeros((n_rows, n_cols) + X.shape[1:])

# Select the class once instead of re-copying it for every grid cell.
class_patches = X[y == 3]

for j in range(n_cols):
    members = class_patches[yd == j]
    n_fill = min(len(members), n_rows)
    overview[:n_fill, j] = members[:n_fill]
print(overview.shape)

# Two nested concatenations turn (rows, cols, H, W, C) into one
# (rows*H, cols*W, C) image.
grid = np.concatenate(np.concatenate(overview, axis=1), axis=1)
print(grid.shape)
plt.figure(figsize=(20,10))
plt.imshow(grid / 255.)
plt.show()

In [None]:
# Run the network on the overview panel, one cluster (grid column) per call.
# Every batch is upsampled 2x along both spatial axes before it is passed in.
column_results = []
for cluster_batch in np.transpose(overview, (1, 0, 2, 3, 4)):
    upsampled = np.repeat(np.repeat(cluster_batch, 2, axis=1), 2, axis=2)
    column_results.append(model(upsampled))
    print("finsished")

# Put the per-cluster results back on axis 1 to restore the grid layout.
normed = np.stack(column_results, axis=1)

In [None]:
# Same panel layout as the overview above, now showing the network output.
# A bare expression in the middle of a cell is never displayed, so print it.
print(normed.shape, overview.shape)

grid = np.concatenate(np.concatenate(normed, axis=1), axis=1)
print(grid.shape)
plt.figure(figsize=(20,10))
plt.imshow(grid / 255.)
plt.show()

In [None]:
# Side-by-side comparison of the three sample patches before and after
# normalization.
# `normed` now holds the result for the whole overview grid (one extra
# leading axis), so it can no longer be paired with the three patches in
# `unnormed`.  Normalize those three patches again under their own name.
sample_normed = model(X[100:103].repeat(2, axis=1).repeat(2, axis=2))

for u, n in zip(unnormed, sample_normed):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

    ax1.imshow(u / 255.)
    ax1.set_title("Input (cropped)")
    ax2.imshow(n / 255.)
    ax2.set_title("FAN normalized")
    plt.show()

# References

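1. Bug, Daniel, et al. "Context-based Normalization of Histological Stains using Deep Convolutional Features." Deep Learning in Medical Image Analysis and Multimodal Learning for Clinical Decision Support. Springer, 2017. 135-142. [arXiv:1708.04099](https://arxiv.org/abs/1708.04099)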
2. Kather, Jakob Nikolas, et al. "Multi-class texture analysis in colorectal cancer histology." Scientific reports 6 (2016): 27988. DOI [10.1038/srep27988](https://dx.doi.org/10.1038/srep27988)
3. Kather, Jakob Nikolas, et al. "Collection of textures in colorectal cancer histology, May 2016." Last Accessed 19 (2016). DOI [10.5281/zenodo.53169](https://dx.doi.org/10.5281/zenodo.53169)