In [None]:
#Expose Google Drive inside the Colab VM so the project folder under 'My Drive' can be reached
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
#Directory that holds the project's own Python scripts.
#It is passed to sys.path.append() in a later cell so those scripts can be imported.
SCRIPTS_PATH = '/content/drive/My Drive/TimbreTransformer/Scripts'

In [None]:
#Project root on the mounted Drive.
#It is passed to os.chdir() in a later cell and becomes the working directory of the notebook.
DEFAULT_PATH = '/content/drive/My Drive/TimbreTransformer'

In [None]:
#Settings for this training run
TRAIN_PARAMS = dict(
    #Name handed to model.Model() when the saved model is loaded
    model_name='flute_v1',
    #Recordings whose sound the network should reproduce
    target_audio=['Data/Flute.mp3'],
    #Recordings the network learns to transform
    train_audio=[
        'Data/Violin.mp3',
        'Data/Piano.mp3',
        'Data/Guitar.mp3',
        'Data/Cello.mp3',
    ],
)

In [None]:
import os
import sys

#Make the project's scripts importable. The membership check keeps sys.path free of
#duplicate entries when this cell is executed more than once in the same kernel.
if SCRIPTS_PATH not in sys.path:
    sys.path.append(SCRIPTS_PATH)

#Work from the project root from here on
os.chdir(DEFAULT_PATH)

In [None]:
import import_audio
import process_audio
import custom_loss
import tensorflow as tf
import numpy as np
import model
import json

In [None]:
#Build a model.Model for the configured model name and keep whatever load_from_file() returns as `gan`.
#Later cells read gan.model_params, gan.model and gan.model_name from this object.
gan = model.Model(TRAIN_PARAMS['model_name']).load_from_file()

# Load Data

In [None]:
#Audio objects for the recordings the network should learn to sound like.
#n_fft and srate come from the loaded model's parameters; shuffle_spec and shuffle_audio are enabled.
target_audio = [
    import_audio.Audio(
        audio_path,
        n_fft=gan.model_params['n_fft'],
        srate=gan.model_params['srate'],
        shuffle_spec=True,
        shuffle_audio=True,
    )
    for audio_path in TRAIN_PARAMS['target_audio']
]



In [None]:
#Audio objects for the recordings the network should learn to transform.
#n_fft and srate come from the loaded model's parameters.
train_audio = [
    import_audio.Audio(
        audio_path,
        n_fft=gan.model_params['n_fft'],
        srate=gan.model_params['srate'],
    )
    for audio_path in TRAIN_PARAMS['train_audio']
]



# Train and Save Model

In [None]:
def _sample_spectrogram_batch(audio_clips, per_clip_batch_size, input_shape):
    """Draw one randomized batch of spectrogram partitions from every clip.

    For each clip, process_audio.partition() is called on ``clip.ft.spec`` and the
    first element of its result is kept. The per-clip batches are concatenated
    along axis 0 and a trailing axis of length 1 is appended.
    """
    per_clip_batches = [
        process_audio.partition(
            clip.ft.spec,
            randomize=True,
            batch_size=per_clip_batch_size,
            input_shape=input_shape,
        )[0]
        for clip in audio_clips
    ]
    return np.expand_dims(np.concatenate(per_clip_batches, axis=0), axis=-1)


