In [None]:
from diffusion.utils import *
from corl.algorithms import sac_n
from corl.shared.buffer import *
import numpy as np
import pathlib
import d3rlpy
from d3rlpy.dataset import MDPDataset
from d3rlpy.metrics.scorer import evaluate_on_environment

import os

# Root of the project checkout that holds the agent data, the diffusion
# outputs and the local results.  Export COMPOSITIONAL_RL_ROOT to run the
# notebook from another checkout; when the variable is unset the original
# location is used, so the three paths below keep their previous values.
project_root = os.environ.get(
    'COMPOSITIONAL_RL_ROOT',
    '/Users/shubhankar/Developer/compositional-rl-synth-data',
)

# Directory the agent (expert) datasets are loaded from.
base_agent_data_path = f'{project_root}/data'
# Directory of the diffusion run whose synthetic datasets can be loaded.
base_synthetic_data_path = f'{project_root}/cluster_results/diffusion/monolithic_seed0_train98_1'
# Parent directory under which numbered result folders are created.
base_results_folder = f'{project_root}/local_results/offline_learning'

In [None]:
def eval_d3rlpy_actor(env: "gym.Env", agent, n_episodes: int, seed: int) -> np.ndarray:
    """Roll out ``agent`` in ``env`` and return the summed reward of every episode.

    The ``env`` annotation is quoted so that defining this function does not
    evaluate the name ``gym``, which is not imported by name in the notebook's
    visible import cell.

    Args:
        env: Environment exposing ``seed(seed)``, ``reset()`` returning the
            first observation, and ``step(action)`` returning a
            ``(state, reward, done, info)`` 4-tuple.
        agent: Object whose ``predict`` accepts a batch (here a one-element
            list) of observations and returns one action per observation.
        n_episodes: Number of episodes to run.
        seed: Value passed to ``env.seed`` once, before the first episode.

    Returns:
        1-D array holding one undiscounted return per episode; empty when
        ``n_episodes`` is 0.
    """
    env.seed(seed)
    episode_rewards = []

    for _ in range(n_episodes):
        state = env.reset()
        done = False
        episode_reward = 0.0

        while not done:
            # predict() works on batches, so wrap the single state and take
            # the first (only) action of the result.
            action = agent.predict([state])[0]
            state, reward, done, _ = env.step(action)
            episode_reward += reward

        episode_rewards.append(episode_reward)

    return np.array(episode_rewards)

def scorer_wrapper(env, agent):
    """Build a two-argument scorer that reports the mean return of ``agent`` in ``env``.

    The returned callable accepts ``(algo, dataset)`` but uses neither: it
    always evaluates the ``agent`` captured here, over 10 episodes with seed 42.
    """
    def _mean_return(algo, dataset):
        # Both arguments are accepted only to match the scorer call signature.
        returns = eval_d3rlpy_actor(env, agent, n_episodes=10, seed=42)
        return returns.mean()

    return _mean_return

In [None]:
# Which transitions are used further down: 'agent' or 'synthetic'.
data_type = 'agent'

# Start from the sac_n default configuration and override a few fields.
config = sac_n.TrainConfig()
config.seed = 0
config.n_episodes = 10
config.batch_size = 1024

# Sub-directories of base_synthetic_data_path; only read when data_type == 'synthetic'.
mode = ''  # train/test
synthetic_run_id = ''

# The four components that identify the task.
robot, obj, obst, subtask = 'Panda', 'Box', 'None', 'Push'

# This environment (with task-id observations) supplies modality_dims for the
# indicator-vector removal below.
env = composuite.make(
    robot,
    obj,
    obst,
    subtask,
    use_task_id_obs=True,
    ignore_done=False,
)

# Load the expert dataset for this task, convert it with transitions_dataset
# and pass it through remove_indicator_vectors; the second return value is
# discarded.
agent_dataset = load_single_composuite_dataset(
    base_path=base_agent_data_path,
    dataset_type='expert',
    robot=robot,
    obj=obj,
    obst=obst,
    task=subtask,
)
agent_dataset, _ = remove_indicator_vectors(
    env.modality_dims,
    transitions_dataset(agent_dataset),
)

# The special dimensions are identified on the agent observations and, in the
# synthetic branch, handed to process_special_dimensions.
integer_dims, constant_dims = identify_special_dimensions(agent_dataset['observations'])
print('Integer dimensions:', integer_dims)
print('Constant dimensions:', constant_dims)

if data_type == 'synthetic':
    synthetic_dataset = load_single_synthetic_dataset(
        base_path=os.path.join(base_synthetic_data_path, synthetic_run_id, mode),
        robot=robot,
        obj=obj,
        obst=obst,
        task=subtask,
    )
    synthetic_dataset = process_special_dimensions(
        synthetic_dataset,
        integer_dims,
        constant_dims,
    )

In [None]:
# Show the shape of the agent dataset's observation array as the cell output.
agent_dataset['observations'].shape

In [None]:
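# Uncomment to inspect the synthetic data; `synthetic_dataset` is only defined when data_type == 'synthetic'.
# synthetic_dataset['observations'].shape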

In [None]:
# Find the first unused "offline_learning_<data_type>_<idx>" directory under
# the results root, create it, and point the config's checkpoint path at it.
base_results_path = pathlib.Path(base_results_folder)
idx = 1
while True:
    results_folder = base_results_path / f"offline_learning_{data_type}_{idx}"
    if not results_folder.exists():
        break
    idx += 1
results_folder.mkdir(parents=True, exist_ok=True)

config.checkpoints_path = results_folder

In [None]:
# Select the transition dictionary that the following cells work with.
if data_type == 'agent':
    dataset = agent_dataset
elif data_type == 'synthetic':
    dataset = synthetic_dataset
else:
    # Without this branch an unrecognised value would leave `dataset` unbound
    # (or silently reuse the value from an earlier run of this cell).
    raise ValueError(
        f"Unknown data_type {data_type!r}; expected 'agent' or 'synthetic'."
    )

# Number of transitions = length of the first axis of the observation array.
num_samples = int(dataset['observations'].shape[0])
print("Samples:", num_samples)

In [None]:
# Rebuild the environment for this stage: task-id observations are switched
# off and the renderer is switched on.
env = composuite.make(
    robot,
    obj,
    obst,
    subtask,
    use_task_id_obs=False,
    has_renderer=True,
    ignore_done=False,
)
state_dim, action_dim = env.observation_space.shape[0], env.action_space.shape[0]

# Observation statistics are computed on the selected dataset and handed to
# sac_n.wrap_env together with the environment.
# NOTE(review): presumably wrap_env standardises observations with these
# statistics — confirm against sac_n.
state_mean, state_std = sac_n.compute_mean_std(dataset["observations"], eps=1e-3)
env = sac_n.wrap_env(env, state_mean=state_mean, state_std=state_std)

In [None]:
from d3rlpy.algos import SAC
from d3rlpy.models.encoders import VectorEncoder, VectorEncoderFactory
from d3rlpy.metrics.scorer import evaluate_on_environment
import torch.nn as nn
import torch

class CustomLayerNormEncoder(VectorEncoder):
    """MLP encoder made of repeated Linear -> LayerNorm -> ReLU stages.

    With ``action_size`` set, the first linear layer is sized for the
    observation concatenated with an action (the critic case); otherwise it
    takes the observation alone.
    """

    def __init__(self, observation_shape, action_size=None, hidden_units=(256, 256, 256)):
        """Build one stage per entry of ``hidden_units``.

        Args:
            observation_shape: Shape of a single observation; only its first
                entry is used to size the input layer.
            action_size: Action dimensionality to append to the input, or
                ``None`` for an observation-only encoder.
            hidden_units: Width of each hidden stage, in order.
        """
        super().__init__(observation_shape)
        # For critics in SAC the action is part of the input.
        extra_inputs = 0 if action_size is None else action_size
        width = observation_shape[0] + extra_inputs

        stages = []
        for units in hidden_units:
            stages += [nn.Linear(width, units), nn.LayerNorm(units), nn.ReLU()]
            width = units

        self.encoder = nn.Sequential(*stages)
        # Output width: the last hidden width, or the input width when
        # hidden_units is empty.
        self._feature_size = width
        self._action_size = action_size

        # also set without underscore for backward compatibility
        self.action_size = action_size

    def forward(self, x, action=None):
        """Encode ``x``; ``action`` is concatenated along dim 1 first when the
        encoder was built with an action size and an action is supplied."""
        skip_action = self._action_size is None or action is None
        inputs = x if skip_action else torch.cat([x, action], dim=1)
        return self.encoder(inputs)

    def get_feature_size(self):
        """Return the width of the encoder output."""
        return self._feature_size

class CustomEncoderFactory(VectorEncoderFactory):
    """Factory producing ``CustomLayerNormEncoder`` instances with a fixed layout."""

    def __init__(self, hidden_units=(256, 256, 256)):
        """Remember the hidden-layer widths every created encoder will use."""
        self.hidden_units = hidden_units

    def create(self, observation_shape):
        """Return an encoder that takes observations only."""
        return CustomLayerNormEncoder(
            observation_shape,
            hidden_units=self.hidden_units,
        )

    def create_with_action(self, observation_shape, action_size):
        """Return an encoder whose input is sized for observation plus action."""
        return CustomLayerNormEncoder(
            observation_shape,
            action_size=action_size,
            hidden_units=self.hidden_units,
        )

    def get_params(self, deep=False):
        """Return the factory's parameters as a dict; ``deep`` is not used."""
        return dict(hidden_units=self.hidden_units)
    
# Keep a handle on the raw transition dictionary.  `dataset` is rebound to an
# MDPDataset further down, so when this cell is run a second time it no longer
# holds the raw data; in that case the `data_dict` from the first run is reused
# instead of indexing an MDPDataset with string keys.
if not isinstance(dataset, MDPDataset):
    data_dict = dataset

# The evaluation `env` was wrapped with (state_mean, state_std), so the policy
# is scored on standardised observations.  Train on the same representation,
# otherwise training and evaluation inputs live on different scales.
# NOTE(review): assumes sac_n.wrap_env applies (obs - state_mean) / state_std,
# as upstream CORL does — confirm against the local sac_n implementation.
raw_observations = np.asarray(data_dict['observations'])
normalized_observations = (
    (raw_observations - state_mean) / state_std
).astype(raw_observations.dtype, copy=False)

dataset = MDPDataset(
    observations=normalized_observations,
    actions=data_dict['actions'],
    rewards=data_dict['rewards'],
    terminals=data_dict['terminals'],
)

# Create SAC with the custom encoder for both the actor and the critic.
encoder_factory = CustomEncoderFactory(hidden_units=(256, 256, 256))
sac = SAC(
    actor_encoder_factory=encoder_factory,
    critic_encoder_factory=encoder_factory
)

# Train offline on the dataset and score the policy in the environment.
# NOTE(review): d3rlpy v1 documents `eval_episodes` as a list of Episode
# objects; the integer is kept as in the original — confirm the installed
# version accepts it (the environment scorer presumably ignores the value).
sac.fit(
    dataset,
    n_steps=50000,
    eval_episodes=10,
    scorers={
        'env': evaluate_on_environment(env)
    }
)

In [None]:
# Compare the critic's value for the policy's own action with its value for
# the action stored in the dataset, both at the dataset's first observation.
state = dataset.observations[0]
state_batch = state[np.newaxis, ...]  # add a leading batch axis

# actor's action π(s) and true dataset action a
pi_action = sac.predict(state_batch)[0]
actual_action = dataset.actions[0]

# predict_value takes batched observations and batched actions
q_pi = sac.predict_value(state_batch, np.expand_dims(pi_action, axis=0))
q_sa = sac.predict_value(state_batch, np.expand_dims(actual_action, axis=0))

print("Q(s, π(s)):", q_pi)
print("Q(s, a):", q_sa)

In [None]:
# Print the first stored reward next to the two Q estimates computed in the
# previous cell (q_pi and q_sa).
reward = dataset.rewards[0]

for label, value in (
    ("Reward from dataset (r):", reward),
    ("Q(s, π(s)):", q_pi),
    ("Q(s, a):", q_sa),
):
    print(label, value)