In [1]:
import gensim
from gensim.models import KeyedVectors
import numpy as np
import torch
from torch import nn, autograd, optim
import torch.nn.functional as F
import time
import math
from tqdm import tqdm

In [None]:
fvec = KeyedVectors.load_word2vec_format('sgns.weibo.word', binary=False)
word_vec = fvec.vectors
vocab = ['<PAD>', '<BOS>', '<EOS>', '<UNK>']
vocab.extend(list(fvec.vocab.keys()))
word_vec = np.concatenate((np.array([[0]*word_vec.shape[1]] * 4), word_vec))
word_vec = torch.tensor(word_vec)

In [3]:
word_to_idx = {ch: i for i, ch in enumerate(vocab)}
idx_to_word = {i: ch for i, ch in enumerate(vocab)}

In [4]:
essays = []
topics = []
with open('processed_data.txt', 'r') as f:
    for line in f:
        essay, topic = line.replace('\n', '').split(' </d> ')
        essays.append(essay.split(' '))
        topics.append(topic.split(' '))

In [5]:
corpus_indice = list(map(lambda x: [word_to_idx[w] for w in x], essays[:8000]))
topics_indice = list(map(lambda x: [word_to_idx[w] for w in x], topics[:8000]))
corpus_test = list(map(lambda x: [word_to_idx[w] for w in x], essays[8000:8800]))
topics_test = list(map(lambda x: [word_to_idx[w] for w in x], topics[8000:8800]))

In [6]:
length = list(map(lambda x: len(x), corpus_indice))

In [7]:
def tav_data_iterator(corpus_indice, topics_indice, batch_size, num_steps):
    epoch_size = len(corpus_indice) // batch_size
    for i in range(epoch_size):
        raw_data = corpus_indice[i*batch_size: (i+1)*batch_size]
        key_words = topics_indice[i*batch_size: (i+1)*batch_size]
        data = np.zeros((len(raw_data), num_steps+1), dtype=np.int64)
        for i in range(batch_size):
            doc = raw_data[i]
            tmp = [1]
            tmp.extend(doc)
            tmp.extend([2])
            tmp = np.array(tmp, dtype=np.int64)
            _size = tmp.shape[0]
            data[i][:_size] = tmp
        key_words = np.array(key_words, dtype=np.int64)
        x = data[:, 0:num_steps]
        y = data[:, 1:]
        mask = np.float32(x != 0)
        x = torch.tensor(x)
        y = torch.tensor(y)
        mask = torch.tensor(mask)
        key_words = torch.tensor(key_words)
        yield(x, y, mask, key_words)

In [8]:
class TATLSTM(nn.Module):
    def __init__(self, hidden_dim, embed_dim, num_layers, weight,
                 num_labels, bidirectional, dropout=0.5, **kwargs):
        super(TATLSTM, self).__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_labels = num_labels
        self.bidirectional = bidirectional
        if num_layers <= 1:
            self.dropout = 0
        else:
            self.dropout = dropout
        self.embedding = nn.Embedding.from_pretrained(weight)
        self.embedding.weight.requires_grad = False
        self.rnn = nn.GRU(input_size=self.embed_dim, hidden_size=self.hidden_dim,
                          num_layers=self.num_layers, bidirectional=self.bidirectional,
                          dropout=self.dropout)
        if self.bidirectional:
            self.decoder = nn.Linear(hidden_dim * 2 + self.embed_dim, self.num_labels)
        else:
            self.decoder = nn.Linear(hidden_dim + self.embed_dim, self.num_labels)
        self.attn = nn.Linear(self.embed_dim * 5, self.embed_dim)
        
        
    def forward(self, inputs, topics, hidden=None):
        embeddings = self.embedding(inputs)
        topics_embed = self.embedding(topics).float()
        topics_attn = self.attn(topics_embed.reshape((topics_embed.shape[0], -1)))
        topics_attn.unsqueeze_(-1)
#         for i in range(embeddings.shape[0]):
#             embeddings[i][0] = topics_embed[i]
        states, hidden = self.rnn(embeddings.permute([1, 0, 2]).float(), hidden)
        topics_attn = topics_attn.expand(topics_attn.shape[0], topics_attn.shape[1], states.shape[0])
        topics_attn = topics_attn.permute([2, 0, 1])
        states_with_topic = torch.cat([states, topics_attn], dim=2)
        outputs = self.decoder(states_with_topic.reshape((-1, states_with_topic.shape[-1])))
        return(outputs, hidden)
    
    def init_hidden(self, num_layers, batch_size, hidden_dim, **kwargs):
        hidden = torch.zeros(num_layers, batch_size, hidden_dim)
        return hidden

