In [1]:
import sys
import theano
import theano.tensor as T
import numpy as np
import string
import matplotlib.pyplot as plt
import json
import nltk
import operator
from nltk import pos_tag, word_tokenize
from sklearn.utils import shuffle
from datetime import datetime

In [2]:
def init_weight(Mi, Mo):
    """Return an (Mi, Mo) array of standard-normal draws scaled by 1/sqrt(Mi + Mo).

    Samples from NumPy's global random state, so the result depends on
    any prior ``np.random.seed`` call.
    """
    fan_sum = Mi + Mo
    draws = np.random.randn(Mi, Mo)
    return draws / np.sqrt(fan_sum)

In [4]:
class GRU(object):
    """Gated recurrent unit layer whose parameters are Theano shared variables.

    Parameters
    ----------
    Mi : int
        Size of each input vector.
    Mo : int
        Size of the hidden state.
    activation : callable
        Nonlinearity applied to the candidate hidden state.
    """

    def __init__(self, Mi, Mo, activation):
        self.Mi = Mi
        self.Mo = Mo
        self.f = activation

        # For each gate (reset 'r', update 'z', candidate 'h') build the
        # (input->hidden, hidden->hidden, bias) triple. Creation order is
        # r, z, h so the random draws happen in the same order as before.
        per_gate = {}
        for gate in ('r', 'z', 'h'):
            per_gate[gate] = (
                theano.shared(init_weight(Mi, Mo)),
                theano.shared(init_weight(Mo, Mo)),
                theano.shared(np.zeros(Mo)),
            )
        self.W_xr, self.W_hr, self.br = per_gate['r']
        self.W_xz, self.W_hz, self.bz = per_gate['z']
        self.W_xh, self.W_hh, self.bh = per_gate['h']

        # Initial hidden state; listed in params so callers can train it too.
        self.h0 = theano.shared(np.zeros(Mo))

        self.params = [
            self.W_xr, self.W_hr, self.br,
            self.W_xz, self.W_hz, self.bz,
            self.W_xh, self.W_hh, self.bh,
            self.h0,
        ]

    def recurrence(self, x_t, h_t1):
        """Compute one GRU step from input ``x_t`` and previous state ``h_t1``."""
        def affine(W_in, W_rec, bias, state):
            return x_t.dot(W_in) + state.dot(W_rec) + bias

        update_gate = T.nnet.sigmoid(affine(self.W_xz, self.W_hz, self.bz, h_t1))
        reset_gate = T.nnet.sigmoid(affine(self.W_xr, self.W_hr, self.br, h_t1))
        # The reset gate scales the previous state before the recurrent matmul.
        candidate = self.f(affine(self.W_xh, self.W_hh, self.bh, reset_gate * h_t1))
        # Blend the old state and the candidate according to the update gate.
        return (1 - update_gate) * h_t1 + update_gate * candidate

    def output(self, X):
        """Scan ``recurrence`` over axis 0 of ``X``, starting from ``h0``.

        Returns the symbolic sequence of hidden states, one per step.
        """
        hidden_states, _scan_updates = theano.scan(
            fn=self.recurrence,
            sequences=X,
            outputs_info=[self.h0],
            n_steps=X.shape[0],
        )
        return hidden_states

