# Training with large datasets

To process datasets that are larger than the available system memory, we use HDF5 through the [h5py package](https://www.h5py.org/).

This lets us save the pre-processed and randomised data to disk incrementally, and then load it incrementally during training using iterable tensors.
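
As a rough illustration of the incremental save and load idea, the sketch below appends batches to a resizable HDF5 dataset and then reads one slice back. It uses plain h5py rather than the CSBDeep helpers shown later; the dataset name `X`, the patch shape and the batch size are assumptions chosen only for this example.

```python
import os
import tempfile

import h5py
import numpy as np

# Hypothetical sizes, for illustration only
patch_shape = (16, 16, 3)
batch = 32

path = os.path.join(tempfile.mkdtemp(), "patches_demo.hdf5")

# Write: start with an empty, resizable dataset and append one batch at a time,
# so only a single batch is held in memory at any point.
with h5py.File(path, "w") as f:
    dset = f.create_dataset(
        "X",
        shape=(0,) + patch_shape,
        maxshape=(None,) + patch_shape,  # unlimited along the first axis
        chunks=(batch,) + patch_shape,
        dtype="float32",
    )
    rng = np.random.default_rng(0)
    for _ in range(3):
        patches = rng.random((batch,) + patch_shape, dtype=np.float32)
        start = dset.shape[0]
        dset.resize(start + batch, axis=0)
        dset[start:start + batch] = patches

# Read: slicing an h5py dataset loads only the requested rows from disk.
with h5py.File(path, "r") as f:
    n_total = f["X"].shape[0]
    first_batch = f["X"][0:batch]

print(n_total, first_batch.shape)
```

Only one batch of patches is resident in memory during either phase, which is what makes the approach suitable for data that does not fit in RAM.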

In contrast to the main example notebook `AstroNoise2Noise.ipynb`, this example trains on two sets of sub frames from different targets, which results in substantially more data.

*Enhancements have been provided by Maria Pavlou as code changes to the base CSBDeep repo.*

In [16]:
# Allow reloading of CSBDeep modules following any code changes
%reload_ext autoreload
%autoreload 2

# A couple of required imports
import numpy as np
from tifffile import imread
from pathlib import Path

# Pre-process and Save Training Data

We have two sets of sub frames in the `data/astro` sub-folders `NGC6888` and `NGC7000`, acquired with the same optical train and equipment.

All images have been aligned and calibrated and saved in the tiff format.

The example data can be downloaded [here](https://1drv.ms/u/s!AvWEkn9Anb_Nq9Aw52Xs3LuYEcq_rg?e=EexXxL)

Place the train images in sub-folders under the `data/astro` folder.

In [17]:
# Setup Data Parameters:
# Root Data path
basepath=Path('data/astro')
# Train Data path/s
source_dirs=['NGC6888','NGC7000']
# Image file pattern. Note: only formats supported by imread currently
pattern='*.tiff'
# Image patch size
patchsize=64
# Training data output savefile path & name
training_data_name="_".join(str(s) for s in source_dirs) + "_p{0}".format(patchsize) + '_NoPreProcessor' + '_NormPerc'
training_data_filename=training_data_name + '.hdf5'
save_file=basepath/training_data_filename


In [18]:
# Make an estimate of the number of non-overlapping patches for the images we have, sampling from the first we find.
first_image_file = list((basepath/source_dirs[0]).glob(pattern))[0]
sampleimage = imread(first_image_file)
# Integer floor division replaces np.int, which has been removed from recent NumPy releases
n_patches_per_image=(sampleimage.shape[0]//patchsize)*(sampleimage.shape[1]//patchsize)
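
The patch estimate above is simple tiling arithmetic: the number of whole patches that fit along each image axis, multiplied together. A small worked example follows; the image dimensions are hypothetical and the helper name is introduced only for illustration.

```python
def count_non_overlapping_patches(height, width, patch_size):
    """Number of whole patch_size x patch_size tiles that fit in the image."""
    return (height // patch_size) * (width // patch_size)

# Hypothetical sub frame of 3520 x 4656 pixels with 64-pixel patches:
# 3520 // 64 = 55 rows of patches, 4656 // 64 = 72 columns of patches
n = count_non_overlapping_patches(3520, 4656, 64)
print(n)  # 3960
```

Pixels left over at the right and bottom edges (48 columns in this example) do not contribute a patch.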

# Save pre-processed data to HDF5

Now we use the `create_patches_hdf5` helper to create the sampled patches from the sub frame images and store them directly to disk as an HDF5 file.

**Note:**

* Saving to disk takes longer than keeping the patches in memory; how much longer depends on your system configuration.
* Randomised indexing cannot be used when writing to HDF5 Datasets. A secondary read/write process is used instead to achieve a compromise shuffle of the raw data. This can take some time, but it improves performance during training, which makes it a good trade-off.
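
To give a feel for what a compromise shuffle can look like, here is a minimal sketch that uses only contiguous reads and writes, which are the access patterns HDF5 handles well. It is not the implementation inside `create_patches_hdf5`; the function name `block_shuffle` is hypothetical, and a NumPy array stands in for the on-disk dataset.

```python
import numpy as np

def block_shuffle(data, block_size, seed=0):
    """Approximate shuffle built from contiguous block reads and writes.

    Blocks are visited in random order, and the rows of each block are
    permuted in memory before being written contiguously to the output.
    """
    rng = np.random.default_rng(seed)
    n = data.shape[0]
    out = np.empty_like(data)
    starts = np.arange(0, n, block_size)
    rng.shuffle(starts)                                # random block order
    write_pos = 0
    for start in starts:
        block = data[start:start + block_size]         # contiguous read
        block = block[rng.permutation(len(block))]     # in-memory shuffle
        out[write_pos:write_pos + len(block)] = block  # contiguous write
        write_pos += len(block)
    return out

data = np.arange(10)  # stand-in for rows of an HDF5 dataset
shuffled = block_shuffle(data, block_size=4)
print(shuffled)
```

Rows never leave their original block, so the result is less random than a full permutation. Larger blocks give a better shuffle at the cost of more memory.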

In [None]:
from csbdeep.data import RawData, create_patches, create_patches_hdf5, norm_percentiles, norm_reinhard
from csbdeep.data import NoPreProcessor, ReinhardPreProcessor

# Load image pairs for Noise2Noise processing, each image paired against every other at most once.
raw_data = RawData.from_folder_n2n(basepath, source_dirs=source_dirs, axes='YXC', pattern=pattern, preprocessor=NoPreProcessor(), imageloader=None)

# Create patch data from image pairs with parameters,
# normalization set as norm_percentiles() by default, optionally set to None, norm_reinhard() or other custom
create_patches_hdf5(
    raw_data, 
    patch_size=(patchsize,patchsize,3),
    normalization=norm_percentiles(),
    n_patches_per_image=n_patches_per_image,
    save_file=save_file,
    patch_filter=None,
    overlap=False,
    shuffle=True)

# Load Training Data
Here we load the training data from a save file created in earlier steps.
As data is loaded we can also split the data into training `X,Y` and validation `X_val,Y_val` sets.
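
The split itself is index arithmetic. The sketch below assumes the final fraction of the stored samples is held out for validation; the source does not state how `HDF5Data.from_hdf5` partitions the data, so treat both that choice and the helper name as assumptions.

```python
def split_indices(n_items, validation_split):
    """Return (train_range, val_range), holding out the last fraction."""
    n_val = int(round(n_items * validation_split))
    n_train = n_items - n_val
    return range(0, n_train), range(n_train, n_items)

# Hypothetical dataset of 1000 patch pairs with a 10% validation split
train_idx, val_idx = split_indices(1000, 0.1)
print(len(train_idx), len(val_idx))  # 900 100
```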

In [None]:
from csbdeep.data.generate_hdf5 import HDF5Data

# Train/Validation split %
validation_split=0.1
# Select 1st channel initially
channel_slice = slice(0,1)
# Since we have already shuffled the raw data, we disable shuffled reads for performance
hdf5_shuffled_read = False

# Load saved training and validation from HDF5 data with iterable wrapper object and channel selection
train_data, val_data = HDF5Data.from_hdf5(save_file, validation_split=validation_split, channels=channel_slice, shuffled_read=hdf5_shuffled_read)

print('Train Data Shape =', train_data.shape)
print('Val Data Shape =', val_data.shape)

# Configure and Train the Learning Model

Here we configure the training model parameters. 
Training will be done for each color channel separately and saved with individual names based on:

* `training_data_name`
* `model_base_name`
* `channel_names[i]`
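
As an illustration of how these parts combine, the sketch below rebuilds the names using the same string formats as the notebook cells. The steps-per-epoch value of 400 is hypothetical; in the notebook it is computed from the size of the training data.

```python
source_dirs = ['NGC6888', 'NGC7000']
patchsize = 64

# Same format as the data parameter cell
training_data_name = (
    "_".join(source_dirs) + "_p{0}".format(patchsize) + "_NoPreProcessor" + "_NormPerc"
)

# probabilistic, batch size, steps per epoch (hypothetical), epochs, version
model_base_name = "_PRB-{0}_B{1}_SPE{2}_E{3}_V{4}".format(True, 128, 400, 50, 1)

channel_names = ['R', 'G', 'B']
model_names = [training_data_name + model_base_name + "_" + c for c in channel_names]
for name in model_names:
    print(name)
```

Each channel therefore gets its own model folder, for example one ending in `_V1_R` for the red channel.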

## Training debug tools

You can monitor the progress during training with [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard) by starting it from the current working directory:

    $ tensorboard --logdir=.

Then connect to [http://localhost:6006/](http://localhost:6006/) with your browser.

In [None]:
# Importing the deep learning framework
from csbdeep.models import Config, HDF5CARE

# Probabilistic training will be used as this yields better results
probabilistic=True
# Recommended: use as large a batch size as fits into GPU memory for your patch size.
train_batch_size = 128
# The total number of training epochs to execute
train_epochs=50
# The number of full passes over the training data to complete across all epochs
train_passes=4
train_steps_per_epoch = int(np.ceil(((len(train_data)*train_passes)/train_epochs)/train_batch_size))

# Since we are training each channel separately: n_channel_in, n_channel_out = 1, 1
config = Config('SYXC', n_channel_in=1, n_channel_out=1, unet_kern_size=3, probabilistic=probabilistic, train_steps_per_epoch=train_steps_per_epoch, train_epochs=train_epochs, train_batch_size=train_batch_size)

# Give a name for the model
version=1
model_base_name = "_PRB-{0}_B{1}_SPE{2}_E{3}_V{4}".format(probabilistic,train_batch_size,train_steps_per_epoch,train_epochs,version)
model_name = training_data_name + model_base_name

skipindex = [ ]
channel_names=['R', 'G', 'B']
for i, channel in enumerate(channel_names):
    if i in skipindex:
        continue
    
    # update the HDF5 iterable channel slice selection
    train_data.set_channel(slice(i,i+1))
    val_data.set_channel(slice(i,i+1))

    # Generate a model name for each channel
    full_model_name = model_name + '_' + channel
    print("Train model name: ", full_model_name)
    
    # Create the Learning Model from the CARE framework with configuration
    model = HDF5CARE(config, name=full_model_name, basedir='models')
        
    # Train the model and capture history
    history = model.train(
        XY_data=train_data, validation_data=val_data)
    
    # Save the model
    model.export_TF()
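
The `train_steps_per_epoch` value in the cell above spreads a fixed number of passes over the data across all epochs. A worked example of that arithmetic follows; the sample count is hypothetical, and the helper name is introduced only for this sketch.

```python
import math

def steps_per_epoch(n_samples, passes, epochs, batch_size):
    """Steps per epoch so that `passes` full passes complete over `epochs` epochs."""
    return math.ceil((n_samples * passes) / epochs / batch_size)

n_samples = 100_000  # hypothetical number of training patches
steps = steps_per_epoch(n_samples, passes=4, epochs=50, batch_size=128)

# 100000 * 4 / 50 = 8000 samples per epoch; 8000 / 128 = 62.5, rounded up to 63
samples_seen = steps * 50 * 128
print(steps, samples_seen / n_samples)  # 63 4.032
```

Because of the rounding up, the model sees slightly more than four passes in total.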

# Load and De-noise an Image 

Here we load an example image to de-noise using the trained set of RGB models.

The image to be de-noised must be normalized in the same way as the training data.
Then for each channel the model is used to predict the de-noised output and these are then saved as a single RGB image.
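
For intuition, percentile normalization rescales an image so that a low and a high percentile map to 0 and 1. The sketch below is a stand-alone NumPy version, not the CSBDeep `PercentileNormalizer` itself; the percentile values and the small `eps` term are assumptions for illustration, and they must match whatever was used to build the training data.

```python
import numpy as np

def percentile_normalize(x, pmin=2.0, pmax=99.8, eps=1e-20):
    """Map the pmin percentile to 0 and the pmax percentile to 1 (values are not clipped)."""
    lo = np.percentile(x, pmin)
    hi = np.percentile(x, pmax)
    return ((x - lo) / (hi - lo + eps)).astype(np.float32)

# Hypothetical single-channel image with a sky background around 1000 counts
rng = np.random.default_rng(0)
image = rng.normal(1000.0, 50.0, size=(128, 128))
normalized = percentile_normalize(image)

print(np.percentile(normalized, 2.0), np.percentile(normalized, 99.8))
```

If the test image were normalized with different percentiles than the training patches, the network would see intensities on a different scale from the one it was trained on.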

The example test image used can be downloaded [here](https://1drv.ms/u/s!AvWEkn9Anb_Nq9Aw52Xs3LuYEcq_rg?e=EexXxL).

Place the test image in the `data/astro/test` folder.

In [None]:
# Specify the test file
test_file_name='CrescentNebula-NoSt-Deep.tiff'

testfilepath=basepath/'test'/test_file_name
x = imread(testfilepath)
testaxes = 'YX'

print('Test Image size =', x.shape)
print('Test Image axes =', testaxes)

In [None]:
from csbdeep.data import PercentileNormalizer, PadAndCropResizer, ReinhardNormalizer, NoNormalizer

channel_names=['R', 'G', 'B']
output_denoised = []
for i, channel in enumerate(channel_names):
    full_model_name = model_name + '_' + channel
    
    # Load the model for the specific channel
    print("Loading model:", full_model_name)
    model = HDF5CARE(config=None, name=full_model_name, basedir='models')

    # Predict/de-noise the image channel with the corresponding trained model
    # Default PercentileNormalizer is used to match the normalization used to train the model
    output_denoised.append(
        model.predict(x[:,:,i],testaxes, normalizer=PercentileNormalizer(), resizer=PadAndCropResizer(), n_tiles = (2, 4))
    )

In [None]:
# Load astropy library for saving the de-noised image in fits format
from astropy.io import fits

output_file_name = model_name + '_RGB_' + Path(test_file_name).stem + '.fits'
output_file_path = basepath/'test'/output_file_name
hdu = fits.PrimaryHDU(output_denoised)
hdul = fits.HDUList([hdu])
hdul.info()
# Note: writeto raises an error if the file already exists; pass overwrite=True to replace it
hdul.writeto(output_file_path)
print("Output file saved:", output_file_path)
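
A note on the array layout of the saved file: a list of three per-channel images becomes a channel-first cube when converted to a single array. The sketch below shows the resulting shapes with NumPy only, assuming each prediction is a 2-D `(Y, X)` array; the tiny image size and fill values are placeholders.

```python
import numpy as np

# Hypothetical per-channel outputs, each of shape (Y, X)
height, width = 4, 6
channels = [np.full((height, width), v, dtype=np.float32) for v in (0.1, 0.5, 0.9)]

cube = np.stack(channels, axis=0)  # shape (3, Y, X), channel-first
rgb = np.moveaxis(cube, 0, -1)     # shape (Y, X, 3), channel-last like the input tiff

print(cube.shape, rgb.shape)  # (3, 4, 6) (4, 6, 3)
```

The channel-last form matches the `YXC` layout used for the training images, which can be convenient when comparing the result against the original test image.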