In [1]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

import time
import tqdm
import itertools
from allennlp.nn import util

In [2]:
from chinese_gpt import TransformerEncoder as Encoder
from chinese_gpt import TransformerDecoderLM as Decoder
from pytorch_pretrained_bert import BertTokenizer, OpenAIAdam

In [3]:
# You need this library for beam search
from allennlp.nn.beam_search import BeamSearch

In [4]:
def top_k_logits(logits, k):
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1].unsqueeze(1).repeat(1, logits.shape[-1])
    return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)

You need to define a step function

In [5]:
temperature = 1.0

def take_step(last_predictions, state):
    
    past = state['past'].permute(1, 2, 0, 3, 4, 5)
    past = [past[i] for i in range(12)]
    past_length = state["length"][0].item()
    
    logits, past = decoder(last_predictions.view(-1, 1), 
                                  state["mask"], 
                                  past=past, 
                                  past_length=past_length)

    logits = logits.squeeze(1) / temperature
    log_probs = F.log_softmax(logits, dim=-1)
    
    state["mask"] = F.pad(state["mask"], (0, 1), "constant", 1.0)
    state["past"] = torch.stack(past, 0).permute(2, 0, 1, 3, 4, 5)
    state["length"] += 1
    
    return log_probs, state

In [6]:
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

In [7]:
beam_search = BeamSearch(end_index=102, max_steps=40, beam_size=4)

In [140]:
encoder = Encoder()
decoder = Decoder()
encoder.eval()
decoder.eval()

# pretrained weights
encoder.load_state_dict(torch.load("5encoder.pth"))
decoder.load_state_dict(torch.load("5decoder.pth"))

In [141]:
# load validation dataset
val_data = torch.load("val_data.pth")

In [142]:
batch_size = 1
val_dataset = TensorDataset(*val_data)
val_dataloader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=batch_size)

In [143]:
device = torch.device("cuda")

encoder = encoder.to(device)
decoder = decoder.to(device)

## Beam Search

In [307]:
loader = iter(val_dataloader)

for i in range(530):
    next(loader)

In [308]:
all_outputs = []

with torch.no_grad():
    for batch in tqdm.tqdm_notebook(loader):
        batch = [item.to(device) for item in batch]

        encoder_input, \
                third, \
                mask_encoder_input, \
                mask_third, \
                encoder_type_ids, \
                third_type_ids = batch

        _, past = encoder(encoder_input, mask_encoder_input, encoder_type_ids)

        state = {}
        start_predictions = torch.LongTensor([[101]]* batch_size).to(device)
        mask = torch.ones(batch_size, start_predictions.shape[1]).to(device)
        mask = torch.cat([mask_encoder_input.float(), mask], dim=1)
        state["mask"] = mask
        state["length"] = torch.LongTensor([[0]] * batch_size).to(device)
        state["past"] = torch.stack(past, 0).permute(2, 0, 1, 3, 4, 5)

        all_top_k_predictions, log_probabilities = beam_search.search(start_predictions, state, take_step)


   
        all_outputs.append("".join(tokenizer.convert_ids_to_tokens(all_top_k_predictions[0][0].tolist())
                      ).replace("##", ""))
        print("Generated beam-1:")
        print("".join(tokenizer.convert_ids_to_tokens(all_top_k_predictions[0][0].tolist())
             ).replace("##", ""))
        print("Generated beam-2:")
        print("".join(tokenizer.convert_ids_to_tokens(all_top_k_predictions[0][-1].tolist())
                     ).replace("##", ""))
        print("原标题+summarization:")
        print("".join(tokenizer.convert_ids_to_tokens(encoder_input[0].tolist())
                     ).replace("##", "").replace("[PAD]", ""))
        print("运营:")
        print("".join(tokenizer.convert_ids_to_tokens(third[0].tolist())
                     ).replace("##", "").replace("[PAD]", ""))
        
        break

HBox(children=(IntProgress(value=0, max=5161), HTML(value='')))

