### Pip Install

In [None]:
# install the package
%pip install --upgrade mani_skill
# install a version of torch that is compatible with your system
%pip install torch torchvision torchaudio numpy diffusers


# etc imports
from typing import Tuple, Sequence, Dict, Union, Optional
from collections import OrderedDict
import collections
import math
import h5py
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import display, Image as IPImage
import io
import os
import csv

# mani_skill imports
from mani_skill.utils import common
from mani_skill.utils.io_utils import load_json
from mani_skill.utils.common import flatten_state_dict
import mani_skill.envs
from typing import Any, Dict, Union

import sapien
from mani_skill.envs.scene import ManiSkillScene
from transforms3d.euler import euler2quat
from mani_skill.agents.robots import PandaWristCam
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.envs.utils import randomization
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import common, sapien_utils
from mani_skill.utils.geometry import rotation_conversions
from mani_skill.utils.registration import register_env
from mani_skill.utils.scene_builder.table import TableSceneBuilder
from mani_skill.utils.structs.pose import Pose
from mani_skill.utils.structs.types import SimConfig


#torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, Dataset
from torch.utils.data import DataLoader

# diffuser imports
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler

# gym imports
import gymnasium as gym
from gymnasium import spaces

# google colab imports
from google.colab import drive
drive.mount('/content/drive')

### Classes

#### Environments

In [None]:
def _build_box_with_hole(
    scene: ManiSkillScene, inner_radius, outer_radius, depth, center=(0, 0)
):
    """Return an actor builder for a box made of four slabs around a hole.

    Each slab is added as both a box collision shape and a box visual shape.
    The hole runs along the local x-axis; ``center`` shifts the hole in the
    local (y, z) plane.

    Args:
        scene: scene used to create the actor builder.
        inner_radius: half extent of the hole.
        outer_radius: half extent of the whole box in y and z.
        depth: half extent of every slab along x.
        center: (y, z) shift of the hole relative to the box center.

    Returns:
        The populated builder. The caller is responsible for building it.
    """
    box_builder = scene.create_actor_builder()

    wall = (outer_radius - inner_radius) * 0.5
    shift_y = center[0] * 0.5
    shift_z = center[1] * 0.5
    wall_offset = wall + inner_radius

    # One (position, half size) pair per slab; x-axis is the hole direction.
    slabs = (
        ([0, wall_offset + shift_y, 0], [depth, wall - shift_y, outer_radius]),
        ([0, -wall_offset + shift_y, 0], [depth, wall + shift_y, outer_radius]),
        ([0, 0, wall_offset + shift_z], [depth, outer_radius, wall - shift_z]),
        ([0, 0, -wall_offset + shift_z], [depth, outer_radius, wall + shift_z]),
    )

    material = sapien.render.RenderMaterial(
        base_color=sapien_utils.hex2rgba("#FFD289"), roughness=0.5, specular=0.5
    )

    for position, slab_half_size in slabs:
        slab_pose = sapien.Pose(position)
        box_builder.add_box_collision(slab_pose, slab_half_size)
        box_builder.add_box_visual(slab_pose, slab_half_size, material=material)
    return box_builder


