# In [None]:
import Ipynb_importer
# %load utils.py
import os
import time
import math
import torch
import random
import logging
import librosa
import argparse
import pescador
import numpy as np
from config import *
from torch import autograd
from torch.autograd import Variable
import matplotlib
#matplotlib.use('agg')
import matplotlib.pyplot as plt
import soundfile as sf
import csv
from scipy import signal

# Sampling frequency passed as `fs` to every scipy STFT/ISTFT call in this
# module and used to build the time axis of the saved plots.
freq = 400_000

# Module-wide logger; DEBUG is the lowest standard level, so nothing is
# filtered out at the logger itself.
LOGGER = logging.getLogger('specgan')
LOGGER.setLevel(logging.DEBUG)

def make_path(output_path):
    """
    Ensure that a directory exists and return its path.

    Args:
        output_path: Directory to create; missing parent directories are
            created as well.

    Returns:
        The same ``output_path`` that was passed in.

    Raises:
        OSError: If the directory cannot be created, e.g. because a regular
            file already occupies ``output_path``.
    """
    # exist_ok=True makes the call safe when the directory already exists,
    # including when another process creates it concurrently (a separate
    # isdir() check followed by makedirs() can lose that race).
    os.makedirs(output_path, exist_ok=True)
    return output_path

# Default for the --audio_dir option below. DATASET_NAME is not defined in
# this file; presumably it comes from `from config import *` -- TODO confirm.
traindata = DATASET_NAME
# Default for the --output_dir option below. NOTE: make_path creates the
# directory, so importing this module has a filesystem side effect.
output = make_path(OUTPUT_PATH)

def time_since(since):
    """
    Format the wall-clock time elapsed since ``since`` as ``'<m>m <s>s'``.

    Args:
        since: Start timestamp as returned by ``time.time()``.

    Returns:
        String with whole minutes and the remaining seconds (truncated).
    """
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    seconds = elapsed - minutes * 60
    return '%dm %ds' % (minutes, seconds)


        
def save_and_plot(epoch_samples, epoch, output_dir):
    """
    Save spectrogram and time-series plots for one epoch's samples.

    Four PNG files are written to ``<output_dir>/<epoch>/``, each prefixed
    with the epoch: ``spectrogram_0.png`` (first sample), ``spectrogram.png``
    (grid of all samples), ``timeseries_0.png`` (first reconstructed signal)
    and ``timeseries.png`` (grid of all reconstructed signals).

    Args:
        epoch_samples: Sequence of 2-D spectrogram arrays. Each one is drawn
            with ``pcolormesh`` and inverted with ``spectrogram_to_time``.
        epoch: Epoch identifier; used as the sub-directory name and as the
            file-name prefix.
        output_dir: Parent directory of the per-epoch sample directory.

    Returns:
        Path of the directory the plots were written to. The directory is
        created even when ``epoch_samples`` is empty (nothing is plotted then).
    """
    sample_dir = make_path(os.path.join(output_dir, str(epoch)))
    prefix = os.path.join(sample_dir, str(epoch))

    n_samples = len(epoch_samples)
    if n_samples == 0:
        # Every plot below needs at least sample 0.
        LOGGER.warning("No samples to plot for epoch %s", epoch)
        return sample_dir

    fs = freq
    # spectrogram_to_time returns the (time, signal) pair produced by istft;
    # only the signal is plotted.
    data = [spectrogram_to_time(sample)[1] for sample in epoch_samples]

    # Passing the colormap by name avoids plt.cm.get_cmap, which was removed
    # in Matplotlib 3.9.
    cmap = 'plasma'

    # Keep the 2x5 layout for up to ten samples and add rows beyond that, so
    # larger sample counts no longer overflow the subplot grid.
    n_cols = 5
    n_rows = max(2, math.ceil(n_samples / n_cols))

    def time_axis(series):
        # arange(n) / fs always has exactly n points; a float-step
        # arange(0, n/fs, 1/fs) can come out one element long and make
        # plt.plot fail with a shape mismatch.
        return np.arange(len(series)) / fs

    plt.figure(figsize=(30, 10))
    try:
        # plot spectrogram
        plt.pcolormesh(epoch_samples[0], cmap=cmap)
        plt.savefig(prefix + "spectrogram_0.png")

        plt.clf()
        for position, sample in enumerate(epoch_samples, start=1):
            plt.subplot(n_rows, n_cols, position)
            plt.pcolormesh(sample, cmap=cmap)
        plt.savefig(prefix + "spectrogram.png")

        # plot time series
        plt.clf()
        plt.plot(time_axis(data[0]), data[0])
        plt.savefig(prefix + "timeseries_0.png")

        plt.clf()
        for position, series in enumerate(data, start=1):
            plt.subplot(n_rows, n_cols, position)
            plt.plot(time_axis(series), series)
        plt.savefig(prefix + "timeseries.png")
    finally:
        # Release the figures even if plotting or saving raised.
        plt.close("all")
    return sample_dir

    
    