Generated beam-1:
终于找到火箭惨败的真因了！别光怪主力不行，这才是最大祸根[SEP][SEP][SEP][SEP]
Generated beam-2:
终于找到火箭惨败的真因了！别光怪主力不行，这1点才是最大祸首[SEP][SEP]
原标题+summarization:
[CLS]引爆火箭交易的不是争冠，而是主力要被累死了[SEP]休斯顿火箭队感恩节之前表现不错，五连胜终于打出了上赛季的火热状态，似乎新赛季一切都随着时间的推移变得好了起来。不仅被活塞队复仇，还输给了东部鱼腩骑士队，面对着矛盾丛生的奇才队也以失利告终。我们经常在比赛的开局阶段表现得很差，经常会打得非常软，从一开始就注定要输球"。[SEP]
运营:
[CLS]势在必行！火箭交易只因1.6亿花得不值，这2人成全队最大障碍[SEP]



In [15]:
all_outputs = [item.replace("[SEP]", "") for item in all_outputs]

In [16]:
with open("results", "w") as f:
    for line in all_outputs:
        f.write(line)
        f.write("\n")

## Sampling Based Search

In [309]:
loader = iter(val_dataloader)

for i in range(530):
    batch = next(loader)

batch = next(loader)
batch = [item.to(device) for item in batch]

In [326]:
length = 0
top_k = 10

total_prob = 0.0

encoder_input, \
        third, \
        mask_encoder_input, \
        mask_third, \
        encoder_type_ids, \
        third_type_ids = batch

_, past = encoder(encoder_input, mask_encoder_input, encoder_type_ids)

start_predictions = torch.LongTensor([[101]]* batch_size).to(device)
mask = torch.ones(batch_size, start_predictions.shape[1]).to(device)
mask = torch.cat([mask_encoder_input.float(), mask], dim=1)

logits, past = decoder(start_predictions, mask, past=past, past_length=0)
logits = logits.squeeze(1) / 1.0
logits = top_k_logits(logits, k=top_k)

sentence = []

probs = F.softmax(logits, dim=-1)
prob, prev_pred = torch.topk(probs, k=1, dim=-1)
sentence.append(prev_pred)
length += 1
total_prob += np.log(prob.item())

for i in range(40):
    mask = F.pad(mask, (0, 1), "constant", 1.0)
    logits, past = decoder(prev_pred, mask, past=past, past_length=length)
    logits = logits.squeeze(1) / 1.0
    logits = top_k_logits(logits, k=top_k)
    probs = F.softmax(logits, dim=-1)
    prev_pred = torch.multinomial(probs, num_samples=1)
    sentence.append(prev_pred)
    length += 1
    total_prob += np.log(probs[0, prev_pred.item()].item())
    
sentence = torch.cat(sentence, dim=-1)

print("Generated sampled")
print(np.exp(total_prob))
print("".join(tokenizer.convert_ids_to_tokens(sentence[0].tolist())
             ).replace("##", ""))
print("原标题+summarization:")
print("".join(tokenizer.convert_ids_to_tokens(encoder_input[0].tolist())
             ).replace("##", "").replace("[PAD]", ""))
print("运营:")
print("".join(tokenizer.convert_ids_to_tokens(third[0].tolist())
             ).replace("##", "").replace("[PAD]", ""))

Generated sampled
1.39405121091997e-11
终于找到火箭失利的真因了！主力要被累死了，这才是最根本祸根[SEP]终结[SEP]终结[SEP]终场哨响[SEP]
原标题+summarization:
[CLS]引爆火箭交易的不是争冠，而是主力要被累死了[SEP]休斯顿火箭队感恩节之前表现不错，五连胜终于打出了上赛季的火热状态，似乎新赛季一切都随着时间的推移变得好了起来。不仅被活塞队复仇，还输给了东部鱼腩骑士队，面对着矛盾丛生的奇才队也以失利告终。我们经常在比赛的开局阶段表现得很差，经常会打得非常软，从一开始就注定要输球"。[SEP]
运营:
[CLS]势在必行！火箭交易只因1.6亿花得不值，这2人成全队最大障碍[SEP]