In [9]:
def predict_rnn(topics, num_chars, model, device, idx_to_word, word_to_idx):
    output = [1]
    topics = [word_to_idx[x] for x in topics]
    topics = torch.tensor(topics)
    topics = topics.reshape((1, topics.shape[0], -1))
    hidden = torch.zeros(num_layers, 1, hidden_dim)
    if use_gpu:
        hidden = hidden.to(device)
        topics = topics.to(device)
    for t in range(num_chars):
        X = torch.tensor(output[-1]).reshape((1, 1))
#         X = torch.tensor(output).reshape((1, len(output)))
        if use_gpu:
            X = X.to(device)
        pred, hidden = model(X, topics, hidden)
        if pred.argmax(dim=1)[-1] == 2:
            break
        else:
            output.append(int(pred.argmax(dim=1)))
#             output.append(int(pred.argmax(dim=1)[-1]))
    return(''.join([idx_to_word[i] for i in output[1:]]), output[1:])

In [10]:
def bleu(pred_tokens, label_tokens, k):
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, k+1):
        num_matches = 0
        for i in range(len_pred - n + 1):
            if ' '.join(pred_tokens[i: i + n]) in ' '.join(label_tokens):
                num_matches += 1
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score

In [11]:
embedding_dim = 100
hidden_dim = 256
lr = 1e2
momentum = 0.0
num_epoch = 100
use_gpu = True
num_layers = 1
bidirectional = False
batch_size = 8
verbose = 5
device = torch.device('cuda:0')
loss_function = nn.CrossEntropyLoss()

In [12]:
model = TATLSTM(hidden_dim=hidden_dim, embed_dim=embedding_dim, num_layers=num_layers,
                num_labels=len(vocab), weight=word_vec, bidirectional=bidirectional)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
if use_gpu:
#     model = nn.DataParallel(model)
    model.to(device)

In [13]:
predict_rnn(['变形金刚','三星级'], 100, model, device, idx_to_word, word_to_idx)[0]

'近视不除揣上溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也溜之乎也'

In [14]:
since = time.time()
for epoch in range(num_epoch):
    start = time.time()
    num, total_loss = 0, 0
#     if epoch == 5000:
#         optimizer.param_groups[0]['lr'] = lr * 0.1
    data = tav_data_iterator(
        corpus_indice, topics_indice, batch_size, max(length) + 1)
    hidden = model.init_hidden(num_layers, batch_size, hidden_dim)
    weight = torch.ones(len(vocab))
    weight[0] = 0
    for X, Y, mask, topics in tqdm(data):
        num += 1
        hidden.detach_()
        if use_gpu:
            X = X.to(device)
            Y = Y.to(device)
            mask = mask.to(device)
            topics = topics.to(device)
            hidden = hidden.to(device)
            weight = weight.to(device)
        optimizer.zero_grad()
        output, hidden = model(X, topics, hidden)
        l = F.cross_entropy(output, Y.t().reshape((-1,)), weight)
        l.backward()
        norm = nn.utils.clip_grad_norm_(model.parameters(), 1e-2)
        optimizer.step()
        total_loss += l.item()
    end = time.time()
    s = end - since
    h = math.floor(s / 3600)
    m = s - h * 3600
    m = math.floor(m / 60)
    s -= (m * 60 + h * 3600)
    if((epoch + 1) % verbose == 0) or (epoch == (num_epoch - 1)):
        bleu_score = 0
        for i in range(len(corpus_test)):
            doc = corpus_test[i]
            _, pred = predict_rnn([idx_to_word[x] for x in topics_test[0]],
                                  100, model, device, idx_to_word, word_to_idx)
            bleu_score += bleu([idx_to_word[int(x)] for x in pred],
                               [idx_to_word[x] for x in doc if x not in [0, 2]], k=2)
        print('epoch %d/%d, loss %.4f, norm %.4f, predict bleu: %.4f, time %.3fs, since %dh %dm %ds'
              % (epoch + 1, num_epoch, total_loss / num, norm, bleu_score / 800, end - start, h, m, s))
        print(predict_rnn(['变形金刚','三星级'],
                          100, model, device, idx_to_word, word_to_idx)[0])

1000it [02:17,  7.30it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.28it/s]
1000it [02:17,  7.28it/s]
1000it [02:17,  7.28it/s]


