In [1]:
# Put the sibling torchdriveenv checkout first on the import path.
# NOTE(review): no visible cell imports torchdriveenv directly; presumably it
# must be importable so pickle can resolve the classes stored in the episode
# files -- confirm.
import sys
sys.path.insert(0, "../../torchdriveenv")

In [2]:
import cv2
import io
import os
import pickle
import random
import torch
import numpy as np

from PIL import Image
from matplotlib import pyplot as plt

In [3]:
def get_value(key, step_data):
    """Fetch one field of a step record and squeeze out singleton dimensions.

    Args:
        key: Selects the field to read:
            ``"obs_birdview"``    -> ``step_data.obs_birdview``
            ``"recurrent_state"`` -> ``step_data.recurrent_states[0]`` wrapped
                                     in ``torch.Tensor``
            ``"action"``          -> ``step_data.ego_action``
        step_data: Object exposing the attributes listed above. Only the
            attribute that belongs to ``key`` is accessed.

    Returns:
        The selected value after ``.squeeze()``, or ``None`` when ``key`` is
        not one of the three supported names.
    """
    if key == "obs_birdview":
        selected = step_data.obs_birdview
    elif key == "recurrent_state":
        selected = torch.Tensor(step_data.recurrent_states[0])
    elif key == "action":
        selected = step_data.ego_action
    else:
        # Unknown keys are not an error; the caller simply gets None.
        return None
    return selected.squeeze()

In [4]:
def to_video(pil_images, fps=10, output_path="output.mp4"):
    """Encode a sequence of RGB images as an MP4 and return its bytes.

    ``cv2.VideoWriter`` writes to a file path, so the video is first written
    to ``output_path`` and the encoded file is then read back into an
    in-memory buffer. (Previously the returned buffer was never written to
    and was therefore always empty.)

    Args:
        pil_images: Non-empty iterable of RGB images; each one is converted
            with ``np.array``. The video size is taken from the first frame.
        fps: Frame rate passed to the video writer.
        output_path: File the video is written to. The file is left on disk.

    Returns:
        ``io.BytesIO`` positioned at the start, holding the contents of the
        encoded video file.

    Raises:
        ValueError: If ``pil_images`` contains no frames.
        RuntimeError: If OpenCV cannot open ``output_path`` for writing.
    """
    # Convert PIL images to numpy arrays
    frames = [np.array(img) for img in pil_images]
    if not frames:
        raise ValueError("to_video requires at least one frame")

    # OpenCV expects the frame size as (width, height)
    height, width = frames[0].shape[:2]
    size = (width, height)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_path, fourcc, fps, size)
    if not video_writer.isOpened():
        video_writer.release()
        # Without this check a stale file from an earlier run could be
        # returned as if it were the freshly encoded video.
        raise RuntimeError(f"could not open {output_path!r} for video writing")

    try:
        for frame in frames:
            # PIL arrays are RGB; OpenCV writes BGR.
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalize the container, even if a frame fails to convert.
        video_writer.release()

    # Load the encoded file into memory; a fresh BytesIO starts at position 0.
    with open(output_path, "rb") as video_file:
        return io.BytesIO(video_file.read())

In [5]:
def to_image(plt):
    """Save the figure to an in-memory PNG, close it, and return a PIL image.

    Args:
        plt: Object providing ``savefig`` and ``close``; ``plot_samples``
            passes the ``matplotlib.pyplot`` module. Note that this parameter
            shadows the module-level ``plt`` import inside the function.

    Returns:
        The image opened by ``Image.open`` from the PNG buffer (saved with
        ``dpi=50`` and a tight bounding box).
    """
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format='png', bbox_inches='tight', dpi=50)
    plt.close()
    # Rewind so Image.open reads the PNG from its first byte.
    png_buffer.seek(0)
    return Image.open(png_buffer)


def plot_samples(x):
    """Scatter ``x[1]`` (horizontal) against ``x[0]`` (vertical) as an image.

    Both axes are fixed to the range [-2, 2]. The plot is drawn through the
    pyplot state machine and handed to ``to_image``, which closes it.

    Args:
        x: Object supporting ``.cpu().numpy()``.
            NOTE(review): presumably a two-component action tensor -- confirm.

    Returns:
        The image returned by ``to_image``.
    """
    samples = x.cpu().numpy()
    horizontal, vertical = samples[1], samples[0]
    plt.scatter(horizontal, vertical)
    plt.xlim(-2, 2)
    plt.ylim(-2, 2)
    return to_image(plt)

