# Train 3DUnet for Image Enhancement

- Create the conda environment from the environment file with the following command: ```conda env create -f environment.yml```

- Activate the environment: ```conda activate 3dunet```

- Install pytorch-3dunet by following the instructions at https://github.com/wolny/pytorch-3dunet.git


## Load config

Create a YAML config file for training and testing.

Some parameters in `train_config.yaml`:
- `checkpoint_dir`: path where the trained model is saved.
- `loaders: train: file_paths`: list of paths to the datasets. Each dataset is a folder containing up to three subfolders:
  - `raw_images`: all raw images
  - `ground_truth`: all ground-truth images
  - `weights`: per-pixel weights, used only when training with weights

  For a training dataset, `raw_images` and `ground_truth` are required and `weights` is optional. For a test dataset, only `raw_images` is required.
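The folder layout described above can be checked before training starts. The sketch below is a hypothetical helper (`check_dataset_layout` is not part of this repository or of pytorch-3dunet); it only encodes the stated rules: `raw_images` and `ground_truth` are required for training, `weights` is optional, and a test dataset needs only `raw_images`.

```python
from pathlib import Path

# Hypothetical helper: encodes the dataset folder rules described above.
REQUIRED = {
    "train": ("raw_images", "ground_truth"),
    "test": ("raw_images",),
}
ALL_SUBFOLDERS = ("raw_images", "ground_truth", "weights")


def check_dataset_layout(dataset_dir, mode="train"):
    """Return the known subfolders present in dataset_dir; raise if a required one is missing."""
    if mode not in REQUIRED:
        raise ValueError(f"mode must be one of {sorted(REQUIRED)}, got {mode!r}")
    root = Path(dataset_dir)
    missing = [name for name in REQUIRED[mode] if not (root / name).is_dir()]
    if missing:
        raise FileNotFoundError(f"{root} is missing required subfolder(s): {missing}")
    # Report every known subfolder that exists, required or optional.
    return [name for name in ALL_SUBFOLDERS if (root / name).is_dir()]
```

For example, `check_dataset_layout("path/to/dataset")` returns `["raw_images", "ground_truth"]` for a training dataset without weights, and raises `FileNotFoundError` if `ground_truth` is absent.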

```python
from train_regression import train_3dunet_regression, load_config_yaml

# Path to the config file
config_file = "train_config.yaml"

# Load the config
config = load_config_yaml(config_file)

# Get the config for the data loaders
config_loaders = config['loaders']
```

# Visualize Dataset

### Visualize raw images

Read the image arrays from TIF files and visualize them:

- Get the path to the dataset folder, such as "path/to/dataset" (here, the first entry of `file_paths` in the loader config).
- Use `TIFDataset.load_dataset_files()` to load all TIF files in the dataset. It returns a list of `(raw_image_file_path, ground_truth_file_path, weight_file_path (optional))` tuples.
- Choose the file to visualize by its index. To print all files, uncomment the loop in the file-loading cell.
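To illustrate the kind of list `TIFDataset.load_dataset_files()` returns, here is a hypothetical re-implementation, `pair_dataset_files`; it is not the repository's code. It assumes, from the file names printed in the output of the file-loading cell, that the ground truth for `raw_images/<name>.tif` is `ground_truth/<name>_gt.tif`. The naming of weight files is not shown anywhere, so the sketch simply assumes they reuse the raw file name and returns `None` when no such file exists.

```python
from pathlib import Path


def pair_dataset_files(dataset_dir):
    """Hypothetical sketch: list (raw, ground_truth, weight) path tuples for one dataset folder.

    Naming assumptions (not confirmed by the repository):
      raw:    raw_images/<name>.tif
      gt:     ground_truth/<name>_gt.tif, or None if missing (e.g. test data)
      weight: weights/<name>.tif, or None if missing
    """
    root = Path(dataset_dir)
    triples = []
    for raw in sorted((root / "raw_images").glob("*.tif")):
        gt = root / "ground_truth" / f"{raw.stem}_gt.tif"
        weight = root / "weights" / raw.name
        triples.append((
            str(raw),
            str(gt) if gt.exists() else None,
            str(weight) if weight.exists() else None,
        ))
    return triples
```

Always returning three-element tuples keeps loops such as `for rf, gf, _ in file_paths:` valid whether or not weights are present.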

```python
import tifffile
import numpy as np
from visualize import display_sequence
from dataloader import TIFDataset

%matplotlib inline
```

```python
# Get the image file paths
train_file_path = config_loaders['train']['file_paths'][0]
file_paths = TIFDataset.load_dataset_files(train_file_path)

# Uncomment the following lines to print all files
# for rf, gf, _ in file_paths:
#     print(f"*---{rf}\n |--{gf}")

# Choose a file to visualize by its index (index 0 here)
raw_file_path = file_paths[0][0]
gt_file_path = file_paths[0][1]
print(f"- Visualize following files: \n\t{raw_file_path}\n\t{gt_file_path}")

# Read the image data
raw_im_arr = tifffile.imread(raw_file_path)
gt_im_arr = tifffile.imread(gt_file_path)

print(f"- Image shape \n\t raw image: {raw_im_arr.shape} \n\t Gt image {gt_im_arr.shape}")
```

```text
- Visualize following files: 
	/Users/w.zhao/Projects/MemSeg/Dataset/dataset01/train/raw_images/201223_RL57M_pos1.tif
	/Users/w.zhao/Projects/MemSeg/Dataset/dataset01/train/ground_truth/201223_RL57M_pos1_gt.tif
- Image shape 
	 raw image: (64, 406, 406) 
	 Gt image (64, 406, 406)
```


```python
# Visualize the 3D image
display_sequence(raw_im_arr)
```


# Train Model

- Use `train_3dunet_regression` to train the model with the settings in the config file.

```python
# Start training
train_3dunet_regression(config_file)
```

```text
2024-04-19 10:41:14,282 [MainThread] INFO UNetTrainer - Number of learnable params 4119227
2024-04-19 10:41:14,282 [MainThread] INFO TIFDataset - Creating training and validation set loaders...
2024-04-19 10:41:14,283 [MainThread] INFO TIFDataset - Loading train set from: /Users/w.zhao/Projects/MemSeg/Dataset/dataset01/train/raw_images/201223_RL57M_pos1.tif...
2024-04-19 10:41:14,299 [MainThread] INFO Dataset - Slice builder config: {'name': 'SliceBuilder', 'patch_shape': [16, 128, 128], 'stride_shape': [32, 32, 32]}
2024-04-19 10:41:14,299 [MainThread] INFO TIFDataset - Number of patches: 300
2024-04-19 10:41:14,300 [MainThread] INFO TIFDataset - Loading train set from: /Users/w.zhao/Projects/MemSeg/Dataset/dataset01/train/raw_images/201223_RL57M_pos2.tif...
2024-04-19 10:41:14,315 [MainThread] INFO Dataset - Slice builder config: {'name': 'SliceBuilder', 'patch_shape': [16, 128, 128], 'stride_shape': [32, 32, 32]}
2024-04-19 10:41:14,315 [MainThread] INFO TIFDataset - Number of patch

KeyboardInterrupt: 
```
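The `Number of patches: 300` line in the log follows from the image shape (64, 406, 406) and the slice-builder settings. The sketch below assumes the sliding-window rule of pytorch-3dunet's `SliceBuilder`: along each axis a window starts every `stride` voxels, and one extra window is placed flush with the far border when the regular steps stop short of it. `window_starts` and `count_patches` are hypothetical helpers, not functions from this repository; under this assumption they reproduce the logged count.

```python
from math import prod


def window_starts(size, patch, stride):
    """Start offsets of a sliding window along one axis.

    Assumes the pytorch-3dunet SliceBuilder convention: step by `stride`,
    then add one final window flush with the far border if needed.
    """
    if size < patch:
        raise ValueError("image axis must be at least as large as the patch")
    starts = list(range(0, size - patch + 1, stride))
    if starts[-1] + patch < size:
        starts.append(size - patch)
    return starts


def count_patches(image_shape, patch_shape, stride_shape):
    """Number of patches extracted from one volume (product over axes)."""
    return prod(
        len(window_starts(i, k, s))
        for i, k, s in zip(image_shape, patch_shape, stride_shape)
    )


print(window_starts(64, 16, 32))                                      # [0, 32, 48]
print(count_patches((64, 406, 406), [16, 128, 128], [32, 32, 32]))   # 300
```

Under this rule the z-axis windows start at 0, 32 and 48, so with a patch depth of 16 the slices 16 to 31 are never sampled; a z-stride no larger than the patch depth would cover every slice.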

# Prediction