In [1]:
import wandb
import composuite
from diffusion.utils import *
from diffusion.elucidated_diffusion import Trainer
from diffusion.train_diffuser import SimpleDiffusionGenerator

import os

# Root of the project checkout. The default is the location this notebook was
# originally written against; set COMPOSITIONAL_RL_ROOT to run it from any other
# checkout without editing the paths below.
project_root = os.environ.get(
    'COMPOSITIONAL_RL_ROOT',
    '/Users/shubhankar/Developer/compositional-rl-synth-data',
)

# Load the gin configuration for the diffusion code.
gin.parse_config_file(os.path.join(project_root, 'config', 'diffusion.gin'))

# Base directories used by later cells to read the datasets and to write results.
base_data_path = os.path.join(project_root, 'data')
base_results_folder = os.path.join(project_root, 'results')

In [None]:
# Which dataset variant to load, and the four components that identify the task.
dataset_type = 'expert'
robot, obj, obst, task = 'IIWA', 'Box', 'None', 'PickPlace'

# Per-task output directory: <base_results_folder>/<robot>_<obj>_<obst>_<task>.
results_folder = os.path.join(base_results_folder, '_'.join((robot, obj, obst, task)))

# Load the dataset for this task and pass it through transitions_dataset.
dataset = transitions_dataset(
    load_single_composuite_dataset(
        base_path=base_data_path,
        dataset_type=dataset_type,
        robot=robot,
        obj=obj,
        obst=obst,
        task=task,
    )
)

# Build the model inputs from the transitions.
inputs = make_inputs(dataset)
print('Before removing task indicators:', inputs.shape)

# The environment is created with use_task_id_obs=True and handed to
# remove_indicator_vectors, which splits the inputs into the remaining inputs
# and the indicator vectors.
env = composuite.make(robot, obj, obst, task, use_task_id_obs=True, ignore_done=False)
inputs, indicators = remove_indicator_vectors(inputs, env)
print('After removing task indicators:', inputs.shape, indicators.shape)

In [None]:
# Visualise the indicator vector of a single row as a one-row heat map.
idx = 0

# Reshape the selected row to (1, n) so imshow draws it as a horizontal strip.
task_vector = indicators[idx, :].reshape(1, -1)

labels = ['Object', 'Robot', 'Obstacle', 'Subtask']

fig, ax = plt.subplots(figsize=(10, 2))
heatmap = ax.imshow(task_vector, cmap="viridis", aspect="auto")
fig.colorbar(heatmap, ax=ax, label="Value")

# One tick per label; NOTE(review): positions 2/6/10/14 presumably mark four
# groups of four entries in a 16-wide indicator — confirm against the env.
ax.set_xticks([2, 6, 10, 14])
ax.set_xticklabels(labels, ha='right')
ax.set_yticks([])
plt.show()

In [4]:
# Convert the inputs and their indicator vectors to float32 tensors and pair
# them row-by-row in a TensorDataset.
#
# torch.as_tensor accepts both NumPy arrays and existing tensors, so this cell
# can be re-run safely; torch.from_numpy would raise a TypeError on the second
# run because `inputs` / `indicators` are already tensors by then.
inputs = torch.as_tensor(inputs).float()
indicators = torch.as_tensor(indicators).float()
dataset = torch.utils.data.TensorDataset(inputs, indicators)

In [None]:
# Size of the conditioning vector given to the model.
# NOTE(review): presumably this equals the indicator width (indicators.shape[1]) — confirm.
cond_dim = 16
diffusion = construct_diffusion_model(inputs=inputs, cond_dim=cond_dim)

In [None]:
# Weights & Biases run configuration.
wandb_project = 'offline_rl_diffusion'
# NOTE(review): entity is left as an empty string here — confirm that wandb
# resolves this to the intended account/team.
wandb_entity = ''
wandb_group = 'diffusion_training'

# Start the run, named after the last component of results_folder.
# os.path.basename is used instead of splitting on '/' so the name is also
# correct where os.path.join builds paths with a different separator.
wandb.init(
    project=wandb_project,
    entity=wandb_entity,
    group=wandb_group,
    name=os.path.basename(results_folder),
)

In [None]:
# Build a Trainer for the diffusion model over the (inputs, indicators) dataset
# and run training. results_folder is handed to the Trainer;
# NOTE(review): presumably checkpoints/outputs are written there — confirm in Trainer.
trainer = Trainer(diffusion, dataset, results_folder=results_folder)
trainer.train()

In [None]:
# Conditioning vector for the selected robot/object/obstacle/task combination.
task_indicator = get_task_indicator(robot, obj, obst, task)

# Re-create the environment, this time with use_task_id_obs=False.
env = composuite.make(robot, obj, obst, task, use_task_id_obs=False, ignore_done=False)

# Sample from the trainer's EMA model, conditioned on the task indicator.
generator = SimpleDiffusionGenerator(
    env=env,
    ema_model=trainer.ema.ema_model,
    sample_batch_size=1000,
)
sampled = generator.sample(num_samples=1000, cond=task_indicator)
observations, actions, rewards, next_observations, terminals = sampled

In [41]:
# Write the sampled transitions to <results_folder>/samples.npz as a compressed
# archive, one named array per component.
samples_path = os.path.join(results_folder, 'samples.npz')
sample_arrays = {
    'observations': observations,
    'actions': actions,
    'rewards': rewards,
    'next_observations': next_observations,
    'terminals': terminals,
}
np.savez_compressed(samples_path, **sample_arrays)