## Import required packages

In [None]:
import os
import datetime

import tensorflow as tf

from data.abc import ABCPreProcessor, ABCTokenizer
from models.symbolic.rnn import FolkLSTM

## Required files and directories

In [None]:
# Paths to the project checkout and to the ABC TFRecord datastore.
# The defaults point at the original author's machine. To run elsewhere, set
# these environment variables instead of editing this cell, e.g.:
#   AI_MUSIC_BASE_DIR=/home/richhiey/Desktop/workspace/projects/AI_Music_Challenge_2020/AI-Music-Generation-Challenge-2020
#   AI_MUSIC_ABC_TFRECORD_DIR=/home/richhiey/Desktop/workspace/projects/AI_Music_Challenge_2020/tfrecords/abc
BASE_DIR = os.environ.get(
    "AI_MUSIC_BASE_DIR",
    "/home/rithomas/project/AI-Music-Generation-Challenge-2020/",
)
ABC_TFRECORD_DIR = os.environ.get(
    "AI_MUSIC_ABC_TFRECORD_DIR",
    os.path.join("/home/rithomas/cache", "ABC", "6-8-Meter/"),
)
# Name of the processed-ABC dataset handed to ABCPreProcessor together with
# ABC_TFRECORD_DIR (presumably a file inside that directory — TODO confirm).
PROCESSED_ABC_FILENAME = 'processed-abc-files'

## Load preprocessed dataset

In [None]:
# Create the ABC preprocessor for the configured datastore and load the
# processed dataset (presumably the TFRecord named by PROCESSED_ABC_FILENAME
# inside ABC_TFRECORD_DIR — TODO confirm in ABCPreProcessor).
preprocessor = ABCPreProcessor(ABC_TFRECORD_DIR, PROCESSED_ABC_FILENAME)
preprocessed_dataset = preprocessor.load_tfrecord_dataset()
# Print the loaded dataset object as a quick sanity check.
print(preprocessed_dataset)

# Train the Folk-LSTM model

In [None]:
# Training hyperparameters for this Folk-LSTM run.
batch_size = 128
initial_learning_rate = 0.001
training = True  # passed to FolkLSTM to build it in training mode
# Per-model directory under the project's configs/ folder (presumably also
# where FolkLSTM keeps its checkpoints — TODO confirm in FolkLSTM).
FOLK_LSTM_DIR = os.path.join(BASE_DIR, 'configs', 'lstm_6_8_meter')

# Dataset dimensions reported by the preprocessor, extended with the batch
# size (presumably so the model is built for the same batch size that
# prepare_dataset uses below).
data_dims = preprocessor.get_data_dimensions(ABC_TFRECORD_DIR)
data_dims['batch_size'] = batch_size
dataset = preprocessor.prepare_dataset(preprocessed_dataset, batch_size)
print(data_dims)

# Build the model with the settings above, show its configuration, and
# train it on the batched dataset.
model = FolkLSTM(FOLK_LSTM_DIR, data_dims, ABC_TFRECORD_DIR, training, initial_learning_rate)
print(model.get_configs())
model.train(dataset)

# Generate double jigs (set `n = 10000` for the full run of 10,000 tunes)

In [None]:
import time


def format_abc_tune(index, tune):
    """Return one generated token sequence formatted as an ABC tune.

    Args:
        index: 1-based tune number, written as the ``X:`` reference field.
        tune: token list returned by ``FolkLSTM.complete_tune``. Assumed
            layout (TODO confirm against the model): ``tune[0]`` is the
            ``'<s>'`` start token, ``tune[1]`` the meter field, ``tune[2]``
            the key field, ``tune[-1]`` an end token, and the tokens in
            between the tune body.

    Returns:
        ABC text with header lines ``X:<index>``, ``tune[1]``, ``tune[2]``
        and ``L:1/8``, followed by ``tune[3:-1]`` joined with no separator.
    """
    header = [f'X:{index}', tune[1], tune[2], 'L:1/8']
    body = ''.join(tune[3:-1])
    return '\n'.join(header) + '\n' + body


# Time the whole generation run so it can be reported at the end.
start = time.perf_counter()

n = 5  # number of tunes to generate; set to 10000 for the full run
start_tokens = ['<s>', 'M:6/8']  # prime every tune with the start token and 6/8 meter
temperature = 1
# Rebuild the model with training=False for sampling (presumably it restores
# the trained weights from FOLK_LSTM_DIR — TODO confirm in FolkLSTM).
model = FolkLSTM(FOLK_LSTM_DIR, data_dims, ABC_TFRECORD_DIR, False)

output_path = os.path.join(ABC_TFRECORD_DIR, "10000_double_jigs.abc")
# `with` guarantees the 20 MiB write buffer is flushed and the file closed
# even if sampling raises part-way through the run.
with open(output_path, 'w', buffering=20 * (1024 ** 2), encoding='utf-8') as abc_file:
    for i in range(n):
        print('-----------------------------------------------------')
        tune = model.complete_tune(start_tokens, temperature)
        formatted_tune = format_abc_tune(i + 1, tune)
        print(formatted_tune)
        abc_file.write(formatted_tune + '\n\n')

print(f'Generated {n} tunes in {time.perf_counter() - start:.1f}s')