# Conversion to ONNX
ONNX is an open format for representing machine learning models. Running a model in this format (here with ONNX Runtime) can be much faster on CPU, sometimes 5 times as fast as PyTorch!
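To check a speed-up like this on your own hardware, a small timing harness is enough. This is a minimal sketch; `median_latency` and `speedup` are hypothetical helpers, not part of this project. It takes the median of several timed runs after a warm-up, so a single slow call doesn't skew the result.

```python
import statistics
import time
from typing import Callable


def median_latency(fn: Callable[[], object], repeats: int = 5, warmup: int = 1) -> float:
    """Return the median wall-clock time of fn() in seconds."""
    for _ in range(warmup):
        fn()  # warm-up runs are not timed (lazy initialisation, caches, ...)
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


def speedup(baseline: Callable[[], object], candidate: Callable[[], object], repeats: int = 5) -> float:
    """How many times faster `candidate` is than `baseline` (values > 1 mean faster)."""
    return median_latency(baseline, repeats) / median_latency(candidate, repeats)
```

For example, you could time `model_manager.say(past, prompt)` against `onnx_model_manager.say(past, prompt)` from the testing section. For a fair CPU comparison the PyTorch model should also run on the CPU, and greedy replies are a better benchmark than sampled ones, whose length varies from run to run.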

While the EAWSW model is designed to be small, accurate and accessible, it's still too heavy for some people to run...

Hosting the model as a free service for players is an option. An ONNX version of the model lets us host it on CPU and still get faster response times! Since this model was made during a chip shortage, running it on hardware I already have in a server is efficient, scalable and cheaper.

An important note: ONNX doesn't execute the generation logic by itself. The model only computes the logits for a given input, so the decoding loop (and any sampling) has to be written yourself. `onnx_model_manager.py` is meant to handle this for us.
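A minimal sketch of that missing logic, assuming plain greedy decoding. `greedy_generate` and `make_onnx_forward` are hypothetical names and not the API of `onnx_model_manager.py`. The input names `input_ids`/`attention_mask` are an assumption about what `transformers.onnx --feature=causal-lm` exports; only the `logits` output name is confirmed by the conversion log. Because the export overrides `use_cache -> False`, there is no key/value cache, so every step feeds the whole sequence again.

```python
from typing import Callable, List, Sequence


def greedy_generate(
    forward: Callable[[List[int]], Sequence[float]],
    input_ids: List[int],
    eos_token_id: int,
    max_new_tokens: int = 50,
) -> List[int]:
    """Greedy decoding: run the model, append the most likely next token, repeat.

    `forward` takes the full token sequence and returns the logits for the LAST position.
    Without a past-key-value cache, each step re-runs the whole sequence.
    """
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        logits = forward(ids)
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax
        ids.append(next_id)
        if next_id == eos_token_id:
            break  # stop at end-of-text (the EOS token is kept in the result)
    return ids


def make_onnx_forward(model_path: str) -> Callable[[List[int]], Sequence[float]]:
    """Wrap an exported causal-LM ONNX file as a `forward` function (assumed input names)."""
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

    def forward(ids: List[int]) -> Sequence[float]:
        input_ids = np.array([ids], dtype=np.int64)
        attention_mask = np.ones_like(input_ids)
        # Assumption: inputs "input_ids"/"attention_mask"; output "logits" is (batch, sequence, vocab).
        (logits,) = session.run(["logits"], {"input_ids": input_ids, "attention_mask": attention_mask})
        return logits[0, -1].tolist()

    return forward
```

With the notebook's tokenizer this would be used roughly as `tokenizer.decode(greedy_generate(make_onnx_forward(path), tokenizer.encode(prompt), tokenizer.eos_token_id))`. Keeping `forward` as a plain callable means the same loop works for any backend.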

In [1]:
%load_ext autoreload
%autoreload 2

from model_utils import train_model, split_data, split_branches, get_model, set_pretrained_model_dropout, get_dataset
from config import Config
import json
import matplotlib.pyplot as plt
%matplotlib inline
import math
import random
import time
import onnx
import logging
from onnx_model_manager import OnnxModelManager
from onnxruntime.quantization import quantize_dynamic, QuantType
import os
import datasets
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from model_manager import ModelManager

In [2]:
saved_model_path = os.path.join("models", "awsw_main")
saved_model_onnx_path = os.path.join("models", "awsw_onnx")
if not os.path.exists(os.path.join(saved_model_path, "special_tokens_map.json")):
    print("Copying config files from huggingface (needed for conversion)... WARNING: this assumes the structure of the model isn't changed!")
    !cd $saved_model_path && git clone https://huggingface.co/$Config.base_model_name
    !cp -n $saved_model_path/$Config.base_model_basename/* $saved_model_path
    !rm -rf $saved_model_path/$Config.base_model_basename
if not os.path.exists(os.path.join(saved_model_onnx_path, "model.onnx")):
    !python3 -m transformers.onnx --model=$saved_model_path --feature=causal-lm --atol=1e-03 $saved_model_onnx_path

Cloning into 'gpt-neo-125M'...
remote: Enumerating objects: 44, done.
remote: Total 44 (delta 0), reused 0 (delta 0), pack-reused 44
Unpacking objects: 100% (44/44), 543.14 KiB | 1.23 MiB/s, done.
Using framework PyTorch: 1.10.1+cu113
Overriding 1 configuration item(s)
	- use_cache -> False
  assert batch_size > 0, "batch_size has to be defined and > 0"
Validating ONNX model...
	-[✓] ONNX model output names match reference model ({'logits'})
	- Validating ONNX Model output "logits":
		-[✓] (2, 8, 50257) matches (2, 8, 50257)
		-[✓] all values close (atol: 0.001)
All good, model saved at: models/awsw_onnx/model.onnx
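The validation step checks that the ONNX logits stay within an absolute tolerance (`--atol=1e-03`) of the PyTorch logits. The same kind of check is handy for measuring how far a quantized model drifts from the original. A minimal pure-Python sketch on flattened outputs (hypothetical helper names; real logits would be flattened first, e.g. with NumPy's `ravel()`). Note that `numpy.allclose` also adds a relative tolerance on top of `atol`.

```python
from typing import Sequence


def max_abs_diff(reference: Sequence[float], candidate: Sequence[float]) -> float:
    """Largest element-wise absolute difference between two flattened outputs."""
    if len(reference) != len(candidate):
        raise ValueError(f"shape mismatch: {len(reference)} vs {len(candidate)}")
    return max((abs(r - c) for r, c in zip(reference, candidate)), default=0.0)


def outputs_match(reference: Sequence[float], candidate: Sequence[float], atol: float = 1e-3) -> bool:
    """True when every element of `candidate` is within `atol` of `reference`."""
    return max_abs_diff(reference, candidate) <= atol
```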


In [3]:
def quantize_onnx():
    """Write an int8 dynamically quantized copy of the exported model, unless it already exists."""
    model_fp32 = os.path.join(saved_model_onnx_path, "model.onnx")
    model_quant = os.path.join(saved_model_onnx_path, "model_quant.onnx")
    if not os.path.exists(model_quant):
        # Weights are converted to int8 ahead of time; activations are quantized on the fly at runtime.
        quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QInt8)
quantize_onnx()

Ignore MatMul due to non constant B: /[MatMul_102]
Ignore MatMul due to non constant B: /[MatMul_133]
Ignore MatMul due to non constant B: /[MatMul_235]
Ignore MatMul due to non constant B: /[MatMul_266]
Ignore MatMul due to non constant B: /[MatMul_368]
Ignore MatMul due to non constant B: /[MatMul_399]
Ignore MatMul due to non constant B: /[MatMul_501]
Ignore MatMul due to non constant B: /[MatMul_532]
Ignore MatMul due to non constant B: /[MatMul_634]
Ignore MatMul due to non constant B: /[MatMul_665]
Ignore MatMul due to non constant B: /[MatMul_767]
Ignore MatMul due to non constant B: /[MatMul_798]
Ignore MatMul due to non constant B: /[MatMul_900]
Ignore MatMul due to non constant B: /[MatMul_931]
Ignore MatMul due to non constant B: /[MatMul_1033]
Ignore MatMul due to non constant B: /[MatMul_1064]
Ignore MatMul due to non constant B: /[MatMul_1166]
Ignore MatMul due to non constant B: /[MatMul_1197]
Ignore MatMul due to non constant B: /[MatMul_1299]
Ignore MatMul due to non c
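The `Ignore MatMul due to non constant B` messages are expected: dynamic quantization converts constant weight tensors to int8 ahead of time, and a MatMul whose second input is computed at runtime has no constant weight to convert, so it stays in fp32. Judging by the pairs of indices, these are most likely the two attention products in each layer. To illustrate what happens to the weights that *are* quantized, here is a sketch of a simple symmetric per-tensor int8 scheme. This is for illustration only: onnxruntime's exact scheme depends on options such as `per_channel` and `reduce_range`, and the helper names are hypothetical.

```python
from typing import List, Tuple


def quantize_int8(weights: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor int8 quantization: w ≈ scale * q with q in [-127, 127]."""
    max_abs = max((abs(w) for w in weights), default=0.0)
    scale = max_abs / 127 if max_abs > 0 else 1.0  # the largest weight maps to ±127
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize_int8(q: List[int], scale: float) -> List[float]:
    """Map int8 values back to floats; each value is off by at most scale / 2."""
    return [scale * v for v in q]
```

Rounding every weight to one of 255 levels is where the accuracy loss comes from that quantization trades for a smaller, faster model.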

In [4]:
# Run the PyTorch model on the GPU when one is available, otherwise on the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
# device_name = 'cpu'
device = torch.device(device_name)

# model-opt.onnx is never created by this notebook, so load the plain fp32 export instead.
onnx_model_manager = OnnxModelManager(os.path.join(saved_model_onnx_path, "model.onnx"))
onnx_model_manager_quant = OnnxModelManager(os.path.join(saved_model_onnx_path, "model_quant.onnx"))
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
model = AutoModelForCausalLM.from_pretrained(saved_model_path)
model.to(device)
model.eval()
model_manager = ModelManager(model=model, tokenizer=tokenizer, device=device)
print(f"Pretrained model loaded on {device_name}")

Pretrained model loaded on cuda:0


In [5]:
prompt = "In my dreams, I'm a dragon"
for i in range(2):
    print("ONNX:", onnx_model_manager.say_raw(prompt, do_sample=True))
    print("ONNX (Quantized):", onnx_model_manager_quant.say_raw(prompt, do_sample=True))
    print("PyTorch:", model_manager.say_raw(prompt, 50, 0.7))
    print('-' * 100)

ONNX: In my dreams, I'm a dragon."<p><msg>c "I see."<|endoftext|>
ONNX (Quantized): In my dreams, I'm a dragon. My dreams, I'm a dragon. I'm a dragon. I'm a. My dreams, I'm a. And...

And I'm a dragon, and I'm arian. My opponent.<|endoftext|>
PyTorch: In my dreams, I'm a dragon. I'm a human with horns, wings, and tails, and I can feel the hunger that comes from hearing those stories."<d><scn>hallway<msg>Rz "You can't be serious? You don't have to explain it to them."<p><msg>c "But you do."<d><scn>hallway<msg>Rz "If I didn't have the opportunity, I certainly would."<p><msg>c "It's not as if you have to explain it to them."<p>
----------------------------------------------------------------------------------------------------
ONNX: In my dreams, I'm a dragon."<p><msg></p><d>"<d><scn>o2<msg>Ad "Oh. I see."<p><msg>c "I see."<p><msg>c "I see."<p><msg>c "I see."<p><msg>c "I see you, [d]kay."<p><msg>c "I see ya."<p><msg>c "leave"<d><scn>park1<msg>m "Oh, Park1<msg>Ip "Oh
ONNX (Quantized): In m

# Testing

We created a few prompt pairs, each made of a past prompt (for context) and a present prompt (the player input), and compare how the models react. This way, we can test the models across different iterations.
The first test uses an old prompt to compare the pre-trained model with the one trained on AWSW. Did it retain its data well? Can it still write things that have nothing to do with AWSW? (That tells us we didn't overfit.)

**This test generates boring and repetitive replies!** That's because no proper sampling algorithm is used, but it still gives us an indication of what the model has learned!
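For reference, a better sampler such as the top-k / nucleus (top-p) sampling used in the sampling test (`top_k = 50, top_p = 0.7`) keeps only the most likely tokens and draws one of them at random, which breaks the repetition. This is a sketch of the idea with a hypothetical function name, not the exact implementation used by `model_manager` or by Hugging Face.

```python
import math
import random
from typing import List, Optional


def top_k_top_p_sample(
    logits: List[float],
    top_k: int = 50,
    top_p: float = 0.7,
    temperature: float = 1.0,  # must be > 0
    rng: Optional[random.Random] = None,
) -> int:
    """Sample a token id after top-k and nucleus (top-p) filtering."""
    rng = rng or random.Random()
    # Softmax over temperature-scaled logits (subtract the max for numerical stability).
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the k most likely tokens, highest probability first.
    ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    if top_k > 0:
        ranked = ranked[:top_k]

    # Nucleus filtering on the renormalised top-k distribution: keep the smallest
    # prefix whose cumulative probability reaches top_p.
    mass = sum(probs[i] for i in ranked)
    kept, cumulative = [], 0.0
    for idx in ranked:
        kept.append(idx)
        cumulative += probs[idx] / mass
        if cumulative >= top_p:
            break

    # random.choices does not need normalised weights.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
```

With `top_k=1` this reduces to greedy decoding, which is essentially what the test below does.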

In [6]:
prompts = [
    ('<p><msg>c "Hey Remy!"<d><scn>park2<msg>Ry "Hello, [player_name]."', "How are you?"),
    ('<p><msg>c "I was with Lorem today."<d><scn>park2<msg>Ad "Very nice."', "What do you think of Lorem?"),
    ('<p><msg>m "In Tatsu park, Adine and I sat down."', "Oh my god, Adine. What is this?"),
    ('<p><msg>m "I sat down on a chair in Anna\'s lab."', "What will we do here?"),
]

for (past, prompt) in prompts:
    print(f"Prompt: {prompt}")
    reply = model_manager.say(past, prompt)
    print(f"[Pytorch] Reply: {reply}\n")
    reply = onnx_model_manager.say(past, prompt)
    print(f"[ONNX] Reply: {reply}\n")
    reply = onnx_model_manager_quant.say(past, prompt)
    print(f"[ONNX Quantized] Reply: {reply}\n")
    print("-" * 10)

Prompt: How are you?
[Pytorch] Reply: park2<msg>Ry "I'm alright. I just wanted to talk to you about something."<p><msg>c "I was with Katsuharu today"<d><scn>park2<msg>Ry "Very nice"<|endoftext|>

[ONNX] Reply: park2<msg>Ry "I'm alright. I just wanted to talk to you about something."<p><msg>c "I was with Katsuharu today"<d><scn>park2<msg>Ry "Very nice"<|endoftext|>

[ONNX Quantized] Reply: <msg>office<msg<msg>d "<p>cave"<d><scn><d>cave<msg<msg>r "<d><HER2>set<msg>Rey<skip><d><p<oje< damned>cave<><msg>Ad "Begging for a nice and precious chance<d>Ad only..."<p><msg>c "Begging for not so"<leave>rcast<msg>Ad "I am not. IADers...]p<p>cxx<d>Ad

----------
Prompt: What do you think of Lorem?
[Pytorch] Reply: park2<msg>Ad "I think he is funny."<|endoftext|>

[ONNX] Reply: park2<msg>Ad "I think he is funny."<|endoftext|>

[ONNX Quantized] Reply: <msg>Ad "<p><msg>c "
<p> adore naomina
<p> ad "Not only is it its own, Adine the ad-taker, and I the only adore."<d><<|endoftext|>

----------
Prompt: O
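Comparing replies by eye gets tedious. Judging only from the outputs above, a well-formed reply seems to start with a scene name, then `<msg>`, a character code and the quoted line. A hypothetical helper that extracts that first line could be used to count how many replies each backend gets right; the format is inferred from these samples, not from a specification.

```python
import re
from typing import NamedTuple, Optional


class ReplyLine(NamedTuple):
    scene: str
    speaker: str
    text: str


# scene name, "<msg>", character code, a space, then the quoted line
_FIRST_LINE = re.compile(r'^(?P<scene>[^<]+)<msg>(?P<speaker>\w+) "(?P<text>[^"]*)"')


def parse_first_line(reply: str) -> Optional[ReplyLine]:
    """Return the scene, speaker code and text of a reply's first line, or None if malformed."""
    match = _FIRST_LINE.match(reply)
    if match is None:
        return None  # e.g. replies that start with a stray tag
    return ReplyLine(match["scene"], match["speaker"], match["text"])
```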

# Sampling test

This is gonna be interesting!

In [7]:
for (past, prompt) in prompts:
    print(f"Prompt: {prompt}")
    reply = model_manager.say(past, prompt, top_k = 50, top_p = 0.7)
    print(f"[Pytorch] Reply: {reply}\n")
    reply = onnx_model_manager.say(past, prompt, do_sample = True)
    print(f"[ONNX] Reply: {reply}\n")
    reply = onnx_model_manager_quant.say(past, prompt, do_sample = True)
    print(f"[ONNX Quantized] Reply: {reply}\n")
    print("-" * 10)

Prompt: How are you?
[Pytorch] Reply: park2<msg>Ry "Not bad, thanks for the flowers."<p><msg>c "I was with Sebastian today"<d><scn>park2<msg>Ry "Very nice"<|endoftext|>

[ONNX] Reply: park2<msg>Ry "Good. I was just leaving for a meeting today."<p><msg>c "I was with Zhong today"<d><scn>park2<msg>Ry "Very nice!"<|endoftext|>

[ONNX Quantized] Reply: <msg>office<msg<msg>d "<p>cave<HER2?is<msg>sad<HERna>c "She looks like she is going to be at least at least being at least at least as soon as the [shendongle<ERT](she may be that) [# of notches]/loremapta"<parms>cxx<msg>rtle<msg>c "Ad "I see that the ad is not so chosen and that its not even the best its notched[/possiblefights not only the<br><pad>

----------
Prompt: What do you think of Lorem?
[Pytorch] Reply: park2<msg>Ad "I think he is funny."<|endoftext|>

[ONNX] Reply: park2<msg>Ad "I like the fact that he is also a member of the police department. What a great opportunity."<p><msg>c "What are you talking about?"<d><scn>park2<msg>Ad "

# RP test
Testing out the injected roleplay actions
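The lowercase test in the cell below lowers the first letter and then the whole action, so one-word actions such as `Fight` and `Attack` end up being tested twice with the same input. A small hypothetical helper that drops duplicate variants (used as `for variant in case_variants(rp): ...`) would avoid the repeated calls:

```python
from typing import List


def case_variants(action: str) -> List[str]:
    """Lowercase variants of an RP action, without duplicates, in test order."""
    variants = [action[:1].lower() + action[1:], action.lower()]
    # dict.fromkeys keeps first-seen order while dropping repeats.
    return list(dict.fromkeys(variants))
```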

In [8]:
test_rps = [
    "Visit Lorem",
    "Meet with Lorem",
    "Visit Adine",
    "Fight",
    "Attack"
]

for rp in test_rps:
    print(f'[Pytorch] {rp} -> {model_manager.say("", rp, top_k = 50, top_p = 0.7)}')
    print(f'[ONNX] {rp} -> {onnx_model_manager.say("", rp, do_sample = True)}')
    print(f'[ONNX Quantized] {rp} -> {onnx_model_manager_quant.say("", rp, do_sample = True)}')
    print("-" * 10)
    
print("Lowercase test")

for rp in test_rps:
    rp = rp[0].lower() + rp[1:]
    print(f'[Pytorch] {rp} -> {model_manager.say("", rp, top_k = 50, top_p = 0.7)}')
    print(f'[ONNX] {rp} -> {onnx_model_manager.say("", rp, do_sample = True)}')
    print(f'[ONNX Quantized] {rp} -> {onnx_model_manager_quant.say("", rp, do_sample = True)}')
    rp = rp.lower()
    print(f'[Pytorch] {rp} -> {model_manager.say("", rp, top_k = 50, top_p = 0.7)}')
    print(f'[ONNX] {rp} -> {onnx_model_manager.say("", rp, do_sample = True)}')
    print(f'[ONNX Quantized] {rp} -> {onnx_model_manager_quant.say("", rp, do_sample = True)}')
    print("-" * 10)

[Pytorch] Visit Lorem -> loremapt<msg>Lo "Oh, [player_name], I wasn't expecting visitors."<|endoftext|>
[ONNX] Visit Lorem -> loremapt<msg>Lo "Oh, really?"<p><msg>c "I was with Katsuhiko today"<d><scn>loremapt<msg>Lo "Very nice"<|endoftext|>
[ONNX Quantized] Visit Lorem -> <msg>m<p<msg>d "M<d>Adine<t><msg>c: "<d><p><msg>c "Be posted"<d><msg>Ad "d>Adine<t>Adspace<d>Adspace<><p><c< borderline"<d><d><<>oreapts<ERTfightersinemediated><td><IANNared<td><<p><3<^<p><actualjeckxx<Iadj><Iadj>ad<Iadj
----------
[Pytorch] Meet with Lorem -> park1<msg>Em "Oh, [player_name], I wasn't expecting visitors."<|endoftext|>
[ONNX] Meet with Lorem -> loremapt<msg>Lo "Oh, I see."<p><msg>c "I was talking with Lorem about it, and I told her that she should go back to her home town. If she doesn't want to come back, I can arrange to meet here, but I'm not sure what else I should be doing."<d><scn>loremapt<msg>Lo "She should probably go now, though."<p><msg>c "Alright."<d><scn>loremapt<msg>Lo "I'm sorry if this 