# In [None]:
# LSTM #######################################################################
def initialize_parameters():
    """Create the randomly initialised weight matrices of the network.

    Every weight is drawn with ``np.random.normal(0, 0.01, shape)``.  The four
    gate matrices (``fgw``, ``igw``, ``ogw``, ``ggw``) have shape
    ``(INPUT + HIDDEN, HIDDEN)`` and the output matrix ``how`` has shape
    ``(HIDDEN, OUTPUT)``.  ``INPUT``, ``HIDDEN`` and ``OUTPUT`` are module-level
    constants defined elsewhere in the file.

    Returns:
        dict mapping each weight name to its freshly sampled ndarray.
    """
    gate_shape = (INPUT + HIDDEN, HIDDEN)
    # The draw order (fgw, igw, ogw, ggw, how) is kept fixed so that a seeded
    # RNG yields the same weights as before.
    weights = {
        gate: np.random.normal(0, 0.01, gate_shape)
        for gate in ('fgw', 'igw', 'ogw', 'ggw')
    }
    weights['how'] = np.random.normal(0, 0.01, (HIDDEN, OUTPUT))
    return weights


def initialize_V(parameters):
    """Build the zero-filled ``V`` accumulators used by ``update_parameters``.

    Args:
        parameters: dict holding the ``fgw``, ``igw``, ``ogw``, ``ggw`` and
            ``how`` weight arrays; only their shapes are read.

    Returns:
        dict with keys ``vfgw``, ``vigw``, ``vogw``, ``vggw`` and ``vhow``,
        each an all-zero array shaped like the corresponding weight.
    """
    weight_names = ('fgw', 'igw', 'ogw', 'ggw', 'how')
    return {
        'v' + name: np.zeros(parameters[name].shape)
        for name in weight_names
    }


def initialize_S(parameters):
    """Build the zero-filled ``S`` accumulators used by ``update_parameters``.

    Args:
        parameters: dict holding the ``fgw``, ``igw``, ``ogw``, ``ggw`` and
            ``how`` weight arrays; only their shapes are read.

    Returns:
        dict with keys ``sfgw``, ``sigw``, ``sogw``, ``sggw`` and ``show``,
        each an all-zero array shaped like the corresponding weight.
    """
    weight_names = ('fgw', 'igw', 'ogw', 'ggw', 'how')
    return {
        's' + name: np.zeros(parameters[name].shape)
        for name in weight_names
    }


def get_embeddings(batch_dataset, embeddings):
    """Map a batch onto the embedding matrix.

    Args:
        batch_dataset: left operand of the product (presumably one-hot rows,
            which would make this a row lookup -- confirm against the caller).
        embeddings: right operand, the embedding matrix.

    Returns:
        ``np.matmul(batch_dataset, embeddings)``.
    """
    return np.matmul(batch_dataset, embeddings)


def lstm_cell(batch_dataset, prev_activation_matrix, prev_cell_matrix, parameters):
    """Run a single LSTM time step.

    The input and the previous activation are concatenated along axis 1 and
    multiplied by each gate's weight matrix.  ``sigmoid`` and ``tanh`` are
    helpers defined elsewhere in the file.

    Args:
        batch_dataset: input for this step.
        prev_activation_matrix: activation produced by the previous step.
        prev_cell_matrix: cell memory produced by the previous step.
        parameters: dict providing the ``fgw``, ``igw``, ``ogw`` and ``ggw``
            weight matrices.

    Returns:
        Tuple ``(lstm_activations, cell_memory_matrix, activation_matrix)``
        where ``lstm_activations`` holds the gate outputs under the keys
        ``fa``, ``ia``, ``oa`` and ``ga``.
    """
    gate_input = np.concatenate((batch_dataset, prev_activation_matrix), axis=1)

    # (cache key, weight name, squashing function) for each gate, evaluated in
    # the order forget, input, output, candidate.
    gate_specs = (
        ('fa', 'fgw', sigmoid),
        ('ia', 'igw', sigmoid),
        ('oa', 'ogw', sigmoid),
        ('ga', 'ggw', tanh),
    )
    gates = {
        key: squash(np.matmul(gate_input, parameters[weight]))
        for key, weight, squash in gate_specs
    }

    # New memory: keep part of the old memory and add the gated candidate.
    kept_memory = np.multiply(gates['fa'], prev_cell_matrix)
    new_memory = np.multiply(gates['ia'], gates['ga'])
    cell_state = kept_memory + new_memory
    hidden_state = np.multiply(gates['oa'], tanh(cell_state))
    return gates, cell_state, hidden_state


