# Train a new model with transfer learning

In [None]:
import os
import json
import time
from pathlib import Path
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util import Retry

# Scientific package imports
import imageio
import numpy as np
import tensorflow as tf
from skimage import io
import matplotlib.pyplot as plt

# Utils import
from shutil import copy
import zipfile
from tqdm import tqdm
import cgi
import tempfile

# AxonDeepSeg imports
# Try to import the package from the current working directory first. If it
# cannot be found there, move up one directory and retry once; a failure of
# the retry (or any other kind of import error) propagates unchanged.
try:
    from AxonDeepSeg.ads_utils import download_data
except ModuleNotFoundError:
    # Change cwd to project main folder
    os.chdir("..")
    from AxonDeepSeg.ads_utils import download_data
# If no exceptions were raised import all folders        
from AxonDeepSeg.config_tools import validate_config
from AxonDeepSeg.train_network import train_model
from AxonDeepSeg.apply_model import axon_segmentation
import AxonDeepSeg.ads_utils as ads

# Reset the TensorFlow default graph so that the new training starts from an empty graph
tf.reset_default_graph()

%matplotlib inline

In [None]:
# Folder containing the training data (relative to the current working directory)
path_training = Path("./training")

In [None]:
# Define path of the init model for transfer learning (TEM)
dir_name_init = Path("default_TEM_model")
path_model_init = Path("../AxonDeepSeg/models") / dir_name_init

# Define file name of network configuration
file_config = 'config_network.json'

# Load config file from init model.
# Fail right away when it is missing: every step below needs `config`, so
# carrying on would only surface later as an unrelated NameError.
fname_config = path_model_init / file_config
if not fname_config.is_file():
    raise FileNotFoundError("Config file does not exist: " + str(fname_config.absolute()))
with open(fname_config, 'r') as fd:
    config = json.load(fd)

print("The model used to init the config and weights is : " + str(path_model_init.resolve().absolute()))

# Define path to where the trained model will be saved.
# The timestamp is taken with a single strftime call so that the date and the
# time parts always come from the same instant.
dir_name = Path(config["trainingset"] + '_' + time.strftime("%Y-%m-%d_%H-%M-%S"))
path_model = Path("../models") / dir_name

# Create the directory of the new model (no error if it already exists)
path_model.mkdir(parents=True, exist_ok=True)

# Set number of epochs for new training
config["epochs"] = 1000

# Save the config next to the new model
fname_config = path_model / file_config
with open(fname_config, 'w') as f:
    json.dump(config, f, indent=2)

print("This training session's model will be saved in the folder: " + str(path_model.resolve().absolute()))

#### 1.4. Launch the training procedure

The training can be launched by calling the *'train_model'* function. After each epoch, the function will display the loss and accuracy of the model. The model checkpoints will be saved according to the "checkpoint_period" parameter in "config".

In [None]:
# Reset the TensorFlow graph before starting the new training run
tf.reset_default_graph()

# Train model: data is read from path_training and the model is written to path_model.
# path_model_init points to the initial (TEM) model used for transfer learning;
# presumably its weights are loaded as the starting point — confirm in train_model.
train_model(str(path_training), str(path_model), config, path_model_init=path_model_init)

#### 1.5. Monitor the training with Tensorboard

[TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard) can be used to monitor the training procedure (loss and accuracy graphs, gradients, activations, identify bugs, etc.). To run TensorBoard, activate ADS virtual environment and run:
```
tensorboard --logdir PATH_MODEL --port 6006
```
where `PATH_MODEL` corresponds to this notebook's `path_model` variable (the folder where the model is being trained), and `port` is the port number on which the TensorBoard local web server will listen (e.g., port 6006). Once the command is running, open a web browser at the address:
```
http://localhost:6006/
```

#### 1.6. Resume training from checkpoint

To resume training from a checkpoint, change the "checkpoint" parameter in "config" from None to "loss" or "accuracy".

In [None]:
# path_model = "../models/Path_of_the_model" # Path of the model where the checkpoint is saved

# train_model(str(path_training), str(path_model), config)

### 2. Test the trained model
#### 2.1. Set the path of the test image to be segmented with the trained model

In [None]:
# Modify the lines below to use your image
path_img = Path("../data")  # folder that contains the test image
file_img = "image.png"  # file name of the test image inside path_img

#### 2.2. Launch the image segmentation

The target resolutions of the current versions of the models are 0.1 for the **'default_SEM_model'** and 0.01 for the **'default_TEM_model'**. In this case, our test sample is a TEM sample of mouse brain, so we set resampled_resolutions to 0.01.

For your own trained model, use a resampled_resolutions corresponding to the general_pixel_size that was set in the 01-guide_dataset_building notebook in section 1.1. "Define the parameters of the patch extraction" when you created the dataset.

In [None]:
# reset the tensorflow graph for new testing
tf.reset_default_graph()
# Segment the test image using the model stored in path_model and its config.
# NOTE(review): the text above this cell says resampled_resolutions is set to 0.01 for a TEM sample,
# but 0.1 is passed for both resolutions here — confirm which value matches your image and model.
prediction = axon_segmentation(path_img, file_img, path_model, config, acquired_resolution=0.1, resampled_resolutions=0.1, verbosity_level=3)

#### 2.3. Display the resulting segmentation

In [None]:
file_img_seg = 'AxonDeepSeg.png'  # axon+myelin segmentation

# Both files sit in path_img; the `/` operator below is pathlib's path concatenation.
img_seg = ads.imread(path_img / file_img_seg)
img = ads.imread(path_img / file_img)

# Input image on the left, segmentation on the right, both in grayscale.
fig, axes = plt.subplots(1, 2, figsize=(13, 10))
ax1, ax2 = axes
panels = (
    (ax1, 'Original image', img),
    (ax2, 'Prediction with the trained model', img_seg),
)
for axis, title, picture in panels:
    axis.set_title(title)
    axis.imshow(picture, cmap='gray')
plt.show()