In [1]:
import torch
import torch.nn.functional as F
from tokenization_unilm import UnilmTokenizer
from modeling_unilm import UnilmForSeq2SeqDecodeSample, UnilmForSeq2SeqDecodeSampleCached, UnilmConfig
import copy
import os
import argparse
import re
from dirty_recognize import dirty_reg
import time

In [2]:
def remove_dirty_sentence(dirty_obj, sentence):
    if len(dirty_obj.match(sentence)) == 0:
        return False
    else:
        return True


def remove_multi_symbol(text):
    r = re.compile(r'([.,，/\\#!！？?。$%^&*;；:：{}=_`´︵~（）()-])[.,，/\\#!！？?。$%^&*;；:：{}=_`´︵~（）()-]+')
    text = r.sub(r'\1', text)
    return text


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    assert logits.dim() == 1
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

In [3]:
use_cuda = torch.cuda.is_available() and not False
device = 'cpu'
model_name_or_path = "kuakua_robot_model/"
max_len = 32
topk = 3
topp = 0.95
repetition_penalty = 1.2

In [4]:
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = device
config = UnilmConfig.from_pretrained(model_name_or_path, max_position_embeddings=512)
tokenizer = UnilmTokenizer.from_pretrained(model_name_or_path, do_lower_case=False)
cached_model = UnilmForSeq2SeqDecodeSampleCached.from_pretrained(model_name_or_path, config=config)
cached_model.to(device)
cached_model.eval()