#### signal processing ####
# Based on https://github.com/librosa/librosa/issues/434

def _stft_forgriffin(y):
    """
    Return the complex STFT matrix of ``y`` using the module's STFT settings.

    The settings (Blackman-Harris window, 256-sample segments, 128-sample
    overlap, ``fs=freq``) match the ones used by ``_istft``.
    """
    stft_result = signal.stft(
        y,
        fs=freq,
        window='blackmanharris',
        nperseg=256,
        noverlap=128,
    )
    # scipy returns (frequencies, times, Zxx); only Zxx is needed here.
    return stft_result[2]


def _griffin_lim(S):
    """
    Estimate a time-domain signal from a magnitude spectrogram (Griffin-Lim).

    Starting from random phases, the spectrogram is repeatedly inverted with
    ``_istft`` and the phase is re-estimated from the STFT of the result,
    while the magnitudes stay fixed.

    Args:
        S: 2-D magnitude spectrogram.

    Returns:
        The value returned by ``_istft`` on the last iteration, i.e. scipy's
        ``(t, x)`` pair; callers index ``[1]`` to get the samples.

    Raises:
        ValueError: If ``griffin_lim_iters`` is smaller than 1.
    """
    # griffin_lim_iters is not defined in this file; presumably it comes from
    # the config star import -- TODO confirm.
    if griffin_lim_iters < 1:
        raise ValueError("griffin_lim_iters must be at least 1")

    # Pad the spectrogram with one extra leading row (column-wise mean) and
    # one extra trailing column (row-wise mean). Presumably this restores the
    # first frequency bin and last frame that sample_generator strips with
    # Zxx[1:, :-1] -- TODO confirm.
    m = np.mean(S, axis=0)
    S = np.vstack((m, S))
    n = np.mean(S, axis=1).reshape(-1, 1)
    S = np.hstack((S, n))

    # Random initial phase on the unit circle.
    angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
    # np.complex was removed in NumPy 1.24; complex128 is the dtype that the
    # alias resolved to.
    S_complex = np.abs(S).astype(np.complex128)

    for i in range(griffin_lim_iters):
        if i > 0:
            # Keep the magnitudes, take the phase from the current estimate.
            angles = np.exp(1j * np.angle(_stft_forgriffin(y[1])))
        y = _istft(S_complex * angles)

    return y

def _istft(S):
    """
    Invert a complex STFT matrix using the module's STFT settings.

    Returns scipy's ``istft`` result unchanged.
    """
    stft_settings = dict(
        fs=freq,
        window='blackmanharris',
        nperseg=256,
        noverlap=128,
    )
    return signal.istft(S, **stft_settings)

