In [None]:
import pandas as pd
from pathlib import Path

In [None]:
import sys
# Put the parent directory (resolved against the current working directory)
# on the import path so the `source.*` packages imported below can be found.
sys.path.append("../")

In [None]:
from source.track_progress import timefmt

In [None]:
import os

In [None]:
# Root of the camera-training run. Set the CAMERA_TRAIN_ROOT environment
# variable to point the notebook elsewhere instead of editing this cell.
root_dir = Path(os.environ.get("CAMERA_TRAIN_ROOT", "/envroot/trainings/camera_train"))

# Data collected while the agents steer the camera.
cmd_dir = root_dir / "collected_commands"
pos_dir = root_dir / "collected_positions"
embed_dir = root_dir / "collected_embeds"
img_dir = root_dir / "collected_imgs"

# Trained model folders.
wm_dir = root_dir / "world_models"
ag_dir = root_dir / "agents"

In [None]:
prog_file = root_dir / "progress_model_names.txt"
# One "<model_name> @ <finish_time>" entry per line. Blank lines (e.g. a
# trailing empty line) are skipped instead of breaking the unpacking.
with open(prog_file, "r") as f:
    entries = [line.partition("@") for line in f if line.strip()]
malformed = [name.strip() for name, sep, _ in entries if not sep]
if malformed:
    raise ValueError(f"Entries without '@' in {prog_file}: {malformed}")
models = tuple(name.strip() for name, _, _ in entries)
timestamps = tuple(stamp.strip() for _, _, stamp in entries)

In [None]:
# 1. sort the models by the images they were trained on
# 2. collect the paths of the training images
# 3. parse, for each model, the world model, agent and images used to train it, plus the images captured by the agent
import yaml
from source.datasets.ptz_dataset import get_position_datetime_from_labels
import numpy as np
import pprint
import copy
import torch
import logging

logger = logging.getLogger(__name__)


