In [2]:
import gensim
from gensim.models import KeyedVectors
import numpy as np
import torch
from torch import nn, autograd, optim
import torch.nn.functional as F
from sru import SRU, SRUCell
import time
import math
from tqdm import tqdm
import collections
import random

In [3]:
fvec = KeyedVectors.load_word2vec_format('vec_100d.txt', binary=False)
word_vec = fvec.vectors
vocab = ['<PAD>', '<BOS>', '<EOS>', '<UNK>']
vocab.extend(list(fvec.vocab.keys()))
word_vec = np.concatenate((np.array([[0]*word_vec.shape[1]] * 4), word_vec))
word_vec = torch.tensor(word_vec).float()

In [4]:
word_to_idx = {ch: i for i, ch in enumerate(vocab)}
idx_to_word = {i: ch for i, ch in enumerate(vocab)}

In [5]:
essays = []
topics = []
with open('composition.txt', 'r') as f:
    for line in f:
        essay, topic = line.replace('\n', '').split(' </d> ')
        words = essay.split(' ')
        new_words = []
        for word in words:
            if word in word_to_idx:
                new_words.append(word)
            else:
                for i in range(len(word)):
                    new_words.append(word[i])
        essays.append(new_words)
        topics.append(topic.split(' '))

In [24]:
corpus_indice = list(map(lambda x: [word_to_idx[w] for w in x], essays[:80]))
topics_indice = list(map(lambda x: [word_to_idx[w] for w in x], topics[:80]))
corpus_test = list(map(lambda x: [word_to_idx[w] for w in x], essays[80:100]))
topics_test = list(map(lambda x: [word_to_idx[w] for w in x], topics[80:100]))

In [25]:
length = list(map(lambda x: len(x), corpus_indice))

In [51]:
def tav_data_iterator(corpus_indice, topics_indice, batch_size, num_steps):
    epoch_size = len(corpus_indice + batch_size - 1) // batch_size
    for i in range(epoch_size):
        raw_data = corpus_indice[i*batch_size: (i+1)*batch_size]
        key_words = topics_indice[i*batch_size: (i+1)*batch_size]
        data = np.zeros((len(raw_data), num_steps+1), dtype=np.int64)
        for i in range(batch_size):
            doc = raw_data[i]
            tmp = [1]
            tmp.extend(doc)
            tmp.extend([2])
            tmp = np.array(tmp, dtype=np.int64)
            _size = tmp.shape[0]
            data[i][:_size] = tmp
        key_words = np.array(key_words, dtype=np.int64)
        x = data[:, 0:num_steps]
        y = data[:, 1:]
        mask = np.float32(x != 0)
        x = torch.tensor(x)
        y = torch.tensor(y)
        mask = torch.tensor(mask)
        key_words = torch.tensor(key_words)
        yield(x, y, mask, key_words)

In [41]:
class TATLSTM(nn.Module):
    def __init__(self, hidden_dim, embed_dim, num_layers, weight,
                 num_labels, bidirectional, dropout=0.5, **kwargs):
        super(TATLSTM, self).__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_labels = num_labels
        self.bidirectional = bidirectional
        if num_layers <= 1:
            self.dropout = 0
        else:
            self.dropout = dropout
        self.embedding = nn.Embedding.from_pretrained(weight)
        self.embedding.weight.requires_grad = False
        self.rnn = nn.GRU(input_size=self.embed_dim, hidden_size=self.hidden_dim,
                          num_layers=self.num_layers, bidirectional=self.bidirectional,
                          dropout=self.dropout).to(device)
        if self.bidirectional:
            self.decoder = nn.Linear(
                hidden_dim * 2 + self.embed_dim, 1000).to(device)
        else:
            self.decoder = nn.Linear(
                hidden_dim + self.embed_dim, 1000).to(device)
#         self.attn = nn.Linear(self.embed_dim * 5, self.embed_dim)
        self.attn = nn.Linear(2, 1, bias=False)
        self.attn.weight.data.fill_(1)
        

    def forward(self, inputs, topics, hidden=None):
        embeddings = self.embedding(inputs).to(device)
        topics_embed = self.embedding(topics)
        topics_attn = torch.zeros(
            topics_embed.shape[0], 1, self.embed_dim)
        for i in range(len(topics_embed)):
            topics_attn[i] = self.attn(topics_embed[i].t()).t()
