# Load SOAR Data (for training)
This notebook is a minimal example of how to load the RLDS SOAR data (e.g. for downstream training)

In [None]:
"""
Make sure the Colab has access to the SOAR repo
"""
!git clone https://github.com/rail-berkeley/soar.git

In [None]:

"""
1. Download a minimal SOAR dataset that's small to be used for testing

In this notebook we load a small dummy dataset for speed. If you wish to load the full dataset, 
use the download script in this directory to download the full dataset. Then it can be loaded
in the same way, changing the path to the saved dataset.
"""
SAVE_DIR = "dummy_soar_data"
!cat soar/soar_data/test_dataset_urls.txt | while read url; do wget -P "dummy_soar_data" "$url"; done

In [None]:
"""
2. Import the Dataloader class

Makes sure the local `jaxrl_m` package is importable (installing it and its
requirements with pip if it is not), then imports the `WidowXDataset` class.
"""
import importlib
import os
import subprocess
import sys


def _pip_install(*pip_args):
    """Run `pip install <pip_args>` for the running kernel and print pip's output.

    Args:
        *pip_args: arguments forwarded verbatim to `pip install`.

    Returns:
        True if pip exited with status 0, False otherwise.
    """
    # `sys.executable -m pip` installs into the interpreter that runs this kernel;
    # a bare `pip` resolved through PATH may belong to a different environment.
    result = subprocess.run(
        [sys.executable, '-m', 'pip', 'install', *pip_args],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    # Print the standard output and error
    print("Pip install Output:\n", result.stdout)
    if result.stderr:
        print("Pip install Error:\n", result.stderr)
    if result.returncode != 0:
        print(f"pip exited with status {result.returncode}")
    return result.returncode == 0


# install jaxrl_m if it is not already installed
# the package is located in model_training/jaxrl_m
try:
    import jaxrl_m
except ImportError:
    print("local jaxrl_m package not installed, trying to install now")
    package_path = 'soar/model_training'

    # install jaxrl_m
    _pip_install('-e', package_path)

    # add the package path to sys.path (for ipynb): an editable install is not
    # picked up by an already-running interpreter. The absolute path keeps later
    # imports working even if the working directory changes.
    package_abspath = os.path.abspath(package_path)
    if package_abspath not in sys.path:
        sys.path.append(package_abspath)

    # install the requirements for the package
    _pip_install('-r', f"{package_path}/requirements.txt")

    # make the import system notice the modules that were installed just now
    importlib.invalidate_caches()

# check that installation was successful
try:
    import jaxrl_m
except ImportError:
    print("Failed to correctly install jaxrl_m package")
    print("Please manually install the package with `pip install -e soar/model_training`")
    raise

# import dataloader class
from jaxrl_m.data.dataset import WidowXDataset

In [None]:
"""
3. Load the dataset
"""
# Options for the training loader. A goal relabeling strategy is specified
# so that the language annotations get loaded.
train_options = dict(
    seed=0,
    batch_size=16,
    train=True,
    load_language=True,
    goal_relabeling_strategy="uniform",
    goal_relabeling_kwargs={"reached_proportion": 0.5, "discount": 0.98},
)
train_data = WidowXDataset([SAVE_DIR], **train_options)

In [None]:
"""
4. Inspect an example batch
"""
!pip install matplotlib
import matplotlib.pyplot as plt

train_data_iter = train_data.iterator()
example_batch = next(train_data_iter)

print(f"Example batch keys: {example_batch.keys()}")
print(f"Actions shape: {example_batch['actions'].shape}, which is (batch_size, action_dim)")
print(f"Observations shape: {example_batch['observations']['image'].shape}, which is (batch_size, observation_dim)")
print(f"Proprio shape: {example_batch['observations']['proprio'].shape}, which is (batch_size, proprio_dim)")

plt.figure(figsize=(10, 10))
for i in range(16):
    plt.subplot(4, 4, i+1)
    plt.imshow(example_batch['observations']['image'][i])
    plt.axis('off')
plt.show()

In [None]:
"""
5. Load only the success/failure split of the SOAR-data
If you wish, you could only load certain splits of the dataset. 
"""
# Build one loader per split; the two loaders differ only in `data_splits`.
# The success loader is constructed first, then the failure loader.
success_data, failure_data = (
    WidowXDataset(
        [SAVE_DIR],
        data_splits=[split_name],
        seed=0,
        batch_size=16,
        train=True,
    )
    for split_name in ("success", "failure")
)

## More Advanced Usage
For more advanced usage, check out the arguments of the `WidowXDataset` class at [model_training/jaxrl_m/data/dataset.py](https://github.com/rail-berkeley/soar/blob/main/model_training/jaxrl_m/data/dataset.py).

An example of how this dataset is used is in `model_training/experiments/train.py`, and the configuration and arguments of the datasets are in `model_training/experiments/configs/train_config.py` and `model_training/experiments/configs/data_config.py`.