In [None]:
import gensim
from gensim.models import KeyedVectors
import numpy as np
import torch
from torch import nn, autograd, optim
import torch.nn.functional as F
import time
from tqdm import tqdm
import math

In [2]:
fvec = KeyedVectors.load_word2vec_format('vec.txt', binary=False)
word_vec = fvec.vectors
vocab = ['<PAD>', '<BOS>', '<EOS>', '<UNK>']
vocab.extend(list(fvec.vocab.keys()))
word_vec = np.concatenate((np.array([[0]*word_vec.shape[1]] * 4), word_vec))
word_vec = torch.tensor(word_vec)

In [3]:
word_to_idx = {ch: i for i, ch in enumerate(vocab)}
idx_to_word = {i: ch for i, ch in enumerate(vocab)}

In [4]:
essays = []
topics = []
with open('composition.txt', 'r') as f:
    for line in f:
        essay, topic = line.replace('\n', '').split(' </d> ')
        essays.append(essay.split(' '))
        topics.append(topic.split(' '))

In [5]:
corpus_indice = list(map(lambda x: [word_to_idx[w] for w in x], essays[:8000]))
topics_indice = list(map(lambda x: [word_to_idx[w] for w in x], topics[:8000]))

In [6]:
length = list(map(lambda x: len(x), corpus_indice))

In [7]:
def tav_data_iterator(corpus_indice, topics_indice, batch_size, num_steps):
    epoch_size = len(corpus_indice) // batch_size
    for i in range(epoch_size):
        raw_data = corpus_indice[i*batch_size: (i+1)*batch_size]
        key_words = topics_indice[i*batch_size: (i+1)*batch_size]
        data = np.zeros((len(raw_data), num_steps+1), dtype=np.int64)
        for i in range(batch_size):
            doc = raw_data[i]
            tmp = [1]
            tmp.extend(doc)
            tmp.extend([2])
            tmp = np.array(tmp, dtype=np.int64)
            _size = tmp.shape[0]
            data[i][:_size] = tmp
        key_words = np.array(key_words, dtype=np.int64)
        x = data[:, 0:num_steps]
        y = data[:, 1:]
        mask = np.float32(x != 0)
        x = torch.tensor(x)
        y = torch.tensor(y)
        mask = torch.tensor(mask)
        key_words = torch.tensor(key_words)
        yield(x, y, mask, key_words)

In [8]:
class TAVLSTM(nn.Module):
    def __init__(self, hidden_dim, embed_dim, num_layers, weight,
                 num_labels, bidirectional, dropout=0.5, **kwargs):
        super(TAVLSTM, self).__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_labels = num_labels
        self.bidirectional = bidirectional
        if num_layers <= 1:
            self.dropout = 0
        else:
            self.dropout = dropout
        self.embedding = nn.Embedding.from_pretrained(weight)
        self.embedding.weight.requires_grad = False
        self.rnn = nn.GRU(input_size=self.embed_dim, hidden_size=self.hidden_dim,
                          num_layers=self.num_layers, bidirectional=self.bidirectional,
                          dropout=self.dropout)
        if self.bidirectional:
            self.decoder = nn.Linear(hidden_dim * 2, self.num_labels)
        else:
            self.decoder = nn.Linear(hidden_dim, self.num_labels)
        
    def forward(self, inputs, topics, hidden=None):
        embeddings = self.embedding(inputs)
        topics_embed = self.embedding(topics)
        topics_embed = topics_embed.mean(dim=1)
        for i in range(embeddings.shape[0]):
            embeddings[i][0] = topics_embed[i]
        states, hidden = self.rnn(embeddings.permute([1, 0, 2]).float(), hidden)
        outputs = self.decoder(states.reshape((-1, states.shape[-1])))
        return(outputs, hidden)
    
    def init_hidden(self, num_layers, batch_size, hidden_dim, **kwargs):
        hidden = torch.zeros(num_layers, batch_size, hidden_dim)
        return hidden

In [9]:
def predict_rnn(topics, num_chars, model, device, idx_to_word, word_to_idx):
    output = [1]
    topics = [word_to_idx[x] for x in topics]
    topics = torch.tensor(topics)
    hidden = torch.zeros(num_layers, 1, hidden_dim)
    if use_gpu:
        hidden = hidden.to(device)
        topics = topics.to(device)
#         hidden = hidden.cuda()
#         topics = topics.cuda()
    for t in range(num_chars):
        X = torch.tensor(output).reshape((1, len(output)))
        if use_gpu:
            X = X.to(device)
#             X = X.cuda()
        pred, hidden = model(X, topics, hidden)
        if pred.argmax(dim=1)[-1] == 2:
            break
        else:
            output.append(int(pred.argmax(dim=1)[-1]))
    return(''.join([idx_to_word[i] for i in output[1:]]))

In [10]:
embedding_dim = 300
hidden_dim = 256
lr = 1e2
momentum = 0.0
num_epoch = 100
use_gpu = True
num_layers = 1
bidirectional = False
batch_size = 8
device = torch.device('cuda:0')
loss_function = nn.CrossEntropyLoss()

