In [1]:
import os
import composuite
from CORL.algorithms.offline.td3_bc import *
from diffusion.utils import *
import numpy as np

def get_weights_norm(model):
    """Return the global L2 norm of a model's trainable parameters.

    Every parameter yielded by ``model.parameters()`` that has
    ``requires_grad`` set contributes the square of its own 2-norm; the
    result is the square root of the sum of those squares, i.e. the norm of
    all trainable weights viewed as one flat vector. Parameters with
    ``requires_grad == False`` are skipped.

    Args:
        model: Object exposing ``parameters()`` whose items provide
            ``requires_grad`` and ``norm(2).item()``.

    Returns:
        float: The combined L2 norm (``0.0`` when nothing is trainable).
    """
    squared_norms = (
        weight.norm(2).item() ** 2
        for weight in model.parameters()
        if weight.requires_grad
    )
    # Start from 0.0 so the accumulation is float-valued even with no terms.
    return sum(squared_norms, 0.0) ** 0.5


# Root of the project checkout. Set the COMPOSITIONAL_RL_ROOT environment
# variable to run this notebook on another machine; when it is unset the
# default below reproduces the original hard-coded locations exactly.
project_root = os.environ.get(
    'COMPOSITIONAL_RL_ROOT',
    '/Users/shubhankar/Developer/compositional-rl-synth-data',
).rstrip('/')

# Base directory handed to the dataset loader when data_type == 'agent'.
base_agent_data_path = f'{project_root}/data'
# Base directory of diffusion outputs, used when data_type == 'synthetic'.
base_synthetic_data_path = f'{project_root}/cluster_results/diffusion'

# Offline-learning runs; checkpoints are read from '<folder>/<run>/'.
base_results_folder = f'{project_root}/local_results/offline_learning'

In [2]:
# Training hyper-parameters; later cells read network_width, network_depth
# and device from this object.
config = TrainConfig()

# Which kind of data the policy under inspection was trained on.
data_type = 'agent'

if data_type == 'synthetic':
    # Identify the diffusion run and split whose samples are loaded later.
    synthetic_run_id, mode = 'cond_diff_20', 'train'
    run_index = 6
else:
    run_index = 12

# Name of the results sub-folder that holds the checkpoint.
run = f'offline_learning_{data_type}_{run_index}'

# NOTE(review): this filename is not read by the checkpoint-loading cell,
# which spells the name out again and then rebinds `checkpoint`.
checkpoint = 'checkpoint_5000.pt'

In [3]:
# Locate the checkpoint inside the selected run folder.
run_dir = os.path.join(base_results_folder, run)
checkpoint_path = os.path.join(run_dir, 'checkpoint_5000.pt')

# Deserialize onto the CPU, restricted to weights-only unpickling.
# NOTE: this rebinds `checkpoint` to the loaded object.
checkpoint = torch.load(
    checkpoint_path,
    map_location='cpu',
    weights_only=True,
)

In [None]:
# Task specification passed to composuite.make and to the dataset loaders.
robot = 'Kinova3'
obj = 'Hollowbox'
obst = 'None'
subtask = 'Trashcan'

if data_type == 'agent':
    # This environment is built with task-id observations enabled so that its
    # `modality_dims` can be handed to remove_indicator_vectors below.
    env = composuite.make(robot, obj, obst, subtask, use_task_id_obs=True, ignore_done=False)
    dataset = load_single_composuite_dataset(base_path=base_agent_data_path, 
                                             dataset_type='expert', 
                                             robot=robot, obj=obj, 
                                             obst=obst, task=subtask)
    # Presumably strips the task-indicator entries from the observations;
    # the second return value is not needed here.
    dataset, _ = remove_indicator_vectors(env.modality_dims, transitions_dataset(dataset))
elif data_type == 'synthetic':
    dataset = load_single_synthetic_dataset(base_path=os.path.join(base_synthetic_data_path, synthetic_run_id, mode), 
                                            robot=robot, obj=obj, 
                                            obst=obst, task=subtask)
else:
    # Fail here rather than with a NameError on `dataset` in a later cell.
    raise ValueError(
        f"Unsupported data_type {data_type!r}; expected 'agent' or 'synthetic'."
    )

In [None]:
# Environment used for the on-screen rollout: the renderer is enabled and
# ignore_done is set.
env_kwargs = dict(
    robot=robot,
    obj=obj,
    obstacle=obst,
    task=subtask,
    has_renderer=True,
    ignore_done=True,
)
env = composuite.make(**env_kwargs)

# Read the dimensions from the environment before it is wrapped.
observation_shape = env.observation_space.shape
action_shape = env.action_space.shape
state_dim, action_dim = observation_shape[0], action_shape[0]

# Statistics of the dataset observations are handed to wrap_env together
# with the environment.
state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
env = wrap_env(env, state_mean=state_mean, state_std=state_std)

# Upper action bound of the first action dimension, passed to the actor.
max_action = float(env.action_space.high[0])

In [None]:
# Quick look at the observation statistics: average of the per-dimension
# means and of the per-dimension standard deviations.
mean_of_state_means = state_mean.mean()
mean_of_state_stds = state_std.mean()
print(mean_of_state_means, mean_of_state_stds)

In [None]:
# Build a freshly initialised actor with the configured width and depth,
# then move it to the configured device.
actor = Actor(
    state_dim,
    action_dim,
    max_action,
    hidden_dim=config.network_width,
    n_hidden=config.network_depth,
)
actor = actor.to(config.device)

# Print the weight norm before and after loading so it is visible that the
# checkpoint actually replaced the initial weights.
norm_before_load = get_weights_norm(actor)
print('Before:', norm_before_load)
actor.load_state_dict(checkpoint['actor'])
norm_after_load = get_weights_norm(actor)
print('After:', norm_after_load)

In [10]:
# Roll the loaded policy out in the on-screen viewer for a fixed number of
# steps.
num_render_steps = 1000

state = env.reset()
env.viewer.set_camera(camera_id=3)

# Action bounds reported by the environment; not used by the loop below.
low, high = env.action_spec

# do visualization
for _ in range(num_render_steps):
    # The actor was moved to config.device when it was built, so the
    # observation tensor must be created on that device as well; relying on
    # act()'s default device fails when config.device is not the CPU.
    action = actor.act(state, device=config.device)
    # Reward, done flag and info are not needed for visualisation.
    state, _, _, _ = env.step(action)
    env.render()