@register_env("PegInsertionSide-v2", max_episode_steps=100)
class PegInsertionSideEnv(BaseEnv):
    """Peg-insertion task: insert a randomly sized peg into a box with a hole.

    A peg and a kinematic box with a hole are built per parallel environment
    with randomly sampled sizes. Success is decided from the position of the
    peg head expressed in the hole frame (see ``has_peg_inserted``).
    """

    SUPPORTED_REWARD_MODES = ("normalized_dense", "dense", "sparse", "none")
    SUPPORTED_ROBOTS = ["panda_wristcam"]
    agent: Union[PandaWristCam]
    # Extra margin added to the (enlarged) peg radius when sizing the hole.
    _clearance = 0.003

    def __init__(
        self,
        *args,
        robot_uids="panda_wristcam",
        num_envs=1,
        reconfiguration_freq=None,
        **kwargs,
    ):
        """Create the environment.

        When ``reconfiguration_freq`` is not given it defaults to 1 for a
        single environment and to 0 when ``num_envs`` is larger than one.
        All remaining arguments are forwarded to ``BaseEnv``.
        """
        if reconfiguration_freq is None:
            if num_envs == 1:
                reconfiguration_freq = 1
            else:
                reconfiguration_freq = 0
        super().__init__(
            *args,
            robot_uids=robot_uids,
            num_envs=num_envs,
            reconfiguration_freq=reconfiguration_freq,
            **kwargs,
        )

    @property
    def _default_sim_config(self):
        """Default simulation configuration (an unmodified ``SimConfig``)."""
        return SimConfig()

    @property
    def _default_sensor_configs(self):
        """Camera configs used for observations: a single ``base_camera``."""
        pose = sapien_utils.look_at([0, -0.3, 0.2], [0, 0, 0.1])
        # NOTE(review): positional arguments are presumably
        # (width, height, fov, near, far) - confirm against CameraConfig.
        return [CameraConfig("base_camera", pose, 128, 128, np.pi / 2, 0.01, 100)]

    @property
    def _default_human_render_camera_configs(self):
        """Camera config used for human rendering: a single ``render_camera``."""
        pose = sapien_utils.look_at([0.5, -0.5, 0.8], [0.05, -0.1, 0.4])
        return CameraConfig("render_camera", pose, 512, 512, 1, 0.01, 100)

    def _load_scene(self, options: dict):
        """Build the table, then one peg and one box-with-hole per environment.

        Peg lengths, radii and hole centers are drawn from ``self._episode_rng``.
        Derived values (``peg_half_sizes``, ``peg_head_offsets``,
        ``box_hole_offsets``, ``box_hole_radii``) are stored on ``self`` for
        later use. The per-environment actors are merged into ``self.peg`` and
        ``self.box``.
        """
        # ``Actor`` is not imported by the visible module-level imports; import
        # it here so that ``Actor.merge`` below does not raise NameError.
        from mani_skill.utils.structs import Actor

        with torch.device(self.device):
            self.table_scene = TableSceneBuilder(self)
            self.table_scene.build()

            lengths = self._episode_rng.uniform(0.085, 0.125, size=(self.num_envs,))
            radii = self._episode_rng.uniform(0.015, 0.025, size=(self.num_envs,))
            centers = (
                0.5
                * (lengths - radii)[:, None]
                * self._episode_rng.uniform(-1, 1, size=(self.num_envs, 2))
            )

            # save some useful values for use later
            self.peg_half_sizes = common.to_tensor(np.vstack([lengths, radii, radii])).T
            peg_head_offsets = torch.zeros((self.num_envs, 3))
            peg_head_offsets[:, 0] = self.peg_half_sizes[:, 0]
            self.peg_head_offsets = Pose.create_from_pq(p=peg_head_offsets)

            box_hole_offsets = torch.zeros((self.num_envs, 3))
            box_hole_offsets[:, 1:] = common.to_tensor(centers)
            self.box_hole_offsets = Pose.create_from_pq(p=box_hole_offsets)
            hole_enlargement_factor = 1.5
            self.box_hole_radii = common.to_tensor(radii * hole_enlargement_factor + self._clearance)

            # in each parallel env we build a different box with a hole and peg (the task is meant to be quite difficult)
            pegs = []
            boxes = []

            for i in range(self.num_envs):
                scene_idxs = [i]
                length = lengths[i]
                radius = radii[i]
                builder = self.scene.create_actor_builder()
                builder.add_box_collision(half_size=[length, radius, radius])
                # peg head
                mat = sapien.render.RenderMaterial(
                    base_color=sapien_utils.hex2rgba("#EC7357"),
                    roughness=0.5,
                    specular=0.5,
                )
                builder.add_box_visual(
                    sapien.Pose([length / 2, 0, 0]),
                    half_size=[length / 2, radius, radius],
                    material=mat,
                )
                # peg tail
                mat = sapien.render.RenderMaterial(
                    base_color=sapien_utils.hex2rgba("#EDF6F9"),
                    roughness=0.5,
                    specular=0.5,
                )
                builder.add_box_visual(
                    sapien.Pose([-length / 2, 0, 0]),
                    half_size=[length / 2, radius, radius],
                    material=mat,
                )
                builder.set_scene_idxs(scene_idxs)
                peg = builder.build(f"peg_{i}")

                # box with hole

                inner_radius, outer_radius, depth = (
                    radius * hole_enlargement_factor + self._clearance,
                    length,
                    length,
                )
                builder = _build_box_with_hole(
                    self.scene, inner_radius, outer_radius, depth, center=centers[i]
                )
                builder.set_scene_idxs(scene_idxs)
                box = builder.build_kinematic(f"box_with_hole_{i}")

                pegs.append(peg)
                boxes.append(box)
            self.peg = Actor.merge(pegs, "peg")
            self.box = Actor.merge(boxes, "box_with_hole")

    def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
        """Randomize the peg pose, the box pose and the robot joint positions.

        Args:
            env_idx: indices of the environments being (re)initialized.
            options: unused here.
        """
        with torch.device(self.device):
            b = len(env_idx)
            self.table_scene.initialize(env_idx)

            # initialize the box and peg
            xy = randomization.uniform(
                low=torch.tensor([-0.1, -0.3]), high=torch.tensor([0.1, 0]), size=(b, 2)
            )
            pos = torch.zeros((b, 3))
            pos[:, :2] = xy
            # peg height is set to its half size along z
            pos[:, 2] = self.peg_half_sizes[env_idx, 2]
            quat = randomization.random_quaternions(
                b,
                self.device,
                lock_x=True,
                lock_y=True,
                bounds=(np.pi / 2 - np.pi / 3, np.pi / 2 + np.pi / 3),
            )
            self.peg.set_pose(Pose.create_from_pq(pos, quat))

            xy = randomization.uniform(
                low=torch.tensor([-0.05, 0.2]),
                high=torch.tensor([0.05, 0.4]),
                size=(b, 2),
            )
            pos = torch.zeros((b, 3))
            pos[:, :2] = xy
            # box height uses the peg length, which is also the box outer radius
            pos[:, 2] = self.peg_half_sizes[env_idx, 0]
            quat = randomization.random_quaternions(
                b,
                self.device,
                lock_x=True,
                lock_y=True,
                bounds=(np.pi / 2 - np.pi / 8, np.pi / 2 + np.pi / 8),
            )
            self.box.set_pose(Pose.create_from_pq(pos, quat))

            # Initialize the robot
            qpos = np.array(
                [
                    0.0,
                    np.pi / 8,
                    0,
                    -np.pi * 5 / 8,
                    0,
                    np.pi * 3 / 4,
                    -np.pi / 4,
                    0.04,
                    0.04,
                ]
            )
            qpos = self._episode_rng.normal(0, 0.02, (b, len(qpos))) + qpos
            # the last two joints are reset to exactly 0.04 (no noise)
            qpos[:, -2:] = 0.04
            self.agent.robot.set_qpos(qpos)
            self.agent.robot.set_pose(sapien.Pose([-0.615, 0, 0]))

    # save some commonly used attributes
    @property
    def peg_head_pos(self):
        """Sum of the peg position and the peg head offset vector."""
        return self.peg.pose.p + self.peg_head_offsets.p

    @property
    def peg_head_pose(self):
        """Peg pose composed with the peg head offset."""
        return self.peg.pose * self.peg_head_offsets

    @property
    def box_hole_pose(self):
        """Box pose composed with the hole offset."""
        return self.box.pose * self.box_hole_offsets

    @property
    def goal_pose(self):
        """Hole pose composed with the inverse of the peg head offset."""
        # NOTE (stao): this is fixed after each _initialize_episode call. You can cache this value
        # and simply store it after _initialize_episode or set_state_dict calls.
        return self.box.pose * self.box_hole_offsets * self.peg_head_offsets.inv()

    def has_peg_inserted(self):
        """Check whether the peg head lies inside the hole.

        Returns:
            A tuple ``(inserted, peg_head_pos_at_hole)`` where ``inserted`` is
            a per-environment boolean and ``peg_head_pos_at_hole`` is the peg
            head position expressed in the hole frame.
        """
        # Only head position is used in fact
        peg_head_pos_at_hole = (self.box_hole_pose.inv() * self.peg_head_pose).p
        # x-axis is hole direction
        x_flag = -0.015 <= peg_head_pos_at_hole[:, 0]
        y_flag = (-self.box_hole_radii <= peg_head_pos_at_hole[:, 1]) & (
            peg_head_pos_at_hole[:, 1] <= self.box_hole_radii
        )
        z_flag = (-self.box_hole_radii <= peg_head_pos_at_hole[:, 2]) & (
            peg_head_pos_at_hole[:, 2] <= self.box_hole_radii
        )
        return (
            x_flag & y_flag & z_flag,
            peg_head_pos_at_hole,
        )

    def evaluate(self):
        """Return ``success`` and ``peg_head_pos_at_hole`` as a dict."""
        success, peg_head_pos_at_hole = self.has_peg_inserted()
        return dict(success=success, peg_head_pos_at_hole=peg_head_pos_at_hole)

    def _get_obs_extra(self, info: Dict):
        """Return extra observations.

        Always contains ``tcp_pose``; in "state" and "state_dict" observation
        modes it also contains the peg pose, the peg half size, the hole pose
        and the hole radius.
        """
        obs = dict(tcp_pose=self.agent.tcp.pose.raw_pose)
        if self._obs_mode in ["state", "state_dict"]:
            obs.update(
                peg_pose=self.peg.pose.raw_pose,
                peg_half_size=self.peg_half_sizes,
                box_hole_pose=self.box_hole_pose.raw_pose,
                box_hole_radius=self.box_hole_radii,
            )
        return obs

    def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
        """Staged dense reward: reach, grasp, align, insert.

        Environments flagged in ``info["success"]`` receive a reward of 10.
        """
        # Stage 1: Encourage gripper to be rotated to be lined up with the peg
        # (no reward term is implemented for this stage in this version)

        # Stage 2: Encourage gripper to move close to peg tail and grasp it
        gripper_pos = self.agent.tcp.pose.p
        tgt_gripper_pose = self.peg.pose
        offset = sapien.Pose(
            [-0.06, 0, 0]
        )  # account for panda gripper width with a bit more leeway
        tgt_gripper_pose = tgt_gripper_pose * (offset)
        gripper_to_peg_dist = torch.linalg.norm(
            gripper_pos - tgt_gripper_pose.p, axis=1
        )

        reaching_reward = 1 - torch.tanh(4.0 * gripper_to_peg_dist)

        # check with max_angle=20 to ensure gripper isn't grasping peg at an awkward pose
        is_grasped = self.agent.is_grasping(self.peg, max_angle=20)
        reward = reaching_reward + is_grasped

        # Stage 3: Orient the grasped peg properly towards the hole

        # pre-insertion award, encouraging both the peg center and the peg head to match the yz coordinates of goal_pose
        peg_head_wrt_goal = self.goal_pose.inv() * self.peg_head_pose
        peg_head_wrt_goal_yz_dist = torch.linalg.norm(
            peg_head_wrt_goal.p[:, 1:], axis=1
        )
        peg_wrt_goal = self.goal_pose.inv() * self.peg.pose
        peg_wrt_goal_yz_dist = torch.linalg.norm(peg_wrt_goal.p[:, 1:], axis=1)

        pre_insertion_reward = 3 * (
            1
            - torch.tanh(
                0.5 * (peg_head_wrt_goal_yz_dist + peg_wrt_goal_yz_dist)
                + 4.5 * torch.maximum(peg_head_wrt_goal_yz_dist, peg_wrt_goal_yz_dist)
            )
        )
        reward += pre_insertion_reward * is_grasped
        # stage 3 passes if peg is correctly oriented in order to insert into hole easily
        pre_inserted = (peg_head_wrt_goal_yz_dist < 0.01) & (
            peg_wrt_goal_yz_dist < 0.01
        )

        # Stage 4: Insert the peg into the hole once it is grasped and lined up
        peg_head_wrt_goal_inside_hole = self.box_hole_pose.inv() * self.peg_head_pose
        insertion_reward = 5 * (
            1
            - torch.tanh(
                5.0 * torch.linalg.norm(peg_head_wrt_goal_inside_hole.p, axis=1)
            )
        )
        reward += insertion_reward * (is_grasped & pre_inserted)

        reward[info["success"]] = 10

        return reward

    def compute_normalized_dense_reward(
        self, obs: Any, action: torch.Tensor, info: Dict
    ):
        """Dense reward divided by 10 (the value assigned on success)."""
        return self.compute_dense_reward(obs, action, info) / 10
    



