In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchfly.modules.transformers import CachedBertDecoderLM, ChineseBERTBaseConfig
from torchfly.text.tokenizers import BertTokenizer
from torchfly.utils import get_pretrained_states

In [2]:
# Load the tokenizer for the "bert-base-chinese" vocabulary. Its token ids must agree with
# the vocabulary the pretrained checkpoint below was trained on (presumably this same
# vocabulary — confirm if swapping either name).
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

In [3]:
# Fetch the pretrained weights named "chinese-gpt-bert-small". As the cell output shows,
# torchfly reuses an existing copy under ~/.cache/torchfly/models/ instead of re-downloading.
# NOTE(review): .pth checkpoints are typically unpickled when loaded — only use trusted files.
model_states = get_pretrained_states("chinese-gpt-bert-small")

File exists: /home/wuqy1203/.cache/torchfly/models/chinese-gpt-bert-small.pth


In [4]:
# Build the cached decoder LM and load the pretrained weights into it.
model = CachedBertDecoderLM(ChineseBERTBaseConfig)
# strict=True raises on any missing/unexpected parameter instead of silently leaving some
# weights randomly initialised (the saved run reported that all keys matched).
model.load_state_dict(model_states, strict=True)
# This notebook only runs inference: disable dropout so sampling reflects the trained model.
model.eval();

<All keys matched successfully>

In [5]:
# Use the GPU when one is available; fall back to CPU so the notebook still runs end to end
# on machines without CUDA (previously this cell crashed there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [6]:
def top_k_logits(logits, k):
    """Mask ``logits`` so that only the top-``k`` entries along the last dim remain.

    Entries strictly below the k-th largest value of their row are replaced with
    ``-1e10`` so they get (effectively) zero probability after softmax. Entries tied
    with the k-th value are kept, so a row may retain more than ``k`` entries.

    Args:
        logits: tensor of shape ``(..., vocab_size)``; any number of leading dims.
        k: number of entries to keep per row. ``k == 0`` disables filtering, and
            ``k >= vocab_size`` keeps everything; both return ``logits`` unchanged.

    Returns:
        Tensor with the same shape and dtype as ``logits``.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"top_k must be >= 0, got {k}")
    vocab_size = logits.shape[-1]
    # Nothing to mask: k == 0 means "no top-k filtering", and k >= vocab keeps all entries.
    if k == 0 or k >= vocab_size:
        return logits
    values, _ = torch.topk(logits, k, dim=-1)
    # k-th largest value per row, kept as (..., 1) so it broadcasts against logits
    # without materialising a full (..., vocab) copy.
    kth_values = values[..., -1:]
    masked = torch.full_like(logits, -1e10)
    return torch.where(logits < kth_values, masked, logits)

In [7]:
# Number of sequences generated in parallel from the same prompt.
batch_size = 1
# Prompt text ("Alibaba Group announces acquisition of Yahoo") converted to token ids.
# NOTE(review): the decoding cell prepends id 101 itself, so this assumes `encode` does not
# already add [CLS]/[SEP] — confirm for torchfly's BertTokenizer.
prompt = tokenizer.encode("阿里巴巴集团宣布收购雅虎")

In [8]:
# --- Sampling configuration ---------------------------------------------------------------
top_k = 50              # keep only the 50 highest-scoring tokens at each step
temperature = 0.8       # < 1.0 sharpens the distribution before sampling
num_decode_steps = 500  # tokens sampled after the first one (1 + 500 generated in total)
generation_seed = 0     # fixed seed so Restart & Run All reproduces the same text
cls_token_id = 101      # presumably [CLS] in the bert-base-chinese vocab — TODO confirm via tokenizer

torch.manual_seed(generation_seed)


def sample_next_token(step_logits, temperature, top_k):
    """Sample one token id per row from ``(batch, vocab)`` logits.

    Applies temperature scaling, then top-k filtering, then draws from the softmax of
    the result. Returns a ``(batch, 1)`` LongTensor.
    """
    filtered = top_k_logits(step_logits / temperature, k=top_k)
    probs = F.softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1)


start_predictions = torch.LongTensor([[cls_token_id] + prompt] * batch_size).to(device)
mask = torch.ones(batch_size, start_predictions.shape[1], device=device)

with torch.no_grad():
    # Encode the whole prompt once; `past` caches it so later steps feed only one token.
    # NOTE(review): assumes `past_length` is the number of positions already held in `past`
    # (0 here), so positions continue from there — confirm against
    # CachedBertDecoderLM.forward. A single (1, 1) zero `position_ids` tensor does not
    # describe a multi-token prompt.
    logits, past = model(start_predictions, mask, past=None, past_length=0)
    cache_length = start_predictions.shape[1]

    # Sample the first token the same way as every later one; taking the argmax here
    # would make the temperature / top-k filtering pointless.
    prev_pred = sample_next_token(logits[:, -1, :], temperature, top_k)
    sentence = [prev_pred]

    # Decoding loop: feed only the newest token and reuse the cache for the prefix.
    for _ in range(num_decode_steps):
        mask = F.pad(mask, (0, 1), "constant", 1.0)  # also attend to the newly fed token
        # Offset by everything already cached (prompt + generated tokens), not only the
        # number of generated tokens.
        logits, past = model(prev_pred, mask, past=past, past_length=cache_length)
        cache_length += prev_pred.shape[1]
        prev_pred = sample_next_token(logits[:, -1, :], temperature, top_k)
        sentence.append(prev_pred)

    sentence = torch.cat(sentence, dim=-1)  # (batch_size, 1 + num_decode_steps), prompt excluded



AttributeError: type object 'ChineseBERTBaseConfig' has no attribute 'n_head'

In [None]:
# Turn the generated ids of the first (with batch_size = 1, the only) sequence back into text.
# `sentence` holds only newly generated tokens, so the prompt itself is not part of the output.
tokenizer.decode(sentence[0].tolist())