In [26]:
import onnxruntime as ort
import json, os, pathlib
import numpy as np

model_path = pathlib.Path("data/gptlite")

# load the model from onnx file and generate text
ort_session = ort.InferenceSession(model_path / "model.onnx")

# load the vocabulary
with open(model_path / "meta.json", "r") as f:
    meta = json.load(f)
    chars = meta['chars']
    vocab_size = len(chars)
    assert vocab_size == meta['vocab_size']
    block_size = meta['block_size']
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}

    # define encode and decode functions that convert strings to arrays of tokens and vice-versa
    encode = lambda x: [stoi[ch] for ch in x]
    decode = lambda x: "".join([itos[i] for i in x])
    vocab_size = len(stoi)
    print(f"Vocabulary size: {vocab_size}")


Vocabulary size: 65


In [34]:
np.array([[*([0]*253),1,2,3]]).shape

(1, 256)

In [27]:
ort_session.run(None, {'input.1': [[*([0]*253),1,2,3]]})[0].shape

(1, 65)

In [99]:
def convert(enc):
    # pad to 256 and add extra dim
    # return [[*([0]*(block_size-len(enc))),*enc]]
    res =  np.pad([enc[-256:]], ((0, 0), (max(block_size - len(enc), 0), 0)), mode="constant", constant_values=stoi["\n"])
    assert res.shape == (1, block_size)
    res = np.array(res, dtype=np.int64)
    return res


convert(encode("\n"* 595 + "hi ") )

array([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 

In [100]:
def generate_text_ort(start_text, max_new_tokens=100, temp=1.0, do_print=True):
    # encode the input text
    input_tokens_raw = encode(start_text)
    input_tokens = convert(input_tokens_raw)
    # input_tokens = np.pad(input_tokens, (0, block_size-len(input_tokens)), 'constant', constant_values=(0, 0))

    # generate new tokens
    if do_print:
        print(start_text, end="")
    for i in range(max_new_tokens):
        # get the logits for the next token
        logits = ort_session.run(None, {"input.1": input_tokens})[0][0]
        # print(logits.shape)
        logits = logits / temp
        probs = np.exp(logits) / np.sum(np.exp(logits))
        next_token = np.random.choice(vocab_size, p=probs)
        input_tokens_raw.append(next_token)
        input_tokens = convert(input_tokens_raw)
        if do_print:
            print(decode([next_token]), end="")
    if do_print:
        print("\n-----------")
    # print(input_tokens)
    # decode the generated tokens
    return decode(input_tokens[0]).lstrip()

In [102]:
print(generate_text_ort("""JULIET:""", max_new_tokens=400, temp=0.9))

JULIET:
What, what a well?

ROMEO:
A Mercution, that be repented
To change thy heavier thee?

PETRUCHIO:
O, you of all of poor bosom of what stock it there,
Being but married at the victory,
That you'll drink me from the world, and leave the substitute
Each their life of sights, instance can tween their
words; they are not follow; I'll be tepostioned
To buy and smiles as near upon our hand:
Most woe, bei
-----------
here,
Being but married at the victory,
That you'll drink me from the world, and leave the substitute
Each their life of sights, instance can tween their
words; they are not follow; I'll be tepostioned
To buy and smiles as near upon our hand:
Most woe, bei
