## Import required libraries and helper code

In [None]:
import os
import tensorflow as tf

from data.abc import ABCPreProcessor
from models.symbolic.transformer import FolkTransformer
from models.symbolic.rnn import FolkLSTM

## Base directory paths and names

In [None]:
# Filesystem locations used by this notebook. The defaults are the original
# author's machine paths; set these environment variables to run the notebook
# elsewhere without editing it:
#   AIMG_BASE_DIR  - checkout of the AI-Music-Generation-Challenge-2020 repo
#                    (model config directories are read from <BASE_DIR>/configs)
#   AIMG_CACHE_DIR - cache root; the TFRecord datasets are under <cache>/ABC
BASE_DIR = os.environ.get(
    "AIMG_BASE_DIR", "/home/rithomas/project/AI-Music-Generation-Challenge-2020/"
)
TFRECORD_PATH = os.path.join(
    os.environ.get("AIMG_CACHE_DIR", "/home/rithomas/cache"), "ABC"
)
# File name handed to ABCPreProcessor together with each dataset directory.
PROCESSED_ABC_FILENAME = 'processed-abc-files'

## Load datasets and prepare for use

In [None]:
DOUBLE_JIGS_DATASET_DIR = os.path.join(TFRECORD_PATH, 'Double-Jigs')
SIXEIGHT_METER_TUNES_DIR = os.path.join(TFRECORD_PATH, '6-8-Meter')
THESESSION_TUNES_DIR = os.path.join(TFRECORD_PATH, 'TheSession-Data')

# Load the raw TFRecord dataset of each corpus, using a preprocessor
# constructed for that corpus' directory.
preprocessor = ABCPreProcessor(DOUBLE_JIGS_DATASET_DIR, PROCESSED_ABC_FILENAME)
double_jigs_dataset = preprocessor.load_tfrecord_dataset()

preprocessor = ABCPreProcessor(SIXEIGHT_METER_TUNES_DIR, PROCESSED_ABC_FILENAME)
sixeight_meter_tunes_dataset = preprocessor.load_tfrecord_dataset()

preprocessor = ABCPreProcessor(THESESSION_TUNES_DIR, PROCESSED_ABC_FILENAME)
thesession_dataset = preprocessor.load_tfrecord_dataset()

# From here on `preprocessor` is the instance built for THESESSION_TUNES_DIR,
# and it is reused for all three corpora. This relies on prepare_dataset() and
# get_data_dimensions() not depending on the directory the instance was
# constructed with. NOTE(review): confirm in data/abc.py.
batch_size = 64
# The double-jigs corpus is deliberately batched smaller than the other two.
# The value is named so the difference from `batch_size` is explicit.
DOUBLE_JIGS_BATCH_SIZE = 16
double_jigs_dataset = preprocessor.prepare_dataset(double_jigs_dataset, DOUBLE_JIGS_BATCH_SIZE)
sixeight_meter_tunes_dataset = preprocessor.prepare_dataset(sixeight_meter_tunes_dataset, batch_size)
thesession_dataset = preprocessor.prepare_dataset(thesession_dataset, batch_size)

# Dimensions are looked up separately for each dataset directory and passed
# to the model built for that corpus.
data_dims_1 = preprocessor.get_data_dimensions(DOUBLE_JIGS_DATASET_DIR)
data_dims_2 = preprocessor.get_data_dimensions(SIXEIGHT_METER_TUNES_DIR)
data_dims_3 = preprocessor.get_data_dimensions(THESESSION_TUNES_DIR)

## Create LSTM and Transformer model instances

In [None]:
# Every model reads its configuration from a sub-directory of <BASE_DIR>/configs.
_CONFIGS_DIR = os.path.join(BASE_DIR, 'configs')

FOLK_LSTM_1_DIR = os.path.join(_CONFIGS_DIR, 'lstm_double_jigs')
FOLK_LSTM_2_DIR = os.path.join(_CONFIGS_DIR, 'lstm_6_8_meter')
FOLK_LSTM_3_DIR = os.path.join(_CONFIGS_DIR, 'lstm_thesession')

FOLK_TRANSFORMER_1_DIR = os.path.join(_CONFIGS_DIR, 'transformer_double_jigs')
FOLK_TRANSFORMER_2_DIR = os.path.join(_CONFIGS_DIR, 'transformer_6_8_meter')
FOLK_TRANSFORMER_3_DIR = os.path.join(_CONFIGS_DIR, 'transformer_thesession')

# One model per (architecture, corpus) pair, built from its config directory,
# the dimensions of its corpus, and the corpus directory.
# The LSTMs are constructed before the Transformers.
lstm_1 = FolkLSTM(FOLK_LSTM_1_DIR, data_dims_1, DOUBLE_JIGS_DATASET_DIR)
lstm_2 = FolkLSTM(FOLK_LSTM_2_DIR, data_dims_2, SIXEIGHT_METER_TUNES_DIR)
lstm_3 = FolkLSTM(FOLK_LSTM_3_DIR, data_dims_3, THESESSION_TUNES_DIR)

transformer_1 = FolkTransformer(FOLK_TRANSFORMER_1_DIR, data_dims_1, DOUBLE_JIGS_DATASET_DIR)
transformer_2 = FolkTransformer(FOLK_TRANSFORMER_2_DIR, data_dims_2, SIXEIGHT_METER_TUNES_DIR)
transformer_3 = FolkTransformer(FOLK_TRANSFORMER_3_DIR, data_dims_3, THESESSION_TUNES_DIR)

# LSTM 1 - Training on double jigs

In [None]:
# Show the configuration in use, then train on the double-jigs corpus.
lstm_1_configs = lstm_1.get_configs()
print(lstm_1_configs)
lstm_1.train(double_jigs_dataset)

# LSTM 2 - Training on 6/8-meter tunes from TheSession

In [None]:
# Show the configuration in use, then train on the 6/8-meter corpus.
lstm_2_configs = lstm_2.get_configs()
print(lstm_2_configs)
lstm_2.train(sixeight_meter_tunes_dataset)

# LSTM 3 - Training on all tunes from TheSession

In [None]:
# Show the configuration in use, then train on the full TheSession corpus.
lstm_3_configs = lstm_3.get_configs()
print(lstm_3_configs)
lstm_3.train(thesession_dataset)

# Transformer 1 - Training on double jigs

In [None]:
# Show the configuration in use, then train on the double-jigs corpus.
transformer_1_configs = transformer_1.get_configs()
print(transformer_1_configs)
transformer_1.train(double_jigs_dataset)

# Transformer 2 - Training on 6/8-meter tunes from TheSession

In [None]:
# Show the configuration in use, then train on the 6/8-meter corpus.
transformer_2_configs = transformer_2.get_configs()
print(transformer_2_configs)
transformer_2.train(sixeight_meter_tunes_dataset)

# Transformer 3 - Training on all tunes from TheSession

In [None]:
# Show the configuration in use, then train on the full TheSession corpus.
transformer_3_configs = transformer_3.get_configs()
print(transformer_3_configs)
transformer_3.train(thesession_dataset)