In [5]:
class RNN(object):
    """Word-level recurrent language model.

    The model applies an embedding lookup, then stacked recurrent units, then
    a softmax over the vocabulary.

    Parameters
    ----------
    D : int
        Word-embedding dimension.
    hidden_layer_sizes : sequence of int
        Hidden size of each stacked recurrent layer, bottom to top.
    V : int
        Vocabulary size, i.e. the number of embedding rows and output classes.
    """

    def __init__(self, D, hidden_layer_sizes, V):
        self.V = V
        self.D = D
        self.hidden_layer_sizes = hidden_layer_sizes

    def fit(self, X, learning_rate=1e-5, mu=0.99, epochs=10, activation=T.nnet.relu, show_fig=True,
            RecurrentUnit=GRU, normalize=True):
        """Train the model with momentum SGD, one sequence per update.

        Parameters
        ----------
        X : list of list of int
            Training sequences of word indices. Index 0 is prepended to every
            input sequence. Index 1 is appended to the target for short
            sequences and for ~1% of the others. These are presumably
            START / END tokens; confirm against the tokenizer.
        learning_rate : float
            Step size for all parameters, including the embedding matrix.
        mu : float
            Momentum coefficient.
        epochs : int
            Number of passes over ``X`` (reshuffled each epoch).
        activation : callable
            Nonlinearity passed to every recurrent unit.
        show_fig : bool
            If True, plot the per-epoch total cost after training.
        RecurrentUnit : class
            Layer constructor called as ``RecurrentUnit(Mi, Mo, activation)``.
            It must expose ``.params`` and ``.output(Z)``.
        normalize : bool
            If True, divide the updated embedding matrix by its L2 norm
            after every step.

        Raises
        ------
        ValueError
            If ``X`` is empty.
        """
        if len(X) == 0:
            raise ValueError("X must contain at least one training sequence")

        V = self.V
        D = self.D
        N = len(X)

        ### recurrent layers: each layer's output size is the next layer's input size
        self.hidden_layers = []
        Mi = D
        for Mo in self.hidden_layer_sizes:
            self.hidden_layers.append(RecurrentUnit(Mi, Mo, activation))
            Mi = Mo

        ### word embedding and output layer (created after the recurrent units,
        ### preserving the random-draw order)
        self.We = theano.shared(init_weight(V, D))
        self.Wo = theano.shared(init_weight(Mi, V))
        self.bo = theano.shared(np.zeros(V))

        # We gets its own (optionally normalized) update below, so it is
        # deliberately kept out of self.params.
        self.params = [self.Wo, self.bo]
        for ru in self.hidden_layers:
            self.params += ru.params

        ### symbolic inputs: word-index sequences
        thx = T.ivector('X')
        thy = T.ivector('Y')

        ### forward propagation
        Z = self.We[thx]
        for ru in self.hidden_layers:
            Z = ru.output(Z)

        py_x = T.nnet.softmax(Z.dot(self.Wo) + self.bo)
        prediction = T.argmax(py_x, axis=1)
        self.prediction_op = theano.function(
            inputs=[thx],
            outputs=[py_x, prediction],
            allow_input_downcast=True,
        )

        ### cost and momentum updates
        # Mean negative log-likelihood of each target index.
        cost = -T.mean(T.log(py_x[T.arange(thy.shape[0]), thy]))
        grads = T.grad(cost, self.params)
        dparams = [theano.shared(p.get_value() * 0) for p in self.params]

        gWe = T.grad(cost, self.We)
        dWe = theano.shared(self.We.get_value() * 0)
        dWe_update = mu * dWe - learning_rate * gWe
        We_update = self.We + dWe_update
        if normalize:
            # NOTE(review): norm(2) with no axis is taken over ALL entries of
            # the matrix, not per word vector. Per-row normalization would use
            # We_update.norm(2, axis=1).dimshuffle(0, 'x'). Kept as-is.
            We_update = We_update / We_update.norm(2)

        # p <- p + (mu*dp - lr*g) and dp <- mu*dp - lr*g, using the local
        # velocity buffers `dparams` (previously referenced as the
        # nonexistent `self.dparams`, which raised AttributeError).
        updates = [
            (p, p + mu * dp - learning_rate * g) for p, dp, g in zip(self.params, dparams, grads)
        ] + [
            (dp, mu * dp - learning_rate * g) for dp, g in zip(dparams, grads)
        ] + [
            (self.We, We_update), (dWe, dWe_update)
        ]

        self.train_op = theano.function(
            inputs=[thx, thy],
            outputs=[cost, prediction],
            updates=updates,
            allow_input_downcast=True,
        )

        ### training
        costs = []
        for i in range(epochs):
            t0 = datetime.now()
            X = shuffle(X)
            epoch_cost = 0
            n_correct = 0
            n_total = 0

            for j in range(N):
                # Sequences of length <= 1 always get the index-1 target appended.
                # Longer ones get it ~1% of the time; otherwise the last word
                # is only a target, never an input.
                if np.random.random() < 0.01 or len(X[j]) <= 1:
                    input_sequence = [0] + X[j]
                    output_sequence = X[j] + [1]
                else:
                    input_sequence = [0] + X[j][:-1]
                    output_sequence = X[j]
                n_total += len(output_sequence)

                try:
                    c, p = self.train_op(input_sequence, output_sequence)
                except Exception:
                    # Print forward-pass shapes to help diagnose the failure,
                    # then re-raise the original exception unchanged.
                    py_x_val, pred = self.prediction_op(input_sequence)
                    print("input_sequence len:", len(input_sequence))
                    print("py_x.shape", py_x_val.shape)
                    print("pred.shape", pred.shape)
                    raise
                epoch_cost += c
                n_correct += sum(1 for pj, xj in zip(p, output_sequence) if pj == xj)

                if j % 200 == 0:
                    sys.stdout.write("j/N: %d/%d correct rate so far: %f\r" % (j, N, float(n_correct)/n_total))
                    sys.stdout.flush()
            print("i:", i, "cost:", epoch_cost, "correct rate:", (float(n_correct)/n_total), "time for epoch:", (datetime.now() - t0))
            costs.append(epoch_cost)

        if show_fig:
            plt.plot(costs)
            plt.show()