def output_cell(activation_matrix, parameters):
    """Turn an LSTM activation into the output layer's prediction.

    Args:
        activation_matrix: activation produced by ``lstm_cell``.
        parameters: dict providing the ``how`` weight matrix.

    Returns:
        ``softmax`` (a helper defined elsewhere in the file) applied to
        ``activation_matrix @ parameters['how']``.
    """
    logits = np.matmul(activation_matrix, parameters['how'])
    return softmax(logits)


def cal_loss_accuracy(batch_labels, output_cache):
    """Compute perplexity, loss and accuracy over all time steps of a batch.

    ``output_cache['o' + str(i)]`` is compared against ``batch_labels[i]`` for
    ``i = 1 .. len(output_cache)``; ``batch_labels[0]`` is only used to read
    the batch size.

    Numerical safety: predictions are promoted to float64 and clipped to
    ``[1e-12, 1 - 1e-12]`` before any logarithm is taken, so a saturated
    output (exactly 0 or 1) no longer produces ``0 * -inf = NaN``.
    The per-sample sequence probability is accumulated as a sum of logs
    instead of a running product, so long sequences do not underflow to 0.

    Args:
        batch_labels: indexable collection of label arrays, one per step
            (presumably one-hot rows -- confirm against the caller).
        output_cache: dict of predictions keyed ``'o1'``, ``'o2'``, ...

    Returns:
        Tuple ``(perplexity, loss, accuracy)``:
        perplexity is the batch mean of ``(1 / p_sequence) ** (1 / steps)``,
        loss is the element-wise binary cross-entropy summed over classes and
        steps and averaged over the batch, and accuracy is the fraction of
        (sample, step) pairs whose argmax matches the label's argmax.

    Raises:
        ZeroDivisionError: if ``output_cache`` is empty.
    """
    eps = 1e-12
    steps = len(output_cache)
    batch_size = batch_labels[0].shape[0]
    inv_steps = 1 / steps

    log_prob = 0.0   # per-sample log-probability of the labelled sequence
    loss = 0
    hits = 0
    for i in range(1, steps + 1):
        labels = batch_labels[i]
        pred = output_cache['o' + str(i)]
        # float64 is required: in float32 the bound 1 - 1e-12 rounds to 1.0.
        safe = np.clip(np.asarray(pred, dtype=np.float64), eps, 1 - eps)

        p_label = np.sum(np.multiply(labels, safe), axis=1).reshape(-1, 1)
        log_prob = log_prob + np.log(p_label)

        step_loss = (np.multiply(labels, np.log(safe))
                     + np.multiply(1 - labels, np.log(1 - safe)))
        loss += np.sum(step_loss, axis=1).reshape(-1, 1)

        # The argmax is taken on the raw prediction; clipping is only needed
        # for the logarithms.
        matches = np.argmax(labels, 1) == np.argmax(pred, 1)
        hits += np.array(matches, dtype=np.float32).reshape(-1, 1)

    # exp(-log p / T) equals (1 / p) ** (1 / T), without the underflow.
    perplexity = np.sum(np.exp(-log_prob * inv_steps)) / batch_size
    loss = np.sum(loss) * (-1 / batch_size)
    accuracy = (np.sum(hits) / batch_size) / steps

    return perplexity, loss, accuracy
    

def forward_propagation(batches, parameters, embeddings):
    """Run the network forward over every step of a batch sequence.

    Steps ``0 .. len(batches) - 2`` are used as inputs; the last element of
    ``batches`` is never fed in.  The initial activation and cell memory are
    float32 zeros of shape ``(batch_size, HIDDEN)`` where ``batch_size`` is
    ``batches[0].shape[0]`` and ``HIDDEN`` is a module-level constant.

    Args:
        batches: indexable sequence of per-step input arrays.
        parameters: weight dict passed to ``lstm_cell`` and ``output_cell``.
        embeddings: embedding matrix passed to ``get_embeddings``.

    Returns:
        Tuple ``(embedding_cache, lstm_cache, activation_cache, cell_cache,
        output_cache)``.  ``embedding_cache`` is keyed ``emb0 .. emb{T-1}``;
        ``lstm_cache`` (``lstm1..``) and ``output_cache`` (``o1..``) are keyed
        from 1; ``activation_cache`` and ``cell_cache`` additionally hold the
        initial ``a0`` / ``c0`` states.
    """
    batch_size = batches[0].shape[0]
    hidden = np.zeros([batch_size, HIDDEN], dtype=np.float32)
    memory = np.zeros([batch_size, HIDDEN], dtype=np.float32)

    embedding_cache, lstm_cache, output_cache = {}, {}, {}
    activation_cache = {'a0': hidden}
    cell_cache = {'c0': memory}

    # `step` is the 1-based index used for the lstm/activation/cell/output
    # keys; the matching input and embedding live at index step - 1.
    for step in range(1, len(batches)):
        tag = str(step)
        embedded = get_embeddings(batches[step - 1], embeddings)
        embedding_cache['emb' + str(step - 1)] = embedded

        gates, memory, hidden = lstm_cell(embedded, hidden, memory, parameters)
        prediction = output_cell(hidden, parameters)

        lstm_cache['lstm' + tag] = gates
        activation_cache['a' + tag] = hidden
        cell_cache['c' + tag] = memory
        output_cache['o' + tag] = prediction

    return embedding_cache, lstm_cache, activation_cache, cell_cache, output_cache