@register_env("PlugCharger-v2", max_episode_steps=200)
class PlugChargerEnv(BaseEnv):
    """Plug-charger task: bring a two-pronged charger to a receptacle goal pose.

    Success (see ``evaluate``) requires the charger pose to be within 5e-3 of
    the goal position and within 0.2 rad of the goal orientation. The dense
    reward is not implemented and always returns zeros.
    """

    _base_size = [2e-2, 1.5e-2, 1.2e-2]  # charger base half size
    _peg_size = [8e-3, 0.75e-3, 3.2e-3]  # charger peg half size
    _peg_gap = 7e-3  # charger peg gap
    _clearance = 5e-4  # single side clearance
    _receptacle_size = [1e-2, 5e-2, 5e-2]  # receptacle half size
    # Scale factors applied in _load_scene to the peg half size when sizing the
    # receptacle hole: *_gap_x multiplies index 1, *_gap_y multiplies index 2.
    _hole_size_gap_x = 3
    _hole_size_gap_y = 1.5

    SUPPORTED_ROBOTS = ["panda_wristcam"]
    agent: Union[PandaWristCam]

    def __init__(
        self, *args, robot_uids="panda_wristcam", robot_init_qpos_noise=0.02, **kwargs
    ):
        """Store the joint-noise scale and forward everything else to ``BaseEnv``."""
        self.robot_init_qpos_noise = robot_init_qpos_noise
        super().__init__(*args, robot_uids=robot_uids, **kwargs)

    @property
    def _default_sim_config(self):
        """Default simulation configuration (an unmodified ``SimConfig``)."""
        return SimConfig()

    @property
    def _default_sensor_configs(self):
        """Camera configs used for observations: one 128x128 ``base_camera``."""
        pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
        return [
            CameraConfig("base_camera", pose=pose, width=128, height=128, fov=np.pi / 2)
        ]

    @property
    def _default_human_render_camera_configs(self):
        """Camera configs for human rendering: one 512x512 ``render_camera``.

        The camera is mounted on ``self.receptacle``, so this property needs
        the receptacle to have been built (done in ``_load_scene``).
        """
        pose = sapien_utils.look_at([0.3, 0.4, 0.1], [0, 0, 0])
        return [
            CameraConfig(
                "render_camera",
                pose=pose,
                width=512,
                height=512,
                fov=1,
                mount=self.receptacle,
            )
        ]

    def _build_charger(self, peg_size, base_size, gap):
        """Build the dynamic ``charger`` actor: two pegs plus a base.

        Args:
            peg_size: half size of each peg.
            base_size: half size of the base.
            gap: y offset of each peg from the charger's x-axis (+gap / -gap).
        """
        builder = self.scene.create_actor_builder()

        # peg
        mat = sapien.render.RenderMaterial()
        mat.set_base_color([1, 1, 1, 1])
        mat.metallic = 1.0
        mat.roughness = 0.0
        mat.specular = 1.0
        builder.add_box_collision(sapien.Pose([peg_size[0], gap, 0]), peg_size)
        builder.add_box_visual(
            sapien.Pose([peg_size[0], gap, 0]), peg_size, material=mat
        )
        builder.add_box_collision(sapien.Pose([peg_size[0], -gap, 0]), peg_size)
        builder.add_box_visual(
            sapien.Pose([peg_size[0], -gap, 0]), peg_size, material=mat
        )

        # base
        mat = sapien.render.RenderMaterial()
        mat.set_base_color([1, 1, 1, 1])
        mat.metallic = 0.0
        mat.roughness = 0.1
        builder.add_box_collision(sapien.Pose([-base_size[0], 0, 0]), base_size)
        builder.add_box_visual(
            sapien.Pose([-base_size[0], 0, 0]), base_size, material=mat
        )

        return builder.build(name="charger")

    def _build_receptacle(self, peg_size, receptacle_size, gap):
        """Build the kinematic ``receptacle`` actor.

        The receptacle is assembled from four outer slabs plus a filler box
        between the two holes (all with collision and visual shapes), and two
        visual-only boxes that mark the holes.

        Args:
            peg_size: half size of a hole (the caller passes an enlarged peg size).
            receptacle_size: half size of the whole receptacle.
            gap: peg gap used to position the holes.
        """
        builder = self.scene.create_actor_builder()

        # half sizes (sy, sz) and center offsets (dx, dy, dz) of the outer slabs
        sy = 0.5 * (receptacle_size[1] - peg_size[1] - gap)
        sz = 0.5 * (receptacle_size[2] - peg_size[2])
        dx = -receptacle_size[0]
        dy = peg_size[1] + gap + sy
        dz = peg_size[2] + sz

        mat = sapien.render.RenderMaterial()
        mat.set_base_color([1, 1, 1, 1])
        mat.metallic = 0.0
        mat.roughness = 0.1

        poses = [
            sapien.Pose([dx, 0, dz]),
            sapien.Pose([dx, 0, -dz]),
            sapien.Pose([dx, dy, 0]),
            sapien.Pose([dx, -dy, 0]),
        ]
        half_sizes = [
            [receptacle_size[0], receptacle_size[1], sz],
            [receptacle_size[0], receptacle_size[1], sz],
            [receptacle_size[0], sy, receptacle_size[2]],
            [receptacle_size[0], sy, receptacle_size[2]],
        ]
        for pose, half_size in zip(poses, half_sizes):
            builder.add_box_collision(pose, half_size)
            builder.add_box_visual(pose, half_size, material=mat)

        # Fill the gap
        pose = sapien.Pose([-receptacle_size[0], 0, 0])
        half_size = [receptacle_size[0], gap - peg_size[1], peg_size[2]]
        builder.add_box_collision(pose, half_size)
        builder.add_box_visual(pose, half_size, material=mat)

        # Add dummy visual for hole
        mat = sapien.render.RenderMaterial()
        mat.set_base_color(sapien_utils.hex2rgba("#DBB539"))
        mat.metallic = 1.0
        mat.roughness = 0.0
        mat.specular = 1.0
        pose = sapien.Pose([-receptacle_size[0], -(gap * 0.5 + peg_size[1]), 0])
        half_size = [receptacle_size[0], peg_size[1], peg_size[2]]
        builder.add_box_visual(pose, half_size, material=mat)
        pose = sapien.Pose([-receptacle_size[0], gap * 0.5 + peg_size[1], 0])
        builder.add_box_visual(pose, half_size, material=mat)

        return builder.build_kinematic(name="receptacle")

    def _load_scene(self, options: dict):
        """Build the table scene, the charger and the receptacle."""
        self.scene_builder = TableSceneBuilder(
            self, robot_init_qpos_noise=self.robot_init_qpos_noise
        )
        self.scene_builder.build()
        self.charger = self._build_charger(
            self._peg_size,
            self._base_size,
            self._peg_gap,
        )
        # The hole is the peg half size enlarged by the scale factors plus the
        # clearance (index 0 is passed through unchanged).
        self.receptacle = self._build_receptacle(
            [
                self._peg_size[0],
                self._peg_size[1] * self._hole_size_gap_x + self._clearance,
                self._peg_size[2] * self._hole_size_gap_y + self._clearance,
            ],
            self._receptacle_size,
            self._peg_gap,
        )

    def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
        """Randomize robot joints, charger pose and receptacle pose, then set the goal.

        Args:
            env_idx: indices of the environments being (re)initialized.
            options: unused here.
        """
        with torch.device(self.device):
            b = len(env_idx)
            self.scene_builder.initialize(env_idx)

            # Initialize agent
            qpos = torch.tensor(
                [
                    0.0,
                    np.pi / 8,
                    0,
                    -np.pi * 5 / 8,
                    0,
                    np.pi * 3 / 4,
                    np.pi / 4,
                    0.04,
                    0.04,
                ]
            )
            qpos = (
                torch.normal(
                    0, self.robot_init_qpos_noise, (b, len(qpos)), device=self.device
                )
                + qpos
            )
            # the last two joints are reset to exactly 0.04 (no noise)
            qpos[:, -2:] = 0.04
            self.agent.robot.set_qpos(qpos)
            self.agent.robot.set_pose(sapien.Pose([-0.615, 0, 0]))

            # Initialize charger
            xy = randomization.uniform(
                [-0.1, -0.2], [-0.01 - self._peg_size[0] * 2, 0.2], size=(b, 2)
            )
            pos = torch.zeros((b, 3))
            pos[:, :2] = xy
            # charger height is set to the base half size along z
            pos[:, 2] = self._base_size[2]
            ori = randomization.random_quaternions(
                n=b, lock_x=True, lock_y=True, bounds=(-torch.pi / 3, torch.pi / 3)
            )
            self.charger.set_pose(Pose.create_from_pq(pos, ori))

            # Initialize receptacle
            xy = randomization.uniform([0.01, -0.1], [0.1, 0.1], size=(b, 2))
            pos = torch.zeros((b, 3))
            pos[:, :2] = xy
            pos[:, 2] = 0.1
            ori = randomization.random_quaternions(
                n=b,
                lock_x=True,
                lock_y=True,
                bounds=(torch.pi - torch.pi / 8, torch.pi + torch.pi / 8),
            )
            self.receptacle.set_pose(Pose.create_from_pq(pos, ori))

            # goal pose: receptacle pose composed with a rotation built from
            # euler2quat(0, 0, pi)
            self.goal_pose = self.receptacle.pose * (
                sapien.Pose(q=euler2quat(0, 0, np.pi))
            )

    @property
    def charger_base_pose(self):
        """Charger pose composed with an offset of ``-_base_size[0]`` along x."""
        return self.charger.pose * (sapien.Pose([-self._base_size[0], 0, 0]))

    def _compute_distance(self):
        """Return ``(distance, angle)`` between the charger pose and the goal pose.

        The angle is the norm of the axis-angle form of the relative rotation,
        folded to ``min(angle, 2*pi - angle)``.
        """
        obj_pose = self.charger.pose
        obj_to_goal_pos = self.goal_pose.p - obj_pose.p
        obj_to_goal_dist = torch.linalg.norm(obj_to_goal_pos, axis=1)

        obj_to_goal_quat = rotation_conversions.quaternion_multiply(
            rotation_conversions.quaternion_invert(self.goal_pose.q), obj_pose.q
        )
        obj_to_goal_axis = rotation_conversions.quaternion_to_axis_angle(
            obj_to_goal_quat
        )
        obj_to_goal_angle = torch.linalg.norm(obj_to_goal_axis, axis=1)
        obj_to_goal_angle = torch.min(
            obj_to_goal_angle, torch.pi * 2 - obj_to_goal_angle
        )

        return obj_to_goal_dist, obj_to_goal_angle

    def evaluate(self):
        """Return the goal distance, the goal angle and the ``success`` flag.

        ``success`` is true where distance <= 5e-3 and angle <= 0.2.
        """
        obj_to_goal_dist, obj_to_goal_angle = self._compute_distance()
        success = (obj_to_goal_dist <= 5e-3) & (obj_to_goal_angle <= 0.2)
        return dict(
            obj_to_goal_dist=obj_to_goal_dist,
            obj_to_goal_angle=obj_to_goal_angle,
            success=success,
        )

    def _get_obs_extra(self, info: Dict):
        """Return extra observations.

        Always contains ``tcp_pose``; in "state" and "state_dict" observation
        modes it also contains the charger, receptacle and goal poses.
        """
        obs = dict(tcp_pose=self.agent.tcp.pose.raw_pose)
        if self._obs_mode in ["state", "state_dict"]:
            obs.update(
                charger_pose=self.charger.pose.raw_pose,
                receptacle_pose=self.receptacle.pose.raw_pose,
                goal_pose=self.goal_pose.raw_pose,
            )
        return obs

    def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
        """Placeholder dense reward: always a zero tensor of length ``num_envs``."""
        return torch.zeros(self.num_envs, device=self.device)

    def compute_normalized_dense_reward(
        self, obs: Any, action: torch.Tensor, info: Dict
    ):
        """Dense reward divided by ``max_reward`` (1.0, so the value is unchanged)."""
        max_reward = 1.0
        return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward

