In [None]:
import os
import sys

import cv2
import keyboard
import numpy as np
import onnxruntime as ort
from IPython.display import Image, Video, display

sys.path.append('..')  # must run before the project-local `utils` imports below
from utils.sampling import softmax, multinomial
from utils.video import write_video, transpose_and_clip

In [None]:
# load model session
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
provider = ["CUDAExecutionProvider", "CPUExecutionProvider"]
session = ort.InferenceSession('../models/gpt2m.onnx', options, provider)

# print shapes: map every model input / output name to its (shape, type) pair
input_shapes = dict((node.name, (node.shape, node.type)) for node in session.get_inputs())
output_shapes = dict((node.name, (node.shape, node.type)) for node in session.get_outputs())
print('input shapes : ', input_shapes)
print('output shapes: ', output_shapes)

# load decoder and encoder (same session options and provider list as the GPT model)
decoder_session = ort.InferenceSession('../models/decoder.onnx', options, provider)
encoder_session = ort.InferenceSession('../models/encoder.onnx', options, provider)

In [None]:
# set constants and define functions
TOKENS_PER_FRAME = 129   # one BOS token followed by 128 (= 8 x 16) image tokens
BOS_TOKEN        = 1024  # token id prepended to every frame
PIXELS_PER_CHUNK = 16    # width of the blend band; turn shifts must stay below this
MAX_CONTEXT_SIZE = 20 * TOKENS_PER_FRAME  # 20 frames worth of tokens

def generate_frame_tokens(session, tokens, override=None):
  """Autoregressively sample the tokens of one frame (TOKENS_PER_FRAME of them).

  All tensors are bound on CUDA device 0 through ONNX Runtime IO binding, so
  the key/value cache produced by one step ('present_*') is handed straight
  back as the next step's 'past_*' input.

  Args:
    session: onnxruntime InferenceSession with an 'input_ids' input, 24
      'past_{i}' inputs, a 'logits' output and 24 'present_{i}' outputs.
    tokens: numpy array of conditioning token ids, bound as 'input_ids'
      (presumably int32 with shape (1, n) -- confirm against the model inputs).
    override: optional token id to force as the first token of the frame
      instead of sampling it.

  Returns:
    The frame's tokens concatenated along axis 1 (int32).
  """
  n_layers = 24

  output_tokens = []
  if override is not None:
    # The forced token has to be part of the model input as well; otherwise the
    # remaining tokens of the frame are sampled as if it did not exist.
    forced = np.array([[override]], dtype=np.int32)
    tokens = np.concatenate([tokens, forced.astype(tokens.dtype)], axis=1)
    output_tokens.append(forced)

  # Start with an empty KV cache (sequence axis of length 0) for every layer.
  data = {'input_ids': tokens,
          **{f'past_{i}': np.zeros((2, 1, 16, 0, 64), dtype=np.float16) for i in range(n_layers)}
          }

  io_binding = session.io_binding()
  # Keep the OrtValues referenced in a dict for as long as they are bound.
  data_ortvalue = {}
  for k, v in data.items():
    data_ortvalue[k] = ort.OrtValue.ortvalue_from_numpy(v, 'cuda', 0)
    io_binding.bind_ortvalue_input(k, data_ortvalue[k])

  for _ in range(TOKENS_PER_FRAME - len(output_tokens)):
    # Outputs are re-bound on every step.
    io_binding.bind_output('logits', 'cuda')
    for j in range(n_layers):
      io_binding.bind_output(f'present_{j}', 'cuda')

    session.run_with_iobinding(io_binding)
    ort_output = io_binding.get_outputs()

    # Sample the next token from the logits of the last position only.
    logits = ort_output[0].numpy()[:,-1,:]
    logits = logits.astype(np.float64)
    probs = softmax(logits, axis=1)
    tokens = multinomial(probs).astype(np.int32)

    output_tokens.append(tokens)
    # Next step: feed only the new token plus the cache the model just produced.
    data_ortvalue['input_ids'] = ort.OrtValue.ortvalue_from_numpy(tokens, 'cuda', 0)
    io_binding.bind_ortvalue_input('input_ids', data_ortvalue['input_ids'])
    for j in range(n_layers):
      io_binding.bind_ortvalue_input(f'past_{j}', ort_output[1+j])

  return np.concatenate(output_tokens, axis=1)