def calculate_output_cell_error(batch_labels, output_cache, parameters):
    """Compute the output-layer error for every time step.

    For each step ``i`` in ``1 .. len(output_cache)`` the output error is
    ``output_cache['o' + str(i)] - batch_labels[i]``; it is then multiplied by
    the transpose of ``parameters['how']`` to obtain the error with respect to
    the activation that fed the output layer.

    Args:
        batch_labels: indexable collection of label arrays.
        output_cache: dict of predictions keyed ``'o1'``, ``'o2'``, ...
        parameters: dict providing the ``how`` weight matrix.

    Returns:
        Tuple ``(output_error_cache, activation_error_cache)`` keyed
        ``'eo<i>'`` and ``'ea<i>'`` respectively.
    """
    output_error_cache = {}
    activation_error_cache = {}
    for step in range(1, len(output_cache) + 1):
        tag = str(step)
        delta = output_cache['o' + tag] - batch_labels[step]
        output_error_cache['eo' + tag] = delta
        activation_error_cache['ea' + tag] = np.matmul(delta, parameters['how'].T)
    return output_error_cache, activation_error_cache


def calculate_single_lstm_cell_error(activation_output_error, next_activation_error,
                                     next_cell_error, parameters, lstm_activation,
                                     cell_activation, prev_cell_activation):
    """Back-propagate the error through one LSTM cell.

    Fix over the previous version: the output-gate error ``eo`` is now
    propagated through ``ogw`` and the candidate-gate error ``eg`` through
    ``ggw``, matching the weights each gate uses in ``lstm_cell``.  The two
    matrices were swapped before, which corrupted the error handed to the
    previous activation and to the embeddings.

    ``tanh`` and ``dtanh`` are helpers defined elsewhere in the file.
    ``dtanh`` is applied to values that are already tanh outputs
    (``tanh(cell_activation)`` and ``ga``), so it presumably expects the
    activated value -- confirm against its definition.

    Args:
        activation_output_error: error reaching this step's activation from
            the output layer.
        next_activation_error: error reaching this step's activation from the
            following time step.
        next_cell_error: error reaching this step's cell memory from the
            following time step.
        parameters: dict providing the ``fgw``, ``igw``, ``ogw`` and ``ggw``
            weight matrices.
        lstm_activation: dict with this step's gate outputs under ``fa``,
            ``ia``, ``oa`` and ``ga``.
        cell_activation: this step's cell memory.
        prev_cell_activation: the previous step's cell memory.

    Returns:
        Tuple ``(prev_activation_error, prev_cell_error, embed_error,
        lstm_error)`` where ``lstm_error`` holds the gate errors under the
        keys ``ef``, ``ei``, ``eo`` and ``eg``.
    """
    activation_error = activation_output_error + next_activation_error
    oa = lstm_activation['oa']
    ia = lstm_activation['ia']
    ga = lstm_activation['ga']
    fa = lstm_activation['fa']

    # Both the output-gate error and the cell error need tanh(c); compute once.
    squashed_cell = tanh(cell_activation)

    # Output gate: sigmoid derivative is oa * (1 - oa).
    eo = np.multiply(np.multiply(np.multiply(activation_error, squashed_cell), oa), 1 - oa)

    # Error on the cell memory: through the activation plus from the next step.
    cell_error = np.multiply(np.multiply(activation_error, oa), dtanh(squashed_cell))
    cell_error += next_cell_error

    ei = np.multiply(np.multiply(np.multiply(cell_error, ga), ia), 1 - ia)
    eg = np.multiply(np.multiply(cell_error, ia), dtanh(ga))
    ef = np.multiply(np.multiply(np.multiply(cell_error, prev_cell_activation), fa), 1 - fa)
    prev_cell_error = np.multiply(cell_error, fa)

    # Error w.r.t. the concatenated [embedding, previous activation] input:
    # each gate error goes back through the transpose of its own weights.
    embed_activation_error = np.matmul(ef, parameters['fgw'].T)
    embed_activation_error += np.matmul(ei, parameters['igw'].T)
    embed_activation_error += np.matmul(eo, parameters['ogw'].T)
    embed_activation_error += np.matmul(eg, parameters['ggw'].T)

    # Gate weights are (input + hidden, hidden), so rows minus columns gives
    # the width of the embedding part of the concatenation.
    input_units = parameters['fgw'].shape[0] - parameters['fgw'].shape[1]
    prev_activation_error = embed_activation_error[:, input_units:]
    embed_error = embed_activation_error[:, :input_units]

    lstm_error = {'ef': ef, 'ei': ei, 'eo': eo, 'eg': eg}
    return prev_activation_error, prev_cell_error, embed_error, lstm_error