def train(model, n_epochs = 10, n_batches = 10):
    """Alternate discriminator and generator training for ``n_epochs`` epochs.

    Every epoch has two optional steps:

    1. Discriminator step: ``model.discriminator`` is fitted on samples drawn from
       ``target_audio`` (label 1) and on ``model.generator`` predictions for samples
       drawn from ``train_audio`` (label 0).
    2. Generator step: ``model.model`` is fitted on samples drawn from
       ``train_audio`` with targets ``[ones, input]``.

    Balancing rules:

    * The discriminator step is skipped while the most recent generator-step
      accuracy ('discriminator_binary_accuracy') is <= 0.5.
    * The generator step is skipped while disc_loss / gan_loss >= 2, but only if
      the discriminator step ran in the same epoch. If both steps were skipped
      the metrics could never change and every remaining epoch would do nothing,
      so the generator is always trained when the discriminator is skipped.

    The function reads the notebook-level lists ``target_audio`` and ``train_audio``.

    Args:
        model: object exposing ``input_shape``, ``generator``, ``discriminator``
            and ``model`` (the last three with Keras-style predict()/fit()).
        n_epochs: number of epochs to run.
        n_batches: number of samples requested per step; it is divided evenly
            (rounded down) across the clips of the list being sampled.

    Returns:
        None. Progress is reported with print().
    """
    #Per-clip sample counts are the same for every epoch, so compute them once
    target_batch_size = int(n_batches // len(target_audio))
    training_batch_size = int(n_batches // len(train_audio))

    #Starting values make sure both steps run in the first epoch
    gan_accuracy = 1
    loss_ratio = 1
    disc_loss = None
    gan_loss = None

    for i in range(n_epochs):
        print('Epoch: ', i+1,)

        skip_discriminator = gan_accuracy <= 0.5

        if skip_discriminator:
            print("Generator accuracy is less than 50%, skipping discriminator training for current epoch")

        else:
            #Real samples from target_audio, labelled 1
            X_real = _sample_spectrogram_batch(target_audio, target_batch_size, model.input_shape)
            y_real = np.ones(X_real.shape[0], dtype=int)

            #Generator output for samples from train_audio, labelled 0
            X_gen = _sample_spectrogram_batch(train_audio, training_batch_size, model.input_shape)
            X_fake = model.generator.predict(X_gen)
            y_fake = np.zeros(X_fake.shape[0], dtype=int)

            X = np.concatenate((X_real, X_fake), axis = 0)
            y = np.concatenate((y_real, y_fake), axis = 0)

            print('Training Discriminator')

            #fit() returns a History object; its loss is used to balance the two steps
            disc_history = model.discriminator.fit(X, y)
            disc_loss = disc_history.history['loss'][0]

        #Only skip the generator when the discriminator trained in this epoch,
        #otherwise neither network would train and the metrics would stay frozen
        if loss_ratio >= 2 and not skip_discriminator:

            print('Discriminator loss is more than 2x the generator loss, skipping generator training for current epoch')

        else:
            X_gan = _sample_spectrogram_batch(train_audio, training_batch_size, model.input_shape)

            #All labels are 1: the loss is higher when the discriminator output judges the result as fake
            y_gan = np.ones(X_gan.shape[0], dtype=int)

            print('Training Gan')

            gan_history = model.model.fit(X_gan, [y_gan, X_gan])

            #Refresh the generator metrics only when the generator was trained
            gan_loss = gan_history.history['discriminator_loss'][0]
            gan_accuracy = gan_history.history['discriminator_binary_accuracy'][0]

        if disc_loss is not None and gan_loss is not None:
            try:
                loss_ratio = disc_loss/gan_loss

            except ZeroDivisionError:
                #A generator loss of exactly 0 is treated as the discriminator being far behind
                loss_ratio = 2

In [None]:
#Train the loaded model for 500 epochs; n_batches=1000 is divided evenly across the audio clips inside train()
train(gan,n_epochs = 500, n_batches = 1000)

Epoch:  1
Training Discriminator
Training Gan
Epoch:  2
Generator accuracy is less than 50%, skipping discriminator training for current epoch
Training Gan
Epoch:  3
Generator accuracy is less than 50%, skipping discriminator training for current epoch
Training Gan
Epoch:  4
Generator accuracy is less than 50%, skipping discriminator training for current epoch
Training Gan
Epoch:  5
Generator accuracy is less than 50%, skipping discriminator training for current epoch
Training Gan
Epoch:  6
Generator accuracy is less than 50%, skipping discriminator training for current epoch
Training Gan
Epoch:  7
Generator accuracy is less than 50%, skipping discriminator training for current epoch
Training Gan
Epoch:  8
Generator accuracy is less than 50%, skipping discriminator training for current epoch
Training Gan
Epoch:  9
Generator accuracy is less than 50%, skipping discriminator training for current epoch
Training Gan
Epoch:  10
Generator accuracy is less than 50%, skipping discriminator tra

In [None]:
#Save gan.model to <save_dir>/<model_name>; both path parts are read from the loaded gan object
tf.keras.models.save_model(gan.model, os.path.join(gan.model_params['save_dir'], gan.model_name))

INFO:tensorflow:Assets written to: Flute_Classifier/flute_v1/assets