class ModelInfo:
    """Metadata of one trained model, loaded from ``<model_dir>/model_info.yaml``.

    For every ``restart_XX`` section of the YAML, the names of the collected
    images whose timestamps fall inside that section's ``images.start_end``
    windows are stored under ``images.filename``.

    Parameters
    ----------
    model_basedir : str or Path
        Directory that contains the model folder.
    model_name : str
        Name of the model folder (e.g. ``wm_00_06``).
    image_dir : str or Path, optional
        Directory of collected images. Defaults to
        ``<self.root_dir>/collected_imgs`` when a subclass set ``self.root_dir``
        before calling this initializer, otherwise to the notebook-level
        ``img_dir``.
    """

    # Extension removed from image file names before they are parsed as labels.
    IMAGE_SUFFIX = ".jpg"

    def __init__(self, model_basedir, model_name, image_dir=None):
        self.model_name = model_name
        self.model_dir = Path(model_basedir, model_name)
        model_info_path = self.model_dir / "model_info.yaml"
        with open(model_info_path, "r") as f:
            self.info_dict = yaml.safe_load(f)
        # Pristine copy for __repr__; info_dict gets image file names attached below.
        self.ori_info_dict = copy.deepcopy(self.info_dict)

        if image_dir is None:
            root = getattr(self, "root_dir", None)
            image_dir = Path(root) / "collected_imgs" if root is not None else img_dir
        # Sorted so the selected file names come out in a deterministic order.
        imgnames = np.array(sorted(os.listdir(image_dir)))
        labels = [self._strip_suffix(name, self.IMAGE_SUFFIX) for name in imgnames]
        _, imgtimes = get_position_datetime_from_labels(labels)

        # figure out images that are used in the trainings
        for key, restart_dict in self.info_dict.items():
            if not key.startswith("restart_"):
                continue
            idx = self._indices_in_windows(imgtimes, restart_dict["images"]["start_end"], timefmt)
            restart_dict["images"]["filename"] = imgnames[idx]

    @staticmethod
    def _strip_suffix(name, suffix):
        """Return ``name`` without a trailing ``suffix``.

        Unlike ``str.strip(suffix)``, which removes any of the suffix's
        *characters* from both ends, this only removes the exact ending.
        """
        name = str(name)
        return name[:-len(suffix)] if suffix and name.endswith(suffix) else name

    @staticmethod
    def _indices_in_windows(times, start_end, fmt):
        """Indices of ``times`` lying inside any ``[start, end]`` window (inclusive).

        ``start_end`` is a flat list ``[start0, end0, start1, end1, ...]`` of
        timestamp strings in format ``fmt``, parsed as UTC; an unpaired
        trailing entry is ignored. Per-window hits are concatenated (not
        stacked), so windows may match different numbers of entries. No
        windows yields an empty index array.
        """
        if start_end is None:
            start_end = []
        hits = []
        for k in range(len(start_end) // 2):
            start, end = pd.to_datetime(start_end[2 * k:2 * k + 2], format=fmt, utc=True)
            hits.append(np.where((start <= times) & (times <= end))[0])
        return np.concatenate(hits) if hits else np.array([], dtype=np.intp)

    def __repr__(self):
        # Show the YAML as loaded, without the attached file-name arrays.
        return pprint.pformat(self.ori_info_dict)

    def get_images_at_restart(self, restart_iter: int):
        """Return the image file names attached to section ``restart_<NN>``.

        Raises
        ------
        ValueError
            If ``restart_iter`` is larger than ``info_dict['num_restart']``.
        """
        if restart_iter > self.info_dict['num_restart']:
            raise ValueError("Current restart of the model is smaller than the input restart iteration, check again")
        return self.info_dict[f"restart_{restart_iter:0>2}"]["images"]["filename"]


class WorldModelInfo(ModelInfo):
    """Metadata of a world model (``wm_*``) stored under ``<root_dir>/world_models``.

    On top of what ModelInfo loads, exposes ``model_path``: the
    ``jepa-latest.pt`` checkpoint inside the model folder.
    """

    def __init__(self, root_dir, model_name):
        # Set before the base initializer runs so it is available there too.
        self.root_dir = Path(root_dir)
        prefix = model_name.partition("_")[0]
        assert prefix == "wm", f"Requires a world model (wm_*), but got {prefix}"
        super().__init__(self.root_dir / "world_models", model_name)
        self.model_path = self.model_dir / "jepa-latest.pt"


class AgentInfo(ModelInfo):
    """Metadata and collected-data access for an agent model (``ag_*``).

    Besides what ModelInfo loads, every ``restart_XX`` section gets a
    ``meta.collect_timestamp`` array: the timestamps of the
    ``collected_positions/positions_at_<ts>.txt`` files that fall inside that
    restart's ``images.start_end`` windows.
    """

    def __init__(self, root_dir, model_name):
        self.root_dir = Path(root_dir)
        ag_dir = self.root_dir / "agents"
        model_type_infer = model_name.split("_")[0]
        assert model_type_infer == "ag", f"Requires an agent (ag_*), but got {model_type_infer}"
        super().__init__(ag_dir, model_name)
        self._get_collected_timestamps()
        self.model_target_path = self.model_dir / "jepa-target_latest.pt"
        self.model_policy_path = self.model_dir / "jepa-policy_latest.pt"

    def _get_collected_timestamps(self):
        """Attach ``meta.collect_timestamp`` to every ``restart_XX`` section.

        Returns None. When ``<root_dir>/collected_positions`` does not exist,
        logs a warning and attaches nothing.
        """
        # ! this requires all collected pos, cmd, embed have the same timestamp
        positions_dir = self.root_dir / "collected_positions"
        if not positions_dir.exists():
            logger.warning("No timestamps found, cannot analyze collected data info")
            return None
        # positions_at_2024-07-31_17:36:08.173102.txt -> "2024-07-31_17:36:08.173102"
        prefix, suffix = "positions_at_", ".txt"
        # Keep the raw timestamp strings: they rebuild the data file names
        # exactly, whereas re-formatting parsed datetimes may not round-trip.
        stamps = np.array(
            sorted(fn[len(prefix):-len(suffix)] for fn in os.listdir(positions_dir)
                   if fn.startswith(prefix) and fn.endswith(suffix)),
            dtype=str,
        )
        ftimes = pd.to_datetime(stamps, format=timefmt, utc=True)
        for key, restart_dict in self.info_dict.items():
            if not key.startswith("restart_"):
                continue
            meta = restart_dict.setdefault("meta", {})
            start_end = restart_dict["images"]["start_end"]
            if start_end is None:
                start_end = []
            hits = []
            for k in range(len(start_end) // 2):
                starttime, endtime = pd.to_datetime(start_end[2 * k:2 * k + 2], format=timefmt, utc=True)
                hits.append(np.where((starttime <= ftimes) & (ftimes <= endtime))[0])
            # Concatenate (not stack): windows may match different numbers of files.
            idx = np.concatenate(hits) if hits else np.array([], dtype=np.intp)
            meta["collect_timestamp"] = stamps[idx]

    def get_collected_data_at_restart(self, restart_iter: int):
        """Return the collected-data file names of one restart.

        Returns ``[positions_files, commands_files, embeds_files]``: three
        equally long tuples, empty when nothing was collected.

        Raises ValueError if ``restart_iter`` exceeds ``num_restart``.
        """
        # pos, cmd, embeds
        if restart_iter > self.info_dict['num_restart']:
            raise ValueError(f"Current restart of the model is smaller than the input restart iteration. Max iteration is {self.info_dict['num_restart']}")
        stamps = np.ravel(self.info_dict[f"restart_{restart_iter:0>2}"]["meta"]["collect_timestamp"])
        return [
            tuple(f"positions_at_{ts}.txt" for ts in stamps),
            tuple(f"commands_at_{ts}.txt" for ts in stamps),
            tuple(f"embeds_at_{ts}.pt" for ts in stamps),
        ]

    def load_collected_data(self, restart_iter, map_location=None):
        """Load every collected session of one restart.

        Parameters
        ----------
        restart_iter : int
            Restart index (see ``get_collected_data_at_restart``).
        map_location : optional
            Forwarded to ``torch.load``, e.g. ``"cpu"`` to load GPU-saved
            embeddings on a CPU-only machine. The default keeps the old behavior.

        Returns
        -------
        (positions, commands, embeddings)
            Parallel lists with one entry per session. Positions and commands
            are float arrays parsed from comma-separated text rows.
        """
        fnpos, fncmd, fnembed = self.get_collected_data_at_restart(restart_iter)
        embed_dir = self.root_dir / "collected_embeds"
        pos_dir = self.root_dir / "collected_positions"
        cmd_dir = self.root_dir / "collected_commands"
        li_embed, li_pos, li_cmd = [], [], []
        for fp, fc, fe in zip(fnpos, fncmd, fnembed):
            li_embed.append(torch.load(embed_dir / fe, map_location=map_location))
            li_pos.append(self._read_float_rows(pos_dir / fp))
            li_cmd.append(self._read_float_rows(cmd_dir / fc))
        return li_pos, li_cmd, li_embed

    @staticmethod
    def _read_float_rows(path):
        """Parse a text file of comma-separated numbers into a float array, skipping blank lines."""
        with open(path, "r") as f:
            return np.array([line.strip().split(",") for line in f if line.strip()], dtype=float)


class ProgressTracker:
    """Read the training progress log ``<root_dir>/progress_model_names.txt``.

    Each non-blank line is ``<model_name> @ <finish_time>``; the last line is
    always the last model name. Training pipeline steps (from the original notes):

    1. train a random world model
    2. generate dreams by a random trained world model
    3. train a random agent (needs to stick to a single world model)
    4. gather images using a random trained agent

    Attributes
    ----------
    model_names, finish_time : tuple of str
        Parallel tuples in file order (empty for an empty log).
    """

    def __init__(self, root_dir):
        self.root_dir = Path(root_dir)
        self.prog_file = self.root_dir / "progress_model_names.txt"
        names, times = [], []
        with open(self.prog_file, "r") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue  # blank lines (e.g. a trailing empty line) carry no entry
                name, sep, finished = line.partition("@")
                if not sep:
                    raise ValueError(
                        f"{self.prog_file}:{lineno}: expected '<model_name> @ <time>', got {line.strip()!r}")
                names.append(name.strip())
                times.append(finished.strip())
        self.model_names = tuple(names)
        self.finish_time = tuple(times)

    # def get_model

In [None]:
# root_dir = Path("/envroot/trainings/camera_train")
# model_name = "wm_00_06"
# wm_dir = root_dir / "world_models"
# model_info_path = wm_dir / model_name / "model_info.yaml"
# with open(model_info_path, "r") as f:
#     info_dict = yaml.safe_load(f)
# imgnames = np.array(os.listdir(img_dir))
# imgposs, imgtimes = get_position_datetime_from_labels([iname.strip(".jpg") for iname in imgnames])

In [None]:
# Agent under analysis; its restart_00 parent world model is loaded further below.
ag = AgentInfo(root_dir=root_dir, model_name="ag_00_26")

In [None]:
# str() of a ModelInfo falls back to its __repr__: the pretty-printed model_info.yaml.
print(repr(ag))

In [None]:
# Parent world model recorded for the agent's restart_00 (presumably the one it was trained with).
wm = WorldModelInfo(
    root_dir=root_dir,
    model_name=ag.info_dict["restart_00"]["parent_model"],
)

In [None]:
# Display the world model's info (ModelInfo.__repr__ pretty-prints the YAML as loaded).
wm

In [None]:
with open(ag.model_dir / "params-agent.yaml", "r") as f:
    params = yaml.safe_load(f)

# -- META: architecture hyper-parameters from the agent's config; used below
#    to build both the world-model encoder and the agent predictor.
meta_cfg = params['meta']
model_arch = meta_cfg['agent_model_arch']
pred_depth = meta_cfg['pred_depth']
pred_emb_dim = meta_cfg['pred_emb_dim']
camerabrand = meta_cfg['camera_brand']

# -- DEVICE: first CUDA GPU when available, CPU otherwise
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    torch.cuda.set_device(device)

# -- DATA
data_cfg = params['data']
(use_gaussian_blur, use_horizontal_flip, use_color_distortion,
 color_jitter, crop_size, crop_scale) = (
    data_cfg[key] for key in (
        'use_gaussian_blur', 'use_horizontal_flip', 'use_color_distortion',
        'color_jitter_strength', 'crop_size', 'crop_scale',
    )
)

# -- MASK
patch_size = params['mask']['patch_size']  # patch-size for model training

In [None]:
# -- ACTIONS: the discrete action table, keyed by action id 0..20
#    (presumably the index of each action in the predictor's output — TODO confirm).
action_noop = params['action']['noop']

short_cfg = params['action']['short']
(action_short_left, action_short_right,
 action_short_left_up, action_short_right_up,
 action_short_left_down, action_short_right_down,
 action_short_up, action_short_down,
 action_short_zoom_in, action_short_zoom_out) = (
    short_cfg[key] for key in (
        'left', 'right', 'left_up', 'right_up', 'left_down', 'right_down',
        'up', 'down', 'zoom_in', 'zoom_out',
    )
)

long_cfg = params['action']['long']
(action_long_left, action_long_right,
 action_long_up, action_long_down,
 action_long_zoom_in, action_long_zoom_out) = (
    long_cfg[key] for key in ('left', 'right', 'up', 'down', 'zoom_in', 'zoom_out')
)

jump_cfg = params['action']['jump']
action_jump_left, action_jump_right, action_jump_up, action_jump_down = (
    jump_cfg[key] for key in ('left', 'right', 'up', 'down')
)

# Ids follow list order: 0 noop, 1-10 short moves, 11-16 long moves, 17-20 jumps.
actions = dict(enumerate((
    action_noop,
    action_short_left, action_short_right,
    action_short_left_up, action_short_right_up,
    action_short_left_down, action_short_right_down,
    action_short_up, action_short_down,
    action_short_zoom_in, action_short_zoom_out,
    action_long_left, action_long_right,
    action_long_up, action_long_down,
    action_long_zoom_in, action_long_zoom_out,
    action_jump_left, action_jump_right,
    action_jump_up, action_jump_down,
)))

num_actions = len(actions)

In [None]:
from source.helper import init_world_model, load_checkpoint, init_agent_model
from source.utils.analysis_viz import scale_pca_tsne_transform

In [None]:
# Architecture arguments shared by the world-model and agent constructors.
arch_kwargs = dict(
    device=device,
    patch_size=patch_size,
    crop_size=crop_size,
    pred_depth=pred_depth,
    pred_emb_dim=pred_emb_dim,
    model_arch=model_arch,
)

# -- World model: keep only the target encoder, freeze it, then load the
#    weights from the parent world model's checkpoint.
target_encoder, _ = init_world_model(**arch_kwargs)
target_encoder.requires_grad_(False)
_, _, target_encoder, _, _, _ = load_checkpoint(
    device=device, r_path=wm.model_path, target_encoder=target_encoder)

# -- Agent: keep only the target predictor, freeze it, then load the
#    weights from the agent's target checkpoint.
_, target_predictor = init_agent_model(**arch_kwargs, num_actions=num_actions)
target_predictor.requires_grad_(False)
_, target_predictor, _, _, _, _ = load_checkpoint(
    device=device, r_path=ag.model_target_path, predictor=target_predictor)

In [None]:
# Sessions the agent collected during restart 1, one list entry per session:
# positions, commands and embeddings (see AgentInfo.load_collected_data).
# The last position of each session is dropped later, since the reward
# is not used to guide the camera from it.
pos, cmd, embed = ag.load_collected_data(restart_iter=1)

tuple(len(seq) for seq in (pos, cmd, embed))

In [None]:
# Predictor output for every kept position (the last position of each session
# is excluded), stacked over sessions and moved to NumPy; presumably shaped
# (positions, actions) — TODO confirm.
action_rewards = torch.vstack([
    target_predictor(
        embed[i].to(device),
        torch.tensor(pos[i][:-1]).to(device, dtype=torch.float32),
    )
    for i in range(len(pos))
]).cpu().numpy()
# Best predicted reward per position: maximum over axis 1.
rewards = action_rewards.max(axis=1)

In [None]:
# Mean embedding per sample: average each session tensor over dim 1
# (presumably the patch/token axis — TODO confirm).
target_embeds = [session_embed.mean(dim=1) for session_embed in embed]

In [None]:
# Project all per-sample mean embeddings (stacked row-wise) with the project helper
# scale_pca_tsne_transform. NOTE(review): the (t-SNE, PCA) return order is assumed
# from the variable names — confirm; embed_tsne needs >= 2 columns for the plots below.
embed_tsne, embed_pca = scale_pca_tsne_transform(torch.vstack(target_embeds))

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Drop the last position of every session (the reward cell also feeds only
# pos[i][:-1] to the predictor) and stack all sessions into one 2-D array,
# presumably row-aligned with `rewards`.
# Guarded so re-running this cell does not trim again: on a rerun `pos` is
# already an ndarray, and `p[:-1]` would chop the last column off each row.
if isinstance(pos, list):
    pos = np.vstack([p[:-1] for p in pos])

In [None]:
# Display the agent's info (ModelInfo.__repr__ pretty-prints the YAML as loaded).
ag

In [None]:
# Display the world model's info (ModelInfo.__repr__ pretty-prints the YAML as loaded).
wm

In [None]:
# Color scale shared by the reward plots: clip to the 10th-90th percentile so
# a few extreme rewards do not wash out the colormap.
vmin = np.quantile(rewards, 0.1)
vmax = np.quantile(rewards, 0.9)

# --- t-SNE of the mean embeddings, colored by reward
fig_tsne, ax_tsne = plt.subplots(figsize=(8, 6))
sc = ax_tsne.scatter(embed_tsne[:, 0], embed_tsne[:, 1], s=2,
                     c=rewards, cmap="jet", vmin=vmin, vmax=vmax)
ax_tsne.set_aspect("equal")
ax_tsne.set_xlabel("tSNE axis-1")
ax_tsne.set_ylabel("tSNE axis-2")
fig_tsne.colorbar(sc, ax=ax_tsne).ax.set_ylabel('Reward')
fig_tsne.savefig("tsne.png")


def _style_pan_tilt_axes(axis):
    """Apply the shared pan/tilt limits, labels and equal aspect to one panel."""
    axis.set_xlim([-185, 185])
    axis.set_ylim([-95, 5])
    axis.set_xlabel("Pan (deg)")
    axis.set_ylabel("Tilt (deg)")
    axis.set_aspect("equal")


# --- camera pointings: top panel colored by reward, bottom by log10 zoom.
# Colorbars are attached with an explicit `ax=` so each sits next to its own panel.
fig, ax = plt.subplots(2, 1, figsize=(8, 6))
im = ax[0].scatter(pos[:, 0], pos[:, 1], c=rewards, s=2, cmap="jet", vmin=vmin, vmax=vmax)
_style_pan_tilt_axes(ax[0])
fig.colorbar(im, ax=ax[0]).ax.set_ylabel('Reward')

im = ax[1].scatter(pos[:, 0], pos[:, 1], c=np.log10(pos[:, 2]), s=2)
_style_pan_tilt_axes(ax[1])
# Raw string: "\l" is an invalid escape sequence in a normal string literal.
fig.colorbar(im, ax=ax[1]).ax.set_ylabel(r'$\log_{10}$ Zoom')
fig.savefig("pointings.png")

In [None]:
# 3-D view of the camera pointings (x = pan, y = zoom, z = tilt), colored by
# reward on the same clipped color scale (vmin/vmax) as the 2-D plots.
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(projection='3d')
im = ax.scatter(pos[:, 0], pos[:, 2], pos[:, 1], s=2, c=rewards, cmap="jet", vmin=vmin, vmax=vmax)
ax.set_xlabel("Pan (deg)")
ax.set_ylabel("Zoom")
ax.set_zlabel("Tilt (deg)")
ax.set_title("Camera pointings colored by predicted reward")
fig.colorbar(im, ax=ax).ax.set_ylabel('Reward');

In [None]:
# Inspect the samples in a small pointing window (pan in (-25, -20) deg,
# tilt in (-75, -65) deg): their positions and the matching command rows.
# NOTE(review): masking np.vstack(cmd) with a mask built from the trimmed `pos`
# assumes each session has one command row fewer than positions — confirm.
sel = (-25 < pos[:, 0]) & (pos[:, 0] < -20) & (-75 < pos[:, 1]) & (pos[:, 1] < -65)
print(pos[sel], np.vstack(cmd)[sel])

In [None]:
# Rewards of the samples selected above, in the order they appear in `rewards`.
fig_sel, ax_sel = plt.subplots()
ax_sel.plot(rewards[sel])
ax_sel.set_xlabel("Selected sample index")
ax_sel.set_ylabel("Reward");

In [None]:
# import matplotlib.animation as animation

# # pos = []
# # for fn in fnames:
# #     with open(pos_dir / f"{fn}.txt", "r") as f:
# #         pos.append(np.array([l.strip().split(",") for l in f.readlines()], dtype=float))
# # pos = np.vstack(pos)

# fig, ax = plt.subplots(1, 1, figsize=(8, 4))
# ax.set_xlim([-185, 185])
# ax.set_ylim([-95, 5])
# ax.set_xlabel("Pan (deg)")
# ax.set_ylabel("Tilt (deg)")
# ax.set_aspect("equal")
# scat = ax.scatter(pos[0, 0], pos[0, 1], s=5, c='tab:red')


# def animate(i):
#     # if i < 5:
#     #     for j in range(i-1):
#     #         scat.plot(pos[j, 0], pos[j, 1], s=5)
#     # else:
#     # scat.set_offsets((pos[i, 0], pos[i, 1]))
#     # for j in range(max(i-5, 0), i):
#     # # if i > 0:
#     ax.scatter(pos[i, 0], pos[i, 1], s=5, c='tab:red')
#     return scat,

# # def update(frame):
# #     global positions
# #     positions += velocities  # Update positions
# #     lo = 0 if frame < 10 else frame - 10
# #     for i in range(lo, frame):
# #         scat.set_offsets(positions[i])
# #     # Update the trails
# #     for i in range(num_points):
# #         trails[i].append(positions[i].copy())
# #         if len(trails[i]) > 10:  # Limit trail length
# #             trails[i].pop(0)

# #     # Update the scatter plot
# #     scatter.set_offsets(positions)

# #     # Draw the trails
# #     for trail in trails:
# #         if len(trail) > 1:
# #             trail_array = np.array(trail)
# #             ax.plot(trail_array[:, 0], trail_array[:, 1], 'b-', alpha=0.5)

# #     return scatter,


# ani = animation.FuncAnimation(fig, animate, repeat=True,
#                                     frames=len(pos) - 1, interval=20)
# # plt.show()
# writer = animation.PillowWriter(fps=15,
#                                 metadata=dict(artist='Me'),
#                                 bitrate=1800)
# ani.save('scatter.gif', writer=writer)