In [187]:
import torch,pprint
import torch.nn.utils.rnn as rnn
import torch.nn as nn

data = [
    'hello world',
    'i choose python',
    'about pytorch',
    'what ?',
]
data = sorted(data, key=lambda x:len(x),reverse=True)  
words = set(''.join(data))
word_to_index = {v:i for i,v in enumerate(words)}

print('feature  length',len(words))

# 字母转换位index
data_index = [[word_to_index[y] for y in x[:-1]] for x in data]
labels_index = [word_to_index[y] for x in data for y in x[1:]]
labels_index = torch.tensor(labels_index)
print(data_index)
print(labels_index)

pprint.pprint(data_index)

batch_size = len(data_index)
# input 为 [:-1], label 为 [1:] 长度都是length-1
max_seq_len = max([len(x) for x in data_index])
features_size = len(words)

def manual_pad():
    """
    手动padding : 新建 zeros tensor,根据data_index 把tensor中相应index的值改为1
    需要给data 加padding
    """
    data_padding = torch.zeros((batch_size,max_seq_len,features_size))
    for i,v in enumerate(data_index):
        for j,v2 in enumerate(v):
            data_padding[i][j][v2] = 1
            
    return data_padding

def auto_pad(features_size):
    """
    先把所有seq转化为 0,1 的tensor后
    使用 pad_sequence
    需要给data 加padding
    """
    def t(x):
        tensor = torch.zeros(len(x),features_size)
        for i,v in enumerate(x):
            tensor[i][v] = 1
        return tensor

    tensor_list = [t(x) for x in data_index]
    data_padding = rnn.pad_sequence(tensor_list,batch_first=True)
    return data_padding


data_padding_1 = manual_pad()
data_padding_2 = auto_pad(features_size)

print(data_padding_1.size())

print(torch.equal(data_padding_1, data_padding_2))
print(torch.equal(label_padding_1, label_padding_2))

# print(data_padding.size())



feature  length 19
[[16, 0, 3, 15, 1, 1, 6, 7, 0, 12, 14, 13, 15, 1], [2, 4, 1, 10, 13, 0, 12, 14, 13, 1, 18, 3], [15, 7, 11, 11, 1, 0, 9, 1, 18, 11], [9, 15, 2, 13, 0]]
tensor([ 0,  3, 15,  1,  1,  6,  7,  0, 12, 14, 13, 15,  1, 17,  4,  1, 10, 13,
         0, 12, 14, 13,  1, 18,  3, 15,  7, 11, 11,  1,  0,  9,  1, 18, 11,  8,
        15,  2, 13,  0,  5])
[[16, 0, 3, 15, 1, 1, 6, 7, 0, 12, 14, 13, 15, 1],
 [2, 4, 1, 10, 13, 0, 12, 14, 13, 1, 18, 3],
 [15, 7, 11, 11, 1, 0, 9, 1, 18, 11],
 [9, 15, 2, 13, 0]]
torch.Size([4, 14, 19])
True
True


In [175]:
data_len = [len(x)-1 for x in data] # batch中,各seq的长度
x = rnn.pack_padded_sequence(data_padding_1, data_len, batch_first=True)
# lables = rnn.pack_padded_sequence(label_padding_1, data_len, batch_first=True)

raw = rnn.pad_packed_sequence(x,batch_first=True)


In [188]:
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.lstm = nn.LSTM(19,50,batch_first=True)
        self.output_layer = nn.Linear(50,19)
    
    def forward(self,x):
        output,(hn,hc) = self.lstm(x)
        
        raw_output,seq_len = rnn.pad_packed_sequence(output,batch_first=True)
        seqs = [seq[:length] for seq,length in zip(raw_output,seq_len)]
        output = torch.cat(seqs)

        output = self.output_layer(output)
        return output
    

def test_train():
    net = Net()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    for i in range(1000):
        optimizer.zero_grad()
        res = net(x)
        print(res.size(),labels_index.size())
        loss = criterion(res,labels_index)
        print(loss.item())
        loss.backward()
        optimizer.step()
test_train()


