In [1]:
import logging

from rlduels.src.primitives.trajectory_pair import Transition, Trajectory, TrajectoryPair, NDArray

from rlduels.src.database.database_manager import MongoDBManager

from rlduels.src.create_video import create_video_from_pair

from rlduels.src.env_wrapper import EnvWrapper, GymWrapper

# Configure the root logger once for the whole notebook: DEBUG level and a
# "timestamp - level - message" line format with second-resolution timestamps.
# NOTE(review): DEBUG on the root logger also surfaces third-party debug
# records (the recorded database cell output contains driver-level DEBUG
# lines); raise the level to INFO if that output is too noisy.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')

In [2]:
import gymnasium as gym

def run_cartpole(num_episodes=5, seed=42, env_name='CartPole-v1'):
    """Roll out random-action episodes and record each one as a Trajectory.

    An environment is created through ``GymWrapper.create_env`` and stepped
    with actions drawn from ``env.sample_action()`` until the episode is
    terminated or truncated. Every step is stored as a ``Transition`` and
    every finished episode as a ``Trajectory``.

    Args:
        num_episodes: Number of episodes to run (default 5).
        seed: Seed passed to ``env.reset`` at the start of *every* episode
            and stored in the trajectory's ``information`` (default 42).
            The same seed is reused on purpose, matching the original
            behaviour; in the recorded run this gave identical initial
            observations for all episodes.
        env_name: Name handed to ``GymWrapper.create_env``
            (default ``'CartPole-v1'``).

    Returns:
        A list with one ``Trajectory`` per episode, in execution order.
    """
    env = GymWrapper.create_env(name=env_name)
    print(f"Type of env: {type(env)}")

    trajectories = []

    # try/finally guarantees the environment is closed even when reset/step
    # (or building a Transition) raises part-way through a rollout.
    try:
        for episode in range(num_episodes):
            # Reset the environment for a new episode.
            observation, _ = env.reset(seed=seed)
            print("Observation:", observation)
            total_reward = 0
            done = False

            transitions = []

            while not done:
                env.render()

                action = env.sample_action()

                next_observation, reward, terminated, truncated, _ = env.step(action)

                # Record the transition.
                transitions.append(Transition(
                    state=NDArray(array=observation),
                    action=NDArray(array=action),
                    reward=reward,
                    terminated=terminated,
                    truncated=truncated,
                    next_state=NDArray(array=next_observation)
                ))

                observation = next_observation
                total_reward += reward
                done = terminated or truncated

            trajectories.append(Trajectory(
                env_name=env.name,
                information={'seed': [seed]},
                transitions=transitions
            ))

            # The loop above only exits once the episode is done, so the
            # summary can be printed unconditionally.
            print(f"Episode {episode + 1}: Total reward = {total_reward}")
    finally:
        # Close the environment.
        env.close()

    # Return trajectories for further analysis if necessary.
    return trajectories

In [3]:
# Collect the random-rollout trajectories produced by run_cartpole().
trajs = run_cartpole()

# Pair the first two trajectories; `x` is reused by the later cells
# (video creation, database round trip, JSON dump).
x = TrajectoryPair(trajectory1=trajs[0], trajectory2=trajs[1])

# Last expression of the cell: display the pair's environment name.
x.env_name

