In [40]:
# code by Tae Hwan Jung @graykode
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

def make_batch():
    input_batch, target_batch = [], []

    for seq in seq_data:
        input = [word_dict[n] for n in seq[:-1]] # 'm', 'a' , 'k' is input
        target = word_dict[seq[-1]] # 'e' is target
        input_batch.append(np.eye(n_class)[input])
        target_batch.append(target)

    return input_batch, target_batch

import math

class CustomLSTM(nn.Module):
    def __init__(self, input_sz, hidden_sz):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
        self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
        self.init_weights()
                
    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)
         
    def forward(self, x, 
                init_states=None):
        """Assumes x is of shape (batch, sequence, feature)"""
        bs, seq_sz, _ = x.size()
        print('bs:',bs)
        print('seq_sz:',seq_sz)
        hidden_seq = []
        if init_states is None:
            h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device), 
                        torch.zeros(bs, self.hidden_size).to(x.device))
        else:
            h_t, c_t = init_states
         
        HS = self.hidden_size
        for t in range(seq_sz):
            x_t = x[:, t, :]
            print('x_t:',x_t.shape)
            print('self.W:',self.W.shape)
            print('h_t:', h_t.shape)
            print('self.U', self.U.shape)
            print('self.bias:', self.bias.shape)
            # batch the computations into a single matrix multiplication
            gates = x_t @ self.W + h_t @ self.U + self.bias
            print('gates:',gates.shape)
            i_t, f_t, g_t, o_t = (
                torch.sigmoid(gates[:, :HS]), # input
                torch.sigmoid(gates[:, HS:HS*2]), # forget
                torch.tanh(gates[:, HS*2:HS*3]),
                torch.sigmoid(gates[:, HS*3:]), # output
            )
            print('i_t',i_t.shape)
            print('f_t',f_t.shape)
            print('g_t',g_t.shape)
            print('o_t', o_t.shape)
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t.unsqueeze(0))
            print(len(hidden_seq))
        hidden_seq = torch.cat(hidden_seq, dim=0)
        print('hidden_seq', hidden_seq.shape)
        # reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()
        print('hidden_seq2', hidden_seq.shape)
        return hidden_seq, (h_t, c_t)

class TextLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = CustomLSTM(n_class, n_hidden)#nn.LSTM(32, 32, batch_first=True)
        self.fc1 = nn.Linear(n_hidden, n_class)
        
    def forward(self, x):
        
        x_, (h_n, c_n) = self.lstm(x)
        x_ = (x_[:, -1, :])
        x_ = self.fc1(x_)
        print('x_:',x_.shape)
        return x_

if __name__ == '__main__':
    n_step = 3 # number of cells(= number of Step)
    n_hidden = 128 # number of hidden units in one cell

    char_arr = [c for c in 'abcdefghijklmnopqrstuvwxyz']
    word_dict = {n: i for i, n in enumerate(char_arr)}
    number_dict = {i: w for i, w in enumerate(char_arr)}
    n_class = len(word_dict)  # number of class(=number of vocab)

    seq_data = ['make', 'need', 'coal', 'word', 'love', 'hate', 'live', 'home', 'hash', 'star']

    model = TextLSTM()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    input_batch, target_batch = make_batch()
    input_batch = torch.FloatTensor(input_batch)
    target_batch = torch.LongTensor(target_batch)

    # Training
    for epoch in range(1000):
        optimizer.zero_grad()

        output = model(input_batch)
        loss = criterion(output, target_batch)
        if (epoch + 1) % 100 == 0:
            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))

        loss.backward()
        optimizer.step()

    inputs = [sen[:3] for sen in seq_data]

    predict = model(input_batch).data.max(1, keepdim=True)[1]
    print(predict)
    ps = predict.squeeze()
    print('---->',number_dict[ps[2].item()])
    for i,n in enumerate(inputs):
        print(n+number_dict[ps[i].item()])
    print(inputs, '->', [number_dict[n.item()] for n in predict.squeeze()])

bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
1
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
2
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
3
hidden_seq torch.Size([3, 10, 128])
hidden_seq2 torch.Size([10, 3, 128])
x_: torch.Size([10, 26])
bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: tor

bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
1
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
2
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
3
hidden_seq torch.Size([3, 10, 128])
hidden_seq2 torch.Size([10, 3, 128])
x_: torch.Size([10, 26])
bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: tor

1
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
2
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
3
hidden_seq torch.Size([3, 10, 128])
hidden_seq2 torch.Size([10, 3, 128])
x_: torch.Size([10, 26])
bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
1
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 51

gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
3
hidden_seq torch.Size([3, 10, 128])
hidden_seq2 torch.Size([10, 3, 128])
x_: torch.Size([10, 26])
bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
1
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
2
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])

gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
1
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
2
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
3
hidden_seq torch.Size([3, 10, 128])
hidden_seq2 torch.Size([10, 3, 128])
x_: torch.Size([10, 26])
bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])

bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
1
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
2
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
3
hidden_seq torch.Size([3, 10, 128])
hidden_seq2 torch.Size([10, 3, 128])
x_: torch.Size([10, 26])
bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: tor

bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
1
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
2
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
3
hidden_seq torch.Size([3, 10, 128])
hidden_seq2 torch.Size([10, 3, 128])
x_: torch.Size([10, 26])
bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: tor

bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
1
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
2
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
3
hidden_seq torch.Size([3, 10, 128])
hidden_seq2 torch.Size([10, 3, 128])
x_: torch.Size([10, 26])
bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: tor

bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
1
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
2
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
3
hidden_seq torch.Size([3, 10, 128])
hidden_seq2 torch.Size([10, 3, 128])
x_: torch.Size([10, 26])
bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: tor

bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
1
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
2
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
3
hidden_seq torch.Size([3, 10, 128])
hidden_seq2 torch.Size([10, 3, 128])
x_: torch.Size([10, 26])
bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: tor

bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
1
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
2
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
3
hidden_seq torch.Size([3, 10, 128])
hidden_seq2 torch.Size([10, 3, 128])
x_: torch.Size([10, 26])
bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: tor

bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
1
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
2
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
3
hidden_seq torch.Size([3, 10, 128])
hidden_seq2 torch.Size([10, 3, 128])
x_: torch.Size([10, 26])
bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: tor

bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
1
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
2
x_t: torch.Size([10, 26])
self.W: torch.Size([26, 512])
h_t: torch.Size([10, 128])
self.U torch.Size([128, 512])
self.bias: torch.Size([512])
gates: torch.Size([10, 512])
i_t torch.Size([10, 128])
f_t torch.Size([10, 128])
g_t torch.Size([10, 128])
o_t torch.Size([10, 128])
3
hidden_seq torch.Size([3, 10, 128])
hidden_seq2 torch.Size([10, 3, 128])
x_: torch.Size([10, 26])
bs: 10
seq_sz: 3
x_t: torch.Size([10, 26])
self.W: tor

In [3]:
input_batch.shape


torch.Size([10, 3, 26])

In [9]:
x = torch.tensor([1, 2, 3])@torch.tensor([0, 2, 1])

In [10]:
x

tensor(7)