sg_width = PIXELS_PER_CHUNK
# Horizontal ramp shared by every one of the 128 rows: column i holds i / sg_width,
# i.e. 0 at the left edge rising towards (but not reaching) 1 at the right edge.
shift_gradient = np.tile(np.arange(sg_width) / float(sg_width), (128, 1)).astype(np.float32)

def save_frame(tokens_gen, img_name, upscale=True):
    """Decode the last entry of `tokens_gen` and write it as an image under ../tmp.

    Args:
        tokens_gen: token array; `tokens_gen[-1]` is reshaped to (1, 8, 16) and
            passed to the decoder as 'encoding_indices'.
        img_name: file name relative to ../tmp. The resulting path is
            normalized, so a name such as "../tmp/temp.png" still resolves to
            ../tmp/temp.png.
        upscale: when True, the written image is read back, resized to a width
            of 1280 pixels (height scaled proportionally) and written again.
    """
    output = decoder_session.run(None, {'encoding_indices': tokens_gen[-1].reshape(1,8,16)})
    output = {o.name: x for o,x in zip(decoder_session.get_outputs(), output)}
    output = output['big_decoded_img']

    frame_path = os.path.normpath(f"../tmp/{img_name}")
    # exist_ok avoids both the check-then-create race and the FileExistsError
    # that os.makedirs raises for un-normalized paths like "../tmp/../tmp".
    os.makedirs(os.path.dirname(frame_path), exist_ok=True)
    decoded_frame = transpose_and_clip([output])[0]
    cv2.imwrite(frame_path, decoded_frame)
    if upscale:
        img = cv2.imread(frame_path)
        img = cv2.resize(img, (1280, int(1280*img.shape[0]/img.shape[1])))
        cv2.imwrite(frame_path, img)

def shift_right(tokens_gen):
    """In place, copy every column i+1 of `tokens_gen[0]` into column i.

    The last column keeps its value; entries other than index 0 on the first
    axis are untouched. Returns None.
    """
    # Copy the source first so the overlapping slices cannot alias.
    successors = tokens_gen[0, :, 1:].copy()
    tokens_gen[0, :, :-1] = successors
def shift_left(tokens_gen):
    """In place, copy every column i-1 of `tokens_gen[0]` into column i.

    The first column keeps its value; entries other than index 0 on the first
    axis are untouched. Returns None.
    """
    # Copy the source first so the overlapping slices cannot alias.
    predecessors = tokens_gen[0, :, :-1].copy()
    tokens_gen[0, :, 1:] = predecessors

def turn_gen(tokens_gen, shift_func):
    """Decode `tokens_gen` before and after applying `shift_func` to it.

    Both images go through a temporary PNG under ../tmp (written by
    `save_frame` without upscaling, then read back with OpenCV).

    Args:
        tokens_gen: token array accepted by `save_frame`. `shift_func` is called
            on it purely for its in-place effect, so the caller's array is
            modified.
        shift_func: callable taking `tokens_gen`; its return value is ignored.

    Returns:
        (img_orig, img_turn): the images decoded before and after the shift.

    Raises:
        RuntimeError: if a temporary image cannot be read back.
    """
    # save_frame expects a name relative to ../tmp, not a full path.
    tmp_name = "temp.png"
    tmp_path = f"../tmp/{tmp_name}"

    save_frame(tokens_gen, tmp_name, upscale=False)
    img_orig = cv2.imread(tmp_path)
    shift_func(tokens_gen)
    save_frame(tokens_gen, tmp_name, upscale=False)
    img_turn = cv2.imread(tmp_path)

    # cv2.imread signals failure by returning None instead of raising.
    if img_orig is None or img_turn is None:
        raise RuntimeError(f"could not read back temporary frame {tmp_path}")
    return img_orig, img_turn
