import os
import sys
import pickle
import numpy as np
from tensorflow.keras.utils import Sequence
from tensorflow.keras.layers import (Dense, Input)
from tensorflow.keras.layers import (Activation, Dense, Input, Conv1D, Conv2D, MaxPooling2D, Reshape,
                                     Dropout, SpatialDropout1D, SpatialDropout2D)
from tensorflow.keras.models import Sequential, Model
import madmom
from madmom.processors import ParallelProcessor, SequentialProcessor
from madmom.audio.spectrogram import FilteredSpectrogramProcessor, LogarithmicSpectrogramProcessor, SpectrogramDifferenceProcessor
from madmom.audio.stft import ShortTimeFourierTransformProcessor
from madmom.audio.signal import SignalProcessor, FramedSignalProcessor
import warnings

import tensorflow.keras.backend as K
import tensorflow as tf

import tensorflow_addons as tfa

from modules.utils import PKL_PATH, DATA_PATH

# TensorFlow runtime environment settings.
# NOTE(review): these are assigned after the `import tensorflow` statements
# above -- confirm they still take effect when set at this point.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "10"
os.environ["TF_DUMP_GRAPH_PREFIX"] = 'tmp'


# GENERAL CONSTANTS
FPS = 100  # default frame rate (frames per second) for features and quantized targets
MASK_VALUE = -1  # not referenced in the visible part of this file

# Hyper-parameters read by simple_TCN() when it builds and trains the model.
lr = 0.05  # learning rate handed to the RectifiedAdam optimizer
num_epochs = 50  # number of training epochs passed to model.fit
dropout_rate = 0.15  # dropout rate for the conv front-end and the TCN
num_filters = 16  # number of filters of every convolutional layer
num_dilations = 11  # the TCN uses the dilation rates 2**0 .. 2**(num_dilations - 1)
kernel_size = 5  # kernel size of the dilated convolutions in the TCN
activation = 'elu'  # activation function of the conv front-end


class PreProcessor(SequentialProcessor):
    """Audio-to-feature chain: signal -> (framing, STFT, filterbank, log) -> stack.

    For every ``(frame_size, num_bands)`` pair a branch is built that frames
    the signal, computes the STFT, filters the magnitudes and scales them
    logarithmically (optionally followed by positive differences). All
    branches run in a ``ParallelProcessor`` and their outputs are stacked
    with ``np.hstack``.

    Parameters
    ----------
    frame_sizes : sequence of int
        Frame size of each branch.
    num_bands : sequence of int
        ``num_bands`` argument of the filterbank of each branch; paired with
        `frame_sizes` via ``zip`` (surplus entries of the longer one are
        ignored).
    fps : number
        Nominal frame rate.
    log, add :
        Forwarded to ``LogarithmicSpectrogramProcessor``.
    diff :
        If truthy, a ``SpectrogramDifferenceProcessor`` with positive, stacked
        differences is appended to every branch.
    start, stop :
        Forwarded to ``SignalProcessor``.
    daug_rate : float
        Data augmentation factor; the effective frame rate is
        ``round(fps * daug_rate)``.
    """

    def __init__(
            self, frame_sizes=(2048,),
            num_bands=(12,),
            fps=FPS, log=np.log, add=1e-6, diff=None, start=None, stop=None, daug_rate=1.):
        # resample to a fixed sample rate in order to get always the same
        # number of filter bins
        sig = SignalProcessor(num_channels=1, sample_rate=44100, start=start, stop=stop)
        # Scale the frame rate to realise the data augmentation. The builtin
        # int() replaces np.int, which no longer exists in NumPy >= 1.24.
        frame_rate = int(np.round(fps * daug_rate))
        # process multi-resolution spec & diff in parallel
        multi = ParallelProcessor([])
        for frame_size, bands in zip(frame_sizes, num_bands):
            # split audio signal in overlapping frames
            frames = FramedSignalProcessor(frame_size=frame_size, fps=frame_rate)
            # compute STFT
            stft = ShortTimeFourierTransformProcessor()
            # filter the magnitudes
            filt = FilteredSpectrogramProcessor(num_bands=bands)
            # scale them logarithmically
            spec = LogarithmicSpectrogramProcessor(log=log, add=add)
            # stack positive differences; every branch gets its own processor
            # and the `diff` flag itself is left untouched
            diff_proc = None
            if diff:
                diff_proc = SpectrogramDifferenceProcessor(positive_diffs=True,
                                                           stack_diffs=np.hstack)
            # process each frame size with spec and diff sequentially
            # (without `diff` the last entry is None, exactly as before)
            multi.append(SequentialProcessor((frames, stft, filt, spec, diff_proc)))
        # instantiate a SequentialProcessor
        super(PreProcessor, self).__init__((sig, multi, np.hstack))


