In [1]:
import torch

from flow_policy.pusht.dp_state_notebook.all import (
    PushTEnv,
    ConditionalUnet1D,
    DiffusionPolicy,
    Rollout,
)
from flow_policy.pusht.dataset import PushTStateDatasetWithNextObsAsAction
from flow_policy.pusht.utils import show_gif

## Parameters

In [2]:
pred_horizon = 16
obs_horizon = 2
action_horizon = 8

num_diffusion_iters = 100
max_rollout_steps = 200

obs_dim = 5
action_dim = 2

In [3]:
# Load dataset for stats
dataset = PushTStateDatasetWithNextObsAsAction(
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon
)

# Load trained weights
ckpt_path = "../../../../models/pusht_dp_obs_100ep.pth"
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)
state_dict = torch.load(ckpt_path, map_location='cuda')
noise_pred_net.load_state_dict(state_dict)
noise_pred_net.cuda()
print('Pretrained weights loaded.')

policy = DiffusionPolicy(
    noise_pred_net=noise_pred_net,
    num_diffusion_iters=num_diffusion_iters,
    pred_horizon=pred_horizon,
    action_dim=action_dim,
    device='cuda',
)

number of parameters: 6.535322e+07


  state_dict = torch.load(ckpt_path, map_location='cuda')


Pretrained weights loaded.


In [4]:
env = PushTEnv()

# use a seed >200 to avoid initial states seen in the training dataset
env.seed(2)

score, imgs = Rollout(
    env,
    policy,
    stats = dataset.stats,
    max_steps = max_rollout_steps,
    obs_horizon = obs_horizon,
    action_horizon = action_horizon,
    device = 'cuda',
)

# print out the maximum target coverage
print('Score: ', score)

# visualize
show_gif(imgs)


Eval PushTStateEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Score:  0.9837819958284089