def encode_turn(img):
    """Run the encoder on `img` and return its 'encoding_indices' output as a flat array.

    The image's last axis is moved to the front, a leading batch axis is added
    and the data is cast to float32 before being fed as 'big_img'.
    """
    batch = img.transpose(2,0,1)[None].astype(np.float32)
    raw_outputs = encoder_session.run(None, {'big_img': batch})
    # Pair each raw output with the name the session reports for it.
    named_outputs = dict(zip((meta.name for meta in encoder_session.get_outputs()), raw_outputs))
    return named_outputs['encoding_indices'].ravel()

def turn_left(tokens_gen, shift_amount):
    """Build a left-turned version of the frame in `tokens_gen` and re-encode it.

    The decoded frame is moved `shift_amount` pixels to the right, the freed
    strip on the left is filled from the frame decoded after `shift_right`,
    and the next PIXELS_PER_CHUNK columns are cross-faded between the two
    images using `shift_gradient`.

    Args:
        tokens_gen: token array of the current frame (modified in place by
            `turn_gen` through `shift_right`).
        shift_amount: shift in pixels, strictly between 0 and PIXELS_PER_CHUNK.

    Returns:
        The flat token array produced by `encode_turn` for the blended image.
    """
    assert 0 < shift_amount < PIXELS_PER_CHUNK
    frame, frame_turned = turn_gen(tokens_gen, shift_right)

    # Slide the frame right, then fill the vacated left strip.
    frame[:, shift_amount:] = frame[:, :-shift_amount]
    frame[:, :shift_amount] = frame_turned[:, PIXELS_PER_CHUNK-shift_amount:PIXELS_PER_CHUNK]

    # Cross-fade all three colour channels at once; the trailing axis lets the
    # (rows, cols) gradient broadcast over the channel axis.
    band = slice(shift_amount, shift_amount + PIXELS_PER_CHUNK)
    weight = shift_gradient[:, :, None]
    turned_band = frame_turned[:, PIXELS_PER_CHUNK:2*PIXELS_PER_CHUNK]
    frame[:, band] = frame[:, band] * weight + turned_band * (1 - weight)
    return encode_turn(frame)

def turn_right(tokens_gen, shift_amount):
    """Build a right-turned version of the frame in `tokens_gen` and re-encode it.

    Mirror image of `turn_left`: the decoded frame is moved `shift_amount`
    pixels to the left, the freed strip on the right is filled from the frame
    decoded after `shift_left`, and the PIXELS_PER_CHUNK columns next to that
    strip are cross-faded between the two images.

    Args:
        tokens_gen: token array of the current frame (modified in place by
            `turn_gen` through `shift_left`).
        shift_amount: shift in pixels, strictly between 0 and PIXELS_PER_CHUNK.

    Returns:
        The flat token array produced by `encode_turn` for the blended image.
    """
    assert shift_amount > 0 and shift_amount < PIXELS_PER_CHUNK
    img_orig, img_turn = turn_gen(tokens_gen, shift_left)

    img_orig[:, :-shift_amount] = img_orig[:, shift_amount:]
    img_orig[:, -shift_amount:] = img_turn[:, -PIXELS_PER_CHUNK:-PIXELS_PER_CHUNK+shift_amount]

    # The inserted strip sits on the RIGHT of the blend band here, so the ramp
    # must be mirrored: the original frame fades out towards the strip. With the
    # un-mirrored ramp the band had hard seams on both of its sides.
    orig_weight = shift_gradient[:, ::-1]
    for c in range(3):
        img_orig[:, -shift_amount-PIXELS_PER_CHUNK:-shift_amount, c] = \
            img_orig[:, -shift_amount-PIXELS_PER_CHUNK:-shift_amount, c] * orig_weight + \
            img_turn[:, -2*PIXELS_PER_CHUNK:-PIXELS_PER_CHUNK, c] * (1 - orig_weight)
    return encode_turn(img_orig)

