In [None]:
from os import makedirs
from os.path import join
from random import choice, randint

from pandas import DataFrame
from tensorflow.keras.callbacks import History

from ai_tools import DataGenerator, ModelManager
from ai_tools.displays import plot_history
from ai_tools.helpers import create_data_frame_from_path, split_stratified_into_train_val_test

# --- Filesystem locations -------------------------------------------------
# Input: the serialized dataset that the data frame is built from.
dataset_path: str = "/mnt/datasets/serialized_dataset"
# Output directory for the train/val/test split CSVs.
logs_path: str = "/mnt/logs"
# Locations handed to ModelManager for its Aim logs, checkpoints,
# training histories and the model-settings CSV log.
aim_logs_path: str = "/mnt/aim"
model_checkpoint_path: str = "/mnt/model-checkpoints"
model_histories: str = "/mnt/model_histories"
model_config_csv_log_path: str = "/mnt/model_settings.csv"

# --- Training settings ----------------------------------------------------
batch_size: int = 32  # shared by the train, validation and test generators

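### Create the dataset

Build a data frame from the serialized dataset, split it into stratified train, validation and test sets, and save each split to CSV so the data generators can be recreated later.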

In [2]:
# Create the dataset data frame and split it into train, validation, and test partitions.
samples_per_class: int = 2_000  # passed as number_of_samples_for_each_class
df: DataFrame = create_data_frame_from_path(
    dataset_path,
    number_of_samples_for_each_class=samples_per_class,
)

df_train, df_val, df_test = split_stratified_into_train_val_test(df)  # type: DataFrame, DataFrame, DataFrame
# NOTE(review): no seed/random_state is passed to the split here — confirm whether the
# helper is deterministic, otherwise Restart & Run All yields different splits.

# Store each split so the data generators can be recreated later without re-splitting.
# DataFrame.to_csv does not create missing parent directories, so make sure it exists.
makedirs(logs_path, exist_ok=True)
for split_name, split_frame in (('train', df_train), ('val', df_val), ('test', df_test)):
    # The index is written as well (pandas default); when reading these files back use
    # index_col=0, otherwise it reappears as an extra 'Unnamed: 0' column.
    split_frame.to_csv(join(logs_path, f'{split_name}_data.csv'))

In [3]:
# One DataGenerator per split (train, validation, test), all sharing `batch_size`.
train_data_generator: DataGenerator
val_data_generator: DataGenerator
test_data_generator: DataGenerator
train_data_generator, val_data_generator, test_data_generator = (
    DataGenerator(split_frame, batch_size=batch_size)
    for split_frame in (df_train, df_val, df_test)
)

In [4]:
# Preview the first rows of the data frame exposed by the training generator (built from df_train).
train_data_generator.get_data_frame.head(n=5)

Unnamed: 0,path,instrument,pitch,instrument_label,pitch_label
0,/mnt/datasets/serialized_dataset/reed/reed_F#_...,reed,F#,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,/mnt/datasets/serialized_dataset/bass/bass_E_0...,bass,E,"[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ..."
2,/mnt/datasets/serialized_dataset/bass/bass_C_0...,bass,C,"[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,/mnt/datasets/serialized_dataset/synth/synth_D...,synth,D#,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ..."
4,/mnt/datasets/serialized_dataset/reed/reed_B_0...,reed,B,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ...","[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


In [5]:
# The manager receives every output location from the config cell; later cells use it
# to build, train and reload models.
model_manager: ModelManager = ModelManager(
    model_checkpoint_dir=model_checkpoint_path,
    history_log_dir=model_histories,
    aim_logs_dir=aim_logs_path,
    path_to_csv_logs=model_config_csv_log_path,
)

In [None]:
# Build a model from the architecture hyper-parameters below, then train it on the generators.
num_epochs: int = 100  # value passed as `epochs` to train_and_optimize_model

model_manager.build_model(
    num_conv_block=4,
    num_filters=64,
    dense_layer_size=64,
    num_dense_layers=1,
    use_separable_conv_layer=False,
    use_regularization=True,
    use_dropout_dense_layers=True,
    use_dropout_conv_blocks=True
)

# NOTE(review): the test generator is handed to the training/optimization routine —
# confirm it is used only for final evaluation, not for model selection (test-set leakage).
accuracy: float = model_manager.train_and_optimize_model(
    train_data_generator,
    val_data_generator,
    test_data_generator,
    epochs=num_epochs
)

# Display the returned score; a cell that ends in an assignment shows no output.
accuracy

In [6]:
# Reload a saved model through the manager's load_model_at_best_epoch helper.
# NOTE(review): the meaning of the literal `1` is not visible here (epoch number? rank or
# index into the checkpoint list?) — confirm it and give it a named constant.
# In the recorded run this cell displayed the checkpoint file names (epoch-01.pb ... epoch-08.pb).
model_manager.load_model_at_best_epoch(1)

['epoch-01.pb', 'epoch-03.pb', 'epoch-05.pb', 'epoch-07.pb', 'epoch-08.pb']
