In [1]:
import torch
from edm.edm import EDM
import tqdm
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
import edm.modules as modules
from data import SequencesDataset
from train import train
import torchvision.transforms as transforms
import os
import numpy as np
import matplotlib.pyplot as plt
import random
import tqdm

In [2]:
def save_imgs(
    frames_real: torch.Tensor,
    frames_generation: torch.Tensor,
    path: str
):
    def get_np_img(tensor: torch.Tensor) -> np.ndarray:
        return (tensor * 127.5 + 127.5).long().clip(0,255).permute(1,2,0).detach().cpu().numpy().astype(np.uint8)

    height_row = 5
    col_width = 5
    cols = len(frames_real)
    fig, axes = plt.subplots(2, cols, figsize=(col_width * cols, height_row * 2))
    for row in range(2):
        for i in range(len(frames_real)):
            axes[row, i].imshow(get_np_img(frames_real[i]) if row == 0 else get_np_img(frames_generation[i]))
    plt.subplots_adjust(wspace=0, hspace=0)
    
    # Save the combined figure
    plt.savefig(path, bbox_inches='tight', pad_inches=0)
    plt.close()

In [3]:
input_channels = 3
context_length = 4
actions_count = 5
batch_size = 1
num_workers = 2
device = "cuda" if torch.cuda.is_available() else "cpu"
# For Mac OS
if torch.backends.mps.is_available():
    device = "mps"
ROOT_PATH = "../"
def local_path(path):
    return os.path.join(ROOT_PATH, path)
MODEL_PATH = local_path("test_models/diffusion/model_25_edm.pth")

In [4]:
edm = EDM(
    p_mean=-1.2,
    p_std=1.2,
    sigma_data=0.5,
    model=modules.UNet((input_channels) * (context_length + 1), 3, None, actions_count, context_length),
    context_length=context_length,
    device=device
)
edm.load_state_dict(torch.load(MODEL_PATH, map_location=device)["model"])

  edm.load_state_dict(torch.load(MODEL_PATH, map_location=device)["model"])


<All keys matched successfully>

In [5]:
transform_to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((.5,.5,.5), (.5,.5,.5))
])

dataset = SequencesDataset(
    images_dir=local_path("training_data/snapshots"),
    actions_path=local_path("training_data/actions"),
    seq_length=context_length,
    transform=transform_to_tensor
)

In [6]:
# from IPython.display import clear_output
# import keyboard
# import time

# fps = 10
# frame_time = 1.0 / fps
# max_frames = 60
# running = True

# # Movement flags
# inputted_action = 4

# # Keyboard event handlers
# def on_key_press(event):
#     global inputted_action
#     if event.name == 'up':
#         inputted_action = 1
#     elif event.name == 'down':
#         inputted_action = 2
#     elif event.name == 'enter':
#         inputted_action = 3

# # Register keyboard listeners
# keyboard.on_press(on_key_press)

# start_time = time.time()
# length = len(dataset)
# index = random.randint(0, length - 1)
# img, last_imgs, actions = dataset[index]

# img = img.to(device)
# last_imgs = last_imgs.to(device)
# actions = actions.to(device)
# gen_imgs = last_imgs.clone()
# frame_number = 0

# # Main game loop
# plt.figure(figsize=(8, 6))
# try:
#     while running or frame_number < max_frames:
#         actions = torch.concat((actions, torch.tensor([inputted_action])))
#         inputted_action = 4
#         gen_img = edm.sample(
#             img.shape,
#             gen_imgs[-context_length:].unsqueeze(0),
#             actions[-context_length:].unsqueeze(0),
#             num_steps=5
#         )[0][:, 2:-2, 2:-2]
#         gen_imgs = torch.concat([gen_imgs, gen_img[None, :, :, :]], dim=0)
        
#         clear_output(wait=True)
#         def get_np_img(tensor: torch.Tensor) -> np.ndarray:
#             return (tensor * 127.5 + 127.5).long().clip(0,255).permute(1,2,0).detach().cpu().numpy().astype(np.uint8)

