In [48]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [49]:
n_step = 3
n_hidden = 128
char_arr = [c for c in 'abcdefghijklmnopqrstuvwxyz']
word_dict = {n:i for i,n in enumerate(char_arr)}
number_dict = {i:w for i,w in enumerate(char_arr)}
n_class = len(word_dict) # 26
seq_data = ['make', 'need', 'coal', 'word', 'love',
           'hate', 'live', 'home', 'hash', 'star']

In [50]:
def make_batch(seq_data):
    input_batch, target_batch = [], []
    for seq in seq_data:
        input = [word_dict[n] for n in seq[:-1]] # m a k 对应的编号
        target = word_dict[seq[-1]]
        input_batch.append(np.eye(n_class)[input])
        target_batch.append(target)
    return input_batch, target_batch       

In [51]:
class TextLSTM(nn.Module):
    def __init__(self):
        super(TextLSTM, self).__init__()
        
        self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden) # (26, 128)
        # input (sequence_length, batch, input_size)  input_size在此为 独热编码
        # output (seq_len, batch, num_directions * hidden_size)
        self.W = nn.Linear(n_hidden, n_class, bias=False)
        self.b = nn.Parameter(torch.ones([n_class]))
    
    def forward(self, X):
        input = X.transpose(0, 1) # 10, 3, 26 -> 3, 10, 26
        hidden_state = torch.zeros(1, len(X), n_hidden) #  1, 10, 128
        cell_state = torch.zeros(1, len(X), n_hidden)
        outputs, (_, _) = self.lstm(input, (hidden_state, cell_state))
        # outputs 3, 10, 1*128
        outputs = outputs[-1] # [batch_size, n_hidden]
        model = self.W(outputs) + self.b # model: [batch_size, n_class]
        return model     

In [52]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
input_batch, target_batch = make_batch(seq_data)
input_batch = torch.FloatTensor(input_batch) # 10,3,26
target_batch = torch.LongTensor(target_batch)

In [54]:
for epoch in range(1000):
    optimizer.zero_grad()
    output = model(input_batch)
    loss = criterion(output, target_batch)
    if (epoch + 1) % 100 == 0:
        print(epoch+1,' epoch, cost: ', loss)
    loss.backward()
    optimizer.step()

100  epoch, cost:  tensor(4.7905e-05, grad_fn=<NllLossBackward>)
200  epoch, cost:  tensor(1.6915e-05, grad_fn=<NllLossBackward>)
300  epoch, cost:  tensor(8.6425e-06, grad_fn=<NllLossBackward>)
400  epoch, cost:  tensor(5.2928e-06, grad_fn=<NllLossBackward>)
500  epoch, cost:  tensor(3.5882e-06, grad_fn=<NllLossBackward>)
600  epoch, cost:  tensor(2.6226e-06, grad_fn=<NllLossBackward>)
700  epoch, cost:  tensor(2.0146e-06, grad_fn=<NllLossBackward>)
800  epoch, cost:  tensor(1.5855e-06, grad_fn=<NllLossBackward>)
900  epoch, cost:  tensor(1.2875e-06, grad_fn=<NllLossBackward>)
1000  epoch, cost:  tensor(1.0252e-06, grad_fn=<NllLossBackward>)


In [55]:
predict_data = ['name','this', 'star']
inputs = [sen[:3] for sen in predict_data]

input_batch, target_batch = make_batch(predict_data)
input_batch = torch.FloatTensor(input_batch) # 10,3,26
target_batch = torch.LongTensor(target_batch)

predict = model(input_batch).data.max(1, keepdim=True)[1]
print(inputs, '->', [number_dict[n.item()] for n in predict.squeeze()])

['nam', 'thi', 'sta'] -> ['d', 'e', 'r']
