# DeepCAD-RT training pipeline            
<img src="https://github.com/STAR-811/DeepCAD-RT-old/blob/main/images/logo-new.png?raw=true" width = "650" height = "180" align=right />

This notebook demonstrates the basic pipeline for training DeepCAD-RT. A TIFF file is downloaded automatically to serve as example data. More information about the method and relevant results can be found in the companion paper:

**Real-time denoising of fluorescence time-lapse imaging enables high-sensitivity observations of biological dynamics beyond the shot-noise limit. bioRxiv (2022).**

```python
from deepcad.train_collection import training_class
from deepcad.movie_display import display
from deepcad.utils import get_first_filename, download_demo
import os
import glob
from tifffile import imread, TiffFile, imwrite
import matplotlib.pyplot as plt
import numpy as np
import h5py
```

## Select file(s) to be processed (download if not present)
The `download_demo` function downloads a demo file and returns its full path. The demo file is stored in `/datasets`. To train on your own data, create a new folder in `/datasets`, copy your data into it, and set `datasets_path` to that folder. All TIFF files inside the dataset folder will be used for training.

In this notebook, `datasets_path` is instead built from the `analysisDir` entry of `00_Metafile/metafile.h5` and points to the registered TIFF output in `suite2p/plane0/reg_tif/data_all`; `download_demo` is not called.
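Before a long training run it can help to check which files would be picked up from the dataset folder. The helper below is a hypothetical sketch, not part of `deepcad`; it assumes that files ending in `.tif` or `.tiff` (any case) directly inside the folder are the ones used, which may differ from DeepCAD-RT's own file filter.

```python
import os


def list_training_tiffs(folder):
    """Return sorted paths of TIFF files directly inside `folder`.

    Hypothetical helper: assumes the .tif/.tiff extension (any case) marks
    training data; subfolders are not searched.
    """
    tiffs = []
    for name in sorted(os.listdir(folder)):
        path = os.path.join(folder, name)
        # keep regular files only, matched on a case-insensitive extension
        if os.path.isfile(path) and name.lower().endswith(('.tif', '.tiff')):
            tiffs.append(path)
    return tiffs


# Usage (after `datasets_path` has been set):
# for path in list_training_tiffs(datasets_path):
#     print(os.path.basename(path))
```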

```python
# Read the session information stored in the metafile (the `with` block closes the file)
metafile_path = os.path.join(os.getcwd(), '00_Metafile', 'metafile.h5')
with h5py.File(metafile_path, 'r') as metainfo:
    recordingDate = metainfo['sessionInfo/recordingDate'][()].decode()
    mouseID = metainfo['sessionInfo/mouseID'][()].decode()
    behaviorInfo = metainfo['sessionInfo/behaviorInfo'][()].decode()
    recTarget = metainfo['sessionInfo/recTarget'][()].decode()
    sensor = metainfo['sessionInfo/sensor'][()].decode()
    fs = metainfo['sessionInfo/fs'][()]                  # sampling (frame) rate
    resolution = metainfo['sessionInfo/resolution'][()]
    analysisDate = metainfo['sessionInfo/analysisDate'][()].decode()
    analysisDir = os.path.join(metainfo['sessionInfo/analysisDir'][()].decode(),
                               '01_ROI_detection_DeepCAD_all')

# Folder containing the registered TIFF file(s) used for training
datasets_path = os.path.join(analysisDir, 'suite2p', 'plane0', 'reg_tif', 'data_all')

print('recording date:', recordingDate)
print('analysis date:', analysisDate)
print('behaviorInfo:', behaviorInfo)
print('sensor:', sensor)
print('recTarget:', recTarget)
print('fs:', fs)
print('resolution:', resolution)

print('analysisDir:', analysisDir)
print('dataDir_for_DeepCAD:', datasets_path)
```

```
recording date: 
analysis date: 2023-05-26 16:03:05.694461
behaviorInfo: 
sensor: GCaMP7.09
recTarget: Soma
fs: 7.65
resolution: 2048
analysisDir: /mnt/ssd1/ysaito/suite2p-pipeline-main/02_analysis/i166-m2_Exp01/01_ROI_detection_DeepCAD_all
dataDir_for_DeepCAD: /mnt/ssd1/ysaito/suite2p-pipeline-main/02_analysis/i166-m2_Exp01/01_ROI_detection_DeepCAD_all/suite2p/plane0/reg_tif/data_all
```

## Set the parameters for training
The default settings are those used for the demo file and are also appropriate for most data. Adjust them to suit your data and hardware. To visualize the training process, set the flags `visualize_images_per_epoch` and `save_test_images_per_epoch` as needed.

```python
n_epochs = 20               # number of training epochs
GPU = '0'                   # index of the GPU(s) to use (e.g. '0', '0,1', '0,1,2')
train_datasets_size = 3000  # training dataset size (number of 3D patches)
patch_xy = 150              # width and height of 3D patches
patch_t = 150               # time dimension (frames) of 3D patches
overlap_factor = 0.4        # overlap factor between two adjacent patches
pth_dir = 'pth'             # path for .pth files and result images
num_workers = 0             # set this to 0 on Windows

# Set up parameters for result visualization during training (optional)
visualize_images_per_epoch = False  # whether to show result images after each epoch
save_test_images_per_epoch = False  # whether to save result images after each epoch
```
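`patch_t` is expressed in frames, so the time span it covers depends on the acquisition rate. A small arithmetic sketch, assuming `fs` from the metafile (7.65) is in frames per second and using the 8000-frame stack reported in the training log; the helper name is hypothetical.

```python
def frames_to_seconds(n_frames, fs_hz):
    """Duration in seconds of n_frames sampled at fs_hz frames per second."""
    if fs_hz <= 0:
        raise ValueError("fs_hz must be positive")
    return n_frames / fs_hz


fs_hz = 7.65  # value of `fs` printed from the metafile (assumed to be in Hz)
print(f"one patch (150 frames): {frames_to_seconds(150, fs_hz):.1f} s")             # → 19.6 s
print(f"whole stack (8000 frames): {frames_to_seconds(8000, fs_hz) / 60:.1f} min")  # → 17.4 min
```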

## Show the input low-SNR data (optional)
Play the input video. This loads the video into memory and is not required for training. The OpenCV library is used for display.
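How much memory this takes depends on the stack size. A rough estimate, assuming uint16 pixels (the dtype is not stated in this notebook) and the shape (8000, 2048, 2048) reported in the training log. Whether `display` reads the whole file or only the first `display_length` frames is not stated here, so both sizes are computed; the helper is hypothetical.

```python
def stack_size_gib(n_frames, height, width, bytes_per_pixel=2):
    """Approximate in-memory size of a (t, y, x) stack in GiB.

    bytes_per_pixel=2 assumes uint16 pixels; use 4 for float32.
    """
    return n_frames * height * width * bytes_per_pixel / 2**30


# Shape reported in the training log: (8000, 2048, 2048)
print(stack_size_gib(8000, 2048, 2048))  # whole stack, uint16 → 62.5
print(stack_size_gib(300, 2048, 2048))   # first 300 frames, uint16 → 2.34375
```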

```python
display_images = False

if display_images:
    display_filename = get_first_filename(datasets_path)
    print('\033[1;31mDisplaying the first raw file -----> \033[0m')
    print(display_filename)
    display_length = 300     # how many frames to display
    # normalize the image and display
    display(display_filename, display_length=display_length, norm_min_percent=1, norm_max_percent=98)
```
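The `norm_min_percent` and `norm_max_percent` arguments suggest percentile-based contrast stretching. The function below is an assumption about what such normalization typically does (clip to the given percentiles, then rescale to [0, 1]); it is not the `deepcad.movie_display` implementation, and its name is hypothetical. It uses NumPy, which the notebook already imports.

```python
import numpy as np


def percentile_normalize(stack, min_percent=1, max_percent=98):
    """Clip a stack to its [min_percent, max_percent] percentiles, rescale to [0, 1].

    Sketch only; not the deepcad.movie_display implementation.
    """
    stack = np.asarray(stack, dtype=np.float32)
    lo, hi = np.percentile(stack, [min_percent, max_percent])
    if hi <= lo:  # flat input: avoid division by zero
        return np.zeros_like(stack)
    return (np.clip(stack, lo, hi) - lo) / (hi - lo)


# Convert to 8-bit for an OpenCV window, e.g.:
# frame_u8 = (percentile_normalize(frames)[0] * 255).astype(np.uint8)
```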

## Create a training object
This creates a training object by passing all parameters as a dictionary. Parameters not specified in the dictionary use their default values.
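The default mechanism can be pictured as a dictionary merge in which user-supplied values override defaults. This is a generic sketch, not the code of `training_class`. The `DEFAULTS` subset reuses values printed in the training parameters for keys that `train_dict` does not set (`output_dir`, `batch_size`, `test_datasize`); rejecting unknown keys is an added safeguard against misspelled parameter names that DeepCAD-RT is not known to perform.

```python
def merge_with_defaults(user_params, defaults):
    """Return a new dict: `defaults` overridden by `user_params`.

    Unknown keys raise KeyError so a misspelled name is not silently ignored
    (an extra safeguard, not confirmed DeepCAD-RT behaviour).
    """
    unknown = sorted(set(user_params) - set(defaults))
    if unknown:
        raise KeyError(f"unknown parameter(s): {unknown}")
    return {**defaults, **user_params}


# Hypothetical subset of defaults (values as printed by training_class)
DEFAULTS = {'output_dir': './results', 'batch_size': 1, 'test_datasize': 400}

print(merge_with_defaults({'test_datasize': 200}, DEFAULTS))
# → {'output_dir': './results', 'batch_size': 1, 'test_datasize': 200}
```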

```python
train_dict = {
    # dataset-dependent parameters
    'patch_x': patch_xy,                          # the width of 3D patches
    'patch_y': patch_xy,                          # the height of 3D patches
    'patch_t': patch_t,                           # the time dimension (frames) of 3D patches
    'overlap_factor': overlap_factor,             # overlap factor
    'scale_factor': 1,                            # the factor for image intensity scaling
    'select_img_num': 1000000,                    # number of frames used for training (a large value uses all frames)
    'train_datasets_size': train_datasets_size,   # training dataset size (number of 3D patches)
    'datasets_path': datasets_path,               # folder containing files for training
    'pth_dir': pth_dir,                           # path for .pth files and result images

    # network-related parameters
    'n_epochs': n_epochs,                         # the number of training epochs
    'lr': 0.00005,                                # learning rate
    'b1': 0.5,                                    # Adam: beta1
    'b2': 0.999,                                  # Adam: beta2
    'fmap': 16,                                   # model complexity
    'GPU': GPU,                                   # GPU index
    'num_workers': num_workers,                   # set this to 0 on Windows
    'visualize_images_per_epoch': visualize_images_per_epoch,   # whether to show result images after each epoch
    'save_test_images_per_epoch': save_test_images_per_epoch    # whether to save result images after each epoch
}

tc = training_class(train_dict)
```

```
Training parameters ----->
{'overlap_factor': 0.4, 'datasets_path': '/mnt/ssd1/ysaito/suite2p-pipeline-main/02_analysis/i166-m2_Exp01/01_ROI_detection_DeepCAD_all/suite2p/plane0/reg_tif/data_all', 'n_epochs': 20, 'fmap': 16, 'output_dir': './results', 'pth_dir': 'pth', 'onnx_dir': './onnx', 'batch_size': 1, 'patch_t': 150, 'patch_x': 150, 'patch_y': 150, 'gap_y': 90, 'gap_x': 90, 'gap_t': 90, 'lr': 5e-05, 'b1': 0.5, 'b2': 0.999, 'GPU': '0', 'ngpu': 1, 'num_workers': 0, 'scale_factor': 1, 'train_datasets_size': 3000, 'select_img_num': 1000000, 'test_datasize': 400, 'visualize_images_per_epoch': False, 'save_test_images_per_epoch': False, 'colab_display': False, 'result_display': ''}
```
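The printed parameters contain `gap_x = gap_y = gap_t = 90`, which equals `patch_xy * (1 - overlap_factor)` = 150 × 0.6. The training log in the next section also reports 3388 batches per epoch (with `batch_size` 1, one batch is one patch) rather than the requested `train_datasets_size` of 3000. The sketch below reproduces both numbers under stated assumptions: spatial positions tiled with stride `gap`, and temporal positions per stack rounded up until the total reaches `train_datasets_size`. It is a reconstruction that matches this log, not DeepCAD-RT's source code, and the function name is hypothetical.

```python
import math


def patches_per_epoch(height, width, n_stacks, patch_xy, overlap_factor,
                      train_datasets_size):
    """Estimate (gap, patches per epoch) for square spatial patches.

    Assumptions (reconstructed to match the log, not taken from DeepCAD-RT):
    - spatial positions are tiled with stride gap = patch_xy * (1 - overlap_factor)
    - the number of temporal positions per stack is rounded up so that the
      total is at least train_datasets_size
    """
    gap = int(patch_xy * (1 - overlap_factor))   # 150 * 0.6 -> 90
    n_x = (width - patch_xy) // gap + 1          # patch positions along x
    n_y = (height - patch_xy) // gap + 1         # patch positions along y
    n_t = math.ceil(train_datasets_size / (n_x * n_y * n_stacks))
    return gap, n_x * n_y * n_t * n_stacks


# One registered stack of 2048 x 2048 pixels, as in the training log
print(patches_per_epoch(2048, 2048, 1, 150, 0.4, 3000))  # → (90, 3388)
```

Here 22 × 22 spatial positions × 7 temporal positions give 3388, which would explain why an epoch has slightly more batches than `train_datasets_size`.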


## Start the training process

Here we launch the training process. The model from each epoch is saved in the `pth` folder set by `pth_dir`.
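To resume work or run inference later, one usually needs the newest checkpoint. A hypothetical helper, assuming checkpoints are `.pth` files somewhere under `pth_dir` (the subfolder layout and file names DeepCAD-RT uses are not shown in this notebook); it picks the most recently modified file rather than parsing epoch numbers from names.

```python
import os


def latest_checkpoint(pth_dir):
    """Return the most recently modified .pth file under pth_dir, or None.

    Assumption: checkpoints are *.pth files somewhere below pth_dir; the
    exact subfolder and file naming used by DeepCAD-RT are not assumed.
    """
    newest, newest_mtime = None, float('-inf')
    for root, _dirs, files in os.walk(pth_dir):
        for name in files:
            if name.endswith('.pth'):
                path = os.path.join(root, name)
                mtime = os.path.getmtime(path)
                if mtime > newest_mtime:
                    newest, newest_mtime = path, mtime
    return newest


# e.g. print(latest_checkpoint(pth_dir))
```

Relying on modification time avoids guessing the naming scheme, but it can pick the wrong file if checkpoints are later copied or touched.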

```python
%%time
tc.run()  # train for n_epochs; %%time reports the total wall time of the cell
```

Image list for training ----->
Total stack number ----->  1
Noise image name ----->  registered.tif
Noise image shape ----->  (8000, 2048, 2048)
Using 1 GPU(s) for training ----->
[Epoch 1/20] [Batch 3388/3388] [Total loss: 325453.31, L1 Loss: 630.18, L2 Loss: 650276.44] [ETA: 8:00:01] [Time cost: 1525 s]         
[Epoch 2/20] [Batch 3388/3388] [Total loss: 486950.69, L1 Loss: 770.43, L2 Loss: 973130.94] [ETA: 7:25:50] [Time cost: 3052 s]         
[Epoch 3/20] [Batch 3388/3388] [Total loss: 321193.66, L1 Loss: 618.11, L2 Loss: 641769.19] [ETA: 6:59:15] [Time cost: 4578 s]          
[Epoch 4/20] [Batch 3388/3388] [Total loss: 865988.88, L1 Loss: 1026.68, L2 Loss: 1730951.12] [ETA: 6:34:31] [Time cost: 6105 s]       
[Epoch 5/20] [Batch 3388/3388] [Total loss: 270377.47, L1 Loss: 568.22, L2 Loss: 540186.75] [ETA: 6:04:11] [Time cost: 7628 s]         
[Epoch 6/20] [Batch 3388/3388] [Total loss: 166310.00, L1 Loss: 436.09, L2 Loss: 332183.91] [ETA: 6:06:13] [Time co