# DenoiSeg Example: Fly Wing Data
This example notebook illustrates how to train DenoiSeg. It uses a membrane-labeled dataset of developing fly wings provided by our collaborators. You can use this notebook as a reference for training DenoiSeg networks on your own data.

In [None]:
# Here we are just importing some libraries which are needed to run this notebook.

import warnings
warnings.filterwarnings('ignore')

import numpy as np
from matplotlib import pyplot as plt
from scipy import ndimage

from denoiseg.models import DenoiSeg, DenoiSegConfig
from denoiseg.utils.misc_utils import combine_train_test_data, shuffle_train_data, augment_data
from denoiseg.utils.seg_utils import *
from denoiseg.utils.compute_precision_threshold import measure_precision, compute_labels
from denoiseg.utils.denoiseg_data_preprocessing import generate_patches_from_list

from csbdeep.utils import plot_history
from tifffile import imread, imsave
from glob import glob

import urllib.request
import os
import zipfile

## Downloading and Loading the Data
We download a dataset of noisy fly wing images. Extracting it creates a folder `MyData` (inside `data`) with three subfolders: `train`, `val` and `test`. The `train` and `val` folders each contain the subfolders `raw` and `gt`. The `train/raw` folder holds `1428` raw images, of which only the first `5` have ground-truth annotations in `train/gt`. Similarly, `val/raw` holds `252` raw images, of which only the first `2` have annotations in `val/gt`. The `test` folder has only a `raw` subfolder, since these are the images for which we want denoised and segmented outputs.
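Once the data has been downloaded, it can help to check that the extracted folders match the layout described above. The following is a minimal sketch using only the standard library; the helper name `count_tifs` is hypothetical and not part of DenoiSeg.

```python
from pathlib import Path


def count_tifs(root):
    """Return {"split/kind": number of .tif files} for each existing raw/gt folder under root."""
    root = Path(root)
    counts = {}
    for split in ("train", "val", "test"):
        for kind in ("raw", "gt"):
            folder = root / split / kind
            if folder.is_dir():
                # only count TIFF files, ignore anything else in the folder
                counts[f"{split}/{kind}"] = len(list(folder.glob("*.tif")))
    return counts
```

Per the description above, `count_tifs("data/MyData")` should report `1428` for `train/raw`, `5` for `train/gt`, `252` for `val/raw` and `2` for `val/gt`.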

In [None]:
# create a folder for our data
os.makedirs('data', exist_ok=True)

link = 'https://owncloud.mpi-cbg.de/index.php/s/9ok6q1azniMJobq/download'

# download and unzip the data, unless the zip file is already present
zipPath = "data/MyData.zip"
if not os.path.exists(zipPath):
    urllib.request.urlretrieve(link, zipPath)
    with zipfile.ZipFile(zipPath, 'r') as zip_ref:
        zip_ref.extractall("data")
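The cell above only checks for the zip file: if the download finished but the extraction never ran (for example, because the kernel was interrupted), re-running the cell will not extract the data. Below is a minimal sketch of a variant that decides on extraction by checking for the extracted folder instead. The helper name `fetch_and_extract` is hypothetical, and it assumes, as described above, that the archive unpacks into a `MyData` folder.

```python
import os
import urllib.request
import zipfile


def fetch_and_extract(url, zip_path, out_dir, expected_subdir):
    """Download url to zip_path if missing, then extract into out_dir
    unless out_dir/expected_subdir already exists. Returns that folder."""
    os.makedirs(os.path.dirname(zip_path) or ".", exist_ok=True)
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(url, zip_path)
    target = os.path.join(out_dir, expected_subdir)
    if not os.path.isdir(target):
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(out_dir)
    return target
```

For this notebook the call would be `fetch_and_extract(link, "data/MyData.zip", "data", "MyData")`.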

In [None]:
# Load the raw training, validation and test images, plus the available ground-truth masks
train_images = imread(sorted(glob("data/MyData/train/raw/*.tif")))
val_images = imread(sorted(glob("data/MyData/val/raw/*.tif")))
test_images = imread(sorted(glob("data/MyData/test/raw/*.tif")))
available_train_masks = imread(sorted(glob("data/MyData/train/gt/*.tif")))
available_val_masks = imread(sorted(glob("data/MyData/val/gt/*.tif")))

### Create zero images for missing masks

For every training and validation image without a segmentation mask, we create an all-zero placeholder mask. Together with the available annotations, this gives one mask per image, and we train the DenoiSeg network on these placeholders along with the annotated masks.

In [None]:
# number of images without a segmentation mask
n_blank_train = train_images.shape[0] - available_train_masks.shape[0]
n_blank_val = val_images.shape[0] - available_val_masks.shape[0]

# all-zero placeholder masks with the same spatial size as the available masks
blank_images_train = np.zeros((n_blank_train,) + available_train_masks.shape[1:], dtype=np.uint16)
blank_images_val = np.zeros((n_blank_val,) + available_val_masks.shape[1:], dtype=np.uint16)

train_masks = np.concatenate((available_train_masks, blank_images_train), axis=0)
val_masks = np.concatenate((available_val_masks, blank_images_val), axis=0)
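The padding step can be wrapped in a small reusable function. This sketch (the name `pad_masks` is hypothetical) allocates the placeholders in the dtype of the existing masks and makes the key assumption explicit: the `k` available masks belong to the first `k` images in sorted file order.

```python
import numpy as np


def pad_masks(n_images, available_masks):
    """Append all-zero masks so that there is exactly one mask per image.

    available_masks: array of shape (k, Y, X); assumed to belong to the
    first k images (sorted file order), as in this dataset.
    """
    k, h, w = available_masks.shape
    if k > n_images:
        raise ValueError("more masks than images")
    blanks = np.zeros((n_images - k, h, w), dtype=available_masks.dtype)
    return np.concatenate((available_masks, blanks), axis=0)
```

With the arrays above, `pad_masks(train_images.shape[0], available_train_masks)` should give the same result as `train_masks`.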

## Data Preprocessing
In the cell below we preprocess the data: we cut the images and masks into 128 x 128 patches, augment the training patches, and convert the training and validation masks into one-hot encoded background, foreground and border classes.
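To illustrate what the three classes mean, here is a minimal NumPy sketch that turns a 2-D instance label image into background, foreground and border channels. It is not DenoiSeg's `convert_to_oneHot`, whose exact border definition may differ. Here a border pixel is assumed to be a labeled pixel with at least one 4-neighbor carrying a different label, and an all-zero mask is assumed to mean "not annotated". The helper name is hypothetical.

```python
import numpy as np


def label_to_three_class(labels):
    """One-hot encode a 2-D label image as (background, foreground, border).

    Border (assumed definition): labeled pixel with a 4-neighbor of a
    different label. An all-zero mask yields all-zero channels.
    """
    labels = np.asarray(labels)
    out = np.zeros(labels.shape + (3,), dtype=np.float32)
    if not labels.any():
        return out  # unannotated image: no class information at all
    # edge padding so that the image boundary itself does not count as a border
    padded = np.pad(labels, 1, mode="edge")
    center = padded[1:-1, 1:-1]
    differs = np.zeros(labels.shape, dtype=bool)
    for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        neighbour = padded[1 + dy:padded.shape[0] - 1 + dy,
                           1 + dx:padded.shape[1] - 1 + dx]
        differs |= neighbour != center
    fg = labels > 0
    border = fg & differs
    out[..., 0] = ~fg            # background
    out[..., 1] = fg & ~border   # object interior
    out[..., 2] = border         # object boundary
    return out
```

Two touching objects with different labels get a line of border pixels between them, so their foreground pixels are not connected; this is what lets a thresholded foreground channel separate them.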

In [None]:
# Here we generate 128 x 128 patches from the training and validation images.
# Only the training patches are augmented.
X_final, Y_final = generate_patches_from_list([train_images], [train_masks], augment=True, shuffle=False, shape=(128, 128))
X_val_final, Y_val_final = generate_patches_from_list([val_images], [val_masks], augment=False, shape=(128, 128))

# Here we add the channel dimension to our input images.
# Dimensionality for training has to be 'SYXC' (Sample, Y-Dimension, X-Dimension, Channel)
X_final = X_final[..., np.newaxis]
# convert the label masks into one-hot encoded background, foreground and border classes
Y_final = convert_to_oneHot(Y_final, n_classes=3)

X_val_final = X_val_final[..., np.newaxis]
Y_val_final = convert_to_oneHot(Y_val_final, n_classes=3)

print("Shape of X:     {}".format(X_final.shape))
print("Shape of Y:     {}".format(Y_final.shape))
print("Shape of X_val: {}".format(X_val_final.shape))
print("Shape of Y_val: {}".format(Y_val_final.shape))

Next we look at a single sample. In the first column we show the input image, in the second column the background segmentation, in the third column the foreground segmentation and in the last column the border segmentation.

With the parameter `sample` you can choose different training patches. You will notice that not all of them have a segmentation ground truth.
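To jump straight to a patch that does have ground truth, you can look for patches whose one-hot target is non-zero. This sketch assumes that patches cut from the all-zero placeholder masks are encoded as all zeros in every channel; confirm this on your data by plotting a few patches. `annotated_patch_indices` is a hypothetical helper.

```python
import numpy as np


def annotated_patch_indices(y_onehot):
    """Indices of patches whose one-hot target (S, Y, X, C) is non-zero anywhere.

    Assumes unannotated patches are all zeros in every channel.
    """
    y = np.asarray(y_onehot)
    has_label = y.reshape(y.shape[0], -1).any(axis=1)
    return np.flatnonzero(has_label)
```

For example, `sample = annotated_patch_indices(Y_final)[0]` selects the first annotated patch.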

In [None]:
sample = 0
plt.figure(figsize=(20,5))
plt.subplot(1,4,1)
plt.imshow(X_final[sample,...,0])
plt.axis('off')
plt.title('Raw training image')
plt.subplot(1,4,2)
plt.imshow(Y_final[sample,...,0], vmin=0, vmax=1, interpolation='nearest')
plt.axis('off')
plt.title('1-hot encoded background')
plt.subplot(1,4,3)
plt.imshow(Y_final[sample,...,1], vmin=0, vmax=1, interpolation='nearest')
plt.axis('off')
plt.title('1-hot encoded foreground')
plt.subplot(1,4,4)
plt.imshow(Y_final[sample,...,2], vmin=0, vmax=1, interpolation='nearest')
plt.axis('off')
plt.title('1-hot encoded border')

### Configure network parameters

In [None]:
train_batch_size = 128
train_steps_per_epoch = min(400, max(int(X_final.shape[0]/train_batch_size), 10))
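The rule above divides the number of training patches by the batch size and clamps the result to the range 10 to 400. Here it is written as a function with a few worked values (the name `steps_per_epoch` is ours):

```python
def steps_per_epoch(n_patches, batch_size=128, lo=10, hi=400):
    """Same rule as the cell above: patches / batch size (truncated), clamped to [lo, hi]."""
    return min(hi, max(int(n_patches / batch_size), lo))


print(steps_per_epoch(1000))    # 1000 / 128 = 7.8 -> 7, raised to the minimum: 10
print(steps_per_epoch(12800))   # 12800 / 128 = 100
print(steps_per_epoch(100000))  # 100000 / 128 = 781.25 -> 781, capped at 400
```

The lower bound keeps very small datasets from training with only a handful of steps per epoch; the upper bound caps the epoch length for large ones.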

In the next cell, you can choose how much relative importance (weight) to assign to the denoising and segmentation tasks by setting `denoiseg_alpha` to a value between 0 and 1, where 0 means segmentation only and 1 means denoising only. Here we choose `denoiseg_alpha = 0.5`.
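One weighting that matches these endpoints is a convex combination of the two task losses. The sketch below shows only that idea with scalar losses; it is not DenoiSeg's training code, and `combined_loss` is a hypothetical name.

```python
def combined_loss(denoise_loss, seg_loss, alpha):
    """alpha * denoising loss + (1 - alpha) * segmentation loss.

    alpha = 0 keeps only the segmentation term, alpha = 1 only the denoising term.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return alpha * denoise_loss + (1.0 - alpha) * seg_loss


print(combined_loss(2.0, 4.0, 0.5))  # 3.0
```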

In [None]:
conf = DenoiSegConfig(X_final, unet_kern_size=3, n_channel_in=1, n_channel_out=4, relative_weights=[1.0, 1.0, 5.0],
                      train_steps_per_epoch=train_steps_per_epoch, train_epochs=120,
                      batch_norm=True, train_batch_size=train_batch_size, unet_n_first=32,
                      unet_n_depth=4, denoiseg_alpha=0.5, train_tensorboard=False)

vars(conf)
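The `relative_weights=[1.0, 1.0, 5.0]` setting presumably weights the three segmentation classes in the loss, with the thin border class counted five times as much. As an assumption, we take the order to match the one-hot channels (background, foreground, border). The sketch below shows a class-weighted categorical cross-entropy in NumPy that ignores pixels with an all-zero target; it illustrates the idea and is not DenoiSeg's loss implementation. `weighted_cross_entropy` is a hypothetical name.

```python
import numpy as np


def weighted_cross_entropy(y_true, y_prob, class_weights, eps=1e-7):
    """Mean class-weighted categorical cross-entropy over annotated pixels.

    y_true: one-hot targets (..., C); all-zero pixels are treated as unannotated.
    y_prob: predicted class probabilities (..., C), e.g. softmax outputs.
    class_weights: length-C weights, e.g. [1.0, 1.0, 5.0].
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_prob = np.clip(np.asarray(y_prob, dtype=np.float64), eps, 1.0)  # avoid log(0)
    w = np.asarray(class_weights, dtype=np.float64)
    per_pixel = -np.sum(w * y_true * np.log(y_prob), axis=-1)
    annotated = y_true.sum(axis=-1) > 0
    if not annotated.any():
        return 0.0  # nothing to learn from an unannotated patch
    return float(per_pixel[annotated].mean())
```

With these weights, a mistake on a border pixel costs five times as much as the same mistake on a background or foreground pixel, which counteracts how few border pixels there are.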

In [None]:
model_name = 'DenoiSeg_Practicalfinal_n20'
basedir = 'models'
model = DenoiSeg(conf, model_name, basedir)

In [None]:
history = model.train(X_final, Y_final, (X_val_final, Y_val_final))

In [None]:
history.history.keys()

In [None]:
plot_history(history, ['loss', 'val_loss'])

## Computing Threshold Value
The network predicts 4 output channels:
1. The denoised input.
2. The background likelihoods.
3. The foreground likelihoods.
4. The border likelihoods.
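The following sketch splits such a 4-channel prediction into the denoised image, a per-pixel class map, and the foreground likelihoods. The foreground channel index is a parameter; its default of `2` assumes the layout `[denoised, background, foreground, border]`, i.e. the segmentation channels in the same order as the one-hot targets visualized earlier. `split_prediction` is a hypothetical helper.

```python
import numpy as np


def split_prediction(pred, fg_channel=2):
    """Split a (Y, X, 4) prediction into denoised image, class map and foreground likelihoods.

    fg_channel=2 assumes the layout [denoised, background, foreground, border].
    """
    pred = np.asarray(pred)
    denoised = pred[..., 0]
    # most likely segmentation class per pixel: 0, 1, 2 index channels 1, 2, 3
    class_map = np.argmax(pred[..., 1:4], axis=-1)
    fg_prob = pred[..., fg_channel]
    return denoised, class_map, fg_prob
```

Thresholding `fg_prob` (for example `fg_prob > 0.5`) then gives a binary foreground mask whose connected components are the segmented objects.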

We threshold the foreground likelihoods to obtain object segmentations. The threshold is determined on the annotated validation images by choosing the value that maximizes a given measure; here we use the Average Precision (AP) measure.
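The idea behind the threshold search can be sketched without DenoiSeg: for each candidate threshold, label the connected components of the thresholded foreground, match them to the ground-truth objects, and keep the threshold with the best mean score. This is not `optimize_thresholds` or `measure_precision`. As a simplified stand-in for AP we use TP / (TP + FP + FN) with a match whenever IoU > 0.5; with that strict cutoff each object can match at most one other. All function names are hypothetical.

```python
from collections import deque

import numpy as np


def label_components(mask):
    """4-connected component labelling of a 2-D boolean mask (breadth-first search)."""
    mask = np.asarray(mask, dtype=bool)
    h, w = mask.shape
    labels = np.zeros((h, w), dtype=np.int32)
    current = 0
    for start in zip(*np.nonzero(mask)):
        if labels[start]:
            continue  # already part of an earlier component
        current += 1
        labels[start] = current
        queue = deque([start])
        while queue:
            y, x = queue.popleft()
            for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
                inside = 0 <= ny < h and 0 <= nx < w
                if inside and mask[ny, nx] and not labels[ny, nx]:
                    labels[ny, nx] = current
                    queue.append((ny, nx))
    return labels


def matching_score(gt, pred, iou_thresh=0.5):
    """TP / (TP + FP + FN), counting a match when IoU > iou_thresh."""
    gt_ids = [i for i in np.unique(gt) if i != 0]
    pred_ids = [j for j in np.unique(pred) if j != 0]
    if not gt_ids and not pred_ids:
        return 1.0  # convention: nothing to find and nothing predicted
    tp = 0
    for i in gt_ids:
        g = gt == i
        for j in pred_ids:
            p = pred == j
            iou = np.logical_and(g, p).sum() / np.logical_or(g, p).sum()
            if iou > iou_thresh:
                tp += 1
                break
    fp = len(pred_ids) - tp
    fn = len(gt_ids) - tp
    return tp / (tp + fp + fn)


def best_threshold(fg_probs, gt_labels, thresholds):
    """Return (threshold, mean score) with the highest mean matching score."""
    best = (None, -1.0)
    for t in thresholds:
        scores = [matching_score(gt, label_components(p > t))
                  for p, gt in zip(fg_probs, gt_labels)]
        mean = float(np.mean(scores))
        if mean > best[1]:
            best = (t, mean)
    return best
```

A threshold that is too low merges or adds spurious objects (false positives); one that is too high drops faint objects (false negatives). The sweep picks the value that balances the two on the validation data.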

In [None]:
# optimize the threshold on the annotated validation images only
threshold, val_score = model.optimize_thresholds(val_images[:available_val_masks.shape[0]].astype(np.float32), available_val_masks, measure=measure_precision(), axes='YX')

print("The highest score of {} is achieved with threshold = {}.".format(np.round(val_score, 3), threshold))

## Test Data
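Finally, we run the prediction on the test images loaded earlier. For each image we keep the denoised channel and compute an instance segmentation by thresholding the foreground channel with the threshold found above.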

In [None]:
denoised_images = []
segmented_images = []

for i in range(test_images.shape[0]):
    predicted_channels = model.predict(test_images[i].astype(np.float32), axes='YX')
    denoised_images.append(predicted_channels[...,0])
    segmented_images.append(compute_labels(predicted_channels, threshold))

### Visualize the results

In [None]:
sl = 6  # index of the test image to display
plt.figure(figsize=(20,10))
plt.subplot(1, 3, 1)
plt.imshow(test_images[sl])
plt.title("Raw image")

plt.subplot(1, 3, 2)
plt.imshow(denoised_images[sl])
plt.title("Predicted denoised image")

plt.subplot(1, 3, 3)
plt.imshow(segmented_images[sl], cmap = "viridis")
plt.title("Predicted segmentation")

plt.show()

### Export your model for Fiji

In [None]:
model.export_TF(name='DenoiSeg - Flywing Example',
                description='This is the 2D DenoiSeg example trained on the Flywing data in python.',
                authors=["You"],
                test_img=X_val_final[0,...,0], axes='YX',
                patch_shape=(128, 128))