#         topics_attn = self.attn(topics_embed.reshape((topics_embed.shape[0], -1)))
#         topics_attn.unsqueeze_(-1)
        topics_attn = topics_attn.permute([0, 2, 1])
        self.rnn.flatten_parameters()
        if hidden is None:
            states, hidden = self.rnn(embeddings.permute([1, 0, 2]))
        else:
            states, hidden = self.rnn(
                embeddings.permute([1, 0, 2]), hidden)
        topics_attn = topics_attn.expand(
            topics_attn.shape[0], topics_attn.shape[1], states.shape[0])
        topics_attn = topics_attn.permute([2, 0, 1]).to(device)
        states_with_topic = torch.cat([states, topics_attn], dim=2)
        outputs = self.decoder(states_with_topic.reshape(
            (-1, states_with_topic.shape[-1])))
        return(outputs, hidden)

    def init_hidden(self, num_layers, batch_size, hidden_dim, **kwargs):
        hidden = torch.zeros(num_layers, batch_size, hidden_dim)
        return hidden

In [42]:
def predict_rnn_ada(topics, num_chars, model, idx_to_word, word_to_idx):
    output = [1]
    topics = [word_to_idx[x] for x in topics]
    topics = torch.tensor(topics)
    topics = topics.reshape((1, topics.shape[0]))
    hidden = torch.zeros(num_layers, 1, hidden_dim)
    if use_gpu:
        hidden = hidden.to(device)
        adaptive_softmax.to(device)
#         topics = topics.to(device)
    for t in range(num_chars):
        X = torch.tensor(output[-1]).reshape((1, 1))
#         X = torch.tensor(output).reshape((1, len(output)))
#         if use_gpu:
#             X = X.to(device)
        pred, hidden = model(X, topics, hidden)
        pred = adaptive_softmax.predict(pred)
        if pred[-1] == 2:
            break
        else:
            output.append(int(pred[-1]))
#             output.append(int(pred.argmax(dim=1)[-1]))
    return(''.join([idx_to_word[i] for i in output[1:]]), output[1:])

In [43]:
def bleu(pred_tokens, label_tokens, k):
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, k+1):
        num_matches = 0
        for i in range(len_pred - n + 1):
            if ' '.join(pred_tokens[i: i + n]) in ' '.join(label_tokens):
                num_matches += 1
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score

In [57]:
embedding_dim = 300
hidden_dim = 256
lr = 1e2
momentum = 0.1
num_epoch = 300
use_gpu = True
num_layers = 1
bidirectional = False
batch_size = 1
verbose = 5
vocab_size = len(vocab)
device = torch.device('cuda:0')
loss_function = nn.CrossEntropyLoss()
adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
    1000, len(vocab), cutoffs=[round(vocab_size / 20), 4*round(vocab_size / 20)])

In [58]:
model = TATLSTM(hidden_dim=hidden_dim, embed_dim=embedding_dim, num_layers=num_layers,
                num_labels=len(vocab), weight=word_vec, bidirectional=bidirectional)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
# if use_gpu:
#     model = nn.DataParallel(model)
#     model.to(device)
# optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters)
# hvd.broadcast_parameters(model.state_dict(), root_rank=0)

In [59]:
predict_rnn_ada(['变形金刚', '三星级'], 100, model, idx_to_word, word_to_idx)[0]

'即便即便炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭炒饭'

In [60]:
since = time.time()
for epoch in range(num_epoch):
    start = time.time()
    num, total_loss = 1, 0
#     if epoch == 5000:
#         optimizer.param_groups[0]['lr'] = lr * 0.1
    data = tav_data_iterator(
        corpus_indice, topics_indice, batch_size, max(length) + 1)
#     hidden = model.module.init_hidden(num_layers, batch_size, hidden_dim)
    weight = torch.ones(len(vocab))
    weight[0] = 0
    for X, Y, mask, topics in tqdm(data):
        num += 1
#         hidden.detach_()
        if use_gpu:
            #             X = X.to(device)
            Y = Y.to(device)
            mask = mask.to(device)
#             topics = topics.to(device)
#             hidden = hidden.to(device)
            weight = weight.to(device)
        optimizer.zero_grad()
#         output, hidden = model(X, topics, hidden)
        output, hidden = model(X, topics)
        hidden.detach_()