#### Dataset

In [None]:
def load_h5_data(data):
    """Recursively copy an h5py group-like mapping into nested dicts.

    Entries that are ``h5py.Dataset`` objects are read fully with ``[:]``;
    every other entry is treated as a nested group and loaded recursively.
    """
    return {
        name: (
            data[name][:]
            if isinstance(data[name], h5py.Dataset)
            else load_h5_data(data[name])
        )
        for name in data.keys()
    }


def create_sample_indices(episode_ends: np.ndarray, sequence_length: int, pad_before: int = 0, pad_after: int = 0):
    """Enumerate every sampling window of ``sequence_length`` steps per episode.

    ``episode_ends`` holds one flag per step. An episode closes at the first
    truthy flag after a falsy one; further consecutive truthy flags do not
    close another episode, but their steps are counted towards the next one.

    Args:
        episode_ends: per-step end-of-episode flags.
        sequence_length: number of steps in each window.
        pad_before: how many steps a window may start before the episode.
        pad_after: how many steps a window may run past the episode.

    Returns:
        Array with one row per window:
        ``[buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx]``.
        The buffer indices slice the flat data; the sample indices say where
        that slice sits inside the (possibly padded) window.
    """
    # Currently uses truncated as episode ends which is the end of the episode and not the end of the trajectory
    windows = []
    steps_in_episode = 0
    previous_step_ended = False
    for step, ended in enumerate(episode_ends):
        steps_in_episode += 1
        if not ended:
            previous_step_ended = False
            continue
        if previous_step_ended:
            # repeated end flag: the episode was already emitted
            continue

        first_step = step - steps_in_episode + 1
        last_window_start = steps_in_episode - sequence_length + pad_after

        # Create indices for each possible sequence in the episode
        for window_start in range(-pad_before, last_window_start + 1):
            window_end = window_start + sequence_length
            buffer_start_idx = max(window_start, 0) + first_step
            buffer_end_idx = min(window_end, steps_in_episode) + first_step
            clipped_front = buffer_start_idx - (window_start + first_step)
            clipped_back = (window_end + first_step) - buffer_end_idx
            windows.append(
                [
                    buffer_start_idx,
                    buffer_end_idx,
                    clipped_front,
                    sequence_length - clipped_back,
                ]
            )
        steps_in_episode = 0
        previous_step_ended = True
    return np.array(windows)



def sample_sequence(train_data, sequence_length, buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx):
    """Cut one window out of every array in ``train_data``, padding the edges.

    The slice ``[buffer_start_idx:buffer_end_idx]`` is taken from each array.
    If it is shorter than ``sequence_length``, it is placed at
    ``[sample_start_idx:sample_end_idx]`` of a new buffer; the leading gap is
    filled with the first sliced element and the trailing gap with the last.

    Args:
        train_data: mapping from key to a numpy array or torch tensor indexed
            by step along the first axis.
        sequence_length: length of the returned window.
        buffer_start_idx: start of the slice in the source array.
        buffer_end_idx: end (exclusive) of the slice in the source array.
        sample_start_idx: where the slice starts inside the window.
        sample_end_idx: where the slice ends (exclusive) inside the window.

    Returns:
        Dict with the same keys as ``train_data``. Unpadded entries are the
        plain slices of the inputs; padded entries are new arrays/tensors.
    """
    needs_padding = (sample_start_idx > 0) or (sample_end_idx < sequence_length)
    result = dict()
    for key, input_arr in train_data.items():
        sample = input_arr[buffer_start_idx:buffer_end_idx]
        if not needs_padding:
            result[key] = sample
            continue

        padded_shape = (sequence_length,) + tuple(input_arr.shape[1:])
        if isinstance(input_arr, torch.Tensor):
            # new_zeros keeps the dtype AND the device of the source tensor, so
            # padded and unpadded samples end up on the same device.
            data = input_arr.new_zeros(padded_shape)
        else:
            data = np.zeros(shape=padded_shape, dtype=input_arr.dtype)
        if sample_start_idx > 0:
            data[:sample_start_idx] = sample[0]
        if sample_end_idx < sequence_length:
            data[sample_end_idx:] = sample[-1]
        data[sample_start_idx:sample_end_idx] = sample
        result[key] = data
    return result

def remove_np_uint16(x: Union[np.ndarray, dict]):
    """Convert every ``uint16`` array in ``x`` to ``int32``.

    Dicts are walked recursively and updated in place (the same dict object is
    returned). Arrays of any other dtype are returned unchanged.
    """
    if not isinstance(x, dict):
        return x.astype(np.int32) if x.dtype == np.uint16 else x
    for name, entry in x.items():
        x[name] = remove_np_uint16(entry)
    return x

def convert_observation(obs, task_id):
    """Concatenate all observation values and append a ``task_id`` column.

    The first value in ``obs`` decides the number of rows and the dtype of the
    appended column. Concatenation is along the last axis.

    Args:
        obs: mapping whose values are arrays (or torch tensors) that can be
            concatenated along the last axis.
        task_id: value written into the appended column.

    Returns:
        A numpy array holding the concatenated values followed by one column
        filled with ``task_id``.
    """
    parts = list(obs.values())
    reference = parts[0]
    if isinstance(reference, torch.Tensor):
        reference = reference.numpy()

    # one task_id entry per row of the reference value
    task_column = np.full((reference.shape[0], 1), task_id, dtype=reference.dtype)
    parts.append(task_column)
    return np.concatenate(parts, axis=-1)

def get_observations(obs):
    """Build a fixed-order, 39-column observation dict shared across tasks.

    The result contains, in order: ``qpos``, ``qvel``, ``tcp_pose``,
    ``goal_pose`` and then every remaining 2-D entry of ``obs["extra"]`` that
    has 7 columns (except ``receptacle_pose``).

    ``goal_pose`` is taken from the first of ``goal_pose``, ``goal_pos``,
    ``box_hole_pose`` or ``cubeB_pose`` found in ``obs["extra"]`` and is
    zero-padded or truncated to exactly 7 columns. If none is present, a zero
    array is used and a message is printed.

    NOTE: ``obs["extra"]`` is modified in place: ``tcp_pose`` and the goal key
    that was used are removed from it.

    Args:
        obs: observation dict with ``"agent"`` (``qpos``, ``qvel``) and
            ``"extra"`` (``tcp_pose`` plus task-specific entries).

    Returns:
        ``OrderedDict`` of the cleaned observations.

    Raises:
        ValueError: if the goal entry has more than 2 dimensions.
        AssertionError: if the columns of all entries do not sum to 39.
    """
    # ensure that the observations are in the correct format
    # and ordered correctly across tasks

    cleaned_obs = OrderedDict()
    cleaned_obs["qpos"] = obs["agent"]["qpos"]
    cleaned_obs["qvel"] = obs["agent"]["qvel"]
    cleaned_obs["tcp_pose"] = obs["extra"]["tcp_pose"]
    obs["extra"].pop("tcp_pose")

    # this code is not generic and only works for the specific observation spaces we have
    # Handle different goal position formats gracefully
    goal_pose_keys = ["goal_pose", "goal_pos", "box_hole_pose", "cubeB_pose"]
    for key in goal_pose_keys:
        if key in obs["extra"]:
            pos = obs["extra"][key]

            # Ensure 'pos' is 2D with the correct number of columns
            if pos.ndim == 1:
                pos = pos.reshape(1, -1)  # Reshape to 2D if necessary
            elif pos.ndim > 2:
                raise ValueError(f"Unexpected dimensions for '{key}': {pos.shape}")

            # Pad or truncate 'pos' to have 7 columns. Truncate first so the
            # pad width is never negative (np.pad rejects negative widths).
            pos = pos[:, :7]
            pos = np.pad(pos, ((0, 0), (0, 7 - pos.shape[1])), mode='constant')
            if isinstance(cleaned_obs["tcp_pose"], torch.Tensor):
                pos = torch.tensor(pos, dtype=cleaned_obs["tcp_pose"].dtype)

            cleaned_obs["goal_pose"] = pos
            obs["extra"].pop(key)
            break  # Stop once a valid goal pose key is found
    else:
        print("No goal pose found. Setting to zero.")
        length = len(cleaned_obs["tcp_pose"])
        cleaned_obs["goal_pose"] = np.zeros((length, 7), dtype=np.float32)  # Ensure 2D shape

    # Filter and add other observations with 7 columns
    for key, value in obs["extra"].items():
        if value.shape[-1] == 7 and value.ndim == 2:
            if key != "receptacle_pose":
                cleaned_obs[key] = value

    count = sum(value.shape[-1] for value in cleaned_obs.values())

    assert count == 39, "Observation size is not 39"

    return cleaned_obs