Some weights of UnilmForSeq2SeqDecodeSampleCached were not initialized from the model checkpoint at kuakua_robot_model/ and are newly initialized: ['bert.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


UnilmForSeq2SeqDecodeSampleCached(
  (bert): UnilmModelIncr(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(6, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

In [5]:
dirty_obj = dirty_reg("data/dirty_words.txt")
_tril_matrix = torch.tril(torch.ones((max_len, max_len), dtype=torch.long))

In [6]:
model = UnilmForSeq2SeqDecodeSample.from_pretrained(model_name_or_path, config=config)
model.to(device)
model.eval()

Some weights of UnilmForSeq2SeqDecodeSample were not initialized from the model checkpoint at kuakua_robot_model/ and are newly initialized: ['bert.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


UnilmForSeq2SeqDecodeSample(
  (bert): UnilmModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(6, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True

In [12]:
text = "你是谁"
input_ids = tokenizer.encode(text)
input_len = len(input_ids)
token_type_ids = [4] * len(input_ids)
generated = []
no_cached_logit = None
for idx in range(max_len):
    curr_input_ids = copy.deepcopy(input_ids)
    curr_input_ids.append(tokenizer.mask_token_id)
    curr_input_tensor = torch.tensor(curr_input_ids).long().to(device).view([1, -1])
    curr_token_type_ids = copy.deepcopy(token_type_ids)
    curr_token_type_ids.extend([5])
    curr_token_type_ids = torch.tensor(curr_token_type_ids).long().to(device).view([1, -1])
    # attention mask
    input_mask = torch.zeros(len(curr_input_ids), len(curr_input_ids), dtype=torch.long)
    input_mask[:, :input_len].fill_(1)
    second_st, second_end = input_len, len(curr_input_ids)
    input_mask[second_st:second_end, second_st:second_end].copy_(_tril_matrix[:second_end-second_st, :second_end-second_st])

    attn_mask = input_mask.unsqueeze(0)
    print("attn_mask:", attn_mask.shape)
    # input_mask = None
    outputs = model(input_ids=curr_input_tensor, token_type_ids=curr_token_type_ids, attention_mask=attn_mask)
    next_token_logits = outputs[-1, -1, :]
    if idx == 2:
        no_cached_logit = next_token_logits
        break
    for id in set(generated):
        next_token_logits[id] /= repetition_penalty
    next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=topk, top_p=topp)
    next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
    if next_token == tokenizer.sep_token_id:  # 遇到[SEP]则表明生成结束
        break
    generated.append(next_token.item())
    if idx == 0:
        input_ids.append(2769)
    if idx == 1:
        input_ids.append(3221)
    token_type_ids.extend([5])
text = tokenizer.convert_ids_to_tokens(generated)
text = remove_multi_symbol("".join(text))

attn_mask: torch.Size([1, 6, 6])
attn_mask: torch.Size([1, 7, 7])
attn_mask: torch.Size([1, 8, 8])


In [13]:
text = "你是谁"
input_ids = tokenizer.encode(text)
input_len = len(input_ids)
token_type_ids = [4] * len(input_ids)
generated = []
prev_embedding, prev_encoded_layers, position_ids = None, None, None
ori_embedding, ori_encoded_layers = None, None
cached_logit = None
for idx in range(max_len):
    if idx == 0:
        curr_input_ids = copy.deepcopy(input_ids)
        curr_input_ids.append(tokenizer.mask_token_id)
        curr_token_type_ids = copy.deepcopy(token_type_ids)
        curr_token_type_ids.extend([5])  
    else:
        curr_input_ids = copy.deepcopy(input_ids)
        curr_input_ids.append(tokenizer.mask_token_id)
        curr_token_type_ids = copy.deepcopy(token_type_ids)
        curr_token_type_ids.extend([5])
        position_ids = [input_len+idx-1, input_len+idx]
        position_ids = curr_input_tensor = torch.tensor(position_ids).long().to(device).view([1, -1])
    curr_input_tensor = torch.tensor(curr_input_ids).long().to(device).view([1, -1])
    curr_token_type_ids = torch.tensor(curr_token_type_ids).long().to(device).view([1, -1])
    # attention mask
    input_mask = torch.zeros(input_len+1+idx, input_len+1+idx, dtype=torch.long)
    input_mask[:, :input_len].fill_(1)
    second_st, second_end = input_len, input_len+1+idx
    input_mask[second_st:second_end, second_st:second_end].copy_(_tril_matrix[:second_end-second_st, :second_end-second_st])
    if idx >= 1:
        input_mask = input_mask[-2:, :]
#         print(input_mask)
    attn_mask = input_mask.unsqueeze(0)
    prev_embedding, prev_encoded_layers, outputs = cached_model(input_ids=curr_input_tensor, token_type_ids=curr_token_type_ids, 
                                    attention_mask=attn_mask, position_ids=position_ids, output_all_encoded_layers=True, prev_embedding=ori_embedding, 
                                    prev_encoded_layers=ori_encoded_layers)
    # prev_embedding, prev_encoded_layers删除最后一个timestamp的数据
    # (1, n, 768) -> (1, n-1, 768)
    if idx == 0:
        ori_embedding = prev_embedding[:,:-1,:]
        ori_encoded_layers = [layer[:, :-1, :] for layer in prev_encoded_layers]
    else:
        ori_embedding = torch.cat((ori_embedding, prev_embedding[:,:-1,:]), dim=1)
        ori_encoded_layers = [torch.cat((ori_encoded_layers[i], prev_encoded_layers[i][:,:-1,:]), dim=1) for i in range(len(ori_encoded_layers))]
    next_token_logits = outputs[-1, -1, :]
    if idx == 2:
        cached_logit = next_token_logits
        break
    for id in set(generated):
        next_token_logits[id] /= repetition_penalty
    next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=topk, top_p=topp)
    next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
    if next_token == tokenizer.sep_token_id:  # 遇到[SEP]则表明生成结束
        break
    generated.append(next_token.item())
    if idx == 0:
        input_ids = [2769]
    if idx == 1:
        input_ids = [3221]
    token_type_ids = [5]
text = tokenizer.convert_ids_to_tokens(generated)
text = remove_multi_symbol("".join(text))

input_ids:  tensor([[ 101,  872, 3221, 6443,  102,  103]])
token_type_ids:  tensor([[4, 4, 4, 4, 4, 5]])
embedding_output:  tensor([[[ 0.0093,  0.0541, -0.2536,  ..., -0.0645, -0.1700, -0.2591],
         [-0.4054, -0.5532,  0.6166,  ..., -0.9640,  0.3675,  0.8820],
         [-0.5610,  0.6875, -0.5392,  ..., -0.9702,  0.8969, -0.5634],
         [-1.2825, -0.6398,  2.0296,  ..., -0.0399,  0.3640,  1.0505],
         [-0.3565, -0.1356,  0.1601,  ..., -1.5584, -0.0705, -0.0762],
         [-0.3105,  0.1816,  0.3181,  ..., -0.3709,  0.4040,  0.3473]]],
       grad_fn=<NativeLayerNormBackward>)
input_ids:  tensor([[2769,  103]])
token_type_ids:  tensor([[5, 5]])
embedding_output:  tensor([[[ 0.5855, -0.0974,  1.6734,  ..., -0.7327,  0.9123, -0.3311],
         [-0.0922, -0.0950,  0.1648,  ..., -0.3694,  0.1536,  0.1964]]],
       grad_fn=<NativeLayerNormBackward>)
input_ids:  tensor([[3221,  103]])
token_type_ids:  tensor([[5, 5]])
embedding_output:  tensor([[[-0.1240,  0.6086, -0.1027,  ..., -

In [14]:
no_cached_logit.shape

torch.Size([21128])

In [15]:
cached_logit

tensor([-8.6942, -5.5811, -6.4095,  ..., -6.5995, -8.9505, -7.6315],
       grad_fn=<SliceBackward>)

In [16]:
no_cached_logit

tensor([-8.6942, -5.5811, -6.4095,  ..., -6.5995, -8.9505, -7.6315],
       grad_fn=<SliceBackward>)