In [None]:
import d3rlpy
import numpy as np
import matplotlib.pyplot as plt

from d3rlpy.dataset import Transition

In [10]:
def _to_transitions(dataset):
    """Convert a flat dict of transition arrays into d3rlpy ``Transition`` objects.

    Args:
        dataset: mapping with ``'observations'``, ``'actions'``, ``'rewards'``,
            ``'next_observations'`` and ``'terminals'`` arrays, one row per
            transition along axis 0. Assumes ``actions`` is 2-D
            (continuous actions) -- TODO confirm for other datasets.

    Returns:
        list of ``Transition``, one per row, in input order. This function does
        not set prev/next links between transitions.

    Raises:
        ValueError: if any of the per-transition arrays has a different number
            of rows than ``'observations'``.
    """
    observations = dataset['observations']
    actions = dataset['actions']
    num_data = observations.shape[0]

    for key in ('actions', 'rewards', 'next_observations', 'terminals'):
        if len(dataset[key]) != num_data:
            raise ValueError(
                f"dataset[{key!r}] has {len(dataset[key])} rows, "
                f"expected {num_data} (from 'observations')"
            )

    # Read shapes from the arrays themselves rather than from element 0, so an
    # empty dataset yields [] instead of raising IndexError.
    observation_shape = observations.shape[1:]
    action_size = actions.shape[1]

    return [
        Transition(
            observation_shape=observation_shape,
            action_size=action_size,
            observation=observation,
            action=action,
            reward=reward,
            next_observation=next_observation,
            terminal=terminal,
        )
        for observation, action, reward, next_observation, terminal in zip(
            observations,
            actions,
            dataset['rewards'],
            dataset['next_observations'],
            dataset['terminals'],
        )
    ]

In [None]:
def merge_dictionary(list_of_dict):
    """Merge a sequence of dicts by concatenating the values stored under each key.

    Values for a key are gathered in input order and joined with
    ``np.concatenate`` (axis 0). A key appearing in only some of the dicts is
    still included. Keys keep their first-seen order. An empty input gives ``{}``.
    """
    grouped = {}
    for record in list_of_dict:
        for key, value in record.items():
            grouped.setdefault(key, []).append(value)

    return {key: np.concatenate(chunks) for key, chunks in grouped.items()}
# Load the generated transitions stored under ./gta/ together with the config
# that produced them, then merge the stored chunks into one dict of arrays.
dataset_name = 'halfcheetah-medium-replay-v0'

# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code -- only load files from a trusted source.
# The `with` block closes the .npz handle; each member is fully read on access.
with np.load(f'./gta/{dataset_name}.npz', allow_pickle=True) as data:
    config_dict = data['config'].item()
    raw_chunks = data['data'].squeeze()

denoiser_cfg = config_dict['construct_diffusion_model']['denoising_network']
generator_cfg = config_dict['SimpleDiffusionGenerator']

metadata = {
    'diffusion_horizon': denoiser_cfg['horizon'],
    # Last component of the dotted `_target_` path, i.e. the class name.
    'diffusion_backbone': denoiser_cfg['_target_'].split('.')[-1],
    # bool() keeps this a plain Python bool even if cond_dim is a numpy scalar.
    'conditioned': bool(denoiser_cfg['cond_dim'] != 0),
    'guidance_target_multiple': generator_cfg['amplify_returnscale'],
    'noise_level': generator_cfg['noise_level'],
}

# Presumably each element of raw_chunks is a dict of arrays -- TODO confirm.
dataset = merge_dictionary(list(raw_chunks))

In [11]:
# Requires the cell that loads and merges `dataset` to have run first.
transition_list = _to_transitions(dataset)
assert len(transition_list) == len(dataset['observations']), "expected one Transition per row"

NameError: name 'dataset' is not defined

In [None]:
# Write the five per-transition arrays of the merged `dataset` to a separate
# .npz file next to the source file. Any other keys in `dataset` are not saved.
np.savez(
    f'./gta/{dataset_name}-dataset.npz',
    **{
        field: dataset[field]
        for field in ('observations', 'actions', 'rewards', 'next_observations', 'terminals')
    },
)