def get_min_max_values(dataloader, exclude_features):
    """Compute per-feature min/max over every batch produced by ``dataloader``.

    Each batch value is flattened to ``(-1, last_dim)``. For the ``"obs"`` key
    the columns listed in ``exclude_features`` are masked out before the
    reduction. Statistics are accumulated across all batches rather than being
    taken from the last batch only.

    Args:
        dataloader: iterable of dict batches mapping key to tensor.
        exclude_features: column indices of ``"obs"`` to leave out of the
            statistics.

    Returns:
        ``{"obs": {"min", "max", "mask"}, "actions": {"min", "max"}}``. The
        ``"actions"`` entry holds the statistics of the last non-``"obs"`` key
        seen (normally ``"actions"``). Entries stay ``None`` when the
        corresponding key never appears.
    """
    mask = None
    running_min = {}
    running_max = {}
    actions_key = None
    for batch in dataloader:
        for key, value in batch.items():
            flat = value.view(-1, value.shape[-1])
            if key == "obs":
                if mask is None:
                    mask = torch.ones(flat.shape[1], dtype=torch.bool)
                    mask[exclude_features] = False
                flat = flat[:, mask]
            else:
                actions_key = key
            batch_min = flat.min(dim=0).values
            batch_max = flat.max(dim=0).values
            if key in running_min:
                # fold this batch into the statistics gathered so far
                batch_min = torch.minimum(running_min[key], batch_min)
                batch_max = torch.maximum(running_max[key], batch_max)
            running_min[key] = batch_min
            running_max[key] = batch_max
    return {
        "obs": {"min": running_min.get("obs"), "max": running_max.get("obs"), "mask": mask},
        "actions": {"min": running_min.get(actions_key), "max": running_max.get(actions_key)},
    }

def normalize_batch(batch, stats):
    """Scale every tensor in ``batch`` with ``(x - min) / (max - min + 0.1)``.

    For the ``"obs"`` key only the columns selected by ``stats["obs"]["mask"]``
    are scaled; the other columns are copied through. ``batch`` is updated in
    place and also returned.

    Args:
        batch: dict mapping key to tensor.
        stats: per-key ``min``/``max`` tensors (plus ``mask`` for ``"obs"``).
    """
    for key in batch:
        original = batch[key]
        key_stats = stats[key]
        flat = original.view(-1, original.shape[-1])
        low = key_stats["min"]
        # the 0.1 keeps the denominator away from zero for constant features
        span = key_stats["max"] - low + 0.1
        if key == "obs":
            columns = key_stats["mask"]
            scaled = flat.clone()
            scaled[:, columns] = (flat[:, columns] - low) / span
        else:
            scaled = (flat - low) / span
        batch[key] = scaled.view(original.shape)
    return batch

def denormalize_batch(batch, stats):
    """Map every tensor in ``batch`` back with ``x * (max - min + 0.1) + min``.

    For the ``"obs"`` key only the columns selected by ``stats["obs"]["mask"]``
    are rescaled; the other columns are copied through. ``batch`` is updated in
    place and also returned.

    Args:
        batch: dict mapping key to tensor.
        stats: per-key ``min``/``max`` tensors (plus ``mask`` for ``"obs"``).
    """
    for key in batch:
        original = batch[key]
        key_stats = stats[key]
        flat = original.view(-1, original.shape[-1])
        low = key_stats["min"]
        # same 0.1 offset as used when normalizing
        span = key_stats["max"] - low + 0.1
        if key == "obs":
            columns = key_stats["mask"]
            restored = flat.clone()
            restored[:, columns] = flat[:, columns] * span + low
        else:
            restored = flat * span + low
        batch[key] = restored.view(original.shape)
    return batch


class StateDataset(Dataset):
    """
    Torch Dataset that cuts fixed-length observation/action windows out of the
    successful trajectories stored in a ManiSkill trajectory .h5 file.

    The metadata file is expected next to the data file, at the same path with
    ".h5" replaced by ".json". Episodes whose metadata entry is not marked as
    successful are skipped. The concatenated observations are passed through
    `get_observations` / `convert_observation` (together with `task_id`) and the
    window index table is built with `create_sample_indices`.

    Args:
        dataset_file (str): path to the .h5 file containing the data you want to load
        pred_horizon (int): length of every sampled sequence
        obs_horizon (int): number of leading observation steps kept per sample;
            the window is padded with `obs_horizon - 1` steps before an episode
        action_horizon (int): the window is padded with `action_horizon - 1`
            steps after an episode
        task_id: task identifier handed to `convert_observation`
        load_count (int): the number of episodes from the metadata to consider. If -1,
            all episodes are considered; values larger than the number of episodes are clamped
        device: if not None, the loaded arrays are converted to tensors on that device
            during construction. It is also the device passed to `common.to_tensor`
            for every sample returned by `__getitem__`

    Raises:
        ValueError: if none of the considered episodes is successful
    """

    def __init__(
        self, dataset_file: str, pred_horizon: int, obs_horizon: int, action_horizon:int, task_id: np.float32, load_count=-1, device=None
    ) -> None:
        self.dataset_file = dataset_file
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon
        self.action_horizon = action_horizon
        self.task_id = task_id
        self.device = device
        self.data = h5py.File(dataset_file, "r")
        json_path = dataset_file.replace(".h5", ".json")
        self.json_data = load_json(json_path)
        self.episodes = self.json_data["episodes"]
        self.env_info = self.json_data["env_info"]
        self.env_id = self.env_info["env_id"]
        self.env_kwargs = self.env_info["env_kwargs"]

        self.obs = None
        self.actions = []
        self.terminated = []
        self.truncated = []
        self.end_episode = []
        self.success, self.fail, self.rewards = None, None, None

        # Never index past the episodes that actually exist in the metadata.
        num_episodes = len(self.episodes)
        if load_count == -1:
            load_count = num_episodes
        else:
            load_count = min(load_count, num_episodes)

        for eps_id in tqdm(range(load_count), desc="Loading Episodes", colour="green"):
            eps = self.episodes[eps_id]
            assert (
                "success" in eps
            ), "episodes in this dataset do not have the success attribute, cannot load dataset with success_only=True"
            if not eps["success"]:
                continue
            trajectory = self.data[f"traj_{eps['episode_id']}"]
            trajectory = load_h5_data(trajectory)
            eps_len = len(trajectory["actions"])

            # exclude the final observation as most learning workflows do not use it
            obs = common.index_dict_array(trajectory["obs"], slice(eps_len))
            # The first *loaded* episode is not necessarily episode 0, because
            # unsuccessful episodes are skipped above.
            if self.obs is None:
                self.obs = obs
            else:
                self.obs = common.append_dict_array(self.obs, obs)

            self.actions.append(trajectory["actions"])
            self.terminated.append(trajectory["terminated"])
            self.truncated.append(trajectory["truncated"])

            # Boolean marker per step: only the last step of the episode is True.
            end_episode = [False] * eps_len
            end_episode[-1] = True
            self.end_episode.append(end_episode)

            # handle data that might optionally be in the trajectory
            if "rewards" in trajectory:
                if self.rewards is None:
                    self.rewards = []
                self.rewards.append(trajectory["rewards"])
            if "success" in trajectory:
                if self.success is None:
                    self.success = []
                self.success.append(trajectory["success"])
            if "fail" in trajectory:
                if self.fail is None:
                    self.fail = []
                self.fail.append(trajectory["fail"])

        if not self.actions:
            raise ValueError(
                f"no successful episodes found among the first {load_count} episodes of {dataset_file}"
            )

        self.actions = np.vstack(self.actions)
        self.terminated = np.concatenate(self.terminated)
        self.truncated = np.concatenate(self.truncated)
        self.end_episode = np.concatenate(self.end_episode)

        if self.rewards is not None:
            self.rewards = np.concatenate(self.rewards)
        if self.success is not None:
            self.success = np.concatenate(self.success)
        if self.fail is not None:
            self.fail = np.concatenate(self.fail)

        def remove_np_uint16(x: Union[np.ndarray, dict]):
            if isinstance(x, dict):
                for k in x.keys():
                    x[k] = remove_np_uint16(x[k])
                return x
            else:
                if x.dtype == np.uint16:
                    return x.astype(np.int32)
                return x

        # uint16 dtype is used to conserve disk space and memory
        # you can optimize this dataset code to keep it as uint16 and process that
        # dtype of data yourself. for simplicity we simply cast to a int32 so
        # it can automatically be converted to torch tensors without complaint
        self.obs = remove_np_uint16(self.obs)

        if device is not None:
            self.actions = common.to_tensor(self.actions, device=device)
            self.obs = common.to_tensor(self.obs, device=device)
            self.terminated = common.to_tensor(self.terminated, device=device)
            self.truncated = common.to_tensor(self.truncated, device=device)
            if self.rewards is not None:
                self.rewards = common.to_tensor(self.rewards, device=device)
            # Convert success/fail from their own arrays (not terminated/truncated).
            if self.success is not None:
                self.success = common.to_tensor(self.success, device=device)
            if self.fail is not None:
                self.fail = common.to_tensor(self.fail, device=device)

        # Added code for diffusion policy
        obs_dict = get_observations(self.obs)
        self.train_data = dict(
            obs=convert_observation(obs_dict, self.task_id),
            actions=self.actions,
        )

        # Index table: one (buffer_start, buffer_end, sample_start, sample_end)
        # entry per sequence that can be sampled from the data.
        self.indices = create_sample_indices(
            episode_ends=self.end_episode,
            sequence_length=self.pred_horizon,
            pad_before=self.obs_horizon - 1,
            pad_after=self.action_horizon - 1
        )

    def __len__(self):
        # all possible sequences of the dataset
        return len(self.indices)

    def __getitem__(self, idx):
        # Change data to fit diffusion policy
        buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx = self.indices[idx]

        sampled = sample_sequence(
            train_data=self.train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx
        )

        for k in sampled.keys():
            if k != "actions":
                # discard unused observations in the sequence
                sampled[k] = sampled[k][:self.obs_horizon,:]
            # convert every entry (not only the last key) to a tensor
            sampled[k] = common.to_tensor(sampled[k], device=self.device)

        return sampled