Type of env: <class 'rlduels.src.env_wrapper.GymWrapper'>
Observation: [ 0.0273956  -0.00611216  0.03585979  0.0197368 ]
Episode 1: Total reward = 9.0
Observation: [ 0.0273956  -0.00611216  0.03585979  0.0197368 ]
Episode 2: Total reward = 30.0
Observation: [ 0.0273956  -0.00611216  0.03585979  0.0197368 ]
Episode 3: Total reward = 10.0
Observation: [ 0.0273956  -0.00611216  0.03585979  0.0197368 ]
Episode 4: Total reward = 15.0
Observation: [ 0.0273956  -0.00611216  0.03585979  0.0197368 ]
Episode 5: Total reward = 28.0


  gym.logger.warn(


'CartPole-v1'

In [4]:
# Render the trajectory pair `x` to video.
# NOTE(review): the recorded output of this cell ends in an AssertionError
# reporting the action as the tuple ('action', NDArray(...)). It looks like
# a (field name, value) pair reaches the environment instead of the raw
# action -- confirm in create_video_from_pair; this cell fails as recorded.
create_video_from_pair(x)

Creating env!
{'CartPole-v1': <rlduels.src.env_wrapper.GymWrapper object at 0x786e5d7cb910>}
Env created
Action:  ('action', NDArray(array=array(0))) 
type:  <class 'tuple'>


AssertionError: ('action', NDArray(array=array(0))) (<class 'tuple'>) invalid

In [None]:
# Create the database manager used by the following cells.
# NOTE(review): the recorded logs show a server at localhost:27017, so this
# presumably needs a local MongoDB instance running -- confirm.
db = MongoDBManager()

In [None]:
# Round-trip check for the pair `x`: add it, look it up, delete it, then
# look it up again. Order matters: the final find_entry runs after the
# delete and is presumably expected to come back empty -- confirm.
print(db.add_entry(x))
print(db.find_entry(x))
print(db.delete_entry(x))
print(db.find_entry(x))

2024-05-06 01:08:00 - DEBUG - {"message": "Server selection started", "selector": "<function writable_server_selector at 0x798a381f7f70>", "operation": "insert", "topologyDescription": "<TopologyDescription id: 663811506766734c6ad25ea9, topology_type: Single, servers: [<ServerDescription ('localhost', 27017) server_type: Standalone, rtt: 0.0012598000175785273>]>", "clientId": {"$oid": "663811506766734c6ad25ea9"}}
2024-05-06 01:08:00 - DEBUG - {"message": "Server selection succeeded", "selector": "<function writable_server_selector at 0x798a381f7f70>", "operation": "insert", "topologyDescription": "<TopologyDescription id: 663811506766734c6ad25ea9, topology_type: Single, servers: [<ServerDescription ('localhost', 27017) server_type: Standalone, rtt: 0.0012598000175785273>]>", "clientId": {"$oid": "663811506766734c6ad25ea9"}, "serverHost": "localhost", "serverPort": 27017}
2024-05-06 01:08:00 - DEBUG - {"clientId": {"$oid": "663811506766734c6ad25ea9"}, "message": "Command started", "comm

('Added 7120a896-3438-4ab2-93a2-ac61a7e16716 to the database', None)
id=UUID('5d6ecda4-c860-432e-ba9c-91bcc601a7b9') trajectory1=Trajectory(env='CartPole-v1', transitions=[Transition(state=NDArray(array=array([-0.0204758 ,  0.03131844, -0.00041203, -0.01910891])), action=NDArray(array=array(0)), reward=1.0, terminated=False, truncated=False, next_state=NDArray(array=array([-0.01984943, -0.1637976 , -0.00079421,  0.273444  ]))), Transition(state=NDArray(array=array([-0.01984943, -0.1637976 , -0.00079421,  0.273444  ])), action=NDArray(array=array(0)), reward=1.0, terminated=False, truncated=False, next_state=NDArray(array=array([-0.02312538, -0.35890821,  0.00467467,  0.56587631]))), Transition(state=NDArray(array=array([-0.02312538, -0.35890821,  0.00467467,  0.56587631])), action=NDArray(array=array(1)), reward=1.0, terminated=False, truncated=False, next_state=NDArray(array=array([-0.03030355, -0.16385216,  0.0159922 ,  0.27466977]))), Transition(state=NDArray(array=array([-0.0303035

In [None]:
# Close the database manager.
# NOTE(review): the recorded log prints "Deleting every entry from the
# database." before "Database closed successfully.", so close_db() appears
# to wipe the stored entries as well -- confirm before pointing this at a
# database that holds data worth keeping.
db.close_db()

2024-05-06 01:08:00 - DEBUG - Deleting every entry from the database.
2024-05-06 01:08:00 - INFO - Database closed successfully.


In [None]:
# NOTE(review): this import sits mid-notebook; consider moving it to the
# imports cell at the top so a fresh-kernel run shows all dependencies.
import json

# Serialise the pair `x` to a JSON string via its .json() method.
serialized_data = x.json()

# Parse the JSON string back into plain Python containers (dict/list/...).
data_to_store = json.loads(serialized_data)

# Move the value stored under 'id' to the key '_id' (pop + reassign).
# Presumably so it serves as the MongoDB document key -- confirm.
data_to_store['_id'] = data_to_store.pop('id')

print(data_to_store)

{'trajectory1': {'env': 'CartPole-v1', 'transitions': [{'state': {'array': [-0.020475801080465317, 0.031318437308073044, -0.0004120297380723059, -0.019108913838863373]}, 'action': {'array': 0}, 'reward': 1.0, 'terminated': False, 'truncated': False, 'next_state': {'array': [-0.01984943263232708, -0.16379760205745697, -0.0007942080264911056, 0.2734439969062805]}}, {'state': {'array': [-0.01984943263232708, -0.16379760205745697, -0.0007942080264911056, 0.2734439969062805]}, 'action': {'array': 0}, 'reward': 1.0, 'terminated': False, 'truncated': False, 'next_state': {'array': [-0.023125384002923965, -0.35890820622444153, 0.004674671683460474, 0.565876305103302]}}, {'state': {'array': [-0.023125384002923965, -0.35890820622444153, 0.004674671683460474, 0.565876305103302]}, 'action': {'array': 1}, 'reward': 1.0, 'terminated': False, 'truncated': False, 'next_state': {'array': [-0.030303549021482468, -0.16385215520858765, 0.015992198139429092, 0.2746697664260864]}}, {'state': {'array': [-0.0

In [None]:
# NOTE(review): gymnasium is already imported as `gym` in an earlier cell;
# this second, mid-notebook import is redundant.
import gymnasium

# Print the ids of every environment registered with gymnasium.
# The output is very long (see the recorded dict_keys dump).
print(gymnasium.envs.registry.keys())

dict_keys(['CartPole-v0', 'CartPole-v1', 'MountainCar-v0', 'MountainCarContinuous-v0', 'Pendulum-v1', 'Acrobot-v1', 'phys2d/CartPole-v0', 'phys2d/CartPole-v1', 'phys2d/Pendulum-v0', 'LunarLander-v2', 'LunarLanderContinuous-v2', 'BipedalWalker-v3', 'BipedalWalkerHardcore-v3', 'CarRacing-v2', 'Blackjack-v1', 'FrozenLake-v1', 'FrozenLake8x8-v1', 'CliffWalking-v0', 'Taxi-v3', 'tabular/Blackjack-v0', 'tabular/CliffWalking-v0', 'Reacher-v2', 'Reacher-v4', 'Pusher-v2', 'Pusher-v4', 'InvertedPendulum-v2', 'InvertedPendulum-v4', 'InvertedDoublePendulum-v2', 'InvertedDoublePendulum-v4', 'HalfCheetah-v2', 'HalfCheetah-v3', 'HalfCheetah-v4', 'Hopper-v2', 'Hopper-v3', 'Hopper-v4', 'Swimmer-v2', 'Swimmer-v3', 'Swimmer-v4', 'Walker2d-v2', 'Walker2d-v3', 'Walker2d-v4', 'Ant-v2', 'Ant-v3', 'Ant-v4', 'Humanoid-v2', 'Humanoid-v3', 'Humanoid-v4', 'HumanoidStandup-v2', 'HumanoidStandup-v4', 'GymV26Environment-v0', 'GymV21Environment-v0', 'Adventure-v0', 'AdventureDeterministic-v0', 'AdventureNoFrameskip-v0