In [1]:
import torch

from flow_policy.pusht.dp_state_notebook.all import (
    PushTEnv,
    ConditionalUnet1D,
    Rollout,
)
from flow_policy.pusht.dataset import PushTStateDatasetWithNextObsAsAction
from flow_policy.pusht.utils import show_gif
from flow_policy.pusht.sfp import StreamingFlowPolicyPositionOnly

## Parameters

In [2]:
# State and action sizes. The velocity network is built with
# input_dim=action_dim and global_cond_dim=obs_dim * obs_horizon.
obs_dim, action_dim = 5, 2

# Horizons shared by the dataset constructor and the rollout call.
obs_horizon, action_horizon = 2, 8

# Passed to Rollout as max_steps.
max_rollout_steps = 200

In [3]:
# Rebuild the policy architecture and restore its trained weights.
ckpt_path = "../../../../models/pusht_sfp_obs_200ep.pth"

# The network's input_dim is action_dim; its global conditioning vector has
# obs_dim * obs_horizon entries.
velocity_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon,
    use_linear_up_down_sampling=True,
)
policy = StreamingFlowPolicyPositionOnly(
    velocity_net=velocity_net,
    action_dim=action_dim,
    device='cuda',
)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code on load. It also
# silences the torch.load FutureWarning about the default pickle behaviour.
# The checkpoint is fed directly to load_state_dict, so a plain state dict is
# all that is needed here.
state_dict = torch.load(ckpt_path, map_location='cuda', weights_only=True)
policy.load_state_dict(state_dict)
policy.cuda()
print('Pretrained weights loaded.')

# Load dataset for stats: dataset.stats is handed to Rollout later on.
# pred_horizon is read from the policy; .item() turns it into a plain Python
# number before it is passed to the dataset constructor.
dataset = PushTStateDatasetWithNextObsAsAction(
    pred_horizon=policy.pred_horizon.item(),
    obs_horizon=obs_horizon,
    action_horizon=action_horizon,
)

number of parameters: 6.371482e+07
Pretrained weights loaded.


  state_dict = torch.load(ckpt_path, map_location='cuda')


In [4]:
# Environment for a single evaluation rollout.
env = PushTEnv()

# Intended: a seed >200, to avoid initial states seen in the training dataset.
# NOTE(review): the seed actually used below is 1, which does not satisfy that
# rule -- confirm whether evaluating on this seed is deliberate.
env.seed(1)

# Options handed to Rollout as `policy_kwargs`.
rollout_policy_kwargs = {
    "num_actions": 1 + action_horizon,  # 1 + 8
    "integration_steps_per_action": 6,
}

score, imgs = Rollout(
    env,
    policy,
    policy_kwargs=rollout_policy_kwargs,
    stats=dataset.stats,
    max_steps=max_rollout_steps,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon,
    device='cuda',
)

# Report the maximum target coverage reached during the rollout.
print('Score: ', score)

# Visualize the images collected during the rollout.
show_gif(imgs)


Eval PushTStateEnv:   0%|          | 0/200 [00:00<?, ?it/s]

Score:  0.8664485929094118