#### Network

In [None]:

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a batch of scalar positions.

    The output concatenates the sines and cosines of the input multiplied by
    ``dim // 2`` geometrically spaced frequencies (from 1 down to 1/10000),
    so the last dimension has size ``2 * (dim // 2)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        n_freqs = self.dim // 2
        log_spacing = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(-log_spacing * torch.arange(n_freqs, device=x.device))
        # outer product: one row of angles per input position
        angles = x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)


class Downsample1d(nn.Module):
    """Stride-2 1D convolution (kernel 3, padding 1) that keeps the channel count."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled

class Upsample1d(nn.Module):
    """Stride-2 transposed 1D convolution (kernel 4, padding 1) that keeps the channel count."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        upsampled = self.conv(x)
        return upsampled


class Conv1dBlock(nn.Module):
    '''
        Length-preserving building block: Conv1d --> GroupNorm --> Mish.

        The convolution is padded with kernel_size // 2 on both sides.
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        half_kernel = kernel_size // 2
        layers = [
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=half_kernel),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        activated = self.block(x)
        return activated


class ConditionalResidualBlock1D(nn.Module):
    """Residual block of two Conv1dBlocks with FiLM conditioning in between.

    The conditioning vector is projected to a per-channel scale and bias that
    modulate the output of the first convolution block; the input is added back
    through a 1x1 convolution (or an identity when the channel counts match).
    """

    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()

        first_block = Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
        second_block = Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
        self.blocks = nn.ModuleList([first_block, second_block])

        # FiLM modulation https://arxiv.org/abs/1709.07871
        # the encoder predicts one scale and one bias per output channel
        self.out_channels = out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, 2 * out_channels),
            nn.Unflatten(-1, (-1, 1))
        )

        # project the skip path only when the channel count changes
        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, cond):
        '''
            x : [ batch_size x in_channels x horizon ]
            cond : [ batch_size x cond_dim]

            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        hidden = self.blocks[0](x)

        film = self.cond_encoder(cond)
        film = film.reshape(film.shape[0], 2, self.out_channels, 1)
        scale = film[:, 0]
        bias = film[:, 1]
        hidden = scale * hidden + bias

        hidden = self.blocks[1](hidden)
        return hidden + self.residual_conv(x)


class ConditionalUnet1D(nn.Module):
    """1D U-Net over the time axis, conditioned through FiLM on the diffusion
    step embedding concatenated with an optional global conditioning vector."""

    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level.
          The length of this sequence determines number of levels.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """

        super().__init__()
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        cond_dim = dsed + global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            # the deepest level keeps its temporal resolution
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # NOTE: this loop runs over len(in_out) - 1 levels, so `is_last` is
            # never True here and every up level ends with an Upsample1d. That
            # matches the number of Downsample1d layers above. Kept unchanged so
            # existing checkpoints keep loading.
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                # dim_out*2 because the skip connection is concatenated on the channel axis
                ConditionalResidualBlock1D(
                    dim_out*2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        # Registration order (mid, step encoder, up, down, final) determines the
        # order of self.parameters(); keep it stable.
        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            global_cond=None):
        """
        sample: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        # (B,T,C)
        sample = sample.moveaxis(-1,-2)
        # (B,C,T)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)

        if global_cond is not None:
            global_feature = torch.cat([
                global_feature, global_cond
            ], dim=-1)

        x = sample
        # stack of skip connections, consumed in reverse order on the way up
        h = []
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for resnet, resnet2, upsample in self.up_modules:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1,-2)
        # (B,T,C)
        return x

### Setup

In [None]:

#=====================================CHANGE=========================================
# env_id: source task; its pretrained checkpoint is loaded and its validation set is evaluated.
env_id = 'PickCube-v1'
# env_id_transfer: transfer task; its training and validation sets are used below.
env_id_transfer = 'StackCube-v1'
#env_id = 'StackCube-v1'
#env_id = 'PegInsertionSide-v2'
#env_id = 'PlugCharger-v2'
#env_id = 'PushCube-v1'
obs_mode = 'state_dict'
control_mode = 'pd_joint_delta_pos'

# sequence lengths handed to StateDataset
pred_horizon = 16
obs_horizon = 2
action_horizon = 8

num_epochs = 50 # number of epochs to train (default: 50)

#======================================CHANGE========================================

# Scalar identifier per task, passed to StateDataset as `task_id`.
# NOTE(review): the PegInsertionSide / PlugCharger keys end in '-v1' while the
# commented-out env ids above end in '-v2' -- confirm before switching env_id.
task_id = {
    'PickCube-v1': 0.0,
    'StackCube-v1': 0.1,
    'PegInsertionSide-v1': 0.2,
    'PlugCharger-v1': 0.3,
    'PushCube-v1': 0.4
}

#exclude_features = [25, 26, 27, 28, 29, 30, 31, 39] # goal pose x, y, z, qw, qx, qy, qz and task_id
exclude_features = [39] # task_id should be excluded and not used for normalization

# part of the path to the dataset
base_path = '/content/drive/MyDrive/Data'
generated_path = f'{base_path}/Generated/{env_id}/motionplanning'
generated_path_transfer = f'{base_path}/Generated/{env_id_transfer}/motionplanning'
checkpoints_load_path = f'{base_path}/Checkpoints/{env_id}'
checkpoints_path = f'{base_path}/Checkpoints/{env_id}_to_{env_id_transfer}'
results_path = f'{base_path}/Results/{env_id}_to_{env_id_transfer}'
# File-name suffix of the pretrained checkpoint; its epoch count is hard-coded to 100 here.
information = f'_p{pred_horizon}_o{obs_horizon}_a{action_horizon}_e{100}' # probably 100 for epochs
# File-name suffix of everything produced by this transfer run.
information_transfer = f'_p{pred_horizon}_o{obs_horizon}_a{action_horizon}_e{num_epochs}'

# load data
train_dataset_path_transfer = f'{generated_path_transfer}/training.{obs_mode}.{control_mode}.h5'
val_dataset_path = f'{generated_path}/validation.{obs_mode}.{control_mode}.h5'
val_dataset_path_transfer = f'{generated_path_transfer}/validation.{obs_mode}.{control_mode}.h5'
model_path = f'{checkpoints_load_path}/model{information}.pt'

# save results
model_path_transfer = f'{checkpoints_path}/model{information_transfer}.pt'
loss_path = f'{results_path}/loss{information_transfer}.npz'
plot_path = f'{results_path}/plot{information_transfer}.png'
animation_path = f'{results_path}/animation{information_transfer}.gif'


# create datasets from file
# training data: transfer task
train_dataset = StateDataset(
    dataset_file=train_dataset_path_transfer,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon,
    task_id=task_id[env_id_transfer],
    load_count=300, # should be around 1000 for training
    device=None
)

# validation data: source task (the task the pretrained model was trained on)
val_dataset = StateDataset(
    dataset_file=val_dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon,
    task_id=task_id[env_id],
    load_count=100, # should be around 100 for validation
    device=None
)

# validation data: transfer task
val_dataset_transfer = StateDataset(
    dataset_file=val_dataset_path_transfer,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon,
    task_id=task_id[env_id_transfer],
    load_count=100, # should be around 100 for validation
    device=None
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=128,
    num_workers=1,
    # don't kill the worker process after each epoch
    persistent_workers=True,
    shuffle=True
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=128,
    num_workers=1,
    # don't kill the worker process after each epoch
    persistent_workers=True,
    shuffle=True
)

val_dataloader_transfer = DataLoader(
    val_dataset_transfer,
    batch_size=128,
    num_workers=1,
    # don't kill the worker process after each epoch
    persistent_workers=True,
    shuffle=True
)

# min/max statistics of the transfer-task training data; used to normalize
# the observations of every dataloader in the training cell
stats = get_min_max_values(train_dataloader, exclude_features)

# visualize data in batch
batch = next(iter(train_dataloader))
print(batch.keys())
print("Data obs:", batch['obs'].shape, batch['obs'].dtype)
print("Data actions:", batch['actions'].shape, batch['actions'].dtype)


# observation and action dimensions corresponding to the dataset
obs_dim = batch['obs'].shape[-1]
action_dim = batch['actions'].shape[-1]
print("obs_dim:", obs_dim)
print("action_dim:", action_dim)

# create network object
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)

# example inputs (not used further in this cell)
noised_action = torch.randn((1, pred_horizon, action_dim))
obs = torch.zeros((1, obs_horizon, obs_dim))
diffusion_iter = torch.zeros((1,))