In [13]:
# Reload the saved arrays (presumably to sanity-check the file written above).
dataset_name = 'halfcheetah-medium-replay-v0'
with np.load(f'./gta/{dataset_name}-dataset.npz') as npz_file:
    # Materialize every array so the file handle is closed right away; `test`
    # still supports key access, e.g. test['observations'].
    test = {key: npz_file[key] for key in npz_file.files}

In [None]:
# get_d4rl returns a (dataset, env) pair; unpacking it means `original_dataset`
# holds the dataset itself, consistent with the `[0]` indexing used when the
# other D4RL datasets are loaded.
original_dataset, original_env = d3rlpy.datasets.get_d4rl('halfcheetah-medium-replay-v0')

In [None]:
# Load the three D4RL medium-replay datasets, keeping only element [0] of each
# get_d4rl result. Both a name->dataset dict and an ordered list are kept.
exps = ['halfcheetah-medium-replay-v0', 'hopper-medium-replay-v0', 'walker2d-medium-replay-v0']
exp_datasets = {exp: d3rlpy.datasets.get_d4rl(exp)[0] for exp in exps}
datasets = list(exp_datasets.values())

In [None]:
# Statistics of the first loaded dataset (halfcheetah-medium-replay-v0), left as
# the cell's last expression so the notebook displays the result.
datasets[0].compute_stats()

In [None]:
# Return of every episode, in stored order, one panel per loaded dataset.
# squeeze=False keeps `axes` 2-D, so axes[0] works for any number of datasets.
fig, axes = plt.subplots(1, len(exp_datasets), figsize=(20, 5), squeeze=False)
# The loop variable is deliberately not named `dataset`: that would overwrite
# the module-level `dataset` that other cells read.
for ax, (exp_name, exp_dataset) in zip(axes[0], exp_datasets.items()):
    returns = [episode.compute_return() for episode in exp_dataset.episodes]
    ax.plot(returns)
    ax.set_title(exp_name)
    ax.set_xlabel('episode index')
    ax.set_ylabel('episode return')
fig.tight_layout()

In [None]:
# CQL agent for the first loaded dataset. Actor and critic share one
# 3x256 MLP encoder factory.
gpu = 0
encoder = d3rlpy.models.encoders.VectorEncoderFactory([256, 256, 256])

exp_name = list(exp_datasets)[0]
dataset = exp_datasets[exp_name]

# Substring test: names containing "medium-v0" get 10.0; "medium-replay-v0"
# does not contain that substring and gets 5.0.
conservative_weight = 10.0 if "medium-v0" in exp_name else 5.0

cql = d3rlpy.algos.CQL(
    actor_learning_rate=1e-4,
    critic_learning_rate=3e-4,
    temp_learning_rate=1e-4,
    actor_encoder_factory=encoder,
    critic_encoder_factory=encoder,
    batch_size=256,
    n_action_samples=10,
    # Presumably 0.0 disables tuning of alpha, leaving the penalty scale
    # fixed by conservative_weight -- confirm against the d3rlpy CQL docs.
    alpha_learning_rate=0.0,
    conservative_weight=conservative_weight,
    use_gpu=gpu,
)

In [None]:
# Build the model from `dataset` before loading weights. Presumably this creates
# networks with the dataset's observation/action sizes so the checkpoint's
# tensors fit -- confirm against the d3rlpy docs.
cql.build_with_dataset(dataset)
# NOTE(review): run-specific, timestamped checkpoint path relative to the
# notebook's working directory; update it when using a different training run.
cql.load_model('../d3rlpy_logs/CQL_halfcheetah-medium-replay-v0_1_trim00_20240529201518/model_190000.pt')

In [None]:
# Q-value predicted by `cql` at each step of every episode in `dataset`,
# one subplot per episode, laid out row by row.
episodes = dataset.episodes
n_cols = 5
# Size the grid from the episode count (ceiling division) instead of a fixed
# 21 rows, which overflows past 105 episodes. squeeze=False keeps `axes` 2-D
# even when there is only one row.
n_rows = max(1, (len(episodes) + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 20), squeeze=False)
for i, episode in enumerate(episodes):
    qs = cql.predict_value(episode.observations, episode.actions)
    axes[i // n_cols][i % n_cols].plot(qs)
# Hide the unused grid cells when the episode count is not a multiple of n_cols.
for ax in axes.flat[len(episodes):]:
    ax.set_visible(False)
fig.tight_layout()