In [None]:
import sys
sys.path.append(r"/home/vidarmarsh/CEZ_Mapping")

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
from Python.legacy.data_generator import generate_image_dataset_from_files
from Python.legacy.data_generator import Data_Generator
from Python.legacy.segmentation_model import deeplabv3plus
from Python.legacy.utils import get_png_paths_from_dir
from Python.legacy.utils import load_weight_map
from Python.legacy.model_tools import load_model
from Python.legacy.utils import write_model_to_disk
from Python.legacy.data_pipeline import split_dataset
from Python.legacy.model_tools import compare_model_predictions
from Python.legacy.data_pipeline import augment_dataset
from Python.legacy.config import Config

# Fix the global TensorFlow and NumPy seeds so that random ops drawing from the
# global generators (e.g. dataset shuffles and weight initialisers that do not
# pass their own seed) give the same sequence on every Restart & Run All. Ops
# created with an explicit op-level seed are unaffected.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

# List the GPUs TensorFlow can see; an empty list means training runs on CPU.
print(tf.config.list_physical_devices('GPU'))

In [None]:
image_files = get_png_paths_from_dir(Config.image_path)
mask_files = get_png_paths_from_dir(Config.segmentation_path)

# Dense per-class weight vector indexed by class id. The weight map's keys are
# converted with int() and used as indices; classes absent from the map keep a
# weight of 0 (presumably excluding them from the weighted loss -- confirm this
# is intended for every missing class).
weight_map = load_weight_map(Config.weight_map_path)
weights = np.zeros(Config.output_channels, dtype=float)
for class_key, class_weight in weight_map.items():
    class_id = int(class_key)
    # Reject ids outside [0, output_channels): a negative id would otherwise
    # silently wrap (NumPy negative indexing) onto the last classes, and a too
    # large id would fail with an uninformative IndexError.
    if not 0 <= class_id < Config.output_channels:
        raise ValueError(
            f"weight map key {class_key!r} is not a class id in "
            f"[0, {Config.output_channels})"
        )
    weights[class_id] = class_weight

# NOTE(review): if generate_image_dataset_from_files shuffles with tf.data's
# default reshuffle_each_iteration=True and split_dataset partitions with
# take()/skip(), the train/val/test membership changes every epoch and the
# splits leak into each other -- confirm against the project code.
dataset = generate_image_dataset_from_files(
    image_files,
    mask_files,
    Config.batch_size,
    Config.shuffle_size,
    weights
)
train_dataset, val_dataset, test_dataset = split_dataset(
    dataset, Config.train_size, Config.val_size, Config.test_size
)
# Augmentation is applied to the training split only.
train_dataset = augment_dataset(train_dataset)

In [None]:
# Build the DeepLabV3+ network from the shared Config. The arguments are passed
# positionally, so their order must match deeplabv3plus's signature.
# NOTE(review): batch_size is handed to the builder; if it fixes the batch
# dimension, every dataset batch must have exactly this size -- TODO confirm.
model = deeplabv3plus(
  Config.input_shape, Config.batch_size, Config.output_channels,
  Config.channels_low, Config.channels_high, Config.middle_repeat,
)

# Sparse categorical cross-entropy takes integer class ids as labels.
# from_logits=True assumes the network's last layer emits raw scores with no
# softmax applied -- TODO confirm against deeplabv3plus.
model.compile(
  optimizer="adam",
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=["accuracy"],
)

In [None]:
def create_mask(pred_mask):
  """Turn per-class scores into a label mask for the first element.

  Takes the arg-max over the last (class) axis, re-inserts a trailing axis of
  size 1, then returns index 0 of the leading axis.

  Args:
    pred_mask: Score tensor whose last axis holds one score per class;
      presumably (batch, height, width, classes) -- TODO confirm.

  Returns:
    Tensor of class ids (arg-max indices) with a trailing axis of size 1.
  """
  labels = tf.expand_dims(tf.math.argmax(pred_mask, axis=-1), axis=-1)
  return labels[0]


def show_predictions(model, dataset=None, num=1):
  """Compare the model's prediction with the ground truth for a few batches.

  For each of the first ``num`` elements of ``dataset``, the first sample's
  image and mask are passed to ``compare_model_predictions``.

  Args:
    model: Model forwarded to ``compare_model_predictions``.
    dataset: Batched dataset yielding ``(image, mask, weight)`` tuples; only
      the image and mask are used. Required -- the ``None`` default is kept
      only so the signature stays unchanged for existing callers.
    num: Number of dataset elements (batches) to draw; one comparison each.

  Raises:
    ValueError: If ``dataset`` is None.
  """
  if dataset is None:
    raise ValueError("show_predictions needs a dataset to draw samples from")
  for images, masks, _weights in dataset.take(num):
    compare_model_predictions(model, images[0], masks[0])

In [None]:
# Untrained baseline: prediction on one batch of the full `dataset` (the
# pipeline output before split_dataset), not of a specific split.
show_predictions(model, dataset=dataset, num=1)

In [None]:
# Number of full passes over train_dataset.
EPOCHS = 20

# Keras' fit() ignores `shuffle` when x is a tf.data.Dataset (or a generator),
# so it is not passed here; any shuffling has to happen in the input pipeline.
# Assumes train_dataset is a tf.data.Dataset -- show_predictions calls .take()
# on the same pipeline's output.
model_history = model.fit(
  train_dataset,
  epochs=EPOCHS,
  validation_data=val_dataset,
)

In [None]:
# Per-epoch training and validation loss recorded by model.fit().
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

fig, ax = plt.subplots()
ax.plot(model_history.epoch, loss, 'r', label='Training loss')
ax.plot(model_history.epoch, val_loss, 'bo', label='Validation loss')
ax.set_title('Training and Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss Value')
# Anchor the axis at 0 but let the top follow the data: a fixed [0, 1] range
# hides any epoch whose (class-weighted, multi-class) loss exceeds 1.
ax.set_ylim(bottom=0)
ax.legend()
plt.show()

In [None]:
# Hand the trained model, its fit() history and the Config to
# write_model_to_disk; what is written, and where, is decided by that helper.
write_model_to_disk(model, model_history, Config)