#         l = F.cross_entropy(output, Y.t().reshape((-1,)), weight)
        l, _ = adaptive_softmax(output, Y.t().reshape((-1,)))
        loss = -l.reshape((-1, batch_size)).t() * mask
        loss = loss.sum(dim=1) / mask.sum(dim=1)
        loss = loss.mean()
        loss.backward()
        norm = nn.utils.clip_grad_norm_(model.parameters(), 1e-2)
        optimizer.step()
        total_loss += loss.item()
        params = model.state_dict()
        params['attn.weight'].clamp_(1)
    end = time.time()
    s = end - since
    h = math.floor(s / 3600)
    m = s - h * 3600
    m = math.floor(m / 60)
    s -= (m * 60 + h * 3600)
    if((epoch + 1) % verbose == 0) or (epoch == (num_epoch - 1)):
        bleu_score = 0
#         for i in range(len(corpus_test)):
#             doc = corpus_test[i]
#             _, pred = predict_rnn([idx_to_word[x] for x in topics_test[0]],
#                                   100, model, device, idx_to_word, word_to_idx)
#             bleu_score += bleu([idx_to_word[int(x)] for x in pred],
#                                [idx_to_word[x] for x in doc if x not in [0, 2]], k=2)
        print('epoch %d/%d, loss %.4f, norm %.4f, predict bleu: %.4f, time %.3fs, since %dh %dm %ds'
              % (epoch + 1, num_epoch, total_loss / num, norm, bleu_score / 800, end - start, h, m, s))
        print(predict_rnn_ada(['变形金刚', '三星级'],
                              100, model, idx_to_word, word_to_idx)[0])

47it [00:00, 49.76it/s]
47it [00:00, 53.37it/s]
47it [00:00, 52.11it/s]
47it [00:00, 52.21it/s]
47it [00:00, 52.24it/s]


epoch 5/300, loss 6.6303, norm 1.1307, predict bleu: 0.0000, time 0.902s, since 0h 0m 4s


10it [00:00, 48.39it/s]

<PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD>


47it [00:00, 53.16it/s]
47it [00:00, 50.30it/s]
47it [00:00, 51.62it/s]
47it [00:00, 52.58it/s]
47it [00:00, 52.54it/s]
5it [00:00, 47.30it/s]

epoch 10/300, loss 5.3223, norm 1.5192, predict bleu: 0.0000, time 0.896s, since 0h 0m 9s
你你童年梦


47it [00:00, 52.42it/s]
47it [00:00, 52.19it/s]
47it [00:00, 52.42it/s]
47it [00:00, 51.90it/s]
47it [00:00, 53.80it/s]
5it [00:00, 47.30it/s]

epoch 15/300, loss 3.8957, norm 1.8362, predict bleu: 0.0000, time 0.877s, since 0h 0m 13s
你把我的童年记忆给毁的了


47it [00:00, 51.87it/s]
47it [00:00, 51.58it/s]
47it [00:00, 52.56it/s]
47it [00:00, 52.54it/s]
47it [00:00, 53.27it/s]
6it [00:00, 53.24it/s]

epoch 20/300, loss 2.6489, norm 1.5467, predict bleu: 0.0000, time 0.885s, since 0h 0m 18s
你把我的童年记忆给毁的了


47it [00:00, 54.42it/s]
47it [00:00, 51.84it/s]
47it [00:00, 51.39it/s]
47it [00:00, 52.71it/s]
47it [00:00, 52.62it/s]
5it [00:00, 46.86it/s]

epoch 25/300, loss 1.7989, norm 1.3062, predict bleu: 0.0000, time 0.895s, since 0h 0m 22s
你把我的童年记忆给毁的了


47it [00:00, 53.19it/s]
47it [00:00, 52.30it/s]
47it [00:00, 52.38it/s]
47it [00:00, 51.95it/s]
47it [00:00, 51.90it/s]
5it [00:00, 47.75it/s]

epoch 30/300, loss 1.2048, norm 0.4773, predict bleu: 0.0000, time 0.908s, since 0h 0m 27s
你把我的童年记忆给毁的了


47it [00:00, 51.27it/s]
47it [00:00, 52.22it/s]
47it [00:00, 52.59it/s]
47it [00:00, 52.48it/s]
47it [00:00, 51.43it/s]
5it [00:00, 48.67it/s]

epoch 35/300, loss 0.8187, norm 0.2958, predict bleu: 0.0000, time 0.917s, since 0h 0m 32s
你把我的童年记忆给毁的了


