In [2]:
import torch
from models.gen.edm import EDM
from models.gen.blocks import UNet
import tqdm
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
from data.data import SequencesDataset
import torchvision.transforms as transforms
import os
import numpy as np
import matplotlib.pyplot as plt
import random
import tqdm

In [3]:
# Model / data dimensions shared by the cells below.
input_channels = 3   # channels per frame (matches the 3-channel Normalize used for the dataset)
context_length = 4   # number of past frames / actions the model is conditioned on
actions_count = 5    # passed to the UNet as the number of actions

# Data-loading settings.
batch_size = 1
num_workers = 2

# Device priority: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# All project paths are resolved relative to this directory.
ROOT_PATH = "../"


def local_path(path: str) -> str:
    """Return ``path`` joined onto ``ROOT_PATH``.

    Args:
        path: Path relative to the project root.

    Returns:
        The result of ``os.path.join(ROOT_PATH, path)``.
    """
    joined = os.path.join(ROOT_PATH, path)
    return joined


# Checkpoint loaded into the EDM model further down.
MODEL_PATH = local_path("models/model.pth")

In [4]:
# Build the diffusion wrapper around a UNet denoiser.
# The UNet's first argument is input_channels * (context_length + 1);
# presumably the context frames stacked with the frame being denoised — TODO confirm.
edm = EDM(
    model=UNet(
        input_channels * (context_length + 1),
        3,
        None,
        actions_count,
        context_length,
    ),
    p_mean=-1.2,
    p_std=1.2,
    sigma_data=0.5,
    context_length=context_length,
    device=device,
)

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
# Kept as the last expression so the key-matching report is shown as the cell output.
edm.load_state_dict(
    torch.load(MODEL_PATH, map_location=device)["model"]
)

  edm.load_state_dict(torch.load(MODEL_PATH, map_location=device)["model"])


<All keys matched successfully>

In [5]:
# Convert each frame to a tensor, then apply (x - 0.5) / 0.5 per channel.
# render_loop's get_np_img later inverts this with x * 127.5 + 127.5.
transform_to_tensor = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Sequences of context_length frames/actions used to seed the interactive rollout.
dataset = SequencesDataset(
    transform=transform_to_tensor,
    seq_length=context_length,
    images_dir=local_path("training_data_initial/snapshots"),
    actions_path=local_path("training_data_initial/actions"),
)

In [None]:
from IPython.display import display, clear_output, Image as iImage
import ipywidgets as widgets
from PIL import Image
import time
import threading

class State:
    """Mutable state shared between the button callbacks and the render thread.

    Attributes:
        action: Index of the most recently requested action (starts at 0).
        is_running: Flag polled by the render loop; it stops once this is False.
    """

    def __init__(self) -> None:
        self.is_running: bool = False
        self.action: int = 0

# Single shared instance read by the render loop and written by the callbacks.
state = State()


def on_button_click(input_action) -> None:
    """Record ``input_action`` as the current action on the shared ``state``."""
    state.action = input_action

# Build the six control buttons in one pass (same creation order as before).
left_button, right_button, up_button, down_button, start_button, stop_button = (
    widgets.Button(description=label)
    for label in ('Left', 'Right', 'Up', 'Down', 'Start', 'Stop')
)

# Label displayed for each action index while rendering.
directions = dict(enumerate(("Right", "Left", "Up", "Down")))

# Each direction button stores its action index in the shared state.
right_button.on_click(lambda _button: on_button_click(0))
left_button.on_click(lambda _button: on_button_click(1))
up_button.on_click(lambda _button: on_button_click(2))
down_button.on_click(lambda _button: on_button_click(3))

# Row layout; Up/Down are stacked vertically between Left and Right.
buttons = widgets.HBox([
    left_button,
    widgets.VBox([up_button, down_button]),
    right_button,
    start_button,
    stop_button,
])

# Separate output areas: one for the controls, one for the rendered frames.
image_output = widgets.Output()
button_output = widgets.Output()

with button_output:
    display(buttons)