# spectrogram to time
def spectrogram_to_time(spectrogram):
    """
    Convert a normalised log-spectrogram back to a time-domain estimate.

    The input is mapped through ``10 ** (x * 4.2 - 4.2)``, which is the
    algebraic inverse of the ``(log10(|Z|) + 4.2) / 4.2`` scaling applied in
    ``sample_generator``, and the result is handed to ``_griffin_lim``.

    Returns:
        Whatever ``_griffin_lim`` returns for the rescaled magnitudes.
    """
    scale = 4.2
    magnitudes = np.power(10, (spectrogram * scale - scale))
    return _griffin_lim(magnitudes)


# Adapted from @jtcramer https://github.com/jtcramer/wavegan/blob/master/sample.py.
def sample_generator(data, window_length=1024, fs=freq):
    """
    Endlessly yield the normalised log-magnitude spectrogram of one signal.

    The STFT of ``data`` is computed once, on the first ``next()`` call. Its
    first frequency row and last time frame are dropped, and the log10
    magnitude is rescaled with ``(log10(|Z|) + 4.2) / 4.2``.

    Args:
        data: Time series to transform (passed straight to
            ``scipy.signal.stft``).
        window_length: Unused; kept so existing callers keep working.
        fs: Sampling frequency handed to ``scipy.signal.stft``. Defaults to
            the module-level ``freq``.

    Yields:
        ``{'X': array}`` where ``array`` is a fresh float32 copy of the
        normalised spectrogram on every iteration.
    """
    tune_num = 4.2
    # Honour the fs argument instead of always reading the global freq.
    _, _, Zxx = signal.stft(data, fs=fs, window='blackmanharris', nperseg=256, noverlap=128)
    X_norm = (np.log10(abs(Zxx[1:, :-1])) + tune_num) / tune_num
    X_norm = X_norm.astype('float32')

    assert not np.any(np.isnan(X_norm)), "spectrogram contains NaN"

    # The spectrogram never changes between iterations, so it is computed
    # once above; each consumer still gets its own array via copy().
    while True:
        yield {'X': X_norm.copy()}

def batch_generator(datas, batch_size):
    """
    Build a batched stream that draws from one streamer per input series.

    Args:
        datas: Sequence of time series; each gets its own
            ``pescador.Streamer`` wrapping ``sample_generator``.
        batch_size: Passed to ``pescador.buffer_stream``.

    Returns:
        The object returned by ``pescador.buffer_stream`` for the shuffled
        mux over all streamers.
    """
    streamers = [pescador.Streamer(sample_generator, series) for series in datas]
    shuffled = pescador.ShuffledMux(streamers)
    return pescador.buffer_stream(shuffled, batch_size)


def numpy_to_var(numpy_data, cuda):
    """
    Wrap a 3-D numpy batch as a non-trainable torch Variable.

    A new axis is inserted at position 1 before the conversion, so an input
    of shape ``(a, b, c)`` becomes ``(a, 1, b, c)``.

    Args:
        numpy_data: 3-D numpy array.
        cuda: When truthy, the tensor is moved to the GPU with ``.cuda()``.

    Returns:
        ``Variable`` built from the tensor with ``requires_grad=False``.
    """
    expanded = numpy_data[:, np.newaxis, :, :]
    tensor = torch.Tensor(expanded)
    tensor = tensor.cuda() if cuda else tensor
    return Variable(tensor, requires_grad=False)
        


def plot_loss(D_loss, G_loss, save_path):
    """
    Plot the discriminator and generator loss curves and save them as a PNG.

    Args:
        D_loss: Sequence of discriminator loss values.
        G_loss: Sequence of generator loss values.
        save_path: Directory in which ``loss_curve.png`` is written.
    """
    save_path = os.path.join(save_path, "loss_curve.png")

    # Draw on a dedicated figure: plotting on whatever figure is current
    # would stack a new pair of curves (and legend entries) on every call,
    # and the figure would never be released.
    fig, ax = plt.subplots()
    try:
        # Each curve gets an x-range from its own length, so the two lists
        # do not have to be equally long.
        ax.plot(range(len(D_loss)), D_loss, label='D_loss')
        ax.plot(range(len(G_loss)), G_loss, label='G_loss')

        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')

        ax.legend(loc=4)
        ax.grid(True)
        fig.tight_layout()

        fig.savefig(save_path)
    finally:
        plt.close(fig)
    
    