In [7]:
def visualize_episode(file_path):
    """Load a pickled episode and build per-step observation/action visuals.

    The loaded episode object is printed, then every entry of
    ``episode.step_data`` contributes one bird's-eye-view observation and one
    action (both read through ``get_value``) plus a scatter image of the
    action (via ``plot_samples``).

    Args:
        file_path: Path of the pickled episode file.

    Returns:
        A tuple ``(obs_images, action_images, actions)``:
        ``obs_images`` are RGB PIL images built from the observations,
        ``action_images`` are the per-step action plots, and ``actions`` are
        the raw per-step action values.
    """
    # SECURITY: pickle.load can execute arbitrary code; only open episode
    # files that come from a trusted source.
    with open(file_path, "rb") as f:
        episode = pickle.load(f)
    print(episode)

    birdviews, actions, action_images = [], [], []
    for step in episode.step_data:
        birdviews.append(get_value("obs_birdview", step))
        step_action = get_value("action", step)
        actions.append(step_action)
        action_images.append(plot_samples(step_action))

    # Cast to uint8 and move axis 0 to the end before building the RGB image
    # (presumably channel-first -> channel-last; verify against the recorder).
    obs_images = [
        Image.fromarray(view.astype(np.uint8).transpose(1, 2, 0), 'RGB')
        for view in birdviews
    ]
    return obs_images, action_images, actions

In [8]:
# Directory that holds the pickled episode files.
data_dir = "../data/waypoint_graph_no_rendering_test"

# Every entry of data_dir joined onto the directory, in os.listdir order
# (entries are not filtered by type or extension).
file_paths = [
    os.path.join(data_dir, entry_name)
    for entry_name in os.listdir(data_dir)
]
# file_paths

In [11]:
# Draw 100 episode files at random (with replacement, so repeats are possible)
# and report the waypoint count stored in each episode's final step.
# SECURITY: pickle.load can execute arbitrary code; only use trusted files.
for i in range(100):
    file_path = random.choice(file_paths)
    with open(file_path, "rb") as f:
        episode_data = pickle.load(f)
    reached_waypoint_num = episode_data.step_data[-1].info["reached_waypoint_num"]
    # Episodes that reached at most two waypoints are only flagged, not listed.
    if reached_waypoint_num <= 2:
        print("too short")
    else:
        print(reached_waypoint_num, ' ', file_path)

too short
too short
too short
too short
too short
too short
too short
too short
3   ../data/waypoint_graph_no_rendering_test/episode_67_72605.pkl
8   ../data/waypoint_graph_no_rendering_test/episode_46_30204.pkl
too short
8   ../data/waypoint_graph_no_rendering_test/episode_74_602.pkl
too short
10   ../data/waypoint_graph_no_rendering_test/episode_86_12007.pkl
too short
too short
too short
4   ../data/waypoint_graph_no_rendering_test/episode_42_2628.pkl
10   ../data/waypoint_graph_no_rendering_test/episode_70_55538.pkl
10   ../data/waypoint_graph_no_rendering_test/episode_46_26128.pkl
too short
too short
too short
4   ../data/waypoint_graph_no_rendering_test/episode_31_71556.pkl
8   ../data/waypoint_graph_no_rendering_test/episode_23_13161.pkl
too short
too short
too short
6   ../data/waypoint_graph_no_rendering_test/episode_75_50360.pkl
too short
5   ../data/waypoint_graph_no_rendering_test/episode_77_30144.pkl
too short
too short
too short
6   ../data/waypoint_graph_no_rendering_test

In [None]:
# Pick one episode file at random (unseeded, so it differs between runs)
# and display the chosen path.
file_path = random.choice(file_paths)
file_path 

In [None]:
# file_path = '../data/waypoint_graph_no_rendering_test/episode_9_95910.pkl'

In [None]:
# !ls -l ../data/itra_data/episode_2367_13523.pkl

In [None]:
# Build the per-step observation images, action plots and raw actions
# for the selected episode.
obs_images, action_images, actions = visualize_episode(file_path)

In [None]:
# Encode the observation frames as a video; to_video writes output.mp4.
to_video(obs_images)

In [None]:
# Display the observation image of the first step.
obs_images[0]

In [None]:
# Display the raw per-step actions.
actions

In [None]:
# Stack the per-step actions along a new leading (step) axis and find, for
# each action component, the step index at which it is largest.
np.argmax(torch.stack(actions).cpu(), axis=0)

In [None]:
# Stack the per-step actions along a new leading (step) axis and find, for
# each action component, the step index at which it is smallest.
np.argmin(torch.stack(actions).cpu(), axis=0)

In [None]:
# Display the raw action at step index 11.
actions[11]

In [None]:
# actions[97]

In [None]:
# Display the action scatter plot for step index 11.
action_images[11]

In [None]:
# Display the observation image for step index 11.
obs_images[11]

In [None]:
# Write all observation frames to an animated GIF: the first image is the
# base frame and the rest are appended; duration is the per-frame display
# time in milliseconds and loop=0 repeats the animation indefinitely.
obs_images[0].save('test_obs.gif',
                   save_all=True, append_images=obs_images[1:], optimize=False, duration=40, loop=0)

In [None]:
# Write all action plots to an animated GIF: the first image is the base
# frame and the rest are appended; duration is the per-frame display time
# in milliseconds and loop=0 repeats the animation indefinitely.
action_images[0].save('test_action.gif',
                      save_all=True, append_images=action_images[1:], optimize=False, duration=40, loop=0)