# for this demo, we use DDPMScheduler with 100 diffusion iterations
num_diffusion_iters = 100
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    # the choice of beta schedule has big impact on performance
    # we found squared cosine works the best
    beta_schedule='squaredcos_cap_v2',
    # clip output to [-1,1] to improve stability
    clip_sample=True,
    # our network predicts noise (instead of denoised action)
    prediction_type='epsilon'
)

# device transfer (a CUDA device is required)
device = torch.device("cuda")
_ = noise_pred_net.to(device)

### Training

In [None]:
# Load the pretrained source-task weights BEFORE the EMA copy is created, so
# that the EMA shadow parameters start from the pretrained model instead of
# from the random initialisation of the freshly constructed network.
state_dict = torch.load(model_path, map_location=device)
noise_pred_net.load_state_dict(state_dict['model_state_dict'])
print('Pretrained weights loaded.')

# Exponential Moving Average
# accelerates training and improves stability
# holds a copy of the model weights
ema = EMAModel(
    parameters=noise_pred_net.parameters(),
    power=0.75)

# Standard ADAM optimizer
# Note that EMA parameters are not optimized
optimizer = torch.optim.AdamW(
    params=noise_pred_net.parameters(),
    lr=1e-4, weight_decay=1e-6)

# Cosine LR schedule with linear warmup
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(train_dataloader) * num_epochs
)

# per-epoch lists of per-batch losses
train_losses = list()
val_losses = list()
val_losses_transfer = list()

# Fine-tune the pretrained noise-prediction network on the transfer task and,
# after every epoch, measure the noise-prediction loss on the source-task and
# the transfer-task validation sets.
with tqdm(range(num_epochs), desc='Epoch') as tglobal:
    # epoch loop
    for epoch_idx in tglobal:
        train_loss = list()

        # training loop
        noise_pred_net.train()
        with tqdm(train_dataloader, desc='Train Batch', leave=False) as tepoch:
            for batch in tepoch:
                # min-max normalize the observations with `stats`;
                # the actions are used as stored (they are not normalized here)
                nbatch = normalize_batch({"obs": batch["obs"]}, stats)
                nbatch = nbatch["obs"]

                # device transfer
                nobs = nbatch.to(device)
                naction = batch['actions'].to(device)
                B = nobs.shape[0]

                # observation as FiLM conditioning
                # (B, obs_horizon, obs_dim)
                obs_cond = nobs[:,:obs_horizon,:]
                # (B, obs_horizon * obs_dim)
                obs_cond = obs_cond.flatten(start_dim=1)

                # sample noise to add to actions
                noise = torch.randn(naction.shape, device=device)

                # sample a diffusion iteration for each data point
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps,
                    (B,), device=device
                ).long()

                # add noise to the clean actions according to the noise magnitude at each diffusion iteration
                # (this is the forward diffusion process)
                noisy_actions = noise_scheduler.add_noise(
                    naction, noise, timesteps)

                # predict the noise residual
                noise_pred = noise_pred_net(
                    noisy_actions, timesteps, global_cond=obs_cond)

                # L2 loss
                loss = nn.functional.mse_loss(noise_pred, noise)

                # optimize
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # step lr scheduler every batch
                # this is different from standard pytorch behavior
                lr_scheduler.step()

                # update Exponential Moving Average of the model weights
                ema.step(noise_pred_net.parameters())

                # logging
                loss_cpu = loss.item()
                train_loss.append(loss_cpu)
                tepoch.set_postfix(loss=loss_cpu)


        # validation loop 1: source task (val_dataloader), evaluated with the
        # current network weights (the EMA weights are not used here)
        val_loss = list()
        noise_pred_net.eval()
        with torch.no_grad():
            with tqdm(val_dataloader, desc='Val Batch', leave=False) as vepoch:
                for batch in vepoch:
                    # Normalize observations with the training-set statistics
                    nbatch = normalize_batch({"obs": batch["obs"]}, stats)
                    nbatch = nbatch["obs"]
                    # Device transfer
                    nobs = nbatch.to(device)
                    naction = batch['actions'].to(device)
                    B = nobs.shape[0]

                    # Observation as FiLM conditioning
                    obs_cond = nobs[:, :obs_horizon, :].flatten(start_dim=1)

                    # Sample noise
                    noise = torch.randn(naction.shape, device=device)

                    # Sample diffusion iteration
                    timesteps = torch.randint(
                        0, noise_scheduler.config.num_train_timesteps,
                        (B,), device=device
                    ).long()

                    # Add noise to actions
                    noisy_actions = noise_scheduler.add_noise(
                        naction, noise, timesteps)

                    # Predict noise residual
                    noise_pred = noise_pred_net(
                        noisy_actions, timesteps, global_cond=obs_cond)

                    # L2 loss
                    loss = F.mse_loss(noise_pred, noise)

                    # Logging
                    loss_cpu = loss.item()
                    val_loss.append(loss_cpu)
                    vepoch.set_postfix(loss=loss_cpu)

        # validation loop 2: transfer task (val_dataloader_transfer); same
        # computation as validation loop 1
        val_loss_transfer = list()
        noise_pred_net.eval()
        with torch.no_grad():
            with tqdm(val_dataloader_transfer, desc='Val Batch', leave=False) as vepoch:
                for batch in vepoch:
                    # Normalize observations with the training-set statistics
                    nbatch = normalize_batch({"obs": batch["obs"]}, stats)
                    nbatch = nbatch["obs"]
                    # Device transfer
                    nobs = nbatch.to(device)
                    naction = batch['actions'].to(device)
                    B = nobs.shape[0]

                    # Observation as FiLM conditioning
                    obs_cond = nobs[:, :obs_horizon, :].flatten(start_dim=1)

                    # Sample noise
                    noise = torch.randn(naction.shape, device=device)

                    # Sample diffusion iteration
                    timesteps = torch.randint(
                        0, noise_scheduler.config.num_train_timesteps,
                        (B,), device=device
                    ).long()

                    # Add noise to actions
                    noisy_actions = noise_scheduler.add_noise(
                        naction, noise, timesteps)

                    # Predict noise residual
                    noise_pred = noise_pred_net(
                        noisy_actions, timesteps, global_cond=obs_cond)

                    # L2 loss
                    loss = F.mse_loss(noise_pred, noise)

                    # Logging
                    loss_cpu = loss.item()
                    val_loss_transfer.append(loss_cpu)
                    vepoch.set_postfix(loss=loss_cpu)

        # Logging: keep the per-batch losses of this epoch and show their means
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_losses_transfer.append(val_loss_transfer)
        tglobal.set_postfix(
            train_loss=np.mean(train_loss),
            val_loss=np.mean(val_loss),
            val_loss_transfer=np.mean(val_loss_transfer)
        )

# The weights of the EMA model
# are used for inference.
# NOTE: `ema_noise_pred_net` is the same object as `noise_pred_net` (no copy is
# made), so copy_to() overwrites the trained network's parameters in place.
ema_noise_pred_net = noise_pred_net
ema.copy_to(ema_noise_pred_net.parameters())

#### Saving Training Data

In [None]:
# Make sure the output directories exist; torch.save / np.savez do not create them.
for output_path in (model_path_transfer, loss_path):
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

torch.save({
    'model_state_dict': ema_noise_pred_net.state_dict(),
    'ema_model_state_dict': ema.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'lr_scheduler_state_dict': lr_scheduler.state_dict(),
    'epoch': epoch_idx,  # index of the last completed epoch
    "stats": stats,  # normalization statistics needed again at inference time
    'loss': loss,  # loss tensor of the last batch that was evaluated
}, model_path_transfer)

# Save the training losses
np.savez(loss_path, train_losses=train_losses, val_losses=val_losses, val_losses_transfer=val_losses_transfer)  # Multiple arrays in one file

#### Visualization

In [None]:
# Load the data from NPZ file
with np.load(loss_path) as data:
    train_losses = data['train_losses']
    val_losses = data['val_losses']
    val_losses_transfer = data['val_losses_transfer']

# Number of epochs (taken from the saved losses; this rebinds the global `num_epochs`)
num_epochs = train_losses.shape[0]

# Calculate x-axis positions for each step, based on epoch boundaries:
# the steps of epoch e are spread evenly over [e, e + 1)
x_positions_train = []
x_positions_val = []
x_positions_val_transfer = []
for epoch in range(num_epochs):
    epoch_start = epoch 
    epoch_end = epoch + 1

    # Training steps within the epoch
    num_steps_train = len(train_losses[epoch])
    step_positions_train = np.linspace(epoch_start, epoch_end, num_steps_train, endpoint=False)
    x_positions_train.extend(step_positions_train)

    # Source-task validation steps within the epoch (if available)
    if epoch < len(val_losses):  
        num_steps_val = len(val_losses[epoch])
        step_positions_val = np.linspace(epoch_start, epoch_end, num_steps_val, endpoint=False)
        x_positions_val.extend(step_positions_val)
    
    # Transfer-task validation steps within the epoch (if available)
    if epoch < len(val_losses_transfer):  
        num_steps_val = len(val_losses_transfer[epoch])
        step_positions_val = np.linspace(epoch_start, epoch_end, num_steps_val, endpoint=False)
        x_positions_val_transfer.extend(step_positions_val)

x_positions_train = np.array(x_positions_train)
x_positions_val = np.array(x_positions_val)
x_positions_val_transfer = np.array(x_positions_val_transfer)