47it [00:00, 51.52it/s]
47it [00:00, 53.36it/s]
47it [00:00, 51.77it/s]
47it [00:00, 53.21it/s]
47it [00:00, 52.62it/s]
5it [00:00, 45.17it/s]

epoch 40/300, loss 0.5628, norm 0.1772, predict bleu: 0.0000, time 0.896s, since 0h 0m 36s
你把我的童年记忆给毁的了


47it [00:00, 52.06it/s]
47it [00:00, 51.38it/s]
47it [00:00, 53.53it/s]
47it [00:00, 52.12it/s]
47it [00:00, 53.46it/s]
5it [00:00, 49.15it/s]

epoch 45/300, loss 0.4038, norm 0.1618, predict bleu: 0.0000, time 0.882s, since 0h 0m 41s
你把我的童年记忆给毁的了


47it [00:00, 53.16it/s]
47it [00:00, 52.47it/s]
47it [00:00, 52.35it/s]
47it [00:00, 53.20it/s]
47it [00:00, 52.02it/s]
5it [00:00, 47.26it/s]

epoch 50/300, loss 0.3274, norm 0.1691, predict bleu: 0.0000, time 0.907s, since 0h 0m 45s
你把我的童年记忆给毁的了


47it [00:00, 52.07it/s]
47it [00:00, 51.12it/s]
47it [00:00, 50.97it/s]
47it [00:00, 51.13it/s]
47it [00:00, 50.52it/s]
6it [00:00, 51.42it/s]

epoch 55/300, loss 0.3009, norm 0.1814, predict bleu: 0.0000, time 0.935s, since 0h 0m 50s
你把我的童年记忆给毁的了


47it [00:00, 51.84it/s]
47it [00:00, 51.79it/s]
47it [00:00, 53.19it/s]
47it [00:00, 52.38it/s]
47it [00:00, 52.53it/s]
6it [00:00, 50.13it/s]

epoch 60/300, loss 0.2843, norm 0.1858, predict bleu: 0.0000, time 0.897s, since 0h 0m 54s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 53.20it/s]
47it [00:00, 51.20it/s]
47it [00:00, 51.90it/s]
47it [00:00, 52.41it/s]
47it [00:00, 52.58it/s]
5it [00:00, 46.43it/s]

epoch 65/300, loss 0.2711, norm 0.1878, predict bleu: 0.0000, time 0.896s, since 0h 0m 59s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.77it/s]
47it [00:00, 52.07it/s]
47it [00:00, 53.03it/s]
47it [00:00, 52.63it/s]
47it [00:00, 50.06it/s]
5it [00:00, 44.76it/s]

epoch 70/300, loss 0.2618, norm 0.1905, predict bleu: 0.0000, time 0.941s, since 0h 1m 3s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 49.24it/s]
47it [00:00, 49.45it/s]
47it [00:00, 49.46it/s]
47it [00:00, 49.73it/s]
47it [00:00, 52.06it/s]
5it [00:00, 49.15it/s]

epoch 75/300, loss 0.2563, norm 0.1904, predict bleu: 0.0000, time 0.904s, since 0h 1m 8s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 52.59it/s]
47it [00:00, 51.73it/s]
47it [00:00, 52.05it/s]
47it [00:00, 51.38it/s]
47it [00:00, 52.19it/s]
5it [00:00, 46.77it/s]

epoch 80/300, loss 0.2483, norm 0.1869, predict bleu: 0.0000, time 0.904s, since 0h 1m 13s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 52.09it/s]
47it [00:00, 52.82it/s]
47it [00:00, 51.23it/s]
47it [00:00, 52.76it/s]
47it [00:00, 51.39it/s]
5it [00:00, 45.78it/s]

epoch 85/300, loss 0.2427, norm 0.1818, predict bleu: 0.0000, time 0.916s, since 0h 1m 17s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.80it/s]
47it [00:00, 52.52it/s]
47it [00:00, 51.94it/s]
47it [00:00, 52.25it/s]
47it [00:00, 52.20it/s]
5it [00:00, 47.75it/s]

epoch 90/300, loss 0.2377, norm 0.1750, predict bleu: 0.0000, time 0.903s, since 0h 1m 22s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.81it/s]
47it [00:00, 52.89it/s]
47it [00:00, 51.96it/s]
47it [00:00, 52.15it/s]
47it [00:00, 53.15it/s]
5it [00:00, 48.21it/s]