class Dataset(object):
    """Audio files of one dataset paired with their beat annotation files.

    Audio is searched in ``<path>/audio`` and annotations in
    ``<path>/annotations/beats/``. Only audio files with at least one matching
    annotation file are kept.

    Attributes set by ``__init__``: ``files`` (base names without suffix),
    ``audio_files`` and ``annotation_files`` (parallel lists). ``x``,
    ``folds``/``split_files`` and ``annotations`` only exist after
    ``pre_process``, ``load_splits`` and ``load_annotations`` have been
    called, respectively.
    """

    def __init__(self, path, name=None, audio_suffix='.flac', beat_suffix='.beats', daug_rate=1., start=None, stop=None):
        """Collect the audio/annotation file pairs found below `path`.

        Parameters
        ----------
        path : str
            Root folder of the dataset.
        name : str, optional
            Dataset name; defaults to the base name of `path`.
        audio_suffix, beat_suffix : str
            Suffixes of the audio and beat annotation files.
        daug_rate : float
            Factor the beat times are multiplied with in `load_annotations`.
        start, stop : float, optional
            Excerpt borders applied to the beat times in `load_annotations`.
        """
        self.path = path
        if name is None:
            name = os.path.basename(path)
        self.name = name
        self.daug_rate = daug_rate
        self.start = start
        self.stop = stop
        # populate lists containing audio and annotation files
        audio_files = madmom.utils.search_files(
            self.path + '/audio', audio_suffix)
        annotation_files = madmom.utils.search_files(
            self.path + '/annotations/beats/', beat_suffix)
        # match annotation to audio files
        self.files = []
        self.audio_files = []
        self.annotation_files = []

        for audio_file in audio_files:
            matches = madmom.utils.match_file(audio_file, annotation_files,
                                              suffix=audio_suffix,
                                              match_suffix=beat_suffix)
            if len(matches) == 0:
                # audio without any annotation is skipped
                continue
            self.audio_files.append(audio_file)
            # strip the suffix by length; unlike `[:-len(suffix)]` this also
            # works for an empty suffix
            stem = audio_file[:len(audio_file) - len(audio_suffix)]
            self.files.append(os.path.basename(stem))
            # an ambiguous match (several annotation files) is stored as None
            self.annotation_files.append(matches[0] if len(matches) == 1 else None)

    def __len__(self):
        """Return the number of matched files."""
        return len(self.files)

    def pre_process(self, pre_processor, num_threads=1):
        """Apply `pre_processor` to every audio file and store the results in ``self.x``.

        Progress is written to stderr. `num_threads` is accepted but currently
        unused; the files are processed one after another.
        """
        self.x = []
        total = len(self.audio_files)
        for i, f in enumerate(self.audio_files):
            sys.stderr.write('\rprocessing file %d of %d' % (i + 1, total))
            sys.stderr.flush()
            self.x.append(pre_processor(f))

    def load_splits(self, path=None, fold_suffix='.fold'):
        """Read the fold definition files into ``self.folds``.

        Every split file lists one file name per line; each fold becomes an
        array of indices into ``self.files``. Names that are not part of the
        dataset only trigger a warning.

        Parameters
        ----------
        path : str, optional
            Folder with the split files; defaults to ``<self.path>/splits``.
        fold_suffix : str
            Suffix of the split files.
        """
        path = path if path is not None else self.path + '/splits'
        self.split_files = madmom.utils.search_files(path, fold_suffix, recursion_depth=1)
        # populate folds
        self.folds = []
        for split_file in self.split_files:
            fold_idx = []
            with open(split_file) as f:
                for file in f:
                    file = file.strip()
                    # get matching file idx
                    try:
                        fold_idx.append(self.files.index(file))
                    except ValueError:
                        # file could be not available, e.g. in Ballrom set a few duplicates were found
                        warnings.warn('no matching audio/annotation files: %s' % file)
            # set indices for fold
            self.folds.append(np.array(fold_idx))

    def load_annotations(self, widen=None):
        """Load the beat times of every file into ``self.annotations``.

        Files without a usable annotation file (``None``) get an empty array.
        For multi-column annotations only the first column is used. Beats
        after ``self.stop`` are discarded, the remaining ones are shifted by
        ``self.start`` and finally multiplied with ``self.daug_rate``.

        `widen` is accepted but currently unused.
        """
        self.annotations = []
        for f in self.annotation_files:
            if f is None:
                beats = np.array([])
            else:
                beats = madmom.io.load_beats(f)
                if beats.ndim > 1:
                    beats = beats[:, 0]

            if self.stop is not None:
                beats = beats[beats <= self.stop]  # discard any after stop point

            if self.start is not None:
                beats = beats - self.start  # subtract the offset for the start point
                # NOTE(review): a beat exactly at `start` ends up at 0 and is
                # dropped as well -- confirm that this is intended
                beats = beats[beats > 0]

            beats = beats * self.daug_rate  # update to reflect the data augmentation
            self.annotations.append(beats)

    def add_dataset(self, dataset):
        """Append the files, features and annotations of `dataset` to this one.

        Both datasets must already have ``x`` and ``annotations`` (see
        `pre_process` and `load_annotations`), otherwise an AttributeError is
        raised. Folds are not merged.
        """
        self.files.extend(dataset.files)
        self.audio_files.extend(dataset.audio_files)
        self.annotation_files.extend(dataset.annotation_files)
        self.x.extend(dataset.x)
        self.annotations.extend(dataset.annotations)

    def dump(self, filename=None):
        """Pickle the dataset (protocol 2) to `filename`.

        Defaults to ``<self.path>/<self.name>.pkl``.
        """
        if filename is None:
            filename = '%s/%s.pkl' % (self.path, self.name)
        # the context manager guarantees the file is flushed and closed
        with open(filename, 'wb') as f:
            pickle.dump(self, f, protocol=2)