# Smoothing for better visualization (adjust window size as needed)
def smooth_curve(data, x_positions, window_size=10):
    """Smooth ``data`` with a moving average and return the matching x positions.

    Args:
        data: 1-D sequence of values to smooth.
        x_positions: Sequence holding the x coordinate of every entry in ``data``.
        window_size: Number of consecutive samples averaged per output point.
            Must be at least 1.

    Returns:
        Tuple ``(valid_x_positions, smoothed_data)`` of equal length. Output
        point ``i`` is the mean of ``data[i:i + window_size]`` and is placed at
        ``x_positions[i + window_size // 2]``. Both are empty when ``data``
        has fewer samples than ``window_size``.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    data = np.asarray(data)
    if data.size < window_size:
        # np.convolve(mode='valid') swaps its operands when the signal is
        # shorter than the kernel and raises on empty input, so handle the
        # "no full window" case explicitly.
        return x_positions[:0], np.empty(0, dtype=float)

    kernel = np.ones(window_size) / window_size
    smoothed_data = np.convolve(data, kernel, mode='valid')

    # Derive the slice from the smoothed length instead of a negative stop
    # index: a negative stop evaluates to 0 for window sizes 1 and 2 and
    # would silently yield an empty result.
    start = window_size // 2
    valid_x_positions = x_positions[start:start + len(smoothed_data)]

    # Keep both outputs the same length even if fewer x positions were given
    smoothed_data = smoothed_data[:len(valid_x_positions)]
    return valid_x_positions, smoothed_data

# Replace each raw curve by its moving average; smooth_curve also returns the
# x positions that line up with the smoothed values.
x_positions_train, smoothed_training_losses = smooth_curve(
    train_losses.flatten(), x_positions_train
)
x_positions_val, smoothed_validation_losses = smooth_curve(
    val_losses.flatten(), x_positions_val
)
x_positions_val_transfer, smoothed_validation_losses_transfer = smooth_curve(
    val_losses_transfer.flatten(), x_positions_val_transfer
)

# (x, y, line style) for every curve, in drawing order
loss_curves = [
    (x_positions_train, smoothed_training_losses,
     dict(label='Training Loss', color='black', linestyle='--')),
    (x_positions_val, smoothed_validation_losses,
     dict(label='Task Val Loss', color='red', linestyle='-')),
    (x_positions_val_transfer, smoothed_validation_losses_transfer,
     dict(label='Transfer Task Val Loss ', color='blue', linestyle='-')),
]

# Draw all curves into one figure with a logarithmic loss axis
plt.figure(figsize=(12, 6))
for curve_x, curve_y, line_style in loss_curves:
    plt.plot(curve_x, curve_y, **line_style)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.yscale('log')
plt.title(f'Training and Validation Loss per Epoch for {env_id_transfer} on {env_id} model', fontsize=14)
plt.xticks(range(num_epochs+1))  # one tick per epoch boundary
plt.legend(fontsize=12)
plt.grid(axis='y', linestyle='--')  # horizontal grid lines only
plt.tight_layout()
plt.savefig(plot_path)  # write the figure to disk before showing it
plt.show()

### Inference

In [None]:
# Roll out the EMA diffusion policy on the transfer environment and write one
# CSV row (episode, max reward, success flag) per episode.
env = gym.make(env_id_transfer, obs_mode=obs_mode, control_mode=control_mode, render_mode='rgb_array')

max_steps = 400  # hard cap on environment steps per episode

helper_techniques = False  # enables the hand-written heuristics further down

num_episodes = 50
mean_success = 0
mean_reward = 0
rewards = []
csv_file = f"{results_path}/results{information}.csv"
with open(csv_file, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Episode', 'Max Reward', 'Success'])
    print(f"Opened file {csv_file} for writing.")

    with tqdm(range(num_episodes), desc='Epoch') as episodes:

        for episode in episodes:

            # reset the environment and bring the observation into the
            # format used for training
            obs, info = env.reset()
            obs = get_observations(obs)
            obs = convert_observation(obs, task_id[env_id_transfer])

            # history buffers: the observation window starts filled with the
            # first observation, the action window with a placeholder action
            obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)
            actions_deque = collections.deque([[0,0,0,0,0,0,1]] * action_horizon, maxlen=action_horizon)

            # debug output of the initial observation before/after normalization
            obs_seq = np.stack(obs_deque)
            print("obs unnormalized", obs_seq)
            nobs = normalize_batch({'obs': torch.tensor(obs_seq, dtype=torch.float32)}, stats)
            nobs = nobs['obs']
            print("obs normalized",nobs)
            print("stats",stats["obs"]["max"])

            # per-episode bookkeeping; imgs keeps the rendered frames of the
            # most recent episode only
            imgs = []
            rewards = []
            done = False
            step_idx = 0
            unsuccessful = False

            B = 1  # a single trajectory is sampled per planning step

            with tqdm(total=max_steps, desc="Eval", leave=False) as pbar:
                while not done:
                    # stack the last obs_horizon observations
                    obs_seq = np.stack(obs_deque)
                    # NOTE(review): the heuristics test env_id although the
                    # rollout runs on env_id_transfer — confirm this is intended.
                    if env_id == "PickCube-v1" and helper_techniques:
                        obs_seq[:, :, 27] += 0.01
                    nobs = normalize_batch({'obs': torch.tensor(obs_seq, dtype=torch.float32)}, stats)

                    # device transfer
                    nobs = nobs['obs'].to(device)

                    # infer action
                    with torch.no_grad():
                        # reshape observation to (B, obs_horizon*obs_dim)
                        obs_cond = nobs.unsqueeze(0).flatten(start_dim=1)

                        # initialize action from Gaussian noise
                        naction = torch.randn(
                            (B, pred_horizon, action_dim), device=device)

                        # init scheduler
                        noise_scheduler.set_timesteps(num_diffusion_iters)

                        for k in noise_scheduler.timesteps:
                            # predict noise
                            noise_pred = ema_noise_pred_net(
                                sample=naction,
                                timestep=k,
                                global_cond=obs_cond
                            )

                            # inverse diffusion step (remove noise)
                            naction = noise_scheduler.step(
                                model_output=noise_pred,
                                timestep=k,
                                sample=naction
                            ).prev_sample

                    # (B, pred_horizon, action_dim); the actions are used
                    # as-is, without denormalization
                    naction = naction.detach().to('cpu').numpy()
                    action_pred = naction[0]

                    # only take action_horizon number of actions
                    start = obs_horizon - 1
                    end = start + action_horizon
                    action = action_pred[start:end, :]

                    # clip in place and remember the planned actions; the
                    # deque stores row views, so later edits to a row are
                    # visible through the deque as well
                    np.clip(action, -1, 1, out=action)
                    actions_deque.extend(action)

                    # execute action_horizon number of steps
                    # without replanning
                    for modified_action in action:

                        if env_id == "StackCube-v1" and helper_techniques:
                            # force the gripper command to the sign of the
                            # summed gripper commands of the last few actions
                            last_x_action = 5
                            start_idx = len(actions_deque) - last_x_action
                            last_x_actions = list(actions_deque)[start_idx:]
                            same_gripper_action = sum(
                                past_action[-1] for past_action in last_x_actions
                            )
                            modified_action[-1] = 1 if same_gripper_action >= 0 else -1

                        # the fourth return value of env.step is discarded;
                        # only the third one ends the episode
                        obs, reward, done, _, info = env.step(modified_action)

                        # process observation the same way as after reset
                        obs = get_observations(obs)
                        obs = convert_observation(obs, task_id[env_id_transfer])

                        # save observations
                        obs_deque.append(obs)

                        # and reward/vis
                        rewards.append(reward)
                        imgs.append(env.render())

                        # update progress bar
                        step_idx += 1
                        pbar.update(1)
                        pbar.set_postfix(reward=reward)

                        # Give up once exactly max_steps steps were taken, but
                        # do not relabel an episode that the environment just
                        # finished on its last permitted step.
                        if not done and step_idx >= max_steps:
                            done = True
                            unsuccessful = True
                        if done:
                            break

            # an episode counts as a success when the environment ended it
            # before the step cap was hit
            if not unsuccessful:
                mean_success += 1
            # NOTE(review): if env.step returns batched tensors, this value is
            # written to the CSV in tensor form — confirm the desired format.
            best_reward = max(rewards)
            mean_reward += best_reward
            writer.writerow([episode + 1, best_reward, int(not unsuccessful)])
            episodes.set_postfix(
                reward=mean_reward / (episode + 1),
                success=mean_success / (episode + 1)
            )

    print("Reward: ", mean_reward / num_episodes)
    print("Success: ", mean_success/num_episodes)

### Save gif

In [None]:
# Convert every rendered frame into a PIL image (the leading axis of each
# frame is squeezed away and the data is moved to host memory first)
images = []
for frame in imgs:
    frame_array = frame.squeeze(0).cpu().numpy()
    images.append(Image.fromarray(frame_array))

# Encode the animation into an in-memory buffer
buffer = io.BytesIO()
images[0].save(
    buffer,
    format='GIF',
    save_all=True,
    append_images=images[1:],
    optimize=False,
    duration=50,
    loop=0,
)
buffer.seek(0)
gif_bytes = buffer.getvalue()

# Persist the encoded GIF
with open(animation_path, 'wb') as gif_file:
    gif_file.write(gif_bytes)

# Show the GIF inline (optional)
display(IPImage(data=gif_bytes))