epoch 5/100, loss 5.9458, norm 0.7823, bleu: 0.0093, time 137.402s, since 0h 11m 26s


1it [00:00,  7.39it/s]

人们的人们都会喜欢我的美丽，我也喜欢我最喜欢我的美丽，我也喜欢我最喜欢我的美丽，我也喜欢我的美丽，我也喜欢我的美丽，我也喜欢我的美丽，我也喜欢我的美丽，我也喜欢我的美丽，我也喜欢我的美丽，我也喜欢我的美丽，我也喜欢我的美丽，我也喜欢我的美丽，我也喜欢我的美丽，


1000it [02:17,  7.29it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.26it/s]
1000it [02:17,  7.25it/s]
1000it [02:17,  7.27it/s]
0it [00:00, ?it/s]

epoch 10/100, loss 5.0203, norm 0.9469, bleu: 0.1399, time 137.403s, since 0h 24m 51s
人们的人们都会喜欢人们的美丽。


1000it [02:16,  7.28it/s]
1000it [02:17,  7.28it/s]
1000it [02:17,  7.28it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.27it/s]
1it [00:00,  7.37it/s]

epoch 15/100, loss 4.3566, norm 1.1449, bleu: 0.0067, time 137.421s, since 0h 46m 41s
人们的美丽的美丽。


1000it [02:17,  7.27it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.28it/s]
1000it [02:17,  7.28it/s]
1000it [02:17,  7.26it/s]
1it [00:00,  7.38it/s]

epoch 20/100, loss 3.8430, norm 1.8767, bleu: 0.0000, time 137.539s, since 1h 0m 6s
人们的美丽的笑脸。


1000it [02:17,  7.26it/s]
1000it [02:17,  7.26it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.28it/s]
1000it [02:17,  7.26it/s]
1it [00:00,  7.42it/s]

epoch 25/100, loss 3.4498, norm 1.4331, bleu: 0.0935, time 137.555s, since 1h 12m 3s
的人们。


1000it [02:16,  7.29it/s]
1000it [02:17,  7.28it/s]
1000it [02:17,  7.28it/s]
1000it [02:17,  7.25it/s]
1000it [02:17,  7.26it/s]
0it [00:00, ?it/s]

epoch 30/100, loss 3.1546, norm 1.3691, bleu: 0.0389, time 137.497s, since 1h 33m 54s
秋天来了，夏天的夏天，人们都会喜欢自己的人。


1000it [02:17,  7.27it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.26it/s]
1000it [02:17,  7.27it/s]
1it [00:00,  7.36it/s]

epoch 35/100, loss 2.9591, norm 1.4780, bleu: 0.0001, time 137.582s, since 1h 49m 23s
的人们。


1000it [02:17,  7.27it/s]
1000it [02:17,  7.25it/s]
1000it [02:17,  7.25it/s]
1000it [02:17,  7.28it/s]
1000it [02:17,  7.26it/s]
0it [00:00, ?it/s]

epoch 40/100, loss 2.8253, norm 1.7286, bleu: 0.0428, time 137.564s, since 2h 2m 10s
秋姑娘又来人们带来了无穷的的秋景。


1000it [02:17,  7.28it/s]
1000it [02:17,  7.25it/s]
1000it [02:17,  7.25it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.28it/s]
0it [00:00, ?it/s]

epoch 45/100, loss 2.7306, norm 1.5368, bleu: 0.0030, time 137.588s, since 2h 19m 25s
秋姑娘又来了，他们在欢迎人们的到来，我最喜欢的是美丽的秋景。


1000it [02:17,  7.26it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.26it/s]


epoch 50/100, loss 2.6637, norm 1.5730, bleu: 0.0110, time 137.627s, since 2h 32m 44s


1it [00:00,  7.40it/s]

秋姑娘的到来，人们也会感受到了夏天的到来，我要喜欢这是秋景的秋景，我爱秋天，因为人们能感受到夏天的寒冷人们带来人们的美丽。


1000it [02:17,  7.25it/s]
1000it [02:17,  7.28it/s]
1000it [02:17,  7.28it/s]
1000it [02:17,  7.28it/s]
1000it [02:17,  7.26it/s]
0it [00:00, ?it/s]

epoch 55/100, loss 2.6071, norm 2.0138, bleu: 0.0188, time 137.630s, since 2h 46m 53s
人们的人们都是在秋中的生活中，人们都会感受到到夏天的到来到来到来。


