In [None]:
import numpy as np
from common.functions import *
class LSTM:
    """One LSTM time step: a forward pass plus the matching backward pass.

    The four gate pre-activations are computed with a single affine
    transform and then split column-wise, in the order
    ``forget, candidate (g), input, output``. Each slice is ``H`` columns
    wide, where ``H`` is the second dimension of ``h_prev``; the weight
    matrices and bias are therefore expected to have ``4 * H`` columns.

    Attributes:
        params: ``[wx, wh, b]`` - input weights, recurrent weights and bias.
        grads: Arrays with the same shapes as ``params``; overwritten in
            place by every call to :meth:`backward`.
        cache: Values stored by :meth:`forward` for use in :meth:`backward`
            (``None`` until ``forward`` has been called).
    """

    def __init__(self, wx, wh, b):
        """Store the parameters and allocate zero-filled gradient buffers.

        Args:
            wx: Weights applied to the input ``x``.
            wh: Weights applied to the previous hidden state.
            b: Bias added to the affine transform.
        """
        self.params = [wx, wh, b]
        self.grads = [np.zeros_like(wx), np.zeros_like(wh), np.zeros_like(b)]
        self.cache = None

    def forward(self, x, h_prev, c_prev):
        """Compute the next hidden and cell state for one time step.

        Args:
            x: Input for this step (2-D; rows are matched with ``h_prev``).
            h_prev: Previous hidden state, a 2-D array whose second
                dimension defines the hidden size ``H``.
            c_prev: Previous cell state, same shape as ``h_prev``.

        Returns:
            Tuple ``(h_next, c_next)``.
        """
        wx, wh, b = self.params
        # Unpacking also enforces that h_prev is 2-D.
        _, H = h_prev.shape

        # Single affine transform producing all four gate pre-activations.
        A = np.matmul(x, wx) + np.matmul(h_prev, wh) + b

        # Column layout: [forget | candidate | input | output].
        f, g, i, o = A[:, :H], A[:, H:2*H], A[:, 2*H:3*H], A[:, 3*H:]
        # NOTE(review): `sigmoid` is not defined in this file; presumably it
        # comes from the `common.functions` star import - confirm.
        f, i, o = sigmoid(f), sigmoid(i), sigmoid(o)
        g = np.tanh(g)

        c_next = c_prev * f + g * i
        h_next = np.tanh(c_next) * o

        self.cache = (x, h_prev, c_prev, i, f, g, o, c_next)
        return h_next, c_next

    def backward(self, dh_next, dc_next):
        """Back-propagate through the step computed by the last ``forward``.

        Reads the values saved in ``self.cache`` and writes the parameter
        gradients into ``self.grads`` in place.

        Args:
            dh_next: Gradient of the loss w.r.t. ``h_next``.
            dc_next: Gradient of the loss w.r.t. ``c_next``.

        Returns:
            Tuple ``(dx, dh_prev, dc_prev)`` - gradients w.r.t. the inputs
            of ``forward``.
        """
        wx, wh, _ = self.params
        x, h_prev, c_prev, i, f, g, o, c_next = self.cache

        tanh_c_next = np.tanh(c_next)
        # Total gradient reaching c_next: the direct path plus the path
        # through h_next = tanh(c_next) * o.
        ds = dc_next + (dh_next * o) * (1 - tanh_c_next**2)

        dc_prev = f * ds
        di = g * ds
        df = c_prev * ds
        do = tanh_c_next * dh_next
        dg = i * ds

        # Through the gate non-linearities: sigmoid' = y * (1 - y),
        # tanh' = 1 - y**2, expressed via the cached activations.
        di *= i * (1 - i)
        df *= f * (1 - f)
        do *= o * (1 - o)
        dg *= (1 - g ** 2)

        # Re-assemble in the same column order used to split A in forward().
        dA = np.hstack((df, dg, di, do))

        db = dA.sum(axis=0)
        dwh = np.dot(h_prev.T, dA)
        dwx = np.dot(x.T, dA)

        # Write in place so external references to the grads arrays stay valid.
        self.grads[0][...] = dwx
        self.grads[1][...] = dwh
        self.grads[2][...] = db

        dx = np.dot(dA, wx.T)
        dh_prev = np.dot(dA, wh.T)

        return dx, dh_prev, dc_prev