torch.Size([41, 19]) torch.Size([41])
2.918750047683716
torch.Size([41, 19]) torch.Size([41])
2.914654016494751
torch.Size([41, 19]) torch.Size([41])
2.910555124282837
torch.Size([41, 19]) torch.Size([41])
2.906439781188965
torch.Size([41, 19]) torch.Size([41])
2.902294397354126
torch.Size([41, 19]) torch.Size([41])
2.8981027603149414
torch.Size([41, 19]) torch.Size([41])
2.8938474655151367
torch.Size([41, 19]) torch.Size([41])
2.889512300491333
torch.Size([41, 19]) torch.Size([41])
2.8850796222686768
torch.Size([41, 19]) torch.Size([41])
2.88053035736084
torch.Size([41, 19]) torch.Size([41])
2.8758463859558105
torch.Size([41, 19]) torch.Size([41])
2.8710038661956787
torch.Size([41, 19]) torch.Size([41])
2.8659815788269043
torch.Size([41, 19]) torch.Size([41])
2.8607521057128906
torch.Size([41, 19]) torch.Size([41])
2.8552892208099365
torch.Size([41, 19]) torch.Size([41])
2.849561929702759
torch.Size([41, 19]) torch.Size([41])
2.843538284301758
torch.Size([41, 19]) torch.Size([41])
2.8