epoch 95/300, loss 0.2336, norm 0.1673, predict bleu: 0.0000, time 0.885s, since 0h 1m 26s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.42it/s]
47it [00:00, 52.99it/s]
47it [00:00, 48.92it/s]
47it [00:00, 49.70it/s]
47it [00:00, 51.46it/s]
5it [00:00, 46.42it/s]

epoch 100/300, loss 0.2303, norm 0.1595, predict bleu: 0.0000, time 0.917s, since 0h 1m 31s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.44it/s]
47it [00:00, 51.22it/s]
47it [00:00, 51.89it/s]
47it [00:00, 52.59it/s]
47it [00:00, 50.08it/s]
5it [00:00, 42.85it/s]

epoch 105/300, loss 0.2271, norm 0.1514, predict bleu: 0.0000, time 0.941s, since 0h 1m 36s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 49.34it/s]
47it [00:00, 51.90it/s]
47it [00:00, 51.19it/s]
47it [00:00, 51.33it/s]
47it [00:00, 51.45it/s]
5it [00:00, 46.63it/s]

epoch 110/300, loss 0.2237, norm 0.1435, predict bleu: 0.0000, time 0.917s, since 0h 1m 40s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.64it/s]
47it [00:00, 52.02it/s]
47it [00:00, 51.88it/s]
47it [00:00, 51.50it/s]
47it [00:00, 53.01it/s]
6it [00:00, 53.24it/s]

epoch 115/300, loss 0.2207, norm 0.1381, predict bleu: 0.0000, time 0.890s, since 0h 1m 45s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 53.43it/s]
47it [00:00, 52.42it/s]
47it [00:00, 52.08it/s]
47it [00:00, 52.83it/s]
47it [00:00, 50.68it/s]
5it [00:00, 48.91it/s]

epoch 120/300, loss 0.2177, norm 0.1363, predict bleu: 0.0000, time 0.929s, since 0h 1m 49s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:01, 46.95it/s]
47it [00:00, 50.55it/s]
47it [00:00, 52.36it/s]
47it [00:00, 51.61it/s]
47it [00:00, 52.11it/s]
5it [00:00, 47.31it/s]

epoch 125/300, loss 0.2152, norm 0.1369, predict bleu: 0.0000, time 0.905s, since 0h 1m 54s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.56it/s]
47it [00:00, 53.19it/s]
47it [00:00, 51.73it/s]
47it [00:00, 52.59it/s]
47it [00:00, 52.05it/s]
6it [00:00, 50.98it/s]

epoch 130/300, loss 0.2129, norm 0.1392, predict bleu: 0.0000, time 0.905s, since 0h 1m 59s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 52.89it/s]
47it [00:00, 51.46it/s]
47it [00:00, 51.96it/s]
47it [00:00, 49.40it/s]
47it [00:00, 49.68it/s]
5it [00:00, 46.85it/s]

epoch 135/300, loss 0.2111, norm 0.1418, predict bleu: 0.0000, time 0.949s, since 0h 2m 3s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 50.45it/s]
47it [00:00, 49.00it/s]
47it [00:00, 49.64it/s]
47it [00:00, 49.97it/s]
47it [00:00, 50.07it/s]
5it [00:00, 45.99it/s]

epoch 140/300, loss 0.2092, norm 0.1418, predict bleu: 0.0000, time 0.942s, since 0h 2m 8s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 48.94it/s]
47it [00:00, 50.08it/s]
47it [00:00, 49.30it/s]
47it [00:00, 51.72it/s]
47it [00:00, 51.31it/s]
5it [00:00, 47.29it/s]

epoch 145/300, loss 0.2077, norm 0.1402, predict bleu: 0.0000, time 0.919s, since 0h 2m 13s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 50.95it/s]
47it [00:00, 52.82it/s]
47it [00:00, 52.86it/s]
47it [00:00, 52.95it/s]
47it [00:00, 51.50it/s]
5it [00:00, 48.67it/s]

epoch 150/300, loss 0.2065, norm 0.1378, predict bleu: 0.0000, time 0.916s, since 0h 2m 17s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.51it/s]
47it [00:00, 51.96it/s]
47it [00:00, 52.19it/s]
47it [00:00, 51.66it/s]
47it [00:00, 51.38it/s]
5it [00:00, 48.21it/s]

