In [1]:
import torch
import pytorch_lightning as pl
from transformers import BertTokenizer

from model import PlModel, BertLinearMix
from data import load_label

In [2]:
# Path to the label vocabulary, relative to the notebook's working directory.
# Forward slashes are used so the path resolves on Windows, Linux and macOS;
# the previous backslash form was only a valid relative path on Windows.
label_file = "../data/label.json"
# load_label returns the two lookup tables used below (label -> id and id -> label).
label2id, id2label = load_label(label_file)

In [3]:
# Build the network with one output per known label. load_pretrain=False is
# passed because the weights come from the checkpoint restored below
# (presumably this skips fetching the pretrained BERT weights — confirm in model.py).
num_labels = len(label2id)
model = BertLinearMix("bert-base-chinese", num_labels, load_pretrain=False)

# Restore the Lightning wrapper from the saved checkpoint, handing it the
# freshly built network and the id -> label table.
checkpoint_path = "./model_dir/epoch=2-step=30.ckpt"
pl_model = PlModel.load_from_checkpoint(
    checkpoint_path,
    model=model,
    id2label=id2label,
)

# Switch to evaluation mode; left as the last expression so the module
# structure is displayed as the cell output.
pl_model.eval()

PlModel(
  (model): BertLinearMix(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(21128, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (La

In [4]:
# Load the tokenizer from the local model directory. The hub identifier
# "bert-base-chinese" is the alternative source that was used previously
# (presumably the same vocabulary — confirm before switching back).
tokenizer_dir = "./model_dir/"
tokenizer = BertTokenizer.from_pretrained(tokenizer_dir)

In [5]:
# Sample text to classify.
query = "加快产城融合 以科技创新引领新城区建设 新城区,城镇化率,中心城区,科技新城,科技创新"

# Encode the query as a single fixed-length example: pad or truncate to
# max_len tokens and return PyTorch tensors.
max_len = 64
inputs = tokenizer(
    query,
    padding="max_length",
    truncation=True,
    max_length=max_len,
    return_tensors="pt",
)

# Sanity check: print the tensor shape, then show which fields the encoding holds.
print(inputs["input_ids"].shape)
inputs.keys()

torch.Size([1, 64])


dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [6]:
# Pull out the two tensors that predict_step is given and bundle them as a batch.
input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
batch = (input_ids, attention_mask)

# Run the prediction with gradient tracking disabled; 0 is the batch index.
with torch.no_grad():
    labels, probs = pl_model.predict_step(batch, 0)

# One value per line: predicted labels first, then their probabilities.
print(labels, probs, sep="\n")

['tech']
[0.20815144]


In [7]:
# Prepare the wrapper for raw-text prediction. The arguments mirror the
# tokenizer directory and the sequence length of 64 used earlier in the notebook
# (assumed meaning of init_predict's parameters — confirm in model.py).
predict_tokenizer_dir = "./model_dir/"
predict_max_len = 64
pl_model.init_predict(predict_tokenizer_dir, predict_max_len)

In [8]:
# Classify raw text directly: predict_raw is given a list of strings, and the
# result is displayed as the cell output.
raw_texts = [query]
pl_model.predict_raw(raw_texts)

[{'text': '加快产城融合 以科技创新引领新城区建设 新城区,城镇化率,中心城区,科技新城,科技创新',
  'pred_label': 'tech',
  'prob': 0.20815144}]