In [11]:
model = TAVLSTM(hidden_dim=hidden_dim, embed_dim=embedding_dim, num_layers=num_layers,
                num_labels=len(vocab), weight=word_vec, bidirectional=bidirectional)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
if use_gpu:
#     model = nn.DataParallel(model)
    model.to(device)

In [12]:
predict_rnn(['妈妈', '希望', '长大', '孩子', '母爱'], 100, model, device, idx_to_word, word_to_idx)

'新欢开放王骞语句驱鬼盆向医术开关机锁于双生子开放泥土墙棕黄色王璐瑶人对祖了应消遥实际意义心爱陈若雨薄纱裙握手言和胶圈翻遍推逝染指王善哲审报狮园和辉实练教具战败断树不明时元时做害做害询询地转线询地转线询询地转线询地转线小王燕道德修养上趣上趣上趣上趣述诉上趣上趣述诉搭配着村谣述诉中数确立夺谢了恩车盘师才会内置毒奶道德败坏雪芊刚短短的个燥小器拒之门外夏虫咪对丧女之痛新欢开放略放林依轮玩鱼诱引金钢滆造完高脚千吨万语吴丝冰小福子书证出事开放父王揭瓦金钢滆造完'

In [13]:
since = time.time()
for epoch in range(num_epoch):
    start = time.time()
    num, total_loss = 0, 0
#     if epoch == 5000:
#         optimizer.param_groups[0]['lr'] = lr * 0.1
    data = tav_data_iterator(corpus_indice, topics_indice, batch_size, max(length)+1)
    hidden = model.init_hidden(num_layers, batch_size, hidden_dim)
    weight = torch.ones(len(vocab))
    weight[0] = 0
    for X, Y, mask, topics in tqdm(data):
        num += 1
        hidden.detach_()
        if use_gpu:
            X = X.to(device)
            Y = Y.to(device)
            mask = mask.to(device)
            topics = topics.to(device)
            hidden = hidden.to(device)
            weight = weight.to(device)
#             X = X.cuda()
#             Y = Y.cuda()
#             mask = mask.cuda()
#             topics = topics.cuda()
#             hidden = hidden.cuda()
        optimizer.zero_grad()
        output, hidden = model(X, topics, hidden)
        l = F.cross_entropy(output, Y.t().reshape((-1,)), weight)
        l.backward()
        norm = nn.utils.clip_grad_norm_(model.parameters(), 1e-2)
        optimizer.step()
        total_loss += l.item()
    end = time.time()
    s = end - since
    h = math.floor(s / 3600)
    m = s - h * 3600
    m = math.floor(m / 60)
    s -= m * 60
    if(epoch % 10 == 0) or (epoch == (num_epoch - 1)):
        print('epoch %d/%d, loss %.4f, norm %.4f, time %.3fs, since %dh %dm %ds'
              %(epoch+1, num_epoch, total_loss / num, norm, end-start, h, m, s))
        print(predict_rnn(['妈妈', '希望', '长大', '孩子', '母爱'], 100, model, device, idx_to_word, word_to_idx))

1000it [01:21, 12.29it/s]
2it [00:00, 11.76it/s]

epoch 1/100, loss 7.8234, norm 0.8318, time 81.356s, since 0h 1m 21s
，


1000it [01:22, 12.10it/s]
1000it [01:22, 12.16it/s]
1000it [01:23, 12.28it/s]
1000it [01:22, 12.13it/s]
1000it [01:23, 12.12it/s]
1000it [01:23, 11.64it/s]
1000it [01:23, 12.38it/s]
1000it [01:22, 12.17it/s]
1000it [01:22, 12.00it/s]
1000it [01:23, 11.93it/s]
0it [00:00, ?it/s]

epoch 11/100, loss 5.0515, norm 0.7263, time 83.162s, since 0h 15m 12s
，我要去寻找，我的心情，我要去看了。


1000it [01:22, 11.82it/s]
1000it [01:23, 11.86it/s]
1000it [01:22, 12.21it/s]
1000it [01:22, 12.12it/s]
1000it [01:22, 12.11it/s]
1000it [01:22, 12.12it/s]
1000it [01:22, 12.07it/s]
1000it [01:22, 12.01it/s]
1000it [01:23, 12.14it/s]
1000it [01:22, 12.00it/s]
2it [00:00, 12.03it/s]

epoch 21/100, loss 4.0981, norm 0.8440, time 82.996s, since 0h 29m 1s
的。


1000it [01:22, 11.76it/s]
1000it [01:23, 11.87it/s]
1000it [01:22, 12.25it/s]
1000it [01:22, 11.95it/s]
1000it [01:23, 12.18it/s]
1000it [01:24, 12.00it/s]
1000it [01:23, 12.16it/s]
1000it [01:23, 11.97it/s]
1000it [01:23, 12.04it/s]
1000it [01:24, 11.95it/s]


epoch 31/100, loss 3.6018, norm 1.0026, time 84.238s, since 0h 42m 57s


2it [00:00, 11.57it/s]