epoch 155/300, loss 0.2059, norm 0.1364, predict bleu: 0.0000, time 0.917s, since 0h 2m 22s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 47.61it/s]
47it [00:01, 45.75it/s]
47it [00:01, 39.15it/s]
47it [00:01, 43.65it/s]
47it [00:00, 52.84it/s]
5it [00:00, 49.65it/s]

epoch 160/300, loss 0.2052, norm 0.1335, predict bleu: 0.0000, time 0.892s, since 0h 2m 27s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.19it/s]
47it [00:00, 50.83it/s]
47it [00:01, 42.51it/s]
47it [00:01, 44.63it/s]
47it [00:01, 42.60it/s]
5it [00:00, 46.42it/s]

epoch 165/300, loss 0.2046, norm 0.1299, predict bleu: 0.0000, time 1.107s, since 0h 2m 32s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 50.24it/s]
47it [00:00, 50.33it/s]
47it [00:00, 51.16it/s]
47it [00:00, 51.79it/s]
47it [00:01, 45.89it/s]
5it [00:00, 45.11it/s]

epoch 170/300, loss 0.2040, norm 0.1265, predict bleu: 0.0000, time 1.026s, since 0h 2m 37s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 49.89it/s]
47it [00:00, 48.67it/s]
47it [00:00, 49.68it/s]
47it [00:01, 43.45it/s]
47it [00:00, 48.91it/s]
5it [00:00, 45.90it/s]

epoch 175/300, loss 0.2035, norm 0.1237, predict bleu: 0.0000, time 0.964s, since 0h 2m 42s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.10it/s]
47it [00:00, 51.25it/s]
47it [00:00, 52.67it/s]
47it [00:00, 53.19it/s]
47it [00:00, 52.65it/s]
5it [00:00, 46.42it/s]

epoch 180/300, loss 0.2031, norm 0.1224, predict bleu: 0.0000, time 0.894s, since 0h 2m 47s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 52.19it/s]
47it [00:00, 52.00it/s]
47it [00:00, 50.08it/s]
47it [00:00, 49.49it/s]
47it [00:00, 50.49it/s]
5it [00:00, 44.16it/s]

epoch 185/300, loss 0.2026, norm 0.1216, predict bleu: 0.0000, time 0.935s, since 0h 2m 51s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 50.15it/s]
47it [00:00, 50.58it/s]
47it [00:00, 49.18it/s]
47it [00:00, 48.42it/s]
47it [00:01, 46.55it/s]
5it [00:00, 45.17it/s]

epoch 190/300, loss 0.2023, norm 0.1207, predict bleu: 0.0000, time 1.012s, since 0h 2m 56s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 49.06it/s]
47it [00:00, 50.08it/s]
47it [00:00, 48.63it/s]
47it [00:01, 46.36it/s]
47it [00:00, 48.63it/s]
4it [00:00, 39.31it/s]

epoch 195/300, loss 0.2018, norm 0.1194, predict bleu: 0.0000, time 0.968s, since 0h 3m 1s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 49.29it/s]
47it [00:00, 51.62it/s]
47it [00:00, 50.92it/s]
47it [00:00, 50.56it/s]
47it [00:00, 51.87it/s]
5it [00:00, 44.37it/s]

epoch 200/300, loss 0.2013, norm 0.1173, predict bleu: 0.0000, time 0.909s, since 0h 3m 6s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 49.72it/s]
47it [00:00, 51.84it/s]
47it [00:00, 51.09it/s]
47it [00:00, 52.01it/s]
47it [00:00, 52.19it/s]
5it [00:00, 45.91it/s]

epoch 205/300, loss 0.2010, norm 0.1154, predict bleu: 0.0000, time 0.904s, since 0h 3m 10s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.16it/s]
47it [00:00, 51.79it/s]
47it [00:00, 51.11it/s]
47it [00:00, 51.61it/s]
47it [00:00, 51.11it/s]
5it [00:00, 46.04it/s]

epoch 210/300, loss 0.2005, norm 0.1143, predict bleu: 0.0000, time 0.922s, since 0h 3m 15s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 50.52it/s]
47it [00:00, 50.36it/s]
47it [00:00, 50.56it/s]
47it [00:00, 51.58it/s]
47it [00:00, 50.90it/s]
5it [00:00, 47.30it/s]

epoch 215/300, loss 0.1998, norm 0.1137, predict bleu: 0.0000, time 0.926s, since 0h 3m 20s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.83it/s]
47it [00:00, 52.06it/s]
47it [00:00, 51.23it/s]
47it [00:00, 51.46it/s]
47it [00:00, 50.42it/s]
5it [00:00, 46.42it/s]

