In [1]:
import numpy as np

tang_file = np.load("tang.npz",allow_pickle=True)
tang_file.files

['ix2word', 'word2ix', 'data']

In [2]:
data = tang_file['data']
word2ix = tang_file['word2ix'].item()
idx2word = tang_file['ix2word'].item()

In [3]:
def idx2poem(idx_poem):
    poem = []
    for id in idx_poem:
        poem.append(idx2word[id])
    return "".join(poem)

In [4]:
idx2poem(data[314])

'</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><START>南楼夜已寂，暗鸟动林间。不见城郭事，沈沈唯四山。<EOP>'

## 对数据进行预处理，去除所有唐诗中的填充符号 \</s\>

In [5]:
word2ix["</s>"]

8292

In [6]:
poems = []
for poem in data:
    for index,ix in enumerate(poem):
        if ix == word2ix["</s>"]:
            continue
        else:
            break
    poems.append(poem[index:])

In [7]:
# del data
poems[0],idx2poem(poems[0]),len(idx2word)

(array([8291, 6731, 4770, 1787, 8118, 7577, 7066, 4817,  648, 7121, 1542,
        6483, 7435, 7686, 2889, 1671, 5862, 1949, 7066, 2596, 4785, 3629,
        1379, 2703, 7435, 6064, 6041, 4666, 4038, 4881, 7066, 4747, 1534,
          70, 3788, 3823, 7435, 4907, 5567,  201, 2834, 1519, 7066,  782,
         782, 2063, 2031,  846, 7435, 8290], dtype=int32),
 '<START>度门能不访，冒雪屡西东。已想人如玉，遥怜马似骢。乍迷金谷路，稍变上阳宫。还比相思意，纷纷正满空。<EOP>',
 8293)

构建X,Y训练集。

In [8]:
seq_len = 48
X = []
Y = []
poems_data = [j for i in poems for j in i]

for i in range(0,len(poems_data) - seq_len -1,seq_len):
    X.append(poems_data[i:i+seq_len])
    Y.append(poems_data[i+1:i+seq_len+1])


In [9]:
from torch.utils.data import DataLoader,Dataset
import torch
class PoemDataset(Dataset):

    def __init__(self,X,Y):
        self.X = X
        self.Y = Y
        self.len = len(X)
    def __getitem__(self,index):
        x = np.array(X[index])
        y = np.array(Y[index])
        return torch.from_numpy(x).long(),torch.from_numpy(y).long()
    def __len__(self):
        return self.len
        
data_loader = DataLoader(PoemDataset(X,Y),batch_size=1024,num_workers=2)

In [10]:
a,b = next(iter(data_loader))
a.shape,b.shape

(torch.Size([1024, 48]), torch.Size([1024, 48]))

In [11]:
import torch
import torch.nn.functional as F
import torch.nn as nn
class PoemNet(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        """
            vocab_size：训练集合字典大小（8293）
            embedding_dim：word2vec的维度
            hidden_dim：LSTM的hidden_dim
        """
        super(PoemNet, self).__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim,batch_first=True)

        self.fc = nn.Sequential(
            nn.Linear(self.hidden_dim,2048),
            nn.ReLU(),
            nn.Dropout(0.25),
            
            nn.Linear(2048,4096),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(4096,vocab_size),
        )

    def forward(self, input,hidden=None):
        """
            input：输入的诗词
            hidden：在生成诗词的时候需要使用，在pytorch中，，如果不指定初始状态h_0和C_0，则其
            默认为0.
            pytorch的LSTM的输出是(output,(h_n,c_n))。实际上，output就是h_1,h_2,……h_n
        """
        embeds = self.embeddings(input)
        batch_size, seq_len = input.size()
        if hidden is None:
            output, hidden = self.lstm(embeds)
        else:
            # h_0,c_0 = hidden
            output, hidden = self.lstm(embeds,hidden)
    
        output = self.fc(output)
        output = output.reshape(batch_size * seq_len, -1)
        output = F.log_softmax(output,dim=1)
        return output,hidden

In [12]:
vocab_size = len(word2ix.keys()) # 8293
embedding_dim = 200
hidden_dim = 1024

In [13]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
my_net = PoemNet(vocab_size,embedding_dim,hidden_dim).to(device)

In [14]:
# x = torch.randint(1,1412,(16,32)).to(device)
# y  = my_net(x)

In [15]:
import torch.optim as optim
optimzer = optim.Adam(my_net.parameters(),lr=0.001)
loss_function = nn.CrossEntropyLoss()

## 开始训练

In [16]:
for epoch in range(100):
    my_net.train()
    losses = []
    for i,data in enumerate(data_loader):
        inputs,target = data[0].to(device),data[1].to(device)
        optimzer.zero_grad()
        outputs,_= my_net(inputs)
        loss = loss_function(outputs,target.view(-1))
        loss.backward()
        optimzer.step()
        losses.append(loss.item())
    print(np.mean(losses))
    

6.86594607681036
6.137322202324867
5.941264726221561
5.740553066134453
5.524653658270836
5.339044913649559
5.185250006616116
5.062027506530285
4.958294026553631
4.8608838096261024
4.768598310649395
4.688210293650627
4.614417053759098
4.543813459575176
4.477163672447205
4.415464676916599
4.357650965452194
4.303020555526018
4.250680532306433
4.202439229935408
4.153485104441643
4.10454773157835
4.052609637379646
4.00092438608408
3.9505771547555923
3.9029736556112766
3.856202345341444
3.815825156867504
3.7771109379827976
3.7380864024162292
3.6966686323285103
3.656085681170225
3.6139936223626137
3.5736645311117172
3.533860892057419
3.4995048344135284
3.469289354979992
3.4374887943267822
3.401293881237507
3.357454866170883
3.316337510943413
3.2771004550158978
3.240685235708952
3.200998354703188
3.1607166416943073
3.1222167909145355
3.085622824728489
3.047156136482954
3.011751551181078
2.9764722660183907
2.9404236376285553
2.9068773686885834
2.876134227961302
2.8468081541359425
2.823373056948

In [19]:
torch.save(my_net,"model.h5")