In [1]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported before each cell executes.
%load_ext autoreload
%autoreload 2

In [13]:
import sys, os
# Switch the working directory to /workspace; relative paths used in later
# cells (e.g. config_path) are written relative to it.
work_dir = "/workspace"
os.chdir(work_dir)
# Put /workspace/src on the import path so packages inside it can be imported
# by their bare name (presumably where `utilities` lives -- TODO confirm).
sys.path.append("/workspace/src")
import argparse
import gymnasium as gym
import pandas as pd

from utilities.managers import ConfigManager
from datetime import datetime
from src.training.setup import setup_env
from configs.offline_training.cartpole_v1.cartpole_v1_test import CONFIG

In [14]:
# Accepted spellings of each supported on-disk format, listed both with and
# without the leading dot. Keys name the format; values are the extensions
# that map to it.
EXTENSIONS = {
    "csv": ["csv", ".csv"],
    "pkl": ["pkl", "pickle", ".pkl", ".pickle"],
}

In [16]:
def save_df(df: pd.DataFrame, file_path: str) -> None:
    """Write ``df`` to ``file_path`` in the format implied by its extension.

    The lower-cased extension of ``file_path`` is looked up in ``EXTENSIONS``:
    CSV extensions are written with ``DataFrame.to_csv`` and pickle
    extensions with ``DataFrame.to_pickle``.

    Args:
        df: Frame to persist.
        file_path: Destination path; its extension selects the format.

    Raises:
        ValueError: If the extension is neither a supported CSV nor pickle
            extension. Such paths used to be skipped silently, so nothing
            was written and the caller was never told.
    """
    extension = os.path.splitext(file_path)[-1].lower()
    if extension in EXTENSIONS["csv"]:
        df.to_csv(file_path)
    elif extension in EXTENSIONS["pkl"]:
        df.to_pickle(file_path)
    else:
        supported = EXTENSIONS["csv"] + EXTENSIONS["pkl"]
        raise ValueError(
            f"Unsupported file extension {extension!r} for {file_path!r}; "
            f"expected one of {supported}"
        )

In [18]:
# Stand-in for parsed command-line arguments: the notebook builds the
# Namespace by hand instead of going through an argument parser.
args_dict = {
    "env": "CartPole-v1",
}
args = argparse.Namespace(**args_dict)

In [19]:
# Path of the config module, relative to the working directory. Only the
# disabled ConfigManager line below refers to it.
config_path = "configs/offline_training/cartpole_v1/cartpole_v1_test.py"
# Loading through ConfigManager is disabled; the CONFIG object imported at the
# top of the notebook is used directly instead.
# config = ConfigManager(config_path=config_path).config
config = CONFIG

In [20]:
# Create the environment from the config and the hand-built args
# (args.env is "CartPole-v1"). prepare_dataset reads this notebook-level `env`.
env = setup_env(config=config, args=args)

In [21]:
def prepare_dataset(num_of_episode, base_save_dir, extension="pkl"):
    """Roll out random-action episodes and save each one as its own file.

    The notebook-level ``env`` is used (it is not passed as an argument). A
    timestamped sub-directory is created under ``base_save_dir`` and every
    finished episode is written there as ``<episode index>.<extension>``
    through ``save_df``.

    Args:
        num_of_episode: Number of episodes to record.
        base_save_dir: Directory under which the timestamped run directory
            is created.
        extension: Extension appended to each episode file name; it selects
            the format used by ``save_df``.

    Raises:
        FileExistsError: If the timestamped directory already exists (the
            timestamp has one-second resolution).

    Note:
        ``env`` is closed when this function exits, including on error, so
        it has to be re-created before calling the function again.
    """
    date_format = "%Y_%m_%d_%H%M_%S"
    save_dir = os.path.join(base_save_dir, datetime.now().strftime(date_format))
    os.makedirs(save_dir)

    try:
        n_episode = 0
        while n_episode < num_of_episode:
            # One reset per recorded episode; no extra reset after the last one.
            observation, _ = env.reset()
            # Per-step frames are collected and concatenated once per episode
            # rather than re-copying a growing frame on every step.
            step_frames = []
            done = False
            while not done:
                # Random policy: the action is sampled without looking at the observation.
                action = env.action_space.sample()
                next_observation, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                transition = {
                    "state": observation,
                    "action": action,
                    "next_state": next_observation,
                    "reward": reward,
                    "done": done,
                }
                # pandas broadcasts the scalar fields against array-like
                # observations, so an observation with k elements yields k rows
                # per transition.
                # NOTE(review): confirm downstream loaders expect this layout.
                step_frames.append(pd.DataFrame(transition))
                observation = next_observation

            episode_df = pd.concat(step_frames, axis=0)
            save_df(episode_df, os.path.join(save_dir, f"{n_episode}.{extension}"))
            n_episode += 1
    finally:
        # Release the environment even when a step or a save fails.
        env.close()


In [24]:
# Record 10000 random-action episodes, one CSV file per episode, under the
# directory configured at config["dirs"]["processed_dataset"]["csv"].
prepare_dataset(10000, config["dirs"]["processed_dataset"]["csv"], extension="csv")