epoch 220/300, loss 0.1991, norm 0.1134, predict bleu: 0.0000, time 0.935s, since 0h 3m 24s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.79it/s]
47it [00:00, 49.46it/s]
47it [00:01, 46.02it/s]
47it [00:00, 51.73it/s]
47it [00:00, 51.77it/s]
5it [00:00, 46.00it/s]

epoch 225/300, loss 0.1984, norm 0.1132, predict bleu: 0.0000, time 0.910s, since 0h 3m 29s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.25it/s]
47it [00:00, 50.98it/s]
47it [00:00, 48.54it/s]
47it [00:00, 49.04it/s]
47it [00:00, 48.17it/s]
5it [00:00, 43.98it/s]

epoch 230/300, loss 0.1976, norm 0.1131, predict bleu: 0.0000, time 0.978s, since 0h 3m 34s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 48.53it/s]
47it [00:00, 48.83it/s]
47it [00:00, 48.38it/s]
47it [00:00, 49.33it/s]
47it [00:00, 51.46it/s]
5it [00:00, 47.75it/s]

epoch 235/300, loss 0.1972, norm 0.1136, predict bleu: 0.0000, time 0.916s, since 0h 3m 39s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 50.90it/s]
47it [00:00, 51.57it/s]
47it [00:00, 51.85it/s]
47it [00:00, 50.62it/s]
47it [00:00, 50.84it/s]
5it [00:00, 46.81it/s]

epoch 240/300, loss 0.1969, norm 0.1143, predict bleu: 0.0000, time 0.928s, since 0h 3m 43s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 50.95it/s]
47it [00:00, 51.73it/s]
47it [00:00, 50.83it/s]
47it [00:00, 50.61it/s]
47it [00:00, 51.94it/s]
5it [00:00, 47.96it/s]

epoch 245/300, loss 0.1969, norm 0.1145, predict bleu: 0.0000, time 0.906s, since 0h 3m 48s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 50.75it/s]
47it [00:00, 50.00it/s]
47it [00:00, 51.71it/s]
47it [00:00, 50.49it/s]
47it [00:00, 51.37it/s]
5it [00:00, 42.54it/s]

epoch 250/300, loss 0.1961, norm 0.1148, predict bleu: 0.0000, time 0.917s, since 0h 3m 53s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 48.56it/s]
47it [00:00, 50.77it/s]
47it [00:00, 52.07it/s]
47it [00:00, 51.20it/s]
47it [00:00, 51.97it/s]
5it [00:00, 48.21it/s]

epoch 255/300, loss 0.2045, norm 0.1164, predict bleu: 0.0000, time 0.906s, since 0h 3m 57s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 50.35it/s]
47it [00:00, 50.89it/s]
47it [00:00, 52.02it/s]
47it [00:00, 51.27it/s]
47it [00:00, 50.78it/s]
5it [00:00, 47.30it/s]

epoch 260/300, loss 0.2039, norm 0.1114, predict bleu: 0.0000, time 0.927s, since 0h 4m 2s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.49it/s]
47it [00:00, 51.49it/s]
47it [00:00, 50.35it/s]
47it [00:00, 50.97it/s]
47it [00:00, 51.16it/s]
5it [00:00, 45.17it/s]

epoch 265/300, loss 0.1985, norm 0.1173, predict bleu: 0.0000, time 0.921s, since 0h 4m 7s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 50.78it/s]
47it [00:00, 47.52it/s]
47it [00:01, 46.75it/s]
47it [00:00, 48.10it/s]
47it [00:00, 48.70it/s]
5it [00:00, 44.37it/s]

epoch 270/300, loss 0.1981, norm 0.1200, predict bleu: 0.0000, time 0.968s, since 0h 4m 12s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 48.58it/s]
47it [00:00, 49.08it/s]
47it [00:00, 49.92it/s]
47it [00:00, 50.64it/s]
47it [00:00, 50.32it/s]
5it [00:00, 48.11it/s]

epoch 275/300, loss 0.1977, norm 0.1215, predict bleu: 0.0000, time 0.937s, since 0h 4m 16s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 51.00it/s]
47it [00:00, 50.11it/s]
47it [00:00, 50.07it/s]
47it [00:00, 50.97it/s]
47it [00:00, 50.33it/s]
5it [00:00, 43.59it/s]