def parse_arguments():
    """
    Build the option dictionary from the parser's default values.

    The parser is always run on an empty argument list, so the real command
    line is never read and the returned dictionary holds the defaults
    declared below, keyed by each option's ``dest``.
    """
    # (option strings, add_argument keyword arguments), in registration order.
    option_specs = [
        (('-ms', '--model-size'),
         dict(dest='model_size', type=int, default=4,
              help='Model size parameter used in WaveGAN')),
        (('-pssf', '--phase-shuffle-shift-factor'),
         dict(dest='shift_factor', type=int, default=2,
              help='Maximum shift used by phase shuffle')),
        (('-psb', '--phase-shuffle-batchwise'),
         dict(dest='batch_shuffle', action='store_true',
              help='If true, apply phase shuffle to entire batches rather than individual samples')),
        (('-ppfl', '--post-proc-filt-len'),
         dict(dest='post_proc_filt_len', type=int, default=512,
              help='Length of post processing filter used by generator. Set to 0 to disable.')),
        (('-lra', '--lrelu-alpha'),
         dict(dest='alpha', type=float, default=0.2,
              help='Slope of negative part of LReLU used by discriminator')),
        (('-vr', '--valid-ratio'),
         dict(dest='valid_ratio', type=float, default=0.1,
              help='Ratio of audio files used for validation')),
        (('-tr', '--test-ratio'),
         dict(dest='test_ratio', type=float, default=0.1,
              help='Ratio of audio files used for testing')),
        (('-bs', '--batch-size'),
         dict(dest='batch_size', type=int, default=BATCH_SIZE,
              help='Batch size used for training')),
        (('-ne', '--num-epochs'),
         dict(dest='num_epochs', type=int, default=EPOCHS,
              help='Number of epochs')),
        (('-ng', '--ngpus'),
         dict(dest='ngpus', type=int, default=4,
              help='Number of GPUs to use for training')),
        (('-ld', '--latent-dim'),
         dict(dest='latent_dim', type=int, default=100,
              help='Size of latent dimension used by generator')),
        (('-eps', '--epochs-per-sample'),
         dict(dest='epochs_per_sample', type=int, default=SAMPLE_EVERY,
              help='How many epochs between every set of samples generated for inspection')),
        (('-ss', '--sample-size'),
         dict(dest='sample_size', type=int, default=SAMPLE_NUM,
              help='Number of inspection samples generated')),
        (('-rf', '--regularization-factor'),
         dict(dest='lmbda', type=float, default=10.0,
              help='Gradient penalty regularization factor')),
        (('-lr', '--learning-rate'),
         dict(dest='learning_rate', type=float, default=0.0002,
              help='Initial ADAM learning rate')),
        (('-bo', '--beta-one'),
         dict(dest='beta1', type=float, default=0.5,
              help='beta_1 ADAM parameter')),
        (('-bt', '--beta-two'),
         dict(dest='beta2', type=float, default=0.9999,
              help='beta_2 ADAM parameter')),
        (('-v', '--verbose'),
         dict(dest='verbose', action='store_true')),
        (('-audio_dir', '--audio_dir'),
         dict(dest='audio_dir', type=str, default=traindata,
              help='Path to directory containing audio files')),
        (('-output_dir', '--output_dir'),
         dict(dest='output_dir', type=str, default=output,
              help='Path to directory where model files will be output')),
    ]

    parser = argparse.ArgumentParser(description='Train a WaveGAN on a given set of audio')
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)

    # An explicit empty list keeps argparse from reading sys.argv.
    namespace = parser.parse_args(args=[])
    return vars(namespace)