def backward_propagation(batch_labels, embedding_cache, lstm_cache,
                         activation_cache, cell_cache, output_cache, parameters):
    """Back-propagate through time and collect the weight derivatives.

    The output-layer errors are computed first, then the LSTM cells are
    visited from the last step to the first, handing the activation and cell
    errors back one step at a time.  Per-step gate derivatives are finally
    summed over all steps.

    Args:
        batch_labels: indexable collection of label arrays.
        embedding_cache, lstm_cache, activation_cache, cell_cache,
        output_cache: caches in the layout built by ``forward_propagation``.
        parameters: weight dict.

    Returns:
        Tuple ``(derivatives, embedding_error_cache)``.  ``derivatives`` has
        the keys ``dhow``, ``dfgw``, ``digw``, ``dogw`` and ``dggw``;
        ``embedding_error_cache`` is keyed ``'eemb<i-1>'`` for LSTM step ``i``.
    """
    output_error_cache, activation_error_cache = calculate_output_cell_error(
        batch_labels, output_cache, parameters)

    lstm_error_cache, embedding_error_cache = {}, {}
    # Nothing flows in from beyond the last step, so both errors start at zero.
    carried_activation_error = np.zeros(activation_error_cache['ea1'].shape)
    carried_cell_error = np.zeros(activation_error_cache['ea1'].shape)
    for step in reversed(range(1, len(lstm_cache) + 1)):
        tag = str(step)
        (carried_activation_error,
         carried_cell_error,
         embed_error,
         gate_errors) = calculate_single_lstm_cell_error(
            activation_error_cache['ea' + tag],
            carried_activation_error,
            carried_cell_error,
            parameters,
            lstm_cache['lstm' + tag],
            cell_cache['c' + tag],
            cell_cache['c' + str(step - 1)])
        lstm_error_cache['elstm' + tag] = gate_errors
        embedding_error_cache['eemb' + str(step - 1)] = embed_error

    derivatives = {
        'dhow': calculate_output_cell_derivatives(
            output_error_cache, activation_cache, parameters),
    }

    # Per-step gate derivatives, in step order 1 .. T.
    per_step = [
        calculate_single_lstm_cell_derivatives(
            lstm_error_cache['elstm' + str(step)],
            embedding_cache['emb' + str(step - 1)],
            activation_cache['a' + str(step - 1)])
        for step in range(1, len(lstm_error_cache) + 1)
    ]

    # Sum each gate's derivative over the steps, starting from zeros.
    for weight in ('fgw', 'igw', 'ogw', 'ggw'):
        total = np.zeros(parameters[weight].shape)
        for step_derivatives in per_step:
            total += step_derivatives['d' + weight]
        derivatives['d' + weight] = total

    return derivatives, embedding_error_cache


def calculate_output_cell_derivatives(output_error_cache, activation_cache, parameters):
    """Accumulate the derivative of the output weights ``how`` over all steps.

    For each step ``i`` in ``1 .. len(output_error_cache)`` the product
    ``activation_cache['a<i>'].T @ output_error_cache['eo<i>']`` is divided by
    the batch size (``activation_cache['a1'].shape[0]``) and added to the
    running total.

    Args:
        output_error_cache: dict of output errors keyed ``'eo<i>'``.
        activation_cache: dict of activations keyed ``'a<i>'``.
        parameters: dict providing ``how``; only its shape is read.

    Returns:
        Array shaped like ``parameters['how']`` holding the summed derivative.
    """
    total = np.zeros(parameters['how'].shape)
    batch_size = activation_cache['a1'].shape[0]
    for step in range(1, len(output_error_cache) + 1):
        tag = str(step)
        step_grad = np.matmul(activation_cache['a' + tag].T,
                              output_error_cache['eo' + tag])
        total += step_grad / batch_size
    return total


