In [None]:
import os

import wandb
import composuite
from diffusion.utils import *
from diffusion.elucidated_diffusion import Trainer
from diffusion.train_diffuser import SimpleDiffusionGenerator

# Root of the project checkout. The default is the location the notebook was
# originally written against; export COMPOSITIONAL_RL_ROOT before starting the
# kernel to point the notebook at a checkout somewhere else without editing it.
project_root = os.environ.get(
    'COMPOSITIONAL_RL_ROOT',
    '/Users/shubhankar/Developer/compositional-rl-synth-data',
)

# Parse the gin configuration file for the diffusion code.
gin.parse_config_file(os.path.join(project_root, 'config', 'diffusion.gin'))

# Dataset location and the parent folder for per-task results, both derived
# from the single root above so the three paths cannot drift apart.
base_data_path = os.path.join(project_root, 'data')
base_results_folder = os.path.join(project_root, 'local_results')

In [None]:
# Which dataset variant to load, and the four elements passed to
# composuite.make / the dataset loader to select the task.
dataset_type = 'expert'
robot, obj, obst, task = 'IIWA', 'Box', 'None', 'PickPlace'

# Per-task output folder: "<robot>_<obj>_<obst>_<task>" under the results root.
results_folder = os.path.join(base_results_folder, '_'.join((robot, obj, obst, task)))

# Load the data for this task and pass it straight through transitions_dataset.
dataset = transitions_dataset(
    load_single_composuite_dataset(
        base_path=base_data_path,
        dataset_type=dataset_type,
        robot=robot,
        obj=obj,
        obst=obst,
        task=task,
    )
)
print('Before removing task indicators:', dataset['observations'].shape)

# The environment's modality_dims are what remove_indicator_vectors needs in
# order to separate the indicator vectors from the rest of the dataset.
env = composuite.make(robot, obj, obst, task, use_task_id_obs=True, ignore_done=False)
dataset, indicators = remove_indicator_vectors(env.modality_dims, dataset)
print('After removing task indicators:', dataset['observations'].shape)

# Build the array the diffusion model is constructed from and trained on.
inputs = make_inputs(dataset)

In [None]:
# Show a single row of the indicator matrix as a one-row heat-map strip.
idx = 0
task_vector = indicators[idx, :].reshape(1, -1)
labels = ['Object', 'Robot', 'Obstacle', 'Subtask']

fig, ax = plt.subplots(figsize=(10, 2))
heatmap = ax.imshow(task_vector, cmap="viridis", aspect="auto")
fig.colorbar(heatmap, ax=ax, label="Value")

# One label per group on the x axis; the single row needs no y ticks.
ax.set_xticks([2, 6, 10, 14])
ax.set_xticklabels(labels, ha='right')
ax.set_yticks([])
plt.show()

In [None]:
# Convert the model inputs and their indicator vectors to float32 tensors.
# torch.as_tensor accepts NumPy arrays and tensors alike, so re-running this
# cell after the names already hold tensors no longer fails with the TypeError
# that torch.from_numpy raises for anything that is not an ndarray.
inputs = torch.as_tensor(inputs, dtype=torch.float32)
indicators = torch.as_tensor(indicators, dtype=torch.float32)

# Pair every input row with the indicator row at the same index.
dataset = torch.utils.data.TensorDataset(inputs, indicators)

In [None]:
# Take the conditioning width from the indicator vectors that are paired with
# the inputs in the training dataset (previously hard-coded as 16), so the
# model's cond_dim always matches the data it is trained with.
diffusion = construct_diffusion_model(inputs=inputs, cond_dim=indicators.shape[-1])

In [None]:
# Weights & Biases run settings.
wandb_project = 'compositional_diffusion_test'
wandb_entity = ''
wandb_group = 'diffusion_training'

# The run is named after the task-specific results folder. os.path.basename
# honours the platform's path separator, matching the os.path.join call that
# built results_folder; splitting on '/' only works for POSIX-style paths.
wandb.init(
    project=wandb_project,
    entity=wandb_entity,
    group=wandb_group,
    name=os.path.basename(results_folder),
)

In [None]:
# Build the trainer for the diffusion model on the (inputs, indicators) dataset.
# results_folder is handed to the Trainer — presumably where it writes its
# outputs; confirm against diffusion.elucidated_diffusion.Trainer.
trainer = Trainer(diffusion, dataset, results_folder=results_folder)
# Run training. The sampling cell later reads trainer.ema.ema_model.
trainer.train()

