In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter

class RNNCell(nn.Module):
    """Single Elman RNN step: ``h' = tanh(x @ W.T + h @ U.T + b)``.

    Args:
        input_size: number of features in the input ``x`` (last dim).
        hidden_size: number of features in the hidden state (last dim).
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.W = Parameter(torch.empty(hidden_size, input_size))
        self.U = Parameter(torch.empty(hidden_size, hidden_size))
        self.b = Parameter(torch.empty(hidden_size))
        # torch.Tensor(...)/torch.empty(...) leave memory uninitialised, so the
        # weights must be explicitly initialised before use.
        self.reset_parameters()

    def reset_parameters(self):
        """Draw every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)).

        Same scheme as ``torch.nn.RNNCell``; a zero hidden size yields a zero bound.
        """
        bound = 1.0 / self.hidden_size ** 0.5 if self.hidden_size > 0 else 0.0
        for param in self.parameters():
            nn.init.uniform_(param, -bound, bound)

    def forward(self, x, hidden):
        """Advance the recurrence by one step.

        Args:
            x: input whose last dim is ``input_size``; leading dims are batch-like.
            hidden: previous state whose last dim is ``hidden_size``; must
                broadcast against the leading dims of ``x``.
        Returns:
            The new hidden state (last dim ``hidden_size``).
        """
        return torch.tanh(x.matmul(self.W.t()) + hidden.matmul(self.U.t()) + self.b)


    
class RNN(nn.Module):
    """Unrolled single-layer RNN built on ``RNNCell``, read out to one value per sequence.

    The final hidden state is projected onto the vector ``W`` (``hidden_size``,),
    so the output drops the feature dimension.

    Args:
        input_size: features per time step (last dim of the input).
        hidden_size: size of the recurrent state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        # Readout vector; allocated uninitialised, so initialise it explicitly
        # with the same U(-1/sqrt(hidden), 1/sqrt(hidden)) bound as the cell.
        self.W = Parameter(torch.empty(hidden_size))
        self.cell = RNNCell(input_size, hidden_size)
        bound = 1.0 / hidden_size ** 0.5 if hidden_size > 0 else 0.0
        nn.init.uniform_(self.W, -bound, bound)

    def forward(self, input_, hidden=None):
        """Run the cell over dim 0 of ``input_`` and project the last hidden state.

        Args:
            input_: tensor whose dim 0 is time and whose last dim is
                ``input_size``; any dims in between (e.g. batch) are carried through.
            hidden: optional initial state broadcastable to
                ``input_.shape[1:-1] + (hidden_size,)``. Defaults to zeros
                (as ``torch.nn.RNN`` does) so the forward pass is deterministic.
        Returns:
            ``final_hidden @ W`` with shape ``input_.shape[1:-1]``.
        Raises:
            ValueError: if ``input_`` has no time steps.
        """
        if input_.size(0) == 0:
            raise ValueError("input_ must contain at least one time step")
        if hidden is None:
            hidden = input_.new_zeros(*input_.shape[1:-1], self.hidden_size)

        # unbind(dim=0) yields one slice per time step, each without the time dim.
        for step in torch.unbind(input_, dim=0):
            hidden = self.cell(step, hidden)

        return hidden.matmul(self.W)

In [2]:
# Seed so weight init and the sample input are reproducible on Restart & Run All.
torch.manual_seed(0)

# One shared input (seq_len=10, batch=1, input_size=4) for both models, so the
# comparison below is like-for-like.
sample_input = torch.randn(10, 1, 4)

rnn = RNN(4, 16)
out = rnn(sample_input)

pnn = nn.RNN(4, 16)
pout, _ = pnn(sample_input)

print(out.shape)
print(pout.shape)

torch.Size([1])
torch.Size([10, 1, 16])


In [65]:
def train(category_tensor, line_tensor, model=None, loss_fn=None, lr=None):
    """Run one manual-SGD step over a single sequence.

    The model is stepped through ``line_tensor`` one element of dim 0 at a
    time, the loss is computed on the last output, and every parameter that
    received a gradient is updated in place: ``p -= lr * p.grad``.

    Args:
        category_tensor: target passed as the second argument to ``loss_fn``.
        line_tensor: sequence whose dim 0 is iterated as time steps.
        model: module exposing ``initHidden()`` and ``forward(x, hidden) ->
            (output, hidden)``. Defaults to the global ``rnn``.
            NOTE(review): the ``RNN`` class defined in this notebook has no
            ``initHidden`` and returns a single tensor, so it does not fit
            this interface as written.
        loss_fn: callable ``(output, target) -> scalar tensor``. Defaults to
            the global ``criterion``.
        lr: learning rate. Defaults to the global ``learning_rate``.
    Returns:
        ``(output, loss)``: the last model output and the loss as a Python float.
    Raises:
        ValueError: if ``line_tensor`` has no time steps.
    """
    # Fall back to the notebook globals for backward compatibility.
    model = rnn if model is None else model
    loss_fn = criterion if loss_fn is None else loss_fn
    lr = learning_rate if lr is None else lr

    if line_tensor.size(0) == 0:
        raise ValueError("line_tensor must contain at least one time step")

    hidden = model.initHidden()
    model.zero_grad()

    for step in line_tensor:
        output, hidden = model(step, hidden)

    loss = loss_fn(output, category_tensor)
    loss.backward()

    # Manual SGD update. no_grad() replaces the discouraged `.data` access, and
    # the `alpha=` form replaces the deprecated add_(Number, Tensor) overload.
    # Parameters not reached by backward() have grad None and are skipped.
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is not None:
                p.add_(p.grad, alpha=-lr)

    return output, loss.item()

tensor([[1.3541, 0.1041, 1.0765, 0.3207]])
tensor([[-1.7803,  0.3306, -1.7975,  1.5376]])
tensor([[-1.0844, -0.2547,  2.0768, -0.2917]])
tensor([[-0.9011, -0.5724, -0.4990, -0.6314]])
tensor([[ 1.3021,  0.4176,  1.7057, -2.0591]])
tensor([[-0.5537, -0.3709,  0.2508,  1.1543]])
tensor([[1.6652, 0.3657, 1.0785, 0.4046]])
tensor([[-0.1286,  0.8798, -0.4899,  0.8600]])
tensor([[-0.0236,  0.1552, -0.2991, -0.1127]])
tensor([[-0.1399, -0.2968, -0.2649, -0.0176]])