def create_training_dataset(dataset, daug_rate=1, start=None, stop=None):
    """Build, pre-process and pickle the dataset stored under ``DATA_PATH/<dataset>``.

    Annotations and fold splits are loaded, features are extracted with a
    single-resolution `PreProcessor` and the result is dumped to
    ``PKL_PATH/<dataset>.pkl``.

    Parameters
    ----------
    dataset : str
        Name of the dataset folder; names containing 'bambuco' use '.wav'
        audio, all others '.flac'.
    daug_rate : float
        Data augmentation rate handed to both `Dataset` and `PreProcessor`.
    start, stop : float, optional
        Excerpt borders handed to both `Dataset` and `PreProcessor`.

    Returns
    -------
    Dataset
        The populated dataset.
    """
    audio_suffix = '.wav' if 'bambuco' in dataset else '.flac'
    # settings that dataset and feature extraction have to agree on
    shared = dict(daug_rate=daug_rate, start=start, stop=stop)

    db = Dataset('%s/%s' % (DATA_PATH, dataset), audio_suffix=audio_suffix,
                 beat_suffix='.beats', **shared)

    print('loading annotations')
    db.load_annotations()

    print('loading splits')
    db.load_splits()

    extractor = PreProcessor(frame_sizes=[2048], num_bands=[12], fps=100,
                             log=np.log, add=1e-6, **shared)
    print('pre-processing')
    db.pre_process(extractor)

    print('saving')
    db.dump('%s/%s.pkl' % (PKL_PATH, dataset))
    return db


class Fold(object):
    """Train/validation/test partition derived from a list of index folds.

    Fold ``fold`` is the test set, the following fold (wrapping around at the
    end) is the validation set and everything else is used for training.
    """

    def __init__(self, folds, fold):
        self.folds, self.fold = folds, fold

    @property
    def test(self):
        """Unique indices of fold N (testing)."""
        chosen = self.folds[self.fold]
        return np.unique(chosen)

    @property
    def val(self):
        """Unique indices of fold N+1, modulo the number of folds (validation)."""
        following = (self.fold + 1) % len(self.folds)
        return np.unique(self.folds[following])

    @property
    def train(self):
        """Indices of all folds that are neither validation nor test data."""
        remaining = np.hstack(self.folds)
        # remove validation first, then test indices
        for held_out in (self.val, self.test):
            remaining = np.setdiff1d(remaining, held_out)
        return remaining


def cnn_pad(data, pad_frames):
    """Extend `data` along its first axis by repeating the edge frames.

    The first frame is prepended and the last frame appended `pad_frames`
    times each; a new array is returned.
    """
    head, tail = (np.repeat(edge, pad_frames, axis=0)
                  for edge in (data[:1], data[-1:]))
    return np.concatenate([head, data, tail], axis=0)