1000it [02:17,  7.29it/s]
1000it [02:17,  7.26it/s]
1000it [02:17,  7.26it/s]
1000it [02:17,  7.24it/s]
1000it [02:17,  7.26it/s]
0it [00:00, ?it/s]

epoch 60/100, loss 2.5658, norm 1.6831, bleu: 0.0625, time 137.728s, since 3h 0m 36s
人们喜欢的春天，这种也是夏天的炎热。


1000it [02:17,  7.28it/s]
1000it [02:17,  7.26it/s]
1000it [02:17,  7.26it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.26it/s]


epoch 65/100, loss 2.5415, norm 1.8972, bleu: 0.0179, time 137.717s, since 3h 17m 39s


1it [00:00,  7.40it/s]

秋天来了，人们都在歌唱着人们的到来，人们在欢迎着，人们都在歌唱着歌唱，在人们的眼中，他们都在歌唱着，歌唱着，人们说：“欢迎，欢迎，我要要喜欢这！


1000it [02:17,  7.28it/s]
1000it [02:17,  7.25it/s]
1000it [02:17,  7.26it/s]
1000it [02:17,  7.26it/s]
1000it [02:17,  7.26it/s]


epoch 70/100, loss 2.5185, norm 1.4968, bleu: 0.0000, time 137.796s, since 3h 32m 20s


1it [00:00,  7.33it/s]

秋天来了，夏阿姨去了果园里，我又听到了：“我要要看看看看我呀！”


1000it [02:17,  7.26it/s]
1000it [02:17,  7.24it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.25it/s]
1000it [02:17,  7.26it/s]


epoch 75/100, loss 2.5103, norm 2.2964, bleu: 0.0000, time 137.828s, since 3h 44m 19s


1it [00:00,  7.35it/s]

秋天来了，夏天的雨里，人们都在歌唱着人们的到来。我要看了，好多人都会歌唱着。


1000it [02:17,  7.26it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.26it/s]
1000it [02:17,  7.23it/s]
1000it [02:17,  7.26it/s]


epoch 80/100, loss 2.5091, norm 1.7077, bleu: 0.0083, time 137.814s, since 3h 56m 17s


1it [00:00,  7.41it/s]

秋天到了，雨中的人们都觉得很难，但我却觉得自己的骄傲，反而也要闻到了。


1000it [02:17,  7.29it/s]
1000it [02:17,  7.26it/s]
1000it [02:17,  7.26it/s]
1000it [02:17,  7.25it/s]
1000it [02:17,  7.26it/s]


epoch 85/100, loss 2.5159, norm 2.4245, bleu: 0.0982, time 137.747s, since 4h 10m 32s


1it [00:00,  7.38it/s]

秋天来了，秋天来了，夏天的炎热，人们都在人们的眼中中，人们尽情地享受着丰收的果实。


1000it [02:17,  7.26it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.26it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.26it/s]


epoch 90/100, loss 2.5474, norm 2.4875, bleu: 0.0002, time 137.778s, since 4h 28m 19s


1it [00:00,  7.40it/s]

秋天的西瓜成熟了，苹果也要穿上了美丽的秋景，人们都在歌唱着跳舞的秋景。


1000it [02:17,  7.28it/s]
1000it [02:17,  7.29it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.24it/s]
1000it [02:17,  7.26it/s]
0it [00:00, ?it/s]

epoch 95/100, loss 2.5836, norm 4.1016, bleu: 0.0929, time 137.747s, since 4h 42m 15s
秋天的雨，是人们的歌唱。


1000it [02:17,  7.26it/s]
1000it [02:17,  7.28it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.27it/s]
1000it [02:17,  7.28it/s]


epoch 100/100, loss 2.6567, norm 2.2180, bleu: 0.0536, time 137.511s, since 5h 4m 6s
秋天的雨真是美丽的。它虽然不喜欢，夏天的炎热，我要喜欢它，它也不像夏天一样炎热的夏天，但我觉得自己的一切都是一派平凡的景象。


In [16]:
print(''.join([idx_to_word[x] for x in pred]))

从前的松树围着着五彩缤纷的的的湖面。如果湖面上，有几片火红的松树。有松树绣在小草，有的半开围着，有的像小姑娘在捉迷藏。有的松树还是那么惹人。


In [18]:
print(' '.join([idx_to_word[x] for x in topics_test[len(topics_test) - 1]]))

鸟儿 动听 中午 郁郁葱葱 看着