的地方。如果有机会，如果你会有机会，你就会遇到挫折，你会失去了你。”


1000it [01:24, 11.61it/s]
1000it [01:24, 11.83it/s]
1000it [01:23, 12.06it/s]
1000it [01:23, 12.04it/s]
1000it [01:23, 11.84it/s]
1000it [01:23, 12.00it/s]
1000it [01:23, 12.01it/s]
1000it [01:23, 12.07it/s]
1000it [01:23, 11.60it/s]
1000it [01:23, 11.93it/s]


epoch 41/100, loss 3.3253, norm 1.2088, time 83.560s, since 0h 56m 55s


2it [00:00, 11.86it/s]

。如果你能看见你会说“你遇到什么事，就会有时间。”


1000it [01:23, 11.94it/s]
1000it [01:23, 12.03it/s]
1000it [01:23, 12.09it/s]
1000it [01:23, 11.99it/s]
1000it [01:23, 12.16it/s]
1000it [01:24, 12.20it/s]
1000it [01:24, 12.02it/s]
1000it [01:24, 11.60it/s]
1000it [01:24, 11.65it/s]
1000it [01:23, 11.76it/s]
0it [00:00, ?it/s]

epoch 51/100, loss 3.1635, norm 1.1696, time 83.969s, since 1h 10m 3655s
。如果你有机会，我要珍惜时间。”


1000it [01:23, 11.67it/s]
1000it [01:24, 11.84it/s]
1000it [01:24, 11.95it/s]
1000it [01:23, 11.84it/s]
1000it [01:23, 12.04it/s]
1000it [01:23, 11.91it/s]
1000it [01:23, 11.44it/s]
1000it [01:23, 11.99it/s]
1000it [01:23, 11.91it/s]
1000it [01:23, 12.07it/s]


epoch 61/100, loss 3.0867, norm 1.6841, time 83.741s, since 1h 24m 3654s


2it [00:00, 12.01it/s]

。如果你如果你来了，我就会遇到什么困难，你就会去尝试。


1000it [01:23, 11.68it/s]
1000it [01:23, 11.66it/s]
1000it [01:23, 12.13it/s]
1000it [01:23, 12.10it/s]
1000it [01:23, 11.56it/s]
1000it [01:23, 12.14it/s]
1000it [01:24, 12.13it/s]
1000it [01:23, 12.11it/s]
1000it [01:23, 11.87it/s]
1000it [01:23, 11.97it/s]


epoch 71/100, loss 3.1119, norm 4.3726, time 83.946s, since 1h 38m 3653s


2it [00:00, 12.17it/s]

。如果你能看见你的朋友之间，你会在你面前的你会生气地对你说：“你一定会开花，所以我要珍惜，爱护你们的。”


1000it [01:24, 12.23it/s]
1000it [01:24, 11.93it/s]
1000it [01:23, 12.19it/s]
1000it [01:23, 11.42it/s]
1000it [01:24, 11.99it/s]
1000it [01:24, 11.73it/s]
1000it [01:23, 12.04it/s]
1000it [01:24, 11.41it/s]
1000it [01:23, 11.90it/s]
1000it [01:24, 11.91it/s]


epoch 81/100, loss 3.3210, norm 3.7118, time 84.119s, since 1h 52m 3653s


2it [00:00, 11.47it/s]

。如果你看你会发现你，你就会在你面前炫耀，你一定会在你面前炫耀。


1000it [01:23, 11.95it/s]
1000it [01:24, 11.53it/s]
1000it [01:23, 12.00it/s]
1000it [01:24, 12.03it/s]
1000it [01:23, 12.05it/s]
1000it [01:23, 12.04it/s]
1000it [01:23, 12.15it/s]
1000it [01:24, 11.44it/s]
1000it [01:24, 12.07it/s]
1000it [01:23, 11.98it/s]


epoch 91/100, loss 3.7038, norm 13.2962, time 83.470s, since 2h 6m 7252s


0it [00:00, ?it/s]

。我终于到了，我的手上拿着一只袜子，有的耳朵，想，我想，长大了，我想，长大了，我想长大了。


1000it [01:23, 11.69it/s]
1000it [01:23, 11.92it/s]
1000it [01:24, 12.03it/s]
1000it [01:23, 11.46it/s]
1000it [01:23, 12.01it/s]
1000it [01:23, 12.08it/s]
1000it [01:23, 11.86it/s]
1000it [01:23, 12.03it/s]
1000it [01:23, 12.21it/s]


epoch 100/100, loss 5.3755, norm 407.2407, time 83.604s, since 2h 19m 7227s
的教室里，在这个时候，我的哦！’，我的哦！啊！我的，像一面，活的眼泪。我感谢的是呀！我爱的，爱叶的天使，像一面的仙女，高高高，但她的是农民伯伯伯伯伯伯说：“孩子们，微笑”理解的，感谢的孩子们感谢的孩子们感谢感恩感恩感恩感恩感恩感恩感恩感恩感恩感恩感恩感恩感恩感恩感恩感恩感恩感恩
