In [64]:
import numpy
import torch
import torch.nn as nn
import torch.optim as optim

In [65]:
def make_batch():
    input_batch = []
    target_batch = []
    
    words = sentence.split()
    for i, word in enumerate(words[:-1]):  
        input_ = [word2idx[word] for word in words[:(i+1)]]
        input_ = input_ + [0]*(max_len - len(input_))
        target_ = word2idx[words[i+1]]
        # 利用numpy.eye(num_class)[list]生成one-hot编码
        input_batch.append(numpy.eye(n_class)[input_])
        target_batch.append(target_)
    return input_batch,target_batch
        

In [88]:
class BiLSTM(nn.Module):
    def __init__(self):
        super(BiLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=n_class,hidden_size=n_hidden, bidirectional=True)
        self.W = nn.Linear(2*n_hidden, n_class, bias=False)
        self.b = nn.Parameter(torch.ones(n_class))
    
    def forward(self, X): #X:[batch_size, seq_len, input_size]
        X = X.transpose(0,1)
        
        h_0 = torch.zeros(2, X.size(1), n_hidden)
        c_0 = torch.zeros(2, X.size(1), n_hidden)
        output, _ = self.lstm(X, (h_0, c_0))
        # 取索引时，-1:  和-1是不同的，-1降维，-1:不降
        output = output[-1]
        out = self.W(output) + self.b  #[batch_size, hidden_size]
        return out
    

In [89]:
if __name__=='__main__':
    sentence = (
            'hello everyone, my name is jia chang min, nice to meet you'
        )
    word2idx = {word:i for i,word in enumerate(list(set(sentence.split())))}
    idx2word = {i:word for i,word in enumerate(list(set(sentence.split())))}
    max_len = len(sentence.split())
    n_class = len(word2idx)
    input_batch, target_batch = make_batch()
    
    n_hidden = 5
    model = BiLSTM()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    
    input_batch = torch.FloatTensor(input_batch)
    target_batch = torch.LongTensor(target_batch)
    for epoch in range(1000):
        optimizer.zero_grad()
        out = model(input_batch)
        loss = criterion(out, target_batch)
        if (epoch + 1)%100==0:
            print(f"Epoch {epoch+1:04d} cost= {loss:.6f} ")
            
        loss.backward()
        optimizer.step()
    

Epoch 0100 cost= 1.269796 
Epoch 0200 cost= 0.791136 
Epoch 0300 cost= 0.443788 
Epoch 0400 cost= 0.301994 
Epoch 0500 cost= 0.217400 
Epoch 0600 cost= 0.124835 
Epoch 0700 cost= 0.080704 
Epoch 0800 cost= 0.057519 
Epoch 0900 cost= 0.043598 
Epoch 1000 cost= 0.034385 


In [91]:
sentence

'hello everyone, my name is jia chang min, nice to meet you'

In [127]:
predict = model(input_batch).detach().max(dim=1, keepdim=True)[1]
[idx2word[idx.item()] for idx in predict]

['everyone,',
 'my',
 'name',
 'is',
 'jia',
 'chang',
 'min,',
 'nice',
 'to',
 'meet',
 'you']

## argmax

In [119]:
model(input_batch).detach().argmax(dim=1)

tensor([11,  3,  6,  2,  7,  5,  4,  1,  8,  9, 10])

## torch.max

In [125]:
model(input_batch).detach().max(dim=1)

torch.return_types.max(
values=tensor([12.2623, 11.9291, 12.8829, 12.8098, 12.9819, 13.3654,  9.9333, 12.5040,
        13.0879, 12.2386, 13.3414]),
indices=tensor([11,  3,  6,  2,  7,  5,  4,  1,  8,  9, 10]))

In [123]:
model(input_batch).detach().max(dim=1, keepdim=True)

torch.return_types.max(
values=tensor([[12.2623],
        [11.9291],
        [12.8829],
        [12.8098],
        [12.9819],
        [13.3654],
        [ 9.9333],
        [12.5040],
        [13.0879],
        [12.2386],
        [13.3414]]),
indices=tensor([[11],
        [ 3],
        [ 6],
        [ 2],
        [ 7],
        [ 5],
        [ 4],
        [ 1],
        [ 8],
        [ 9],
        [10]]))

In [126]:
model(input_batch).detach().max(dim=1, keepdim=True)[1]

tensor([[11],
        [ 3],
        [ 6],
        [ 2],
        [ 7],
        [ 5],
        [ 4],
        [ 1],
        [ 8],
        [ 9],
        [10]])

In [104]:
[idx2word[idx.item()] for idx in target_batch]

['everyone,',
 'my',
 'name',
 'is',
 'jia',
 'chang',
 'min,',
 'nice',
 'to',
 'meet',
 'you']

In [None]:
predict = model(input_batch)

In [72]:
word2idx

{'hello': 0,
 'nice': 1,
 'is': 2,
 'my': 3,
 'min,': 4,
 'chang': 5,
 'name': 6,
 'jia': 7,
 'to': 8,
 'meet': 9,
 'you': 10,
 'everyone,': 11}

In [73]:
idx2word

{0: 'hello',
 1: 'nice',
 2: 'is',
 3: 'my',
 4: 'min,',
 5: 'chang',
 6: 'name',
 7: 'jia',
 8: 'to',
 9: 'meet',
 10: 'you',
 11: 'everyone,'}

In [74]:
max_len

12

In [76]:
n_class

12

In [46]:
torch.ones((10,3))

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

In [50]:
torch.FloatTensor(input_batch).size()

torch.Size([26, 27, 27])

In [51]:
torch.LongTensor(target_batch).size()

torch.Size([26])

In [63]:
f"{max_len:.6f}"

'27.000000'

In [84]:
a = torch.arange(18).view(2,3,3)

In [85]:
a

tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[ 9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]]])

In [86]:
a[-1]

tensor([[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]])

In [87]:
a[-1:]

tensor([[[ 9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]]])