In [None]:
# RUNNER
# Generates frames until "q" is pressed. Every frame is decoded, written to
# ../tmp/frame.png and kept in `frames`; holding an arrow key turns the NEXT
# generated frame. `frames` grows for as long as the loop runs because the
# whole session is written out as a video at the end.

CONTEXT_TOKENS = MAX_CONTEXT_SIZE - TOKENS_PER_FRAME  # context kept; leaves room for one new frame
TURN_STEP = 2                                         # pixels shifted per frame while a key is held

# load tokens: prepend BOS to every frame, keep the most recent frames, flatten
tokens_condition = np.load("../examples/tokens.npy").astype(np.int32)
tokens_condition = np.c_[np.ones(len(tokens_condition), dtype=np.int32)*BOS_TOKEN, tokens_condition]
tokens_condition = tokens_condition[-(MAX_CONTEXT_SIZE//TOKENS_PER_FRAME - 1):].reshape(1,-1)

frame_path = "../tmp/frame.png"
os.makedirs(os.path.dirname(frame_path), exist_ok=True)

frames = []

movement = 0
frame_count = 0
while True:
    tokens = generate_frame_tokens(session, tokens_condition[:, -CONTEXT_TOKENS:])
    # Only the last CONTEXT_TOKENS tokens are ever fed back, so drop older ones
    # instead of letting the array grow on every iteration.
    tokens_condition = np.concatenate([tokens_condition, tokens], axis=1)[:, -CONTEXT_TOKENS:]

    # reshape and remove BOS token (astype copies, so `tokens` itself is never modified)
    tokens_gen = tokens.reshape(-1,TOKENS_PER_FRAME)
    tokens_gen = tokens_gen[:, 1:].astype(np.int64)
    tokens_gen = tokens_gen[-1].reshape(1,8,16)

    if movement != 0:
        if movement < 0:
            new_state = turn_left(tokens_gen, -movement)
        else:
            new_state = turn_right(tokens_gen, movement)
        # Replace the newest frame in the context (all but its BOS token) so the
        # model continues from the turned view ...
        tokens_condition[0, -TOKENS_PER_FRAME+1:] = new_state.astype(np.int32)
        # ... and decode that same turned frame below, so what is shown and
        # recorded matches what the model is conditioned on.
        tokens_gen = new_state.astype(np.int64).reshape(1,8,16)

    output = decoder_session.run(None, {'encoding_indices': tokens_gen})
    output = {o.name: x for o,x in zip(decoder_session.get_outputs(), output)}
    output = output['big_decoded_img']
    frames.append(output)

    decoded_frame = transpose_and_clip([output])[0]

    # Write, read back, upscale to a width of 1280 px, stamp the frame counter.
    cv2.imwrite(frame_path, decoded_frame)
    img = cv2.imread(frame_path)
    img = cv2.resize(img, (1280, int(1280*img.shape[0]/img.shape[1])))
    cv2.putText(img, str(frame_count), (10,30), cv2.FONT_ITALIC, 1, (0,0,255), 2)
    frame_count += 1
    cv2.imwrite(frame_path, img)

    # Poll the keyboard once per frame; the result applies to the next frame.
    if keyboard.is_pressed("q"):
        break
    if keyboard.is_pressed("left"):
        movement = -TURN_STEP
    elif keyboard.is_pressed("right"):
        movement = TURN_STEP
    else:
        movement = 0

if len(frames) > 20:
    decoded_video = transpose_and_clip(frames)
    save_dst = '../tmp/generated.mp4'
    write_video(decoded_video, save_dst, fps=20)
    # A bare expression inside an `if` block is not rendered by the notebook;
    # it has to be displayed explicitly.
    display(Video(save_dst, embed=True, width=700))

######################
# Key   | Action     #
# ================== #
# left  | turn left  #
# right | turn right #
# q     | quit       #
######################

# Each generated frame is written to ../tmp/frame.png, replacing the previous one.
# The key bindings should work OS-wide: launch this cell, open the generated
# frame in a VSCode tab, and steer the "car" from there.