def render_loop(image_output: widgets.Output, max_frames: int = 80, fps: float = 1):
    """Generate frames with ``edm`` and show them in ``image_output`` until stopped.

    A random dataset sample seeds the frame/action context. Each iteration appends
    the current ``state.action`` to the action history, samples the next frame from
    the model, appends it to the frame history and displays it as a 360x360 PNG.
    The loop ends when ``state.is_running`` becomes False or after ``max_frames``
    frames. Intended to run as the target of the background render thread.

    Args:
        image_output: Output widget that receives the direction label and the frame.
        max_frames: Maximum number of frames to generate before finishing.
        fps: Target frames per second (must be > 0); the loop sleeps to avoid
            running faster than this.

    Raises:
        Exception: Any error raised while sampling or rendering is reported in
            ``image_output`` and then re-raised.
    """
    import io  # imported once here instead of on every frame

    global render_thread

    def get_np_img(tensor: torch.Tensor) -> np.ndarray:
        """Undo the 0.5/0.5 normalisation and return a channels-last uint8 array."""
        scaled = (tensor * 127.5 + 127.5).long().clip(0, 255)
        return scaled.permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)

    frame_time = 1 / fps
    try:
        # random.randint raises ValueError on an empty dataset; the handler below reports it.
        index = random.randint(0, len(dataset) - 1)
        img, last_imgs, actions = dataset[index]

        img = img.to(device)  # only its shape is used, as the sample shape
        gen_imgs = last_imgs.to(device).clone()
        actions = actions.to(device)

        frame_number = 0
        while state.is_running and frame_number < max_frames:
            start_time = time.time()

            # Read the action once so the label always matches what the model was given.
            action = state.action
            # Only the last context_length entries are ever used, so trim the history
            # each frame instead of letting it grow.
            actions = torch.concat(
                (actions, torch.tensor([action], device=device))
            )[-context_length:]

            # Inference only: avoid building an autograd graph across frames.
            with torch.no_grad():
                gen_img = edm.sample(
                    10,  # presumably the number of sampling steps — TODO confirm
                    img.shape,
                    gen_imgs[-context_length:].unsqueeze(0),
                    actions.unsqueeze(0)
                )[0]
            gen_imgs = torch.concat(
                [gen_imgs, gen_img[None, :, :, :]], dim=0
            )[-context_length:]

            buffer = io.BytesIO()
            Image.fromarray(get_np_img(gen_img)).resize(
                (360, 360), Image.Resampling.LANCZOS
            ).save(buffer, format='PNG')
            img_bytes = buffer.getvalue()

            image_output.outputs = []
            image_output.append_stdout('Direction: {}'.format(directions.get(action, action)))
            image_output.append_display_data(iImage(data=img_bytes))
            frame_number += 1

            # Maintain constant frame rate
            elapsed_time = time.time() - start_time
            if elapsed_time < frame_time:
                time.sleep(frame_time - elapsed_time)

        image_output.outputs = []
        image_output.append_stdout('Finished rendering')
    except Exception as exc:
        # Show the failure in the widget; otherwise a background-thread error is easy to miss.
        image_output.append_stderr('Rendering failed: {!r}\n'.format(exc))
        raise
    finally:
        # Always release the slot so a new render thread can be started after an error.
        render_thread = None

from typing import Optional

# Handle of the background render thread, or None when no render is in progress.
render_thread: Optional[threading.Thread] = None


def start_render_loop():
    """Start the background render thread unless one is already running.

    ``state.is_running`` is set to True *before* the thread starts, because the
    render loop polls that flag and would exit immediately if it saw False.
    A leftover thread object that is no longer alive (for example after the
    worker crashed) does not block a new start.
    """
    global render_thread
    current = render_thread  # snapshot: the worker may reset the global concurrently
    if current is not None and current.is_alive():
        return
    state.is_running = True
    worker = threading.Thread(target=render_loop, args=(image_output,))
    # Publish the handle before starting so the worker's final reset cannot be overwritten.
    render_thread = worker
    worker.start()

def stop_render_loop():
    """Ask the render thread to stop, wait for it to exit, and clear the image output.

    The global ``render_thread`` is copied into a local first: the worker sets the
    global to None when it finishes, so checking and then dereferencing the global
    directly could hit None between the two reads.
    """
    global render_thread
    state.is_running = False
    worker = render_thread
    if worker is not None and worker.is_alive():
        worker.join()  # Wait for thread to finish
    render_thread = None
    image_output.outputs = []

# Start/Stop control the background render thread.
start_button.on_click(lambda _button: start_render_loop())
stop_button.on_click(lambda _button: stop_render_loop())

# Show the controls first, then the area the frames are rendered into.
display(button_output, image_output)

Output()

Output()