In [1]:
import functools
import json

import numpy as np
import onnxruntime as ort

In [2]:
# Load the exported ONNX model into an inference session.
model_sess = ort.InferenceSession('tmp/ngram/model.onnx', providers=['CUDAExecutionProvider'])
# Read the JSON config; the context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open('tmp/ngram/config.json') as config_file:
    config = json.load(config_file)

In [3]:
def exec_model(input):
    """Run the ONNX session on ``input`` and return its first output.

    ``input`` is converted to an ``int32`` NumPy array and fed to the graph
    input named ``'input'``.
    """
    token_ids = np.array(input, dtype=np.int32)
    outputs = model_sess.run(None, {'input': token_ids})
    return outputs[0]

In [4]:
# Lookup tables built from the character list stored in the config.
itos = dict(enumerate(config["chars"]))  # integer id -> character
stoi = {ch: idx for idx, ch in itos.items()}  # character -> integer id


def encode(s):
    """Encoder: take a string, output a list of integers."""
    return [stoi[ch] for ch in s]


def decode(l):
    """Decoder: take a list of integers, output a string."""
    return "".join(itos[idx] for idx in l)

In [5]:
def fix_func(old_func):
    """Decorator that lifts a vector function to inputs of any rank.

    The returned wrapper converts its argument to a NumPy array. Arrays with
    more than one dimension are split along the first axis and processed
    recursively, so ``old_func`` only ever receives arrays of rank 0 or 1;
    the per-slice results are collected back into a single array.

    ``functools.wraps`` keeps the decorated function's ``__name__`` and
    ``__doc__`` intact, so ``help()`` and tracebacks show the real function
    rather than the generic wrapper.

    Args:
        old_func: Callable taking a single NumPy array.

    Returns:
        The wrapping callable, accepting any array-like input.
    """

    @functools.wraps(old_func)
    def new_func(x):
        x = np.array(x)
        if x.ndim > 1:
            # Recurse over the leading axis until only vectors remain.
            return np.array([new_func(row) for row in x])
        return old_func(x)

    return new_func


@fix_func
def softmax(x):
    """Turn a vector of scores into values that sum to one.

    The maximum is subtracted before exponentiating so that large scores do
    not overflow; this shift does not change the result.
    """
    shifted = x - np.max(x)
    weights = np.exp(shifted)
    total = weights.sum(axis=0)
    return weights / total


@fix_func
def choose(probs):
    """Draw one index at random, weighted by ``probs``.

    Returns a length-1 array holding the sampled index. For greedy decoding,
    return ``[np.argmax(probs, axis=-1)]`` instead of sampling.
    """
    num_options = len(probs)
    return np.random.choice(num_options, size=(1,), p=probs)


def generate(x, *, max_len, temperature=1.0, amt=1):
    """Extend the token sequence ``x`` by sampling ``max_len`` new tokens.

    Args:
        x: One-dimensional sequence of integer token ids used as the prompt.
        max_len: Number of tokens to append to the prompt.
        temperature: Divisor applied to the logits before the softmax; must
            be non-zero.
        amt: Number of copies of the prompt to extend; each row is sampled
            separately.

    Returns:
        Array of shape ``(amt, len(x) + max_len)``: every row is the prompt
        followed by its sampled tokens.

    Raises:
        AssertionError: If ``x`` is not one-dimensional.
        ValueError: If ``temperature`` is zero.
    """
    assert len(np.shape(x)) == 1, "x should be a list of integers"
    if temperature == 0:
        # Dividing the logits by zero yields inf/NaN probabilities, which
        # would only fail later inside the sampler with a confusing message
        # and after a wasted model call. Fail early instead.
        raise ValueError("temperature must be non-zero")
    x = np.reshape(x, (1, -1))
    x = np.repeat(x, amt, axis=0)
    for _ in range(max_len):
        # Only the logits at the last position decide the next token.
        logits = exec_model(x)[:, -1] / temperature
        probs = softmax(logits)
        next_token = choose(probs)
        x = np.concatenate([x, next_token], axis=-1)
    return x


def generate_text(x, **kwargs):
    """Generate text continuations of the prompt string ``x``.

    The prompt is encoded, passed to ``generate`` together with ``kwargs``,
    and every returned row is decoded back into a string.
    """
    prompt_ids = encode(x)
    sequences = generate(prompt_ids, **kwargs)
    return list(map(decode, sequences))

In [None]:
# Sample two continuations of the prompt and print each under a separator.
for text in generate_text('LUCENT', max_len=100, temperature=0.7, amt=2):
    print("----------")
    print(text)