epoch 280/300, loss 0.2001, norm 0.1234, predict bleu: 0.0000, time 0.937s, since 0h 4m 21s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 49.55it/s]
47it [00:00, 49.71it/s]
47it [00:00, 50.43it/s]
47it [00:00, 49.79it/s]
47it [00:00, 50.70it/s]
5it [00:00, 43.59it/s]

epoch 285/300, loss 0.2021, norm 0.1265, predict bleu: 0.0000, time 0.928s, since 0h 4m 26s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 49.19it/s]
47it [00:00, 51.33it/s]
47it [00:00, 50.94it/s]
47it [00:00, 50.18it/s]
47it [00:00, 49.58it/s]
5it [00:00, 46.42it/s]

epoch 290/300, loss 0.2040, norm 0.1313, predict bleu: 0.0000, time 0.951s, since 0h 4m 31s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 49.16it/s]
47it [00:00, 50.02it/s]
47it [00:00, 50.85it/s]
47it [00:00, 50.40it/s]
47it [00:00, 49.98it/s]
5it [00:00, 47.04it/s]

epoch 295/300, loss 0.2063, norm 0.1355, predict bleu: 0.0000, time 0.943s, since 0h 4m 35s
这个魔方居然还有和如意金箍棒一样的缩放功能。


47it [00:00, 48.53it/s]
47it [00:00, 49.17it/s]
47it [00:00, 48.86it/s]
47it [00:00, 48.28it/s]
47it [00:00, 49.09it/s]


epoch 300/300, loss 0.2090, norm 0.1398, predict bleu: 0.0000, time 0.958s, since 0h 4m 40s
这个魔方居然还有和如意金箍棒一样的缩放功能。


In [61]:
bleu_score = 0
for i in range(len(corpus_test)):
    doc = corpus_test[i]
    _, pred = predict_rnn_ada([idx_to_word[x] for x in topics_test[0]],
                              100, model, idx_to_word, word_to_idx)
    bleu_score += bleu([idx_to_word[int(x)] for x in pred],
                       [idx_to_word[x] for x in doc if x not in [0, 2]], k=2)

In [71]:
movies = ['变形金刚', '仙人']
comments = ['一星级','二星级','三星级','四星级','五星级']
for movie in movies:
    for comment in comments:
        print(predict_rnn_ada([movie, comment], 100, model, idx_to_word, word_to_idx)[0])

你把我的童年记忆给毁的了
导演想表达的是一群阳痿的外星机械老儿崇拜症患者来地球炫耀性功能，深感自卑的地球人投靠汽车人领袖~大老二~，并立志推翻不带套子乱搞的~霸天虎~。霸天虎心智正常的表示~你们家打飞机才带套呢~
这个魔方居然还有和如意金箍棒一样的缩放功能。
巨幕厅看的超过瘾！梅根的身材超棒
其实重新再刷会发现，明亮与阴暗的画风对一部电影加成着实不一样，虽然我更喜欢环太，但环太在变1前还有B级片的气质，至于后面四部嘛……
这都什么呀诛仙是这样的？一点都不仙，看着挺脏的。演员表情狰狞，像极了乡村土巴佬
场面很华丽，但是不得不说爱豆演戏真的无法让人有代入感，硬演！看海报感觉是第二个捉妖记，看完了大失所望。。。
程小东这次没失手，东方仙侠终于在近几年有了一部拿的出手的作品。在质感上，《诛仙Ⅰ》做到了《蜀山传》之后最佳，演员方面肖战在气质上很贴合原著中张小凡，而肖战和原著两大IP加持，看好成为中秋爆款的潜质。
最近仙侠不景气，这部看起来还可以啊。
李沁终于找到了她正确的打开方式，陆雪琪一开始选择她就觉得太契合了。最素的妆容造型最考验女演员的气质，她执剑而立的时候翩然如踏雪飞鸿。希望未来能有更多这类的角色出演，古装于她，添了不止一分仙气。


In [17]:
print(''.join([idx_to_word[x] for x in pred]))

我的妈妈是个“小金鱼”的“小”。妈妈说：“妈妈，我是一个好的。”妈妈说：“你是个好的。”妈妈说：“你是谁的吗？”


In [18]:
print(' '.join([idx_to_word[x] for x in topics_test[len(topics_test) - 1]]))

音乐 节目 兴奋 响起 小手