1.2072595357894897
torch.Size([41, 19]) torch.Size([41])
1.1946924924850464
torch.Size([41, 19]) torch.Size([41])
1.1822588443756104
torch.Size([41, 19]) torch.Size([41])
1.1699599027633667
torch.Size([41, 19]) torch.Size([41])
1.157796859741211
torch.Size([41, 19]) torch.Size([41])
1.1457706689834595
torch.Size([41, 19]) torch.Size([41])
1.1338822841644287
torch.Size([41, 19]) torch.Size([41])
1.122132420539856
torch.Size([41, 19]) torch.Size([41])
1.110520601272583
torch.Size([41, 19]) torch.Size([41])
1.0990477800369263
torch.Size([41, 19]) torch.Size([41])
1.0877134799957275
torch.Size([41, 19]) torch.Size([41])
1.076516032218933
torch.Size([41, 19]) torch.Size([41])
1.0654538869857788
torch.Size([41, 19]) torch.Size([41])
1.0545237064361572
torch.Size([41, 19]) torch.Size([41])
1.0437222719192505
torch.Size([41, 19]) torch.Size([41])
1.0330464839935303
torch.Size([41, 19]) torch.Size([41])
1.0224931240081787
torch.Size([41, 19]) torch.Size([41])
1.0120586156845093
torch.Size([41, 

0.2635660767555237
torch.Size([41, 19]) torch.Size([41])
0.2607729732990265
torch.Size([41, 19]) torch.Size([41])
0.25801077485084534
torch.Size([41, 19]) torch.Size([41])
0.25527918338775635
torch.Size([41, 19]) torch.Size([41])
0.2525775134563446
torch.Size([41, 19]) torch.Size([41])
0.24990561604499817
torch.Size([41, 19]) torch.Size([41])
0.2472628653049469
torch.Size([41, 19]) torch.Size([41])
0.24464881420135498
torch.Size([41, 19]) torch.Size([41])
0.24206319451332092
torch.Size([41, 19]) torch.Size([41])
0.2395055592060089
torch.Size([41, 19]) torch.Size([41])
0.23697534203529358
torch.Size([41, 19]) torch.Size([41])
0.2344723641872406
torch.Size([41, 19]) torch.Size([41])
0.231996089220047
torch.Size([41, 19]) torch.Size([41])
0.2295462191104889
torch.Size([41, 19]) torch.Size([41])
0.22712235152721405
torch.Size([41, 19]) torch.Size([41])
0.22472411394119263
torch.Size([41, 19]) torch.Size([41])
0.22235125303268433
torch.Size([41, 19]) torch.Size([41])
0.2200034111738205
torc

0.062092166393995285
torch.Size([41, 19]) torch.Size([41])
0.061554793268442154
torch.Size([41, 19]) torch.Size([41])
0.06102399528026581
torch.Size([41, 19]) torch.Size([41])
0.06049955263733864
torch.Size([41, 19]) torch.Size([41])
0.05998159572482109
torch.Size([41, 19]) torch.Size([41])
0.05946982651948929
torch.Size([41, 19]) torch.Size([41])
0.058964285999536514
torch.Size([41, 19]) torch.Size([41])
0.05846472084522247
torch.Size([41, 19]) torch.Size([41])
0.0579712800681591
torch.Size([41, 19]) torch.Size([41])
0.057483870536088943
torch.Size([41, 19]) torch.Size([41])
0.057002101093530655
torch.Size([41, 19]) torch.Size([41])
0.05652627721428871
torch.Size([41, 19]) torch.Size([41])
0.056056104600429535
torch.Size([41, 19]) torch.Size([41])
0.055591605603694916
torch.Size([41, 19]) torch.Size([41])
0.05513262003660202
torch.Size([41, 19]) torch.Size([41])
0.05467917397618294
torch.Size([41, 19]) torch.Size([41])
0.05423098802566528
torch.Size([41, 19]) torch.Size([41])
0.053788

0.023143233731389046
torch.Size([41, 19]) torch.Size([41])
0.02302461676299572
torch.Size([41, 19]) torch.Size([41])
0.02290688455104828
torch.Size([41, 19]) torch.Size([41])
0.02279009483754635
torch.Size([41, 19]) torch.Size([41])
0.022674281150102615
torch.Size([41, 19]) torch.Size([41])
0.02255934104323387
torch.Size([41, 19]) torch.Size([41])
0.022445259615778923
torch.Size([41, 19]) torch.Size([41])
0.02233220264315605
torch.Size([41, 19]) torch.Size([41])
0.022220076993107796
torch.Size([41, 19]) torch.Size([41])
0.022108763456344604
torch.Size([41, 19]) torch.Size([41])
0.021998276934027672
torch.Size([41, 19]) torch.Size([41])
0.021888790652155876
torch.Size([41, 19]) torch.Size([41])
0.021780025213956833
torch.Size([41, 19]) torch.Size([41])
0.02167210914194584
torch.Size([41, 19]) torch.Size([41])
0.021565239876508713
torch.Size([41, 19]) torch.Size([41])
0.021459009498357773
torch.Size([41, 19]) torch.Size([41])
0.021353652700781822
torch.Size([41, 19]) torch.Size([41])
0.0

0.012326507829129696
torch.Size([41, 19]) torch.Size([41])
0.01228268537670374
torch.Size([41, 19]) torch.Size([41])
0.012239142321050167
torch.Size([41, 19]) torch.Size([41])
0.01219585444778204
torch.Size([41, 19]) torch.Size([41])
0.012152648530900478
torch.Size([41, 19]) torch.Size([41])
0.01210991945117712
torch.Size([41, 19]) torch.Size([41])
0.01206734124571085
torch.Size([41, 19]) torch.Size([41])
0.012025019153952599
torch.Size([41, 19]) torch.Size([41])
0.0119828712195158
torch.Size([41, 19]) torch.Size([41])
0.011941095814108849
torch.Size([41, 19]) torch.Size([41])
0.011899378150701523
torch.Size([41, 19]) torch.Size([41])
0.01185809075832367
torch.Size([41, 19]) torch.Size([41])
0.011816920712590218
torch.Size([41, 19]) torch.Size([41])
0.01177594717592001
torch.Size([41, 19]) torch.Size([41])
0.011735265143215656
torch.Size([41, 19]) torch.Size([41])
0.011694827117025852
torch.Size([41, 19]) torch.Size([41])
0.011654586531221867
torch.Size([41, 19]) torch.Size([41])
0.011

0.007947270758450031
torch.Size([41, 19]) torch.Size([41])
0.00792551040649414
torch.Size([41, 19]) torch.Size([41])
0.007903855293989182
torch.Size([41, 19]) torch.Size([41])
0.00788229238241911
torch.Size([41, 19]) torch.Size([41])
0.007860776968300343
torch.Size([41, 19]) torch.Size([41])
0.007839307188987732
torch.Size([41, 19]) torch.Size([41])
0.007817954756319523
torch.Size([41, 19]) torch.Size([41])
0.00779682258144021
torch.Size([41, 19]) torch.Size([41])
0.0077756671234965324
torch.Size([41, 19]) torch.Size([41])
0.007754651363939047
torch.Size([41, 19]) torch.Size([41])
0.00773384515196085
torch.Size([41, 19]) torch.Size([41])
0.007712829392403364
torch.Size([41, 19]) torch.Size([41])
0.007692069746553898
torch.Size([41, 19]) torch.Size([41])
0.0076714144088327885
torch.Size([41, 19]) torch.Size([41])
0.007650840561836958
torch.Size([41, 19]) torch.Size([41])
0.007630394771695137
torch.Size([41, 19]) torch.Size([41])
0.007609960623085499
torch.Size([41, 19]) torch.Size([41])