def residual_block_v1(x, dilation_rate, activation, num_filters, kernel_size, padding, dropout_rate=0, name=''):
    """Build one residual block: dilated conv -> activation -> spatial dropout -> 1x1 conv.

    Parameters
    ----------
    x : tensor
        Block input; it is added to the 1x1 conv output, so its channel count
        presumably has to equal `num_filters` -- verify against the caller.
    dilation_rate : int
        Dilation rate of the first convolution; also part of all layer names.
    activation : str or callable
        Activation applied after the dilated convolution.
    num_filters : int
        Number of filters of both convolutions.
    kernel_size : int
        Kernel size of the dilated convolution.
    padding :
        Accepted but not used; both convolutions always use 'same' padding.
    dropout_rate : float
        Rate of the spatial dropout.
    name : str
        Prefix for the layer names.

    Returns
    -------
    tuple
        ``(residual_sum, conv_output)``: the sum of block input and 1x1 conv
        output, and the 1x1 conv output itself.
    """
    block_input = x
    dilated = Conv1D(
        num_filters,
        kernel_size=kernel_size,
        dilation_rate=dilation_rate,
        padding='same',
        name=name + '_%d_dilated_conv' % (dilation_rate),
    )(block_input)
    activated = Activation(
        activation, name=name + '_%d_activation' % (dilation_rate))(dilated)
    dropped = SpatialDropout1D(
        dropout_rate,
        name=name + '_%d_spatial_dropout_%.2f' % (dilation_rate, dropout_rate),
    )(activated)
    projected = Conv1D(
        num_filters, 1, padding='same',
        name=name + '_%d_conv_1x1' % (dilation_rate))(dropped)
    summed = tf.keras.layers.add(
        [block_input, projected], name=name + '_%d_residual' % (dilation_rate))
    return summed, projected


class TCN_v1():
    """Stack of `residual_block_v1` blocks, one per dilation rate.

    Parameters
    ----------
    num_filters : int
        Number of filters of every convolution.
    kernel_size : int
        Kernel size of the dilated convolutions.
    dilations : sequence of int
        Dilation rates, applied in the given order. A copy is stored, so
        instances never share (or modify) the default.
    activation : str or callable
        Activation used inside the blocks and after the last block.
    dropout_rate : float
        Spatial dropout rate used inside every block.
    name : str
        Prefix for all layer names.
    """

    def __init__(self, num_filters=8, kernel_size=5, dilations=(1, 2, 4, 8, 16, 32, 64, 128),
                 activation='elu', dropout_rate=0.15, name='tcn'):
        self.name = name
        self.dropout_rate = dropout_rate
        self.activation = activation
        self.dilations = list(dilations)
        self.kernel_size = kernel_size
        self.num_filters = num_filters

    def __call__(self, inputs):
        """Apply all residual blocks to `inputs` and return the activated output."""
        x = inputs
        for d in self.dilations:
            # `padding` and `dropout_rate` are passed by keyword: positionally
            # the dropout rate ended up in the `padding` slot and the blocks
            # were silently built with a dropout rate of 0.
            x, _ = residual_block_v1(x, d, self.activation, self.num_filters,
                                     self.kernel_size, padding='same',
                                     dropout_rate=self.dropout_rate, name=self.name)
        x = Activation(self.activation, name=self.name + '_activation')(x)
        return x


class DataSequence_TCNv1(Sequence):
    """Keras ``Sequence`` that yields one (features, targets) pair per item.

    Parameters
    ----------
    x : list
        Feature arrays, one per item; the first axis is the frame axis.
    y : list
        Event times per item; they are quantized with
        ``madmom.utils.quantize_events`` at `fps` to as many frames as the
        corresponding feature array has.
    fps : number
        Frame rate used for the quantization.
    pad_frames : int, optional
        Number of frames added by `cnn_pad` at both ends of the features.
        ``None`` (or 0) disables padding.
    """

    def __init__(self, x, y, fps=FPS, pad_frames=None):
        self.x = x
        self.y = [madmom.utils.quantize_events(o, fps=fps, length=len(d))
                  for o, d in zip(y, self.x)]
        self.pad_frames = pad_frames

    def __len__(self):
        """Return the number of items (one item per batch)."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return item `idx` with a leading and a trailing singleton axis added."""
        features = self.x[idx]
        # Only pad when requested: with the default pad_frames=None the
        # np.repeat call inside cnn_pad would raise.
        if self.pad_frames:
            features = cnn_pad(features, self.pad_frames)
        x = np.array(features)[np.newaxis, ..., np.newaxis]
        y = self.y[idx][np.newaxis, ..., np.newaxis]
        return x, y

    def widen_targets(self, size=3, value=0.5):
        """Raise the targets around each event, in place.

        Every target becomes the maximum of itself and `value` times the
        maximum found inside a window of `size` frames around it.
        """
        from scipy.ndimage import maximum_filter1d
        for y in self.y:
            np.maximum(y, maximum_filter1d(y, size=size) * value, out=y)