def calculate_single_lstm_cell_derivatives(lstm_error, embedding_matrix, activation_matrix):
    """Compute one step's derivatives of the four gate weight matrices.

    The embedding and activation are concatenated along axis 1 (the same
    layout the gates receive in ``lstm_cell``); each gate derivative is the
    transposed concatenation times that gate's error, divided by the batch
    size ``embedding_matrix.shape[0]``.

    Args:
        lstm_error: dict of gate errors under ``ef``, ``ei``, ``eo``, ``eg``.
        embedding_matrix: embedded input of the step.
        activation_matrix: activation that entered the step.

    Returns:
        dict with the keys ``dfgw``, ``digw``, ``dogw`` and ``dggw``.
    """
    gate_input_t = np.concatenate((embedding_matrix, activation_matrix), axis=1).T
    batch_size = embedding_matrix.shape[0]
    # 'f', 'i', 'o', 'g' select both the error key ('e?') and the output key
    # ('d?gw').
    return {
        'd' + gate + 'gw': np.matmul(gate_input_t, lstm_error['e' + gate]) / batch_size
        for gate in ('f', 'i', 'o', 'g')
    }


def update_parameters(parameters, derivatives, V, S):
    """Apply one moving-average gradient step to every weight matrix.

    For each weight ``w`` in ``fgw``, ``igw``, ``ogw``, ``ggw``, ``how``:

        v = BETA1 * V['v' + w] + (1 - BETA1) * d
        s = BETA2 * S['s' + w] + (1 - BETA2) * d ** 2
        parameters[w] -= LEARNING_RATE * v / (sqrt(s) + 1e-6)

    where ``d = derivatives['d' + w]``.  No bias-correction term is applied.
    ``BETA1``, ``BETA2`` and ``LEARNING_RATE`` are module-level constants
    defined elsewhere in the file.

    Args:
        parameters: weight dict; its arrays are modified in place.
        derivatives: dict of derivatives keyed ``'d<weight>'``.
        V: dict of first-moment averages keyed ``'v<weight>'``; entries are
            replaced with the new averages.
        S: dict of second-moment averages keyed ``'s<weight>'``; entries are
            replaced with the new averages.

    Returns:
        The same ``(parameters, V, S)`` objects after the update.
    """
    weight_names = ('fgw', 'igw', 'ogw', 'ggw', 'how')

    # Compute all new averages before touching any weight.
    first_moment = {
        name: BETA1 * V['v' + name] + (1 - BETA1) * derivatives['d' + name]
        for name in weight_names
    }
    second_moment = {
        name: BETA2 * S['s' + name] + (1 - BETA2) * derivatives['d' + name]**2
        for name in weight_names
    }

    for name in weight_names:
        scaled_step = first_moment[name] / (np.sqrt(second_moment[name]) + 1e-6)
        parameters[name] -= LEARNING_RATE * scaled_step

    for name in weight_names:
        V['v' + name] = first_moment[name]
    for name in weight_names:
        S['s' + name] = second_moment[name]
    return parameters, V, S


def update_embeddings(embeddings, embedding_error_cache, batch_labels):
    """Take one plain gradient-descent step on the embedding matrix.

    For every step ``i`` in ``0 .. len(embedding_error_cache) - 1`` the
    product ``batch_labels[i].T @ embedding_error_cache['eemb<i>']`` divided
    by the batch size (``batch_labels[0].shape[0]``) is added to the gradient.
    ``LEARNING_RATE`` is a module-level constant defined elsewhere.

    Args:
        embeddings: current embedding matrix (not modified in place).
        embedding_error_cache: dict of embedding errors keyed ``'eemb<i>'``.
        batch_labels: indexable collection of per-step arrays (despite the
            name, presumably the inputs that were embedded at step ``i`` --
            confirm against the caller).

    Returns:
        A new array ``embeddings - LEARNING_RATE * gradient``.
    """
    gradient = np.zeros(embeddings.shape)
    batch_size = batch_labels[0].shape[0]
    for step in range(len(embedding_error_cache)):
        step_error = embedding_error_cache['eemb' + str(step)]
        gradient += np.matmul(batch_labels[step].T, step_error) / batch_size
    return embeddings - LEARNING_RATE * gradient