In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [29]:
class MyLSTM(nn.Module):
    def __init__(self,input_size,hidden_size,worddict_len):
        super(MyLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)

        self.linear = nn.Linear(hidden_size, worddict_len,bias=False)

        self.bias = nn.Parameter(torch.zeros(worddict_len))

    def forward(self, x,hidcell):
        x = x.transpose(1,0)
        
        outputs,hidcell = self.lstm(x,hidcell)
        output = outputs[-1]
        result = self.linear(output) + self.bias

        return result

        


In [None]:
def make_batch(sentences, word_dict, worddict_len, hidden_size):
    """Build one-hot LSTM inputs, next-token targets and zero initial states.

    Every character except the last of each sentence becomes an input step;
    the last character is the target class.

    Args:
        sentences: sequence of strings of equal length, so the per-sentence
            inputs can be stacked into one tensor.
        word_dict: mapping from character to integer index in ``range(worddict_len)``.
        worddict_len: vocabulary size (width of each one-hot vector).
        hidden_size: LSTM hidden size used for the zero initial states.

    Returns:
        input_batch: float32 tensor of shape (batch, seq_len - 1, worddict_len).
        target_batch: int64 tensor of shape (batch,).
        batch_size: number of sentences.
        hicell: ``(hidden, cell)`` tuple of zeros, each (1, batch, hidden_size).

    Raises:
        KeyError: if ``word_dict`` is a dict and a character is missing from it.
    """
    # Identity matrix rows are the one-hot vectors; build it once, not per sentence.
    one_hot = np.eye(worddict_len, dtype=np.float32)
    input_rows = []
    target_ids = []
    for sentence in sentences:
        input_ids = [word_dict[ch] for ch in sentence[:-1]]
        target_id = word_dict[sentence[-1]]
        input_rows.append(one_hot[input_ids])
        target_ids.append(target_id)

    batch_size = len(input_rows)

    # Stack into a single ndarray first: creating a tensor from a list of
    # ndarrays is very slow and triggers a PyTorch UserWarning.
    input_batch = torch.from_numpy(np.array(input_rows, dtype=np.float32))
    target_batch = torch.tensor(target_ids, dtype=torch.long)
    hidden = torch.zeros(1, batch_size, hidden_size)
    cell = torch.zeros(1, batch_size, hidden_size)
    # Pack the hidden and cell states into the tuple nn.LSTM expects.
    hicell = (hidden, cell)

    # Return inputs, targets, batch size and the initial (hidden, cell) state.
    return input_batch, target_batch, batch_size, hicell
    
# #test
# word_arr = "abcdefghijklmnopqrstuvwxyz"
# word_dict = {word_arr[i]:i for i in range(len(word_arr))}
# worddict_len = len(word_dict)
# sequences = ["make", "hate", "love", "home", "star"]

# input_batch,target_batch,batch_size,hicell = make_batch(sequences,word_dict,worddict_len, 128)
# print(input_batch.shape)

In [None]:
# Vocabulary: each lowercase letter maps to its position in the alphabet.
word_arr = "abcdefghijklmnopqrstuvwxyz"
word_dict = {ch: i for i, ch in enumerate(word_arr)}
worddict_len = len(word_dict)
print(word_dict)
# Toy task: given all but the last letter of each word, predict the last letter.
sequences = ["make", "hate", "love", "home", "star"]

hidden_size = 128
SEED = 0
# Seed before building the model so weight initialisation (and therefore the
# training run below) is reproducible under Restart & Run All.
torch.manual_seed(SEED)
model = MyLSTM(worddict_len, hidden_size, worddict_len)

input_batch, target_batch, batch_size, hicell = make_batch(sequences, word_dict, worddict_len, hidden_size)
print(input_batch.shape)

In [None]:
# Training
epochs = 100
lr = 0.001
log_every = 10  # print the loss every `log_every` epochs (and at the end) instead of every epoch
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Nothing inside the loop changes the module mode, so set it once.
model.train()
for epoch in range(epochs):
    optimizer.zero_grad()
    result = model(input_batch, hicell)
    loss = criterion(result, target_batch)
    loss.backward()
    optimizer.step()
    if epoch % log_every == 0 or epoch == epochs - 1:
        print("epoch:{},loss:{}".format(epoch, loss.item()))

In [None]:
# Testing: predict the final letter of each test word from its preceding letters.

model.eval()
test_data = ["make"]
# Use test-specific names so the training globals `target_batch` and
# `batch_size` are not overwritten; otherwise re-running the training cell
# afterwards would pair the 5 training inputs with this 1-element target and
# CrossEntropyLoss would fail on the batch-size mismatch.
test_batch, test_target, test_batch_size, test_hidcell = make_batch(test_data, word_dict, worddict_len, hidden_size)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    result = model(test_batch, test_hidcell)
print(result)
_, predict = torch.max(result, 1)
print(predict)
# Map predicted indices back to letters and report accuracy on the test words.
index_to_word = {i: w for w, i in word_dict.items()}
print([index_to_word[i] for i in predict.tolist()])
print("accuracy:", (predict == test_target).float().mean().item())

