### Pip Install

In [1]:
# install the package
%pip install --upgrade mani_skill
# install a version of torch that is compatible with your system
%pip install torch torchvision torchaudio numpy diffusers

import math
from typing import Union
import h5py
from tqdm import tqdm
import numpy as np
import os

from mani_skill.utils import common
from mani_skill.utils.io_utils import load_json
from mani_skill.utils.common import flatten_state_dict
import mani_skill.envs
from collections import OrderedDict
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.utils.data import IterableDataset, Dataset
from torch.utils.data import DataLoader
from typing import Tuple, Sequence, Dict, Union, Optional
import numpy as np
import math
import torch
import torch.nn as nn
import collections
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

# env import
import gymnasium as gym
from gymnasium import spaces

Collecting mani_skill
  Downloading mani_skill-3.0.0b5-py3-none-any.whl (58.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.9/58.9 MB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
Collecting dacite (from mani_skill)
  Downloading dacite-1.8.1-py3-none-any.whl (14 kB)
Collecting gymnasium==0.29.1 (from mani_skill)
  Downloading gymnasium-0.29.1-py3-none-any.whl (953 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.9/953.9 kB[0m [31m39.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sapien==3.0.0.b1 (from mani_skill)
  Downloading sapien-3.0.0b1-cp310-cp310-manylinux2014_x86_64.whl (49.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.6/49.6 MB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython (from mani_skill)
  Downloading GitPython-3.1.43-py3-none-any.whl (207 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
Collecting tran

  warn("Failed to find system libvulkan. Fallback to SAPIEN builtin libvulkan.")
  warn(
  warn(
The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import h5py
import json
import numpy as np

# File names
original_h5_file = '/content/drive/MyDrive/data3/trajectory.state_dict.pd_joint_pos.h5'
original_json_file = '/content/drive/MyDrive/data3/trajectory.state_dict.pd_joint_pos.json'
training_h5_file = '03_pd_joint_pos_training_set.h5'
validation_h5_file = '03_pd_joint_pos_validation_set.h5'
training_json_file = '03_pd_joint_pos_training_set.json'
validation_json_file = '03_pd_joint_pos_validation_set.json'

with open(original_json_file, 'r') as f:
    json_data = json.load(f)

with h5py.File(original_h5_file, 'r') as f:
    traj_datasets = [key for key in f.keys() if key.startswith('traj_')]

    np.random.shuffle(traj_datasets)

    split_point = int(0.75 * len(traj_datasets))

    training_datasets = traj_datasets[:split_point]
    validation_datasets = traj_datasets[split_point:]

    with h5py.File(training_h5_file, 'w') as f_train:
        for key in training_datasets:
            f.copy(key, f_train)

    with h5py.File(validation_h5_file, 'w') as f_val:
        for key in validation_datasets:
            f.copy(key, f_val)

dataset_indices = {f'traj_{i}': i for i in range(len(json_data['episodes']))}

training_episodes = [json_data['episodes'][dataset_indices[key]] for key in training_datasets]
validation_episodes = [json_data['episodes'][dataset_indices[key]] for key in validation_datasets]

training_json = {
    'env_info': json_data['env_info'],
    'episodes': training_episodes
}

validation_json = {
    'env_info': json_data['env_info'],
    'episodes': validation_episodes
}

with open(training_json_file, 'w') as f:
    json.dump(training_json, f, indent=4)

with open(validation_json_file, 'w') as f:
    json.dump(validation_json, f, indent=4)

### Classes

#### Dataset Code

In [4]:
def load_h5_data(data):
    out = dict()
    for k in data.keys():
        if isinstance(data[k], h5py.Dataset):
            out[k] = data[k][:]
        else:
            out[k] = load_h5_data(data[k])
    return out


def create_sample_indices(episode_ends: np.ndarray, sequence_length: int, pad_before: int = 0, pad_after: int = 0):
    # Currently uses truncated as episode ends which is the end of the episode and not the end of the trajectory
    indices = list()
    episode_length = 0
    episode_index = 1 # Start 1 for human readability
    for i in range(len(episode_ends)):
        episode_length += 1
        if episode_ends[i]:
            start_idx = 0 if i <= 0 else i - episode_length + 1
            min_start = -pad_before
            max_start = episode_length - sequence_length + pad_after

            # Create indices for each possible sequence in the episode
            for idx in range(min_start, max_start + 1):
                buffer_start_idx = max(idx, 0) + start_idx
                buffer_end_idx = min(idx + sequence_length, episode_length) + start_idx
                start_offset = buffer_start_idx - (idx + start_idx)
                end_offset = (idx + sequence_length + start_idx) - buffer_end_idx
                sample_start_idx = 0 + start_offset
                sample_end_idx = sequence_length - end_offset
                indices.append([buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx])
            episode_length = 0
            episode_index += 1
    return np.array(indices)


def sample_sequence(train_data, sequence_length, buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx):
    result = dict()
    for key, input_arr in train_data.items():
        sample = input_arr[buffer_start_idx:buffer_end_idx]
        data = sample
        if (sample_start_idx > 0) or (sample_end_idx < sequence_length):
            if isinstance(input_arr, torch.Tensor):
                data = torch.zeros((sequence_length,) + input_arr.shape[1:], dtype=input_arr.dtype)
            else:
                data = np.zeros(shape=(sequence_length,) + input_arr.shape[1:], dtype=input_arr.dtype)
            if sample_start_idx > 0:
                data[:sample_start_idx] = sample[0]
            if sample_end_idx < sequence_length:
                data[sample_end_idx:] = sample[-1]
            data[sample_start_idx:sample_end_idx] = sample
        result[key] = data
    return result

def remove_np_uint16(x: Union[np.ndarray, dict]):
            if isinstance(x, dict):
                for k in x.keys():
                    x[k] = remove_np_uint16(x[k])
                return x
            else:
                if x.dtype == np.uint16:
                    return x.astype(np.int32)
                return x

def convert_observation(obs, task_id):
    # adds task_id to the observation
    values = list(obs.values())
    example = values[0]
    if isinstance(example, torch.Tensor):
          example = example.numpy()

    # add task_id to the observation
    task_id_array = np.full((example.shape[0], 1), task_id, dtype=example.dtype)
    values.append(task_id_array)
    # concatenate all the values
    return np.concatenate(values, axis=-1)

def get_observations(obs):
    #ensoure that the observations are in the correct format
    #and ordered correctly across tasks

    cleaned_obs = OrderedDict()
    cleaned_obs["qpos"] = obs["agent"]["qpos"]
    cleaned_obs["qvel"] = obs["agent"]["qvel"]
    cleaned_obs["tcp_pose"] = obs["extra"]["tcp_pose"]
    obs["extra"].pop("tcp_pose")

    #this code is not generic and only works for the specific observation spaces we have
    # Handle different goal position formats gracefully
    goal_pose_keys = ["goal_pose", "goal_pos", "box_hole_pose", "cubeB_pose"]
    for key in goal_pose_keys:
        if key in obs["extra"]:
            pos = obs["extra"][key]

            # Ensure 'pos' is 2D with the correct number of columns
            if pos.ndim == 1:
                pos = pos.reshape(1, -1)  # Reshape to 2D if necessary
            elif pos.ndim > 2:
                raise ValueError(f"Unexpected dimensions for '{key}': {pos.shape}")

            # Pad or truncate 'pos' to have 7 columns
            pos = np.pad(pos[:, :7], ((0, 0), (0, 7 - pos.shape[1])), mode='constant')
            if isinstance(cleaned_obs["tcp_pose"], torch.Tensor):
                pos = torch.tensor(pos, dtype=cleaned_obs["tcp_pose"].dtype)

            cleaned_obs["goal_pose"] = pos
            obs["extra"].pop(key)
            break  # Stop once a valid goal pose key is found
    else:
        print("No goal pose found. Setting to zero.")
        length = len(cleaned_obs["tcp_pose"])
        cleaned_obs["goal_pose"] = np.zeros((length, 7), dtype=np.float32)  # Ensure 2D shape

    #is_grasped_reshaped = np.reshape(obs["extra"]["is_grasped"], (len(obs["extra"]["is_grasped"]), 1))

    # Filter and add other observations with 7 columns
    for key, value in obs["extra"].items():
        if value.shape[-1] == 7 and value.ndim == 2:
            cleaned_obs[key] = value

    return cleaned_obs


def normalize_batch(batch, min_vals, max_vals, exclude_features):
    batch = batch["obs"]
    batch_reshaped = batch.view(-1, batch.shape[-1])
    mask = torch.ones(batch_reshaped.shape[1], dtype=torch.bool)
    mask[exclude_features] = False

    normalized_batch = batch_reshaped.clone()
    normalized_batch[:, mask] = (batch_reshaped[:, mask] - min_vals) / (max_vals - min_vals + 0.1)
    return normalized_batch.view(batch.shape)

def denormalize_batch(batch, min_vals, max_vals, exclude_features):
    batch = batch["obs"]
    batch_reshaped = batch.view(-1, batch.shape[-1])
    mask = torch.ones(batch_reshaped.shape[1], dtype=torch.bool)
    mask[exclude_features] = False

    denormalized_batch = batch_reshaped.clone()
    denormalized_batch[:, mask] = batch_reshaped[:, mask] * (max_vals - min_vals + 0.1) + min_vals
    return denormalized_batch.view(batch.shape)


class StateDataset(Dataset):
    """
    A general torch Dataset you can drop in and use immediately with just about any trajectory .h5 data generated from ManiSkill.
    This class simply is a simple starter code to load trajectory data easily, but does not do any data transformation or anything
    advanced. We recommend you to copy this code directly and modify it for more advanced use cases

    Args:
        dataset_file (str): path to the .h5 file containing the data you want to load
        load_count (int): the number of trajectories from the dataset to load into memory. If -1, will load all into memory
        success_only (bool): whether to skip trajectories that are not successful in the end. Default is false
        device: The location to save data to. If None will store as numpy (the default), otherwise will move data to that device
    """

    def __init__(
        self, dataset_file: str, pred_horizon: int, obs_horizon: int, action_horizon:int, task_id: np.float32, load_count=-1, device=None
    ) -> None:
        self.dataset_file = dataset_file
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon
        self.action_horizon = action_horizon
        self.task_id = task_id
        self.device = device
        self.data = h5py.File(dataset_file, "r")
        json_path = dataset_file.replace(".h5", ".json")
        self.json_data = load_json(json_path)
        self.episodes = self.json_data["episodes"]
        self.env_info = self.json_data["env_info"]
        self.env_id = self.env_info["env_id"]
        self.env_kwargs = self.env_info["env_kwargs"]

        self.obs = None
        self.actions = []
        self.terminated = []
        self.truncated = []
        self.success, self.fail, self.rewards = None, None, None
        if load_count == -1:
            load_count = len(self.episodes)
        for eps_id in tqdm(range(load_count), desc="Loading Episodes", colour="green"):
            eps = self.episodes[eps_id]
            assert (
                "success" in eps
            ), "episodes in this dataset do not have the success attribute, cannot load dataset with success_only=True"
            if not eps["success"]:
                continue
            trajectory = self.data[f"traj_{eps['episode_id']}"]
            trajectory = load_h5_data(trajectory)
            eps_len = len(trajectory["actions"])

            # exclude the final observation as most learning workflows do not use it
            obs = common.index_dict_array(trajectory["obs"], slice(eps_len))
            if eps_id == 0:
                self.obs = obs
            else:
                self.obs = common.append_dict_array(self.obs, obs)

            self.actions.append(trajectory["actions"])
            self.terminated.append(trajectory["terminated"])
            self.truncated.append(trajectory["truncated"])

            # handle data that might optionally be in the trajectory
            if "rewards" in trajectory:
                if self.rewards is None:
                    self.rewards = [trajectory["rewards"]]
                else:
                    self.rewards.append(trajectory["rewards"])
            if "success" in trajectory:
                if self.success is None:
                    self.success = [trajectory["success"]]
                else:
                    self.success.append(trajectory["success"])
            if "fail" in trajectory:
                if self.fail is None:
                    self.fail = [trajectory["fail"]]
                else:
                    self.fail.append(trajectory["fail"])

        self.actions = np.vstack(self.actions)
        self.terminated = np.concatenate(self.terminated)
        self.truncated = np.concatenate(self.truncated)

        self.truncated = np.zeros(self.actions.shape[0], dtype=bool)
        self.truncated[-1] = True

        if self.rewards is not None:
            self.rewards = np.concatenate(self.rewards)
        if self.success is not None:
            self.success = np.concatenate(self.success)
        if self.fail is not None:
            self.fail = np.concatenate(self.fail)

        def remove_np_uint16(x: Union[np.ndarray, dict]):
            if isinstance(x, dict):
                for k in x.keys():
                    x[k] = remove_np_uint16(x[k])
                return x
            else:
                if x.dtype == np.uint16:
                    return x.astype(np.int32)
                return x

        # uint16 dtype is used to conserve disk space and memory
        # you can optimize this dataset code to keep it as uint16 and process that
        # dtype of data yourself. for simplicity we simply cast to a int32 so
        # it can automatically be converted to torch tensors without complaint
        self.obs = remove_np_uint16(self.obs)

        if device is not None:
            self.actions = common.to_tensor(self.actions, device=device)
            self.obs = common.to_tensor(self.obs, device=device)
            self.terminated = common.to_tensor(self.terminated, device=device)
            self.truncated = common.to_tensor(self.truncated, device=device)
            if self.rewards is not None:
                self.rewards = common.to_tensor(self.rewards, device=device)
            if self.success is not None:
                self.success = common.to_tensor(self.terminated, device=device)
            if self.fail is not None:
                self.fail = common.to_tensor(self.truncated, device=device)



        # Added code for diffusion policy
        obs_dict = get_observations(self.obs)
        self.train_data = dict(
                        obs=convert_observation(obs_dict, self.task_id),
                        actions=self.actions,
                        )

         # Initialize index lists and stat dicts
        self.indices = create_sample_indices(
            episode_ends=self.truncated,
            sequence_length=self.pred_horizon,
            pad_before=self.obs_horizon - 1,
            pad_after=self.action_horizon - 1
        )


    def __len__(self):
        # all possible sequenzes of the dataset
        return len(self.indices)

    def __getitem__(self, idx):
        # Change data to fit diffusion policy
        buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx = self.indices[idx]


        sampled = sample_sequence(
            train_data=self.train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buffer_start_idx,
            buffer_end_idx=buffer_end_idx,
            sample_start_idx=sample_start_idx,
            sample_end_idx=sample_end_idx
        )

        # discard unused observations in the sequence
        for k in sampled.keys():
            if k != "actions":
                # discard unused observations in the sequence
                sampled[k] = sampled[k][:self.obs_horizon,:]
        sampled[k] = common.to_tensor(sampled[k], device=self.device)

        return sampled

#### Network code

In [5]:

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)

class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class ConditionalResidualBlock1D(nn.Module):
    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
            Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
        ])

        # FiLM modulation https://arxiv.org/abs/1709.07871
        # predicts per-channel scale and bias
        cond_channels = out_channels * 2
        self.out_channels = out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, cond_channels),
            nn.Unflatten(-1, (-1, 1))
        )

        # make sure dimensions compatible
        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
            if in_channels != out_channels else nn.Identity()

    def forward(self, x, cond):
        '''
            x : [ batch_size x in_channels x horizon ]
            cond : [ batch_size x cond_dim]

            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        out = self.blocks[0](x)
        embed = self.cond_encoder(cond)

        embed = embed.reshape(
            embed.shape[0], 2, self.out_channels, 1)
        scale = embed[:,0,...]
        bias = embed[:,1,...]
        out = scale * out + bias

        out = self.blocks[1](out)
        out = out + self.residual_conv(x)
        return out


class ConditionalUnet1D(nn.Module):
    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=[256,512,1024],
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level.
          The length of this array determines numebr of levels.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """

        super().__init__()
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        cond_dim = dsed + global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_out*2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            global_cond=None):
        """
        x: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        # (B,T,C)
        sample = sample.moveaxis(-1,-2)
        # (B,C,T)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)

        if global_cond is not None:
            global_feature = torch.cat([
                global_feature, global_cond
            ], axis=-1)

        x = sample
        h = []
        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1,-2)
        # (B,T,C)
        return x

### Setup

In [16]:
# download demonstration data from Google Drive
env_id = 'PickCube-v1'
#env_id = 'StackCube-v1'
#env_id = 'PegInsertionSide-v1'
#env_id = 'PlugCharger-v1'
#env_id = 'PushCube-v1'
obs_mode = 'state_dict'
control_mode = 'pd_joint_delta_pos'
#control_mode = 'pd_ee_delta_pos'

pred_horizon = 16
obs_horizon = 2
action_horizon = 8

#==============================================================================

task_id = {
    'PickCube-v1': 0.0,
    'StackCube-v1': 0.1,
    'PegInsertionSide-v1': 0.2,
    'PlugCharger-v1': 0.3,
    'PushCube-v1': 0.4
}

exclude_features = [25, 26, 27, 28, 29, 30, 31, 39] # goal pose x, y, z, qw, qx, qy, qz and task_id
min_vals = None
max_vals = None

train_dataset_path = '/content/03_pd_joint_pos_training_set.h5'
val_dataset_path = '/content/03_pd_joint_pos_validation_set.h5'

model_path = f'drive/MyDrive/Data/Checkpoints/{env_id}_{control_mode}_model.pt'


# create dataset from file
train_dataset = StateDataset(
    dataset_file=train_dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon,
    task_id=task_id[env_id],
    device=None
)

val_dataset = StateDataset(
    dataset_file=val_dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon,
    task_id=task_id[env_id],
    device=None
)

# create dataloader
train_dataloader = DataLoader(
    train_dataset,
    batch_size=128,
    num_workers=1,
    # don't kill worker process afte each epoch
    persistent_workers=True,
    shuffle=True
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=128,
    num_workers=1,
    # don't kill worker process afte each epoch
    persistent_workers=True,
    shuffle=True
)

for batch in train_dataloader:
    batch = batch['obs']
    batch_reshaped = batch.view(-1, batch.shape[-1])
    mask = torch.ones(batch_reshaped.shape[1], dtype=torch.bool)
    mask[exclude_features] = False

    batch_min = batch_reshaped[:, mask].min(dim=0)[0]
    batch_max = batch_reshaped[:, mask].max(dim=0)[0]

    if min_vals is None and max_vals is None:
        min_vals = batch_min
        max_vals = batch_max
    else:
        min_vals = torch.min(min_vals, batch_min)
        max_vals = torch.max(max_vals, batch_max)


# visualize data in batch
batch = next(iter(train_dataloader))
print(batch.keys())
print("observations:", batch['obs'].shape, batch['obs'].dtype)
print("actions:", batch['actions'].shape, batch['actions'].dtype)


# observation and action dimensions corrsponding to the dataset
obs_dim = batch['obs'].shape[-1]
action_dim = batch['actions'].shape[-1]
print("obs_dim:", obs_dim)
print("action_dim:", action_dim)

# create network object
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)

# example inputs
noised_action = torch.randn((1, pred_horizon, action_dim))
obs = torch.zeros((1, obs_horizon, obs_dim))
diffusion_iter = torch.zeros((1,))

# the noise prediction network
# takes noisy action, diffusion iteration and observation as input
# predicts the noise added to action
noise = noise_pred_net(
    sample=noised_action,
    timestep=diffusion_iter,
    global_cond=obs.flatten(start_dim=1))

# illustration of removing noise
# the actual noise removal is performed by NoiseScheduler
# and is dependent on the diffusion noise schedule
denoised_action = noised_action - noise

# for this demo, we use DDPMScheduler with 100 diffusion iterations
num_diffusion_iters = 100
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    # the choise of beta schedule has big impact on performance
    # we found squared cosine works the best
    beta_schedule='squaredcos_cap_v2',
    # clip output to [-1,1] to improve stability
    clip_sample=True,
    # our network predicts noise (instead of denoised action)
    prediction_type='epsilon'
)

# device transfer
device = torch.device("cuda")
_ = noise_pred_net.to(device)

Loading Episodes: 100%|[32m██████████[0m| 750/750 [00:04<00:00, 164.37it/s]
Loading Episodes: 100%|[32m██████████[0m| 250/250 [00:01<00:00, 215.08it/s]


dict_keys(['obs', 'actions'])
observations: torch.Size([128, 2, 40]) torch.float32
actions: torch.Size([128, 16, 8]) torch.float32
obs_dim: 40
action_dim: 8
number of parameters: 6.636750e+07


### Training

In [17]:
num_epochs = 5
import torch
import numpy as np
from tqdm import tqdm
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import functional as F

num_epochs = 5

# Exponential Moving Average
ema = EMAModel(
    parameters=noise_pred_net.parameters(),
    power=0.75
)

# Standard ADAM optimizer
optimizer = AdamW(
    params=noise_pred_net.parameters(),
    lr=1e-4, weight_decay=1e-6
)

# Cosine LR schedule with linear warmup
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(train_dataloader) * num_epochs
)

with tqdm(range(num_epochs), desc='Epoch') as tglobal:
    for epoch_idx in tglobal:
        train_loss = list()

        # Training loop
        noise_pred_net.train()
        with tqdm(train_dataloader, desc='Train Batch', leave=False) as tepoch:
            for batch in tepoch:
                # Normalize data
                nbatch = normalize_batch(batch, min_vals, max_vals, exclude_features)

                # Device transfer
                nobs = nbatch.to(device)
                naction = batch['actions'].to(device)
                B = nobs.shape[0]

                # Observation as FiLM conditioning
                obs_cond = nobs[:, :obs_horizon, :].flatten(start_dim=1)

                # Sample noise
                noise = torch.randn(naction.shape, device=device)

                # Sample diffusion iteration
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps,
                    (B,), device=device
                ).long()

                # Add noise to actions
                noisy_actions = noise_scheduler.add_noise(
                    naction, noise, timesteps)

                # Predict noise residual
                noise_pred = noise_pred_net(
                    noisy_actions, timesteps, global_cond=obs_cond)

                # L2 loss
                loss = F.mse_loss(noise_pred, noise)

                # Optimize
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step()

                # Update EMA of model weights
                ema.step(noise_pred_net.parameters())

                # Logging
                loss_cpu = loss.item()
                train_loss.append(loss_cpu)
                tepoch.set_postfix(loss=loss_cpu)

        # Validation loop
        val_loss = list()
        noise_pred_net.eval()
        with torch.no_grad():
            with tqdm(val_dataloader, desc='Val Batch', leave=False) as vepoch:
                for batch in vepoch:
                    # Normalize data
                    nbatch = normalize_batch(batch, min_vals, max_vals, exclude_features)

                    # Device transfer
                    nobs = nbatch.to(device)
                    naction = batch['actions'].to(device)
                    B = nobs.shape[0]

                    # Observation as FiLM conditioning
                    obs_cond = nobs[:, :obs_horizon, :].flatten(start_dim=1)

                    # Sample noise
                    noise = torch.randn(naction.shape, device=device)

                    # Sample diffusion iteration
                    timesteps = torch.randint(
                        0, noise_scheduler.config.num_train_timesteps,
                        (B,), device=device
                    ).long()

                    # Add noise to actions
                    noisy_actions = noise_scheduler.add_noise(
                        naction, noise, timesteps)

                    # Predict noise residual
                    noise_pred = noise_pred_net(
                        noisy_actions, timesteps, global_cond=obs_cond)

                    # L2 loss
                    loss = F.mse_loss(noise_pred, noise)

                    # Logging
                    loss_cpu = loss.item()
                    val_loss.append(loss_cpu)
                    vepoch.set_postfix(loss=loss_cpu)

        tglobal.set_postfix(
            train_loss=np.mean(train_loss),
            val_loss=np.mean(val_loss)
        )

# Use weights of the EMA model for inference
ema_noise_pred_net = noise_pred_net
ema.copy_to(ema_noise_pred_net.parameters())


Epoch:   0%|          | 0/5 [00:00<?, ?it/s]
Train Batch:   0%|          | 0/1071 [00:00<?, ?it/s][A
Train Batch:   0%|          | 0/1071 [00:00<?, ?it/s, loss=1.18][A
Train Batch:   0%|          | 1/1071 [00:00<05:57,  2.99it/s, loss=1.18][A
Train Batch:   0%|          | 1/1071 [00:00<05:57,  2.99it/s, loss=1.17][A
Train Batch:   0%|          | 2/1071 [00:00<04:50,  3.69it/s, loss=1.17][A
Train Batch:   0%|          | 2/1071 [00:00<04:50,  3.69it/s, loss=1.19][A
Train Batch:   0%|          | 3/1071 [00:00<04:30,  3.95it/s, loss=1.19][A
Train Batch:   0%|          | 3/1071 [00:01<04:30,  3.95it/s, loss=1.16][A
Train Batch:   0%|          | 4/1071 [00:01<04:20,  4.09it/s, loss=1.16][A
Train Batch:   0%|          | 4/1071 [00:01<04:20,  4.09it/s, loss=1.16][A
Train Batch:   0%|          | 5/1071 [00:01<04:14,  4.18it/s, loss=1.16][A
Train Batch:   0%|          | 5/1071 [00:01<04:14,  4.18it/s, loss=1.15][A
Train Batch:   1%|          | 6/1071 [00:01<04:12,  4.22it/s, loss=1.1

KeyboardInterrupt: 

#### Saving model

In [None]:
torch.save({
    'model_state_dict': ema_noise_pred_net.state_dict(),
    'ema_model_state_dict': ema.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'lr_scheduler_state_dict': lr_scheduler.state_dict(),
    'epoch': epoch_idx,
    'loss': loss, # Save the current epoch
}, model_path)

#### Loading Model

In [None]:
state_dict = torch.load(model_path, map_location='cuda')
ema_noise_pred_net = noise_pred_net
ema_noise_pred_net.load_state_dict(state_dict['model_state_dict'])
print('Pretrained weights loaded.')

### Inference

In [None]:
# limit enviornment interaction to 200 steps before termination
env = gym.make(env_id, obs_mode=obs_mode, control_mode=control_mode, render_mode='rgb_array')

max_steps = 400

# reset
obs, info = env.reset()
obs = get_observations(obs)
obs = convert_observation(obs, task_id[env_id])

# save observations
obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)

# save visualization
imgs = []
rewards = []
done = False
step_idx = 0


with tqdm(total=max_steps, desc="Eval") as pbar:
    while not done:
        B = 1
        # stack the last obs_horizon (2) number of observations
        obs_seq = np.stack(obs_deque)

        nobs = normalize_batch({'obs': torch.tensor(obs_seq, dtype=torch.float32)}, min_vals, max_vals, exclude_features)

        # device transfer
        #nobs = torch.from_numpy(nobs).to(device, dtype=torch.float32)
        nobs= nobs.to(device)

        # infer action
        with torch.no_grad():
            # reshape observation to (B,obs_horizon*obs_dim)
            obs_cond = nobs.unsqueeze(0).flatten(start_dim=1)

            # initialize action from Guassian noise
            noisy_action = torch.randn(
                (B, pred_horizon, action_dim), device=device)
            naction = noisy_action

            # init scheduler
            noise_scheduler.set_timesteps(num_diffusion_iters)

            for k in noise_scheduler.timesteps:
                # predict noise
                noise_pred = ema_noise_pred_net(
                    sample=naction,
                    timestep=k,
                    global_cond=obs_cond
                )

                # inverse diffusion step (remove noise)
                naction = noise_scheduler.step(
                    model_output=noise_pred,
                    timestep=k,
                    sample=naction
                ).prev_sample

        # unnormalize action
        naction = naction.detach().to('cpu').numpy()
        # (B, pred_horizon, action_dim)
        action_pred = naction[0] # we dont have to denormalize the action

        # only take action_horizon number of actions
        start = obs_horizon - 1
        end = start + action_horizon
        action = action_pred[start:end,:]

        # execute action_horizon number of steps
        # without replanning
        for i in range(len(action)):
            # stepping env
            obs, reward, done, _, info = env.step(action[i])

            # process observation
            # From the observation dictionary, we concatenate all the observations
            # as done in the training data
            obs = get_observations(obs)
            obs = convert_observation(obs, task_id[env_id])

            # save observations
            obs_deque.append(obs)

            # and reward/vis
            rewards.append(reward)
            imgs.append(env.render())

            # update progress bar
            step_idx += 1
            pbar.update(1)
            pbar.set_postfix(reward=reward)
            if step_idx > max_steps:
                done = True
            if done:
                break
# print out the maximum target coverage
print('Score: ', max(rewards))

### Save gif

In [None]:
from PIL import Image
from IPython.display import display, Image as IPImage
import io

images = [Image.fromarray(img.squeeze(0).cpu().numpy()) for img in imgs]

# Save to a bytes buffer
buffer = io.BytesIO()
images[0].save(buffer, format='GIF', save_all=True, append_images=images[1:], optimize=False, duration=50, loop=0)
buffer.seek(0)

# Save to a file
with open(f'drive/MyDrive/Data/Results/{env_id}_{control_mode}_animation.gif', 'wb') as f:
    f.write(buffer.getvalue())

# Display the GIF (optional)
display(IPImage(data=buffer.getvalue()))