# Train Model

### Imports

In [None]:
import sys
import os
from config import Config
# Make the project root importable so the `Python.*` packages below resolve.
# This must run before those imports, which is why it sits mid-block.
sys.path.append(Config.root_path)

import importlib
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Load Model
from Python.model.segmentation_model import deeplabv3plus
# Load Dataset & Preprocessing
from Python.data_processing.utils import get_png_paths_from_dir
from Python.data_processing.utils import load_weight_map
from Python.data_processing.utils import split_dataset_paths
from Python.data_processing.data_generator import generate_image_dataset_from_files
from Python.data_processing.data_generator import augment_dataset
# Print Model Prediction
from Python.data_processing.compare_predictions import show_predictions
# Save Model
from Python.data_processing.save_model import write_model_to_disk


# Show the GPUs TensorFlow can see; an empty list means no GPU is visible.
print(tf.config.list_physical_devices('GPU'))

### Prepare dataset

In [None]:
# Gather the image / mask PNG paths and the per-class weight map from Config locations.
image_files = get_png_paths_from_dir(Config.image_path)
mask_files = get_png_paths_from_dir(Config.segmentation_path)
weight_map = load_weight_map(Config.weight_map_path)

# Expand the weight map into a dense vector indexed by class id.
# Keys are passed through int() before indexing; classes absent from the map keep weight 0.
weights = np.zeros(Config.output_channels, dtype=float)
for class_key in weight_map.keys():
    class_index = int(class_key)
    weights[class_index] = weight_map.get(class_key)

# Partition the paths into train / validation / test subsets.
# Each subset is unpacked below as an (image paths, mask paths) pair.
train_files, val_files, test_files = split_dataset_paths(
    image_files, mask_files, Config.train_size, Config.val_size, Config.test_size
)

# Build one tf.data pipeline per split, all using the same batch size,
# shuffle buffer size and class weights.
train_dataset, val_dataset, test_dataset = (
    generate_image_dataset_from_files(
        split_images,
        split_masks,
        Config.batch_size,
        tf.data.AUTOTUNE,
        Config.shuffle_size,
        weights,
    )
    for split_images, split_masks in (train_files, val_files, test_files)
)

# Augmentation is applied to the training split only.
train_dataset = augment_dataset(train_dataset)

### Load Model

In [None]:
# Build DeepLabV3+ from the architecture hyper-parameters stored in Config.
model = deeplabv3plus(
    Config.input_shape, Config.batch_size, Config.output_channels,
    Config.channels_low, Config.channels_high, Config.middle_repeat,
)

# Integer-label cross-entropy. from_logits=True means the loss applies softmax itself;
# NOTE(review): confirm deeplabv3plus's final layer has no softmax activation.
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

### Compare Image, Segmentation Mask and Prediction

In [None]:
# Sanity check on validation samples. When the notebook is run in order, the model
# has only been compiled at this point, so these are untrained predictions.
show_predictions(model, val_dataset)

### Model Training

Train Model

In [None]:
# Update any changes to Config
importlib.reload(sys.modules["config"])
from config import Config
# Train model
model_history = model.fit(
    train_dataset,
    epochs=Config.epochs,
    validation_data=val_dataset,
    shuffle=True
) 

Show Model Performance: training and validation loss per epoch

In [None]:
# Per-epoch training and validation loss recorded by model.fit.
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

fig, ax = plt.subplots()
# Labels describe what is actually plotted (loss), not accuracy.
ax.plot(model_history.epoch, loss, 'r', label='Training loss')
ax.plot(model_history.epoch, val_loss, 'bo', label='Validation loss')
ax.set_title('Training and Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss Value')
# Anchor the axis at 0 but let the top auto-scale: cross-entropy loss is not bounded
# by 1, so a fixed [0, 1] range can clip the curves (especially early epochs).
ax.set_ylim(bottom=0)
ax.legend()
plt.show()

### Save Model

In [None]:
# Basenames of the image files in each split. Element 0 of each split is the list of
# image paths, matching how the splits are unpacked in the dataset cell.
train_names = [os.path.basename(img_path) for img_path in train_files[0]]
val_names = [os.path.basename(img_path) for img_path in val_files[0]]
test_names = [os.path.basename(img_path) for img_path in test_files[0]]
# NOTE(review): the *_names lists above are not passed to write_model_to_disk;
# confirm whether file_partitions should record these names instead of the path pairs.

# dict() does not accept alternating key/value positional arguments (it raises
# TypeError), so build the split mapping with keyword arguments.
file_partitions = dict(train=train_files, validation=val_files, test=test_files)
write_model_to_disk(
    model, model_history, file_partitions, Config.model_dir_path, Config
)