In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers as l
from transformers import *
import utils


# Fixed input length the classifier is built with (passed to utils.distilbert_model below).
MAX_SEQ_LEN = 128
# HuggingFace checkpoint used for both the tokenizer vocabulary and the base config.
MODEL_NAME = "distilbert-base-cased"
# Fine-tuned weights restored into the freshly built model (path relative to the
# notebook's working directory).
WEIGHTS_PATH = "distilbert.h5"

# Label names for the classifier. Presumably index i corresponds to output unit i
# of the model -- TODO confirm utils.docred_labels() returns them in training order.
classes = utils.docred_labels()
num_labels = len(classes)

tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
config = DistilBertConfig.from_pretrained(MODEL_NAME, num_labels=num_labels)

# Build the architecture first, then load the fine-tuned weights into it.
model = utils.distilbert_model(MODEL_NAME, config, num_labels, MAX_SEQ_LEN)
model.load_weights(WEIGHTS_PATH)

In [2]:
def infer(input_text, model, max_seq_len=128, tokenizer=None, labels=None):
    """Predict the label for a single piece of text.

    The text is tokenized, truncated/padded at the end to ``max_seq_len``
    tokens, and fed to ``model`` together with a matching attention mask.
    The index of the highest score is mapped back to a label name.

    Args:
        input_text: Raw text to classify.
        model: Object with a Keras-style ``predict`` method. It receives
            ``[input_ids, attention_mask]``, where each is an int32 array of
            shape ``(1, max_seq_len)``.
        max_seq_len: Fixed sequence length. It presumably must match the
            length the model was built with (``MAX_SEQ_LEN`` in the setup
            cell). ``None`` means no truncation or padding, mirroring
            ``pad_sequences(maxlen=None)``.
        tokenizer: Object with an ``encode(text)`` method that returns token
            ids. Defaults to the notebook-level ``tokenizer``.
        labels: Sequence mapping class index to label name. Defaults to the
            notebook-level ``classes``.

    Returns:
        The entry of ``labels`` at the argmax of the model's scores.

    Raises:
        ValueError: If ``max_seq_len`` is an integer less than 1.
    """
    if max_seq_len is not None and max_seq_len < 1:
        raise ValueError(f"max_seq_len must be >= 1, got {max_seq_len}")
    # The parameter shadows the global of the same name, so look the global up explicitly.
    if tokenizer is None:
        tokenizer = globals()["tokenizer"]
    if labels is None:
        labels = classes

    token_ids = list(tokenizer.encode(input_text))
    if max_seq_len is None:
        max_seq_len = len(token_ids)
    # Keep the first max_seq_len ids (same as pad_sequences(truncating="post")).
    # NOTE(review): for long inputs this presumably drops the trailing [SEP] that
    # encode() adds by default; confirm how training inputs were truncated before changing.
    token_ids = token_ids[:max_seq_len]
    n_tokens = len(token_ids)

    # Post-pad with 0 (pad_sequences' default value); the mask marks real tokens with 1.
    input_ids = np.zeros((1, max_seq_len), dtype=np.int32)
    input_ids[0, :n_tokens] = token_ids
    attention_mask = np.zeros((1, max_seq_len), dtype=np.int32)
    attention_mask[0, :n_tokens] = 1

    # predict returns batched output; [0] presumably selects our single example's scores.
    scores = model.predict([input_ids, attention_mask])[0]
    return labels[int(np.argmax(scores))]

In [3]:
# Sanity check on a sentence whose expected label is 'country_of_citizenship'.
infer(input_text="Tom is a citizen of Singapore", model=model, max_seq_len=MAX_SEQ_LEN)

'country_of_citizenship'