## Loading

In [None]:
# install the package
%pip install --upgrade mani_skill
# install a version of torch that is compatible with your system
%pip install torch torchvision torchaudio numpy diffusers


# miscellaneous standard-library and utility imports
from typing import Tuple, Sequence, Dict, Union, Optional
from collections import OrderedDict
import collections
import math
import h5py
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import display, Image as IPImage
import io
import os
import csv

# mani_skill imports
from mani_skill.utils import common
from mani_skill.utils.io_utils import load_json
from mani_skill.utils.common import flatten_state_dict
import mani_skill.envs

#torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, Dataset
from torch.utils.data import DataLoader

# diffusers imports
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler

# gym imports
import gymnasium as gym
from gymnasium import spaces

# google colab imports
from google.colab import drive
drive.mount('/content/drive')

## Network

In [None]:

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of positions.

    For an input of shape ``(B,)`` the result has shape
    ``(B, 2 * (dim // 2))``: the first half of the last axis holds sines and
    the second half cosines, evaluated at ``dim // 2`` frequencies spaced
    geometrically from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        n_freqs = self.dim // 2
        # Per-index exponent step. NOTE: n_freqs == 1 (dim of 2 or 3) divides
        # by zero here and raises ZeroDivisionError.
        step = -(math.log(10000) / (n_freqs - 1))
        freqs = torch.exp(torch.arange(n_freqs, device=x.device) * step)
        # Outer product: one row per position, one column per frequency.
        angles = x[:, None] * freqs.unsqueeze(0)
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)


class Downsample1d(nn.Module):
    """Channel-preserving strided convolution that shortens the last axis.

    Uses ``Conv1d`` with kernel 3, stride 2 and padding 1, so an input of
    length ``T`` comes out with length ``(T - 1) // 2 + 1`` (``T / 2`` for
    even ``T``).
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        shortened = self.conv(x)
        return shortened

class Upsample1d(nn.Module):
    """Channel-preserving transposed convolution that doubles the last axis.

    Uses ``ConvTranspose1d`` with kernel 4, stride 2 and padding 1, so an
    input of length ``T`` comes out with length ``2 * T``.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        lengthened = self.conv(x)
        return lengthened


class Conv1dBlock(nn.Module):
    """Conv1d --> GroupNorm --> Mish, applied in that order.

    The convolution pads by ``kernel_size // 2`` on each side, which keeps
    the sequence length unchanged for odd kernel sizes.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        same_padding = kernel_size // 2
        layers = [
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=same_padding),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        ]
        # Kept as a Sequential named ``block`` so parameter keys stay
        # ``block.0.*`` / ``block.1.*``.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        activated = self.block(x)
        return activated


class ConditionalResidualBlock1D(nn.Module):
    """Residual pair of ``Conv1dBlock`` layers with FiLM conditioning.

    The conditioning vector is mapped to a per-channel scale and bias
    (FiLM, https://arxiv.org/abs/1709.07871) that modulate the output of the
    first conv block; the second conv block follows, and the (optionally
    1x1-projected) input is added back as a residual.
    """

    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()

        first_stage = Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
        second_stage = Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
        self.blocks = nn.ModuleList([first_stage, second_stage])

        self.out_channels = out_channels
        # One linear layer emits both FiLM terms: the first ``out_channels``
        # values are the scale, the remaining ``out_channels`` the bias.
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, 2 * out_channels),
            nn.Unflatten(-1, (-1, 1))
        )

        # A 1x1 conv aligns channel counts for the skip path when they differ.
        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, cond):
        '''
            x : [ batch_size x in_channels x horizon ]
            cond : [ batch_size x cond_dim]

            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        first_stage, second_stage = self.blocks
        hidden = first_stage(x)

        film = self.cond_encoder(cond)
        # -> [ batch_size x 2 x out_channels x 1 ]; axis 1 selects scale/bias
        film = film.reshape(film.shape[0], 2, self.out_channels, 1)
        scale, bias = film.unbind(dim=1)
        hidden = scale * hidden + bias

        hidden = second_stage(hidden)
        return hidden + self.residual_conv(x)


class ConditionalUnet1D(nn.Module):
    """1-D U-Net over an action sequence, conditioned via FiLM.

    Every residual block receives the same conditioning vector: the encoded
    diffusion step concatenated with ``global_cond``. The down path stores
    one skip tensor per level; the up path has one stage fewer than the down
    path and consumes the skips in reverse order.
    """

    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level (any sequence of ints).
          The length of this sequence determines the number of levels.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """

        super().__init__()
        # The default is an immutable tuple; normalise to a list locally.
        down_dims = list(down_dims)
        all_dims = [input_dim] + down_dims
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        cond_dim = dsed + global_cond_dim

        def make_block(channels_in, channels_out):
            # All residual blocks share conditioning width and conv settings.
            return ConditionalResidualBlock1D(
                channels_in, channels_out, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups)

        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        # NOTE: sub-modules are built and registered in this exact order on
        # purpose, so seeded initialisation and parameter ordering stay stable.
        self.mid_modules = nn.ModuleList([
            make_block(mid_dim, mid_dim),
            make_block(mid_dim, mid_dim),
        ])

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            # The deepest level keeps its resolution.
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                make_block(dim_in, dim_out),
                make_block(dim_out, dim_out),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        up_modules = nn.ModuleList([])
        # There is one up stage per Downsample1d in the down path (the first
        # in_out pair is skipped), so every up stage ends with an upsample.
        for dim_in, dim_out in reversed(in_out[1:]):
            up_modules.append(nn.ModuleList([
                # dim_out * 2: the skip tensor is concatenated on channels.
                make_block(dim_out * 2, dim_in),
                make_block(dim_in, dim_in),
                Upsample1d(dim_in)
            ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            global_cond=None):
        """
        sample: (B,T,input_dim)
        timestep: (B,) tensor, 0-dim tensor, or Python number; diffusion step
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        # (B,T,C)
        sample = sample.moveaxis(-1,-2)
        # (B,C,T)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif timesteps.dim() == 0:
            timesteps = timesteps[None]
        # Tensor timesteps may live on another device than the sample; align
        # them, then broadcast to the batch dimension in a way that's
        # compatible with ONNX/Core ML.
        timesteps = timesteps.to(sample.device).expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)

        if global_cond is not None:
            global_feature = torch.cat([
                global_feature, global_cond
            ], dim=-1)

        x = sample
        skips = []
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            skips.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for resnet, resnet2, upsample in self.up_modules:
            # Deepest skip first; the first level's skip is never consumed.
            x = torch.cat((x, skips.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1,-2)
        # (B,T,C)
        return x

## Evaluation

In [None]:
# Evaluation configuration: choose the task and horizons, derive every
# dataset/checkpoint/result path from them, then rebuild the network and
# restore its trained weights from the checkpoint.
#=====================================CHANGE=========================================
env_id = 'PickCube-v1'
#env_id = 'StackCube-v1'
#env_id = 'PegInsertionSide-v2'
#env_id = 'PlugCharger-v2'
#env_id = 'PushCube-v1'
obs_mode = 'state_dict'
control_mode = 'pd_joint_delta_pos'

pred_horizon = 16
obs_horizon = 2
action_horizon = 8

#======================================CHANGE========================================

# Scalar identifier per task; passed to convert_observation() during rollout.
task_id = {
    'PickCube-v1': 0.0,
    'StackCube-v1': 0.1,
    'PegInsertionSide-v2': 0.2,
    'PlugCharger-v2': 0.3,
    'PushCube-v1': 0.4
}

base_path = '/content/drive/MyDrive/Data'
generated_path = f'{base_path}/Generated/{env_id}/motionplanning'
checkpoints_path = f'{base_path}/Checkpoints/{env_id}'
results_path = f'{base_path}/Results/{env_id}'

train_dataset_path = f'{generated_path}/training.{obs_mode}.{control_mode}.h5'
val_dataset_path = f'{generated_path}/validation.{obs_mode}.{control_mode}.h5'
model_path = f'{checkpoints_path}/model.pt'
loss_path = f'{results_path}/loss.npz'
plot_path = f'{results_path}/plot.png'
animation_path = f'{results_path}/animation.gif'

# Hard-coded sizes -- presumably they match the data the checkpoint was
# trained on; TODO confirm when switching env_id.
obs_dim = 40
action_dim = 8
print("obs_dim:", obs_dim)
print("action_dim:", action_dim)

# create network object
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)

# Map checkpoint tensors to the GPU when one is present; otherwise fall back
# to CPU so the checkpoint can still be loaded on a CPU-only runtime.
state_dict = torch.load(
    model_path,
    map_location='cuda' if torch.cuda.is_available() else 'cpu'
)
noise_pred_net.load_state_dict(state_dict['model_state_dict'])
# 'stats' is stored in the checkpoint next to the weights.
stats = state_dict['stats']
print('Pretrained weights loaded.')


In [None]:
# Roll out the diffusion policy for num_episodes episodes, writing one CSV
# row per episode (max reward, success flag) and printing the averages.
# An episode counts as successful when the env reports done before the
# max_steps cap is reached.
env = gym.make(env_id, obs_mode=obs_mode, control_mode=control_mode, render_mode='rgb_array')

max_steps = 400

num_episodes = 50
# Running sums; divided by the episode count whenever they are reported.
mean_success = 0
mean_reward = 0
rewards = []
csv_file = f"{results_path}/results.csv"
# open() does not create parent directories, so make sure the results
# directory exists before writing into it.
os.makedirs(results_path, exist_ok=True)
with open(csv_file, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Episode', 'Max Reward', 'Success'])
    print(f"Opened file {csv_file} for writing.")

    with tqdm(range(num_episodes), desc='Epoch') as episodes:

        for episode in episodes:

            # reset
            obs, info = env.reset()
            obs = get_observations(obs)
            obs = convert_observation(obs, task_id[env_id])

            # observation history, pre-filled with the first observation
            obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)

            # per-episode rendered frames and rewards
            imgs = []
            rewards = []
            done = False
            step_idx = 0
            unsuccessful = False

            with tqdm(total=max_steps, desc="Eval", leave=False) as pbar:
                while not done:
                    B = 1
                    # stack the last obs_horizon (2) number of observations
                    obs_seq = np.stack(obs_deque)

                    nobs = normalize_batch({'obs': torch.tensor(obs_seq, dtype=torch.float32)}, min_vals, max_vals, exclude_features)

                    # device transfer
                    nobs = nobs.to(device)

                    # infer action
                    with torch.no_grad():
                        # reshape observation to (B,obs_horizon*obs_dim)
                        obs_cond = nobs.unsqueeze(0).flatten(start_dim=1)

                        # initialize action from Gaussian noise
                        noisy_action = torch.randn(
                            (B, pred_horizon, action_dim), device=device)
                        naction = noisy_action

                        # init scheduler
                        noise_scheduler.set_timesteps(num_diffusion_iters)

                        for k in noise_scheduler.timesteps:
                            # predict noise
                            noise_pred = ema_noise_pred_net(
                                sample=naction,
                                timestep=k,
                                global_cond=obs_cond
                            )

                            # inverse diffusion step (remove noise)
                            naction = noise_scheduler.step(
                                model_output=noise_pred,
                                timestep=k,
                                sample=naction
                            ).prev_sample

                    # move the prediction to host memory; the action is used
                    # as predicted (no denormalization is applied)
                    naction = naction.detach().to('cpu').numpy()
                    # (B, pred_horizon, action_dim)
                    action_pred = naction[0]

                    # only take action_horizon number of actions
                    start = obs_horizon - 1
                    end = start + action_horizon
                    action = action_pred[start:end,:]

                    # execute action_horizon number of steps
                    # without replanning
                    for step_action in action:
                        # stepping env
                        # NOTE(review): the fourth value returned by step() is
                        # ignored here -- confirm that is intended.
                        obs, reward, done, _, info = env.step(step_action)

                        # process observation
                        # From the observation dictionary, we concatenate all the observations
                        # as done in the training data
                        obs = get_observations(obs)
                        obs = convert_observation(obs, task_id[env_id])

                        # save observations
                        obs_deque.append(obs)

                        # and reward/vis
                        rewards.append(reward)
                        imgs.append(env.render())

                        # update progress bar
                        step_idx += 1
                        pbar.update(1)
                        pbar.set_postfix(reward=reward)
                        # Stop after exactly max_steps env steps (the progress
                        # bar's total). An episode the env finished on this
                        # very step is not a timeout.
                        if not done and step_idx >= max_steps:
                            done = True
                            unsuccessful = True
                        if done:
                            break

            if not unsuccessful:
                mean_success += 1
            episode_max_reward = max(rewards)
            mean_reward += episode_max_reward
            writer.writerow([episode + 1, episode_max_reward, int(not unsuccessful)])
            episodes.set_postfix(
                reward=mean_reward / (episode + 1),
                success=mean_success / (episode + 1)
            )

    print("Reward: ", mean_reward / num_episodes)
    print("Success: ", mean_success/num_episodes)