In [2]:
import numpy as np
from gensim.models.word2vec import Word2Vec
import pickle
import os
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.nn import Sequential, LSTM, Dropout, Flatten, Linear, CrossEntropyLoss
import torch.optim.adam

# Data split (insert a space between every character)
1.   Prepare the space-separated dataset (`split.txt`) that is used to train Word2Vec.



In [16]:
def split_poetry(file = "poetry_5.txt", out_file = "split.txt"):
  """Write a space-separated copy of a poetry corpus and return its lines.

  A single space is inserted between every pair of adjacent characters of the
  input text (newline characters included), and the result is written to
  ``out_file``.

  Args:
    file: path of the UTF-8 corpus to read.
    out_file: path the space-separated text is written to. Defaults to
      "split.txt", the path that was previously hard-coded.

  Returns:
    The original (un-spaced) text split on '\n', as a list of strings.
    Note that this is the list of lines, not the path of the written file.
  """
  # Context managers guarantee both handles are closed, even on error.
  with open(file, 'r', encoding = 'utf-8') as src:
    all_data = src.read()
  # Joining a str iterates it character by character.
  all_data_split = " ".join(all_data)
  with open(out_file, 'w', encoding = 'utf-8') as f:
    f.write(all_data_split)
  return all_data.split('\n')


# Word2Vec: convert poems into vectors
Two versions:
1.   with punctuation
2.   without punctuation






In [17]:
# with punctuations
def train_vec(split_file = 'split.txt', org_file = "poetry_5.txt", num_train=300):
  """Train Word2Vec on the space-separated corpus, keeping punctuation.

  Args:
    split_file: path of the space-separated corpus. If it does not exist it
      is generated from ``org_file`` with ``split_poetry``.
    org_file: path of the original UTF-8 corpus, one poem per line.
    num_train: number of leading lines of ``org_file`` returned as training
      poems.

  Returns:
    A tuple ``(model, org_data, model.syn1neg, model.wv.index2word,
    word2index)`` where ``org_data`` is the first ``num_train`` lines of
    ``org_file`` and ``word2index`` maps each token of
    ``model.wv.index2word`` to its position in that list.
  """
  if not os.path.exists(split_file):
    # split_poetry() returns the list of poem lines, not a path, so its
    # return value must not be passed to open(). It writes its output to
    # "split.txt" by default; read that file afterwards.
    split_poetry(file = org_file)
    split_file = 'split.txt'

  with open(org_file, 'r', encoding='utf-8') as org_fh:
    org_data = org_fh.read().split('\n')[:num_train]
  with open(split_file, 'r', encoding='utf-8') as split_fh:
    split_data = split_fh.read().split('\n')
  # In the file written by split_poetry every line after the first starts with
  # a space (the separator placed after '\n'); give the first line the same
  # layout.
  split_data[0] = ' ' + split_data[0]

  # Each "sentence" is a str, so it is consumed one character at a time.
  # NOTE(review): size= / index2word / syn1neg are the gensim 3.x names.
  model = Word2Vec(split_data, size=128, min_count=1, workers=7)
  # implement word2index based on index2word
  word2index = {token: token_index for token_index, token in enumerate(model.wv.index2word)}
  # NOTE(review): syn1neg is returned as the embedding table; presumably its
  # rows follow index2word order -- confirm this is intended rather than
  # model.wv.vectors.
  return model, org_data, model.syn1neg, model.wv.index2word, word2index

In [25]:
# Load the space-separated corpus, one list entry per line of split.txt.
split_file = 'split.txt'
# The context manager closes the handle as soon as the text has been read.
with open(split_file, 'r', encoding='utf-8') as split_fh:
  split_data = split_fh.read().split('\n')

In [19]:
#remove punctuations
def punctuation_remove(file_name):
  """Return the given text with every '。' and '，' character deleted.

  Despite its name, ``file_name`` is the text itself, not a path.
  """
  stripped = file_name
  for mark in ('。', '，'):
    stripped = stripped.replace(mark, '')
  return stripped

# without punctuations
def train_vec_nopunc(split_file = 'split.txt', org_file = 'poetry_5.txt', train_num=300):
  """Train Word2Vec on the space-separated corpus with '。' and '，' removed.

  The space-separated file is always regenerated from ``org_file`` first.

  Args:
    split_file: kept for backward compatibility. The corpus is regenerated by
      ``split_poetry`` on every call, which writes to "split.txt", and that
      file is the one read (the previous code also discarded this argument).
    org_file: path of the original UTF-8 corpus, one poem per line.
    train_num: number of leading lines of the punctuation-free corpus
      returned as training poems.

  Returns:
    A tuple ``(model, org_data, model.syn1neg, model.wv.index2word,
    word2index)`` where ``org_data`` is the first ``train_num``
    punctuation-free lines and ``word2index`` maps each token of
    ``model.wv.index2word`` to its position in that list.
  """
  # split_poetry() returns the list of poem lines, not a path; call it for its
  # side effect of writing "split.txt" and read that file explicitly.
  split_poetry(file = org_file)
  split_file = 'split.txt'

  # remove punctuation in org_data
  with open(org_file, 'r', encoding='utf-8') as org_fh:
    org_data = punctuation_remove(org_fh.read())
  org_data = org_data.split('\n')[:train_num]

  # remove punctuation in split_data
  with open(split_file, 'r', encoding='utf-8') as split_fh:
    split_data = punctuation_remove(split_fh.read())
  split_data = split_data.split('\n')
  # Every later line starts with the separator space that followed '\n';
  # give the first line the same layout.
  split_data[0] = ' ' + split_data[0]

  # Each "sentence" is a str, so it is consumed one character at a time.
  # NOTE(review): size= / index2word / syn1neg are the gensim 3.x names.
  model = Word2Vec(split_data, size=128, min_count=1, workers=7)
  # implement word2index based on index2word
  word2index = {token: token_index for token_index, token in enumerate(model.wv.index2word)}
  # NOTE(review): syn1neg is returned as the embedding table; presumably its
  # rows follow index2word order -- confirm this is intended rather than
  # model.wv.vectors.
  return model, org_data, model.syn1neg, model.wv.index2word, word2index

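# data_set
1.   all_data: the dataset (a list of poem strings)
2.   shape of w1: [word_size, embedding_num], where word_size is the vocabulary size (the number of distinct tokens learned by Word2Vec, not the number of poems) and embedding_num is the embedding dimension. The shape was noted as [5364, 101]; with `size=128` in the Word2Vec call the second dimension should be 128 — verify.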



In [20]:
class Mydataset(Dataset):
  """Next-token dataset over a list of poems.

  Item ``k`` is built from ``all_data[k]``: each element of the poem is
  mapped to an integer id through ``word_2_index``; the ids without the last
  one select rows of ``w1`` (the inputs) and the ids without the first one
  are the targets.
  """

  def __init__(self, all_data, w1, word_2_index, index_2_word ):
    # The arguments are only stored; nothing is copied or validated here.
    self.all_data = all_data
    self.w1 = w1
    self.word_2_index = word_2_index
    self.index_2_word = index_2_word

  def __len__(self):
    """Number of poems in the dataset."""
    return len(self.all_data)

  def __getitem__(self, index):
    """Return ``(input_embeddings, target_ids)`` for the poem at ``index``."""
    poem = self.all_data[index]
    token_ids = list(map(self.word_2_index.__getitem__, poem))
    # Shift by one position: every token is used to predict its successor.
    input_ids, target_ids = token_ids[:-1], token_ids[1:]
    targets = np.array(target_ids).astype(np.int64)
    return self.w1[input_ids], targets


# LSTM MODEL


In [21]:
class lstm_model(nn.Module):
  """Single-layer LSTM followed by dropout and a linear projection.

  Args:
    embedding_num: size of each input embedding vector (LSTM input size).
    hidden_num: LSTM hidden-state size.
    word_size: number of output classes produced by the final linear layer.
  """

  def __init__(self, embedding_num, hidden_num, word_size):
    super().__init__()
    self.embedding_num = embedding_num
    self.hidden_num = hidden_num
    self.word_size = word_size
    self.lstm = LSTM(input_size = embedding_num, hidden_size=hidden_num, batch_first = True)
    # Flatten(0, 1) merges the batch and time axes so the linear layer emits
    # one row of logits per (sample, time step).
    self.model = Sequential(Dropout(0.3),
                            Flatten(0,1),
                            Linear(hidden_num, word_size)
                            )
    self.cross_entropy = CrossEntropyLoss()
    # Kept for callers that read it. forward() does not rely on it, because
    # nothing in this class moves the parameters to that device.
    self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

  def forward(self, xs_embedding, h_0=None, c_0=None):
    """Run the LSTM and project every time step to ``word_size`` logits.

    Args:
      xs_embedding: input tensor laid out batch-first, i.e.
        ``(batch, seq_len, embedding_num)``.
      h_0: optional initial hidden state of shape ``(1, batch, hidden_num)``;
        zeros when omitted.
      c_0: optional initial cell state of shape ``(1, batch, hidden_num)``;
        zeros when omitted.

    Returns:
      ``(logits, (h, c))`` where ``logits`` has shape
      ``(batch * seq_len, word_size)`` and ``(h, c)`` is the final LSTM state.
    """
    # Use the device the weights actually live on. Sending the states to
    # 'cuda' while the module and the input stay on the CPU raises a device
    # mismatch on any machine with a GPU.
    device = next(self.parameters()).device
    xs_embedding = xs_embedding.to(device)
    state_shape = (1, xs_embedding.shape[0], self.hidden_num)
    # `is None` avoids comparing a tensor with `==`; each state is defaulted
    # on its own so passing only one of them no longer fails.
    if h_0 is None:
      h_0 = torch.zeros(state_shape, dtype=torch.float32)
    if c_0 is None:
      c_0 = torch.zeros(state_shape, dtype=torch.float32)
    h_0 = h_0.to(device)
    c_0 = c_0.to(device)
    pre, (h, c) = self.lstm(xs_embedding, (h_0, c_0))
    return self.model(pre), (h, c)

# auto_generater

In [29]:
def auto_generater(word_size, model, hidden_num, length=24):
  """Generate text greedily, starting from a random vocabulary entry.

  Relies on the notebook-level globals ``w1`` (embedding table) and
  ``index_2_word`` (id -> token list).

  Args:
    word_size: vocabulary size; the first token id is drawn uniformly from
      ``[0, word_size)``.
    model: network called as ``model(embedding, h_0, c_0)`` and returning
      ``(logits, (h, c))``; expected to be a ``torch.nn.Module``.
    hidden_num: hidden-state size used to build the initial zero states.
    length: total number of tokens in the result, including the random first
      one. The default of 24 matches the previously hard-coded 1 + 23.

  Returns:
    The generated tokens concatenated into one string.
  """
  word_index = np.random.randint(0, word_size, 1)[0]
  tokens = [index_2_word[word_index]]
  h_0 = torch.zeros((1, 1, hidden_num), dtype=torch.float32)
  c_0 = torch.zeros((1, 1, hidden_num), dtype=torch.float32)

  # Sample in eval mode so dropout does not perturb the predictions, and
  # without autograd so the carried state does not grow a graph step by step.
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for _ in range(length - 1):
        # [None][None] adds the batch and time axes: (1, 1, embedding_num).
        word_embeddings = torch.tensor(w1[word_index][None][None])
        pre, (h_0, c_0) = model(word_embeddings, h_0, c_0)
        word_index = int(torch.argmax(pre))
        tokens.append(index_2_word[word_index])
  finally:
    # This is called from inside the training loop; restore the prior mode.
    model.train(was_training)

  return ''.join(tokens)


# acrostic_generater
An acrostic is a verse in which the initial characters of the lines form a word or phrase.



In [22]:
def acrostic_generater(word_size, model, hidden_num, input_text):
  """Generate an acrostic whose lines start with the characters of ``input_text``.

  Relies on the notebook-level globals ``w1`` (embedding table),
  ``word_2_index`` (token -> id) and ``index_2_word`` (id -> token).

  Args:
    word_size: unused; kept so the signature matches ``auto_generater``.
    model: network called as ``model(embedding, h_0, c_0)`` and returning
      ``(logits, (h, c))``; expected to be a ``torch.nn.Module``.
    hidden_num: hidden-state size used to build the initial zero states.
    input_text: line-head characters. At most the first four are used; if
      fewer are given, that many lines are produced.

  Returns:
    One string in which every line is its head character, four greedily
    generated tokens, and a closing mark taken in order from '，。，。'.

  Raises:
    KeyError: if a head character is not in ``word_2_index``.
  """
  punctuation_list = '，。，。'
  pieces = []
  h_0 = torch.zeros((1, 1, hidden_num), dtype=torch.float32)
  c_0 = torch.zeros((1, 1, hidden_num), dtype=torch.float32)

  # Sample in eval mode (no dropout) and without autograd bookkeeping.
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      # Pair each head with its own closing mark. Reusing the outer loop
      # index in the inner loop made every line end with punctuation_list[3].
      for head, closing_mark in zip(input_text, punctuation_list):
        word_index = word_2_index[head]
        pieces.append(head)
        for _ in range(4):
          # [None][None] adds the batch and time axes: (1, 1, embedding_num).
          word_embeddings = torch.tensor(w1[word_index][None][None])
          pre, (h_0, c_0) = model(word_embeddings, h_0, c_0)
          word_index = int(torch.argmax(pre))
          pieces.append(index_2_word[word_index])
        pieces.append(closing_mark)
  finally:
    model.train(was_training)
  return ''.join(pieces)

# poem generation

In [26]:
# --- hyper-parameters ---
batch_size = 15   # dataset items (poems) per DataLoader batch
lr = 0.01         # Adam learning rate
hidden_num = 51   # LSTM hidden-state size

# --- embeddings and vocabulary (corpus with punctuation) ---
model_vec, all_data, w1, index_2_word, word_2_index = train_vec()
word_size, embedding_num = w1.shape

# --- data pipeline ---
data_set = Mydataset(all_data=all_data,
                     w1=w1,
                     word_2_index=word_2_index,
                     index_2_word=index_2_word)
data_loader = DataLoader(dataset=data_set, batch_size=batch_size)

# --- network and optimizer ---
model = lstm_model(embedding_num=embedding_num,
                   hidden_num=hidden_num,
                   word_size=word_size)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

  if sys.path[0] == '':


In [30]:
# Train for a fixed number of passes over data_loader. On every 50th batch
# index (including the first batch of each pass) print a freshly generated
# sample together with the loss of that batch.
epochs = 1000
for epoch in range(epochs):
  print('--- round {} generation ---'.format(epoch+1))
  for batch_index, (xs_embedding, ys_index) in enumerate(data_loader):
    # Clear gradients left over from the previous step before this one.
    optimizer.zero_grad()
    pre, (h_0, c_0) = model(xs_embedding)
    # The model emits one row of logits per (sample, time step); flatten the
    # targets to the same layout.
    targets = ys_index.reshape(-1)
    loss = model.cross_entropy(pre, targets)
    loss.backward()
    optimizer.step()
    if batch_index % 50 != 0:
      continue
    result = auto_generater(word_size, model, hidden_num)
    print('beautiful chinese 5 words poem:{} loss:{}'.format(result, loss))

--- round 1 generation ---
beautiful chinese 5 words poem:砥山嗟山山嗟殊山殊拟雷城宁城，，甲城，山，，嗟嗟 loss:8.189482688903809
--- round 2 generation ---
beautiful chinese 5 words poem:苓。。，，，。，，，，。，，，，，，，，，。。， loss:6.597034931182861
--- round 3 generation ---
beautiful chinese 5 words poem:衲。。。。。，。。。。，，，，。，。。。。，。。 loss:6.6297454833984375
--- round 4 generation ---
beautiful chinese 5 words poem:簰，，，，，，。，。。，，，，。，，，，，，，， loss:6.540030002593994
--- round 5 generation ---
beautiful chinese 5 words poem:嗟，风，，，，风，，，，，风，，，，，风，风，， loss:6.50416374206543
--- round 6 generation ---
beautiful chinese 5 words poem:迎风风，风，风，风风，风，风，，风，风，风，风风 loss:6.3922882080078125
--- round 7 generation ---
beautiful chinese 5 words poem:岔海风，风无，，风风无，，风，不，无风，，风风， loss:6.253900051116943
--- round 8 generation ---
beautiful chinese 5 words poem:湘风无无，海风风，无风无，，风风无，，不不无，， loss:6.012547969818115
--- round 9 generation ---
beautiful chinese 5 words poem:严雨风无，，不无无，，无无无，。不不一，，不不无 loss:5.903372287750244
--- round 10 generation ---
beautiful chinese

In [None]:
%debug


> [0;32m<ipython-input-195-295b1f124138>[0m(8)[0;36macrostic_generater[0;34m()[0m
[0;32m      6 [0;31m  [0;32mfor[0m [0mi[0m [0;32min[0m [0mrange[0m[0;34m([0m[0;36m0[0m[0;34m,[0m [0;36m4[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      7 [0;31m    [0mword_index[0m [0;34m=[0m [0mword_2_index[0m[0;34m[[0m[0minput_text[0m[0;34m[[0m[0mi[0m[0;34m][0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 8 [0;31m    [0mresult[0m[0;34m+=[0m[0mimput_text[0m[0;34m[[0m[0mi[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m    [0;32mfor[0m [0mi[0m [0;32min[0m [0mrange[0m[0;34m([0m[0;36m0[0m[0;34m,[0m [0;36m7[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     10 [0;31m      [0mword_embeddings[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtensor[0m[0;34m([0m[0mw1[0m[0;34m[[0m[0mword_index[0m[0;34m][0m[0;34m[[0m[0;32mNone[0m[0;34m][0m[0;34m[[0m[0;32mNone[