In [None]:
@gin.configurable
class SimpleDiffusionGenerator:
    """Draw transitions from a diffusion model in fixed-size batches.

    NOTE(review): this notebook-local class has the same name as the
    ``SimpleDiffusionGenerator`` imported from ``diffusion.train_diffuser`` at
    the top of the notebook and replaces that binding once this cell has run.
    """

    def __init__(
            self,
            env: gym.Env,
            ema_model,
            num_sample_steps: int = 128,
            sample_batch_size: int = 100000,
    ):
        """Store the sampling configuration and switch the model to eval mode.

        Args:
            env: Environment forwarded to ``split_diffusion_samples`` when the
                sampled vectors are split into transition components.
            ema_model: Model to sample from. It must provide ``eval()``,
                ``normalizer``, ``device`` and ``sample(...)``.
            num_sample_steps: Forwarded to ``ema_model.sample`` as
                ``num_sample_steps``.
            sample_batch_size: Number of samples requested from
                ``ema_model.sample`` per call.
        """
        self.env = env
        self.diffusion = ema_model
        self.diffusion.eval()
        # Clamp samples if normalizer is MinMaxNormalizer
        self.clamp_samples = isinstance(self.diffusion.normalizer, MinMaxNormalizer)
        self.num_sample_steps = num_sample_steps
        self.sample_batch_size = sample_batch_size
        print(f'Sampling using: {self.num_sample_steps} steps, {self.sample_batch_size} batch size.')

    def sample(
            self,
            num_samples: int,
            cond: "np.ndarray | torch.Tensor | None" = None,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Generate ``num_samples`` transitions, optionally conditioned.

        Args:
            num_samples: Total number of transitions to generate. Must be a
                positive multiple of ``sample_batch_size``.
            cond: Optional conditioning, as a NumPy array or a tensor. A 1-D
                vector is repeated for every sample in a batch; a 2-D input
                must have a first dimension of 1 or ``sample_batch_size``.
                ``None`` (the default) passes no conditioning to the model.

        Returns:
            ``(observations, actions, rewards, next_observations, terminals)``,
            each concatenated along axis 0 over all generated batches.

        Raises:
            AssertionError: If ``num_samples`` is not a multiple of
                ``sample_batch_size``.
            ValueError: If ``num_samples`` is not positive.
        """
        assert num_samples % self.sample_batch_size == 0, 'num_samples must be a multiple of sample_batch_size'

        num_batches = num_samples // self.sample_batch_size
        if num_batches <= 0:
            # Without this guard the failure surfaces later as an opaque
            # "need at least one array to concatenate" ValueError.
            raise ValueError(f'num_samples must be positive, got {num_samples}.')

        if cond is not None:
            # as_tensor handles both ndarrays and tensors and performs the
            # float32 cast and device placement in one step.
            cond = torch.as_tensor(cond, dtype=torch.float32, device=self.diffusion.device)
            if cond.dim() == 1:
                # A single conditioning vector: add the batch dimension.
                cond = cond.unsqueeze(0)
            cond = cond.expand(self.sample_batch_size, -1)

        batches = []
        for i in range(num_batches):
            print(f'Generating split {i + 1} of {num_batches}.')
            sampled_outputs = self.diffusion.sample(
                batch_size=self.sample_batch_size,
                num_sample_steps=self.num_sample_steps,
                clamp=self.clamp_samples,
                cond=cond
            )
            sampled_outputs = sampled_outputs.cpu().numpy()

            # Split samples into (s, a, r, s') format
            transitions = split_diffusion_samples(sampled_outputs, self.env)
            if len(transitions) == 4:
                # No terminal component was returned: use one zero per row.
                obs, act, rew, next_obs = transitions
                terminal = np.zeros_like(next_obs[:, 0])
            else:
                obs, act, rew, next_obs, terminal = transitions
            batches.append((obs, act, rew, next_obs, terminal))

        # Regroup the per-batch tuples by component and join each component
        # along the sample axis.
        observations, actions, rewards, next_observations, terminals = (
            np.concatenate(parts, axis=0) for parts in zip(*batches)
        )

        return observations, actions, rewards, next_observations, terminals

In [None]:
# Conditioning vector for the task selected earlier in the notebook.
task_indicator = get_task_indicator(robot, obj, obst, task)

# Rebuild the environment, this time with use_task_id_obs=False.
env = composuite.make(
    robot, obj, obst, task,
    use_task_id_obs=False,
    ignore_done=False,
)

# Sample from the trainer's EMA model, conditioned on the task indicator.
generator = SimpleDiffusionGenerator(env=env, ema_model=trainer.ema.ema_model)
(observations,
 actions,
 rewards,
 next_observations,
 terminals) = generator.sample(num_samples=100000, cond=task_indicator)

In [None]:
# Make sure the output directory exists before writing, so freshly generated
# samples are not lost to a FileNotFoundError if nothing has created
# results_folder yet.
os.makedirs(results_folder, exist_ok=True)

# Store the five transition arrays in a single compressed .npz archive.
np.savez_compressed(
    os.path.join(results_folder, 'samples.npz'),
    observations=observations,
    actions=actions,
    rewards=rewards,
    next_observations=next_observations,
    terminals=terminals
)