In [None]:
import tensorflow as tf
import datetime, os
from model import *
from data import *

#### Unpack videos and labels

In [None]:
# Unpack the raw training data under data/larynx/train, resized to 256x256:
#   video/ -> image/glottis   (input images)
#   tifs/  -> label/glottis   (segmentation labels)
TRAIN_ROOT = 'data/larynx/train'
FRAME_SIZE = (256, 256)

unpack_video(TRAIN_ROOT, video_folder='video', image_folder='image/glottis', target_size=FRAME_SIZE)
unpack_tif(TRAIN_ROOT, tif_folder='tifs', label_folder='label/glottis', target_size=FRAME_SIZE)

#### Initialize the model (U-Net)

In [None]:
# Build the U-Net defined in the project's `model` module.
model = unet()

# Write the model to unet_larynx.hdf5 only when val_loss improves
# (save_best_only=True), so the file always holds the best epoch so far.
model_checkpoint = ModelCheckpoint(
    'unet_larynx.hdf5',
    monitor='val_loss',
    verbose=1,
    save_best_only=True,
)



#### Prepare training and validation datasets

In [None]:
import shutil
from pathlib import Path

import splitfolders

SPLIT_SEED = 1234
SPLIT_RATIO = (.8, .2)  # (train, validation)
BATCH_SIZE = 2
TRAIN_DATASET_DIR = 'data/larynx/train/train_dataset/'
VAL_DATASET_DIR = 'data/larynx/train/validation_dataset/'

# Split images and labels separately, using the same seed for both. This keeps
# each image in the same split as its label only if both folders hold the same
# number of files in matching sorted order.
# NOTE(review): this relies on split-folders shuffling deterministically for a
# given seed. The count check below is a cheap guard, not a full proof of pairing.
splitfolders.ratio("data/larynx/train/image", output="data/larynx/train/image_split",
                   seed=SPLIT_SEED, ratio=SPLIT_RATIO, group_prefix=None, move=False)
splitfolders.ratio("data/larynx/train/label", output="data/larynx/train/label_split",
                   seed=SPLIT_SEED, ratio=SPLIT_RATIO, group_prefix=None, move=False)

# split-folders writes <output>/train/... and <output>/val/..., but the
# generators below read <dataset_dir>/image and <dataset_dir>/label. Nothing else
# in this notebook creates those directories, so copy each split into that
# layout. The class sub-folder (e.g. glottis/) is preserved, matching the layout
# the unpack step produced. dirs_exist_ok=True lets the cell be re-run.
for split_name, dataset_dir in (('train', TRAIN_DATASET_DIR), ('val', VAL_DATASET_DIR)):
    for kind in ('image', 'label'):
        shutil.copytree(Path('data/larynx/train') / f'{kind}_split' / split_name,
                        Path(dataset_dir) / kind,
                        dirs_exist_ok=True)
    n_images = sum(p.is_file() for p in (Path(dataset_dir) / 'image').rglob('*'))
    n_labels = sum(p.is_file() for p in (Path(dataset_dir) / 'label').rglob('*'))
    assert n_images == n_labels, (
        f'{dataset_dir}: {n_images} images but {n_labels} labels')

training_dataset = training_dataset_generator(BATCH_SIZE,
                                              TRAIN_DATASET_DIR,
                                              'image',
                                              'label',
                                              save_to_dir=None)

# NOTE(review): the validation set is built with the *training* generator. If
# that helper applies augmentation, validation images are augmented too — confirm.
val_dataset = training_dataset_generator(BATCH_SIZE,
                                         VAL_DATASET_DIR,
                                         'image',
                                         'label',
                                         save_to_dir=None)

In [None]:
# TensorBoard: one log directory per run, named by the launch timestamp.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join("logs", run_stamp)
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

# Stop once val_loss has not improved for 5 epochs. When stopping triggers,
# restore the weights from the best epoch.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                  mode='min',
                                                  patience=5,
                                                  verbose=1,
                                                  restore_best_weights=True)

# Training schedule. With generators, these step counts define what one
# "epoch" means.
STEPS_PER_EPOCH = 10
VALIDATION_STEPS = 10
MAX_EPOCHS = 100

training_callbacks = [model_checkpoint, tensorboard_callback, early_stopping]
history = model.fit(training_dataset,
                    steps_per_epoch=STEPS_PER_EPOCH,
                    epochs=MAX_EPOCHS,
                    callbacks=training_callbacks,
                    validation_data=val_dataset,
                    validation_steps=VALIDATION_STEPS)

#### Test the model

In [None]:
# Run the trained model on the test images and save the predictions via
# saveResult into data/larynx/test.
NUM_TEST_IMAGES = 5
test_dataset = test_dataset_generator("data/larynx/test", num_image=NUM_TEST_IMAGES)

# The second positional parameter of Model.predict is `batch_size`, not `steps`.
# The former positional `3` was therefore passed as a batch size, which Keras
# does not support for generator inputs, and it never limited the step count.
# Pass `steps` explicitly instead, with one step per test image.
# NOTE(review): this assumes test_dataset_generator yields one image per step —
# confirm against its definition.
results = model.predict(test_dataset, steps=NUM_TEST_IMAGES, verbose=1)
saveResult("data/larynx/test", results)