#         plt.imshow(get_np_img(gen_img), cmap='gray')
#         plt.axis('off')
#         plt.title(f'FPS: 10')
#         plt.draw()
#         plt.pause(0.01)
#         frame_number += 1

#         # Maintain constant frame rate
#         elapsed_time = time.time() - start_time
#         if elapsed_time < frame_time:
#             time.sleep(frame_time - elapsed_time)

# except KeyboardInterrupt:
#     running = False

# # Cleanup
# keyboard.unhook_all()
# plt.close()

In [8]:
from IPython.display import display, clear_output, Image as iImage
import ipywidgets as widgets
from PIL import Image
import time
import threading

class State:
    def __init__(self):
        self.action = 4
        self.is_running = False

state = State()

def on_button_click(input_action):
    state.action = input_action

# Create buttons
circle_button = widgets.Button(description='Rotate left')
square_button = widgets.Button(description='Rotate right')
start_button = widgets.Button(description='Start')
stop_button = widgets.Button(description='Stop')

# Set up button callbacks
circle_button.on_click(lambda b: on_button_click(1))
square_button.on_click(lambda b: on_button_click(2))

# Display buttons horizontally
buttons = widgets.HBox([circle_button, square_button, start_button, stop_button])

button_output = widgets.Output()
image_output = widgets.Output()

with button_output:
    display(buttons)

def render_loop(image_output: widgets.Output):
    def get_np_img(tensor: torch.Tensor) -> np.ndarray:
        return (tensor * 127.5 + 127.5).long().clip(0,255).permute(1,2,0).detach().cpu().numpy().astype(np.uint8)
    
    frame_number = 0
    fps = 10
    frame_time = 1 / fps
    length = len(dataset)
    index = random.randint(0, length - 1)
    img, last_imgs, actions = dataset[index]
    
    img = img.to(device)
    last_imgs = last_imgs.to(device)
    actions = actions.to(device)
    gen_imgs = last_imgs.clone()
    while state.is_running and frame_number < 80:
        start_time = time.time()
        actions = torch.concat((actions, torch.tensor([state.action], device=device)))
        state.action = 4
        gen_img = edm.sample(
            img.shape,
            gen_imgs[-context_length:].unsqueeze(0),
            actions[-context_length:].unsqueeze(0),
            num_steps=5
        )[0][:, 2:-2, 2:-2]
        gen_imgs = torch.concat([gen_imgs, gen_img[None, :, :, :]], dim=0)
        gen_img = get_np_img(gen_img)
        # gen_img = get_np_img(torch.ones(3,60,60) * (state.action - 2))
        import io
        buffer = io.BytesIO()
        Image.fromarray(gen_img).resize((360, 360), Image.Resampling.LANCZOS).save(buffer, format='PNG')
        img_bytes = buffer.getvalue()
        image_output.outputs = []
        image_output.append_stdout('{}'.format(state.action))
        image_output.append_display_data(iImage(data=img_bytes))
        frame_number += 1
    
        # Maintain constant frame rate
        elapsed_time = time.time() - start_time
        if elapsed_time < frame_time:
            time.sleep(frame_time - elapsed_time)

from typing import Optional
render_thread: Optional[threading.Thread] = None

def start_render_loop():
    global render_thread
    if render_thread is not None:
        return
    render_thread = threading.Thread(target=render_loop, args=(image_output,))
    render_thread.start()
    state.is_running = True

def stop_render_loop():
    global render_thread
    state.is_running = False
    if render_thread is not None and render_thread.is_alive():
        render_thread.join()  # Wait for thread to finish
    # button_output.append_stdout("finish")
    render_thread = None

start_button.on_click(lambda b: start_render_loop())
stop_button.on_click(lambda b: stop_render_loop())

display(button_output)
display(image_output)

Output()

Output()