In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
%matplotlib inline

In [2]:
INPUT = 28
HIDDEN = 128
OUTPUT = 10

INPUT += HIDDEN

ALPHA = 0.001
BATCH_NUM = 64

ITER_NUM = 10000
LOG_ITER = ITER_NUM // 10
PLOT_ITER = ITER_NUM // 200

In [3]:
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

wf = np.random.normal(0, 1, [INPUT, HIDDEN])
wi = np.random.normal(0, 1, [INPUT, HIDDEN])
wc = np.random.normal(0, 1, [INPUT, HIDDEN])
wo = np.random.normal(0, 1, [INPUT, HIDDEN])
wy = np.random.normal(0, 1, [HIDDEN, OUTPUT])

bf = np.random.normal(0, 1, [HIDDEN])
bi = np.random.normal(0, 1, [HIDDEN])
bc = np.random.normal(0, 1, [HIDDEN])
bo = np.random.normal(0, 1, [HIDDEN])
by = np.random.normal(0, 1, [OUTPUT])

dwf = np.zeros_like(wf)
dwi = np.zeros_like(wi)
dwc = np.zeros_like(wc)
dwo = np.zeros_like(wo)
dwy = np.zeros_like(wy)

dbf = np.zeros_like(bf)
dbi = np.zeros_like(bi)
dbc = np.zeros_like(bc)
dbo = np.zeros_like(bo)
dby = np.zeros_like(by)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [4]:
errors = []

In [5]:
def softmax(arr):
    e = np.exp(arr)
    return e / np.sum(e)

def cross_entropy(out, label):
    return -np.sum(label * np.log(out))

def sigmoid(arr):
    return 1 / 1 + np.exp(-arr)

def deriv_sigmoid(out):
    return out * (1 - out)

def tanh(arr):
    return 2 / (1 + np.exp(-2 * arr)) - 1

def deriv_tanh(out):
    return 1 - np.square(out)

In [6]:
for i in range(ITER_NUM):
    X, Y = mnist.train.next_batch(BATCH_NUM)
    Xt = np.transpose(np.reshape(X, [-1, 28, 28]), [1, 0, 2])
        
    caches = []
    
    hidden_values = []
    hidden_values.append(np.zeros([BATCH_NUM, HIDDEN]))
        
    cell_values = []
    cell_values.append(np.zeros([BATCH_NUM, HIDDEN]))
        
    for x in Xt:
        prev_cell = cell_values[-1]
        prev_hidden = hidden_values[-1]
        x = np.column_stack([x, prev_hidden])
            
        hf = sigmoid(np.dot(x, wf) + bf)
        hi = sigmoid(np.dot(x, wi) + bi)
        ho = sigmoid(np.dot(x, wo) + bo)
        hc = tanh(np.dot(x, wc) + bc)
            
        cell = hf * prev_cell + hi * hc
        hidden = ho * tanh(cell)
        
        cell_values.append(cell)
        hidden_values.append(hidden)
        caches.append([x, hf, hi, ho, hc])
        
    out = np.dot(hidden, wy) + by
    pred = softmax(out)
    entropy = cross_entropy(pred, Y)
    
    dout = pred - Y
    dwy = np.dot(hidden.T, dout)
    dby = np.sum(dout, axis=1)
    
    dc_next = np.zeros_like(cell_values[-1])
    dh_next = np.zeros_like(hidden_values[-1])
    
    for x in range(Xt.shape[0]):
        cell = cell_values[-x-1]
        prev_cell = cell_values[-x-2]
        
        hidden = hidden_values[-x-1]
        prev_hidden = hidden_values[-x-2]
        
        x, hf, hi, ho, hc = caches[-x-1]
        
        tc = tanh(cell)
        
        dh = np.dot(dout, wy.T) + dh_next
        
        dc = dh * ho * deriv_tanh(tc)
        dc = dc + dc_next
        
        dho = dh * tc 
        dho = dho * deriv_sigmoid(ho)
        
        dhf = dc * prev_cell 
        dhf = dhf * deriv_sigmoid(hf)
        
        dhi = dc * hc 
        dhi = dhi * deriv_sigmoid(hi)
        
        dhc = dc * hi 
        dhc = dhc * deriv_tanh(hc)
        
        dwf += np.dot(X.T, dhf)
        dbf += np.sum(dhf, axis=1)
        dXf = np.dot(dhf, wf.T)
        
        dwi += np.dot(X.T, dhi)
        dbi += np.sum(dhi, axis=1)
        dXi = np.dot(dhi, wi.T)
        
        dwo += np.dot(X.T, dho)
        dbo += np.sum(dho, axis=1)
        dXo = np.dot(dho, wo.T)
        
        dhc += np.dot(X.T, dhc)
        dbc += np.sum(dhc, axis=1)
        dXc = np.dot(dhc, wc.T)
        
        dX = dXf + dXi + dXo + dXc
        
        dc_next = hf * dc
        dh_next = dX[:, -HIDDEN:]
        
    wf += ALPHA * dwf
    wi += ALPHA * dwi
    wc += ALPHA * dwc
    wo += ALPHA * dwo
    wy += ALPHA * dwy
    
    bf += ALPHA * dbf
    bi += ALPHA * dbi
    bc += ALPHA * dbc
    bo += ALPHA * dbo
    by += ALPHA * dby
    
    dwf *= 0
    dwi *= 0
    dwc *= 0
    dwo *= 0
    dwy *= 0
    
    dbf *= 0
    dbi *= 0
    dbc *= 0
    dbo *= 0
    dby *= 0
    
    if i % PLOT_ITER == 0:
        errors.append(entropy)
    
    if i % LOG_ITER == 0:
        print('iter', i)
        print('entropy', entropy)
        print('----------')



ValueError: operands could not be broadcast together with shapes (156,128) (784,128) (156,128) 

In [None]:
plt.plot(errors)