def simple_TCN(dataset):
    """Train the conv + TCN beat model on fold 0 of a pickled dataset.

    The dataset is read from ``PKL_PATH/<dataset>.pkl``. Fold 0 is held out
    for testing, fold 1 is used for validation and the remaining folds for
    training. Network size and training settings come from the module-level
    hyper-parameters (`num_filters`, `kernel_size`, `num_dilations`,
    `dropout_rate`, `activation`, `lr`, `num_epochs`).

    Parameters
    ----------
    dataset : str
        Name of the pickled dataset (file name without '.pkl').

    Returns
    -------
    bool
        Always ``True``.
    """
    # NOTE: pickle.load can execute arbitrary code -- only load files that
    # were created by a trusted source (e.g. create_training_dataset).
    with open('%s/%s.pkl' % (PKL_PATH, dataset), 'rb') as f:
        train_db = pickle.load(f)

    num_fold = 0
    fold = Fold(train_db.folds, num_fold)
    # pad_frames=2 makes up for the frames removed by the two 'valid' 3x3
    # convolutions, so the output has as many frames as the targets
    train = DataSequence_TCNv1([train_db.x[i] for i in fold.train],
                               [train_db.annotations[i] for i in fold.train],
                               pad_frames=2)
    val = DataSequence_TCNv1([train_db.x[i] for i in fold.val],
                             [train_db.annotations[i] for i in fold.val],
                             pad_frames=2)
    train.widen_targets()
    val.widen_targets()

    # convolutional front-end; the last two input dims are taken from the
    # first training item, the frame axis stays variable
    input_layer = Input(shape=((None, ) + train[0][0].shape[-2:]))
    conv_1 = Conv2D(num_filters, (3, 3), padding='valid',
                    name='conv_1_convolution')(input_layer)
    conv_1 = Activation(activation, name='conv_1_activation')(conv_1)
    conv_1 = MaxPooling2D((1, 3), name='conv_1_pooling')(conv_1)
    conv_1 = Dropout(dropout_rate, name='conv_1_dropout')(conv_1)
    conv_2 = Conv2D(num_filters, (3, 3), padding='valid',
                    name='conv_2_convolution')(conv_1)
    conv_2 = Activation(activation, name='conv_2_activation')(conv_2)
    conv_2 = MaxPooling2D((1, 3), name='conv_2_pooling')(conv_2)
    conv_2 = Dropout(dropout_rate, name='conv_2_dropout')(conv_2)
    conv_3 = Conv2D(num_filters, (1, 8), padding='valid',
                    name='conv_3_convolution')(conv_2)
    conv_3 = Activation(activation, name='conv_3_activation')(conv_3)
    conv_3 = Dropout(dropout_rate, name='conv_3_dropout')(conv_3)

    # reshape to (frames, filters) and run the TCN
    x = Reshape((-1, num_filters), name='tcn_reshape')(conv_3)
    dilations = [2 ** i for i in range(num_dilations)]
    tcn_layer = TCN_v1(num_filters=num_filters, kernel_size=kernel_size, dilations=dilations,
                       activation='elu', dropout_rate=dropout_rate, name='tcn')(x)

    # one sigmoid activation per frame
    output_layer = Dense(1, name='output')(tcn_layer)
    output_layer = Activation('sigmoid', name='output_activation')(output_layer)

    model = Model(input_layer, output_layer)

    # RectifiedAdam wrapped in Lookahead; `learning_rate` replaces the legacy
    # `lr` keyword
    radam = tfa.optimizers.RectifiedAdam(learning_rate=lr, clipnorm=0.5)
    ranger = tfa.optimizers.Lookahead(radam, sync_period=6, slow_step_size=0.5)
    model.compile(optimizer=ranger, loss=K.binary_crossentropy, metrics=['binary_accuracy'])
    model.fit(train, steps_per_epoch=len(train), epochs=num_epochs, shuffle=True,
              validation_data=val, validation_steps=len(val),
              verbose=True)
    return True


# Script entry point: build and pickle the features of one dataset.
if __name__ == "__main__":
    # enable TensorFlow's soft device placement
    tf.config.set_soft_device_placement(True)
    # name of the dataset folder below DATA_PATH
    dataset = 'traintest_smallsmc'
    # extract the features without data augmentation and dump them to PKL_PATH
    db = create_training_dataset(dataset, daug_rate=1)
    # training is disabled; uncomment to train on the pickled dataset
    # simple_TCN(dataset)
