In [3]:
# Change the working directory to the repo root.
!mkdir logs
!git clone https://github.com/philipp01wagner/gym-pybullet-drones.git
%cd gym-pybullet-drones
!python setup.py install
!pip install pybullet
!pip install stable-baselines3[extra]
# Add the repo root to the Python path.
import sys, os
sys.path.append(os.getcwd())
%cd ..
import time
import gym
import numpy as np
import argparse
from stable_baselines3 import A2C, PPO, DDPG, SAC
from stable_baselines3.common.env_checker import check_env
from gym_pybullet_drones.envs.single_agent_rl.StraightFlightAviary import StraightFlightAviary

from gym_pybullet_drones.envs.BaseAviary import DroneModel, Physics
from gym_pybullet_drones.utils.Logger import Logger
from gym_pybullet_drones.utils.utils import sync
from gym.envs.registration import register
from gym_pybullet_drones.utils.utils import sync, str2bool

In [6]:
# Run configuration, expressed as an argparse parser so the same options can be
# reused from a command-line script. String defaults such as "ha" and "pyb" are
# passed through the argument's `type` callable by argparse, so ARGS.drone and
# ARGS.physics end up as DroneModel / Physics values rather than plain strings.
parser = argparse.ArgumentParser(description='A2C training script for a single drone in StraightFlightAviary')
parser.add_argument('--drone',              default="ha",       type=DroneModel,    help='Drone model (default: ha)', metavar='', choices=DroneModel)
parser.add_argument('--physics',            default="pyb",      type=Physics,       help='Physics updates (default: PYB)', metavar='', choices=Physics)
parser.add_argument('--gui',                default=False,       type=str2bool,      help='Whether to use PyBullet GUI (default: False)', metavar='')
parser.add_argument('--aggregate',          default=True,       type=str2bool,      help='Whether to aggregate physics steps (default: True)', metavar='')
parser.add_argument('--simulation_freq_hz', default=240,        type=int,           help='Simulation frequency in Hz (default: 240)', metavar='')
parser.add_argument('--control_freq_hz',    default=48,         type=int,           help='Control frequency in Hz (default: 48)', metavar='')
parser.add_argument('--duration_sec',       default=12,         type=int,           help='Duration of the simulation in seconds (default: 12)', metavar='')
parser.add_argument('--trajectory',         default=1,          type=int,           help='Trajectory type (default: 1)', metavar='')
parser.add_argument('--wind',               default=False,      type=str2bool,      help='Whether to enable wind (default: False)', metavar='')
parser.add_argument('--record_video',       default=False,      type=str2bool,      help='Whether to record a video (default: False)', metavar='')

# An explicit empty argument list makes argparse ignore the kernel's own
# sys.argv, so ARGS holds exactly the defaults declared above.
ARGS = parser.parse_args(args=[])

In [7]:
#### Initial pose of the single drone #######################
H = 1.0   # z-component of the start position below
R = .3    # defined here but not referenced anywhere in this cell
INIT_XYZS = np.array([[0, 0, H]])
INIT_RPYS = np.array([[0, 0,  0]])

#### Physics steps per env.step() ###########################
# Ratio of simulation to control frequency (truncated to an int) when
# aggregation is enabled, otherwise a single physics step.
if ARGS.aggregate:
    AGGR_PHY_STEPS = int(ARGS.simulation_freq_hz/ARGS.control_freq_hz)
else:
    AGGR_PHY_STEPS = 1

#### Register the environment id with gym ###################
register(id='straight-flight-aviary-v0',
         entry_point='gym_pybullet_drones.envs.single_agent_rl:StraightFlightAviary',
         reward_threshold=1.0,
         nondeterministic=False)

#### Build the training environment from ARGS ###############
env = StraightFlightAviary(
    drone_model=ARGS.drone,
    initial_xyzs=INIT_XYZS,
    initial_rpys=INIT_RPYS,
    physics=ARGS.physics,
    freq=ARGS.simulation_freq_hz,
    aggregate_phy_steps=AGGR_PHY_STEPS,
    gui=ARGS.gui,
    record=ARGS.record_video,
)

[INFO] BaseAviary.__init__() loaded parameters from the drone's .urdf:
[INFO] m 0.500000, L 0.175000,
[INFO] ixx 0.002300, iyy 0.002300, izz 0.004000,
[INFO] kf 0.000000, km 0.000000,
[INFO] t2w 2.000000, max_speed_kmh 50.000000,
[INFO] gnd_eff_coeff 0.000000, prop_radius 0.000000,
[INFO] drag_xy_coeff 0.000000, drag_z_coeff 0.000000,
[INFO] dw_coeff_1 0.000000, dw_coeff_2 0.000000, dw_coeff_3 1.000000




In [11]:
#### Check the environment's spaces ########################
# Print both spaces, then let Stable-Baselines3 check `env`
# (warnings enabled, render check skipped).
for label, space in (("[INFO] Action space:", env.action_space),
                     ("[INFO] Observation space:", env.observation_space)):
    print(label, space)
check_env(env, warn=True, skip_render_check=True)

[INFO] Action space: Box([-1. -1.], [1. 1.], (2,), float32)
[INFO] Observation space: Box([-1. -1.  0. -1. -1. -1. -1. -1. -1. -1. -1. -1.], [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], (12,), float32)


In [None]:
# Load (or reload) the TensorBoard notebook extension that provides the %tensorboard magic.
%reload_ext tensorboard

In [None]:
# Start TensorBoard on ./logs/, the same directory that is passed as tensorboard_log to A2C.
%tensorboard --logdir=./logs/

In [13]:
# Train an A2C agent with an MLP policy on `env`; training metrics are written
# to ./logs/ for TensorBoard.
SEED = 0                    # fixed seed so repeated runs of this cell are comparable
TOTAL_TIMESTEPS = 100000    # number of environment steps to train for
model = A2C("MlpPolicy",
            env,
            seed=SEED,
            verbose=1,
            tensorboard_log="./logs/"
            )
# The trailing semicolon keeps the returned model's repr out of the cell output.
model.learn(total_timesteps=TOTAL_TIMESTEPS);

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./a2c_StraightFlight_tensorboard/A2C_6
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 1781      |
|    iterations         | 100       |
|    time_elapsed       | 0         |
|    total_timesteps    | 500       |
| train/                |           |
|    entropy_loss       | -2.84     |
|    explained_variance | -6.2e-06  |
|    learning_rate      | 0.0007    |
|    n_updates          | 99        |
|    policy_loss        | -155      |
|    std                | 0.999     |
|    value_loss         | 4.22e+03  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/             

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 1922      |
|    iterations         | 1300      |
|    time_elapsed       | 3         |
|    total_timesteps    | 6500      |
| train/                |           |
|    entropy_loss       | -2.84     |
|    explained_variance | -2.38e-07 |
|    learning_rate      | 0.0007    |
|    n_updates          | 1299      |
|    policy_loss        | -155      |
|    std                | 1         |
|    value_loss         | 3.97e+03  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 1929      |
|    iterations         | 1400      |
|    time_elapsed       | 3         |
|    total_t

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 1970      |
|    iterations         | 2500      |
|    time_elapsed       | 6         |
|    total_timesteps    | 12500     |
| train/                |           |
|    entropy_loss       | -2.89     |
|    explained_variance | -1.19e-07 |
|    learning_rate      | 0.0007    |
|    n_updates          | 2499      |
|    policy_loss        | -206      |
|    std                | 1.03      |
|    value_loss         | 3.75e+03  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 1973      |
|    iterations         | 2600      |
|    time_elapsed       | 6         |
|    total_t

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 1988      |
|    iterations         | 3700      |
|    time_elapsed       | 9         |
|    total_timesteps    | 18500     |
| train/                |           |
|    entropy_loss       | -3        |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 3699      |
|    policy_loss        | -182      |
|    std                | 1.08      |
|    value_loss         | 3.55e+03  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 1988      |
|    iterations         | 3800      |
|    time_elapsed       | 9         |
|    total_t

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 1998      |
|    iterations         | 4900      |
|    time_elapsed       | 12        |
|    total_timesteps    | 24500     |
| train/                |           |
|    entropy_loss       | -2.96     |
|    explained_variance | -1.19e-07 |
|    learning_rate      | 0.0007    |
|    n_updates          | 4899      |
|    policy_loss        | -136      |
|    std                | 1.06      |
|    value_loss         | 3.35e+03  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 1998      |
|    iterations         | 5000      |
|    time_elapsed       | 12        |
|    total_t

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2004      |
|    iterations         | 6100      |
|    time_elapsed       | 15        |
|    total_timesteps    | 30500     |
| train/                |           |
|    entropy_loss       | -2.96     |
|    explained_variance | 1.19e-07  |
|    learning_rate      | 0.0007    |
|    n_updates          | 6099      |
|    policy_loss        | 173       |
|    std                | 1.06      |
|    value_loss         | 2.62e+04  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2004      |
|    iterations         | 6200      |
|    time_elapsed       | 15        |
|    total_t

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2007      |
|    iterations         | 7300      |
|    time_elapsed       | 18        |
|    total_timesteps    | 36500     |
| train/                |           |
|    entropy_loss       | -2.95     |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 7299      |
|    policy_loss        | -131      |
|    std                | 1.06      |
|    value_loss         | 2.96e+03  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2008      |
|    iterations         | 7400      |
|    time_elapsed       | 18        |
|    total_t

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2010      |
|    iterations         | 8500      |
|    time_elapsed       | 21        |
|    total_timesteps    | 42500     |
| train/                |           |
|    entropy_loss       | -2.95     |
|    explained_variance | 1.19e-07  |
|    learning_rate      | 0.0007    |
|    n_updates          | 8499      |
|    policy_loss        | -135      |
|    std                | 1.06      |
|    value_loss         | 2.78e+03  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2010      |
|    iterations         | 8600      |
|    time_elapsed       | 21        |
|    total_t

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2012      |
|    iterations         | 9700      |
|    time_elapsed       | 24        |
|    total_timesteps    | 48500     |
| train/                |           |
|    entropy_loss       | -3.04     |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 9699      |
|    policy_loss        | -184      |
|    std                | 1.11      |
|    value_loss         | 2.6e+03   |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2012      |
|    iterations         | 9800      |
|    time_elapsed       | 24        |
|    total_t

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2014      |
|    iterations         | 10900     |
|    time_elapsed       | 27        |
|    total_timesteps    | 54500     |
| train/                |           |
|    entropy_loss       | -3.1      |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 10899     |
|    policy_loss        | -140      |
|    std                | 1.14      |
|    value_loss         | 2.43e+03  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2014      |
|    iterations         | 11000     |
|    time_elapsed       | 27        |
|    total_t

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2015      |
|    iterations         | 12100     |
|    time_elapsed       | 30        |
|    total_timesteps    | 60500     |
| train/                |           |
|    entropy_loss       | -3.21     |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 12099     |
|    policy_loss        | -126      |
|    std                | 1.21      |
|    value_loss         | 2.25e+03  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2015      |
|    iterations         | 12200     |
|    time_elapsed       | 30        |
|    total_t

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2016      |
|    iterations         | 13300     |
|    time_elapsed       | 32        |
|    total_timesteps    | 66500     |
| train/                |           |
|    entropy_loss       | -3.33     |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 13299     |
|    policy_loss        | -153      |
|    std                | 1.28      |
|    value_loss         | 2.1e+03   |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2016      |
|    iterations         | 13400     |
|    time_elapsed       | 33        |
|    total_t

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2017      |
|    iterations         | 14500     |
|    time_elapsed       | 35        |
|    total_timesteps    | 72500     |
| train/                |           |
|    entropy_loss       | -3.33     |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 14499     |
|    policy_loss        | -166      |
|    std                | 1.28      |
|    value_loss         | 1.95e+03  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2018      |
|    iterations         | 14600     |
|    time_elapsed       | 36        |
|    total_t

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2018      |
|    iterations         | 15700     |
|    time_elapsed       | 38        |
|    total_timesteps    | 78500     |
| train/                |           |
|    entropy_loss       | -3.43     |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 15699     |
|    policy_loss        | -133      |
|    std                | 1.34      |
|    value_loss         | 1.8e+03   |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2018      |
|    iterations         | 15800     |
|    time_elapsed       | 39        |
|    total_t

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2020      |
|    iterations         | 16900     |
|    time_elapsed       | 41        |
|    total_timesteps    | 84500     |
| train/                |           |
|    entropy_loss       | -3.52     |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 16899     |
|    policy_loss        | -137      |
|    std                | 1.41      |
|    value_loss         | 1.66e+03  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2020      |
|    iterations         | 17000     |
|    time_elapsed       | 42        |
|    total_t

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2022      |
|    iterations         | 18100     |
|    time_elapsed       | 44        |
|    total_timesteps    | 90500     |
| train/                |           |
|    entropy_loss       | -3.55     |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 18099     |
|    policy_loss        | -146      |
|    std                | 1.43      |
|    value_loss         | 1.52e+03  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2022      |
|    iterations         | 18200     |
|    time_elapsed       | 44        |
|    total_t

-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2024      |
|    iterations         | 19300     |
|    time_elapsed       | 47        |
|    total_timesteps    | 96500     |
| train/                |           |
|    entropy_loss       | -3.62     |
|    explained_variance | 0         |
|    learning_rate      | 0.0007    |
|    n_updates          | 19299     |
|    policy_loss        | -135      |
|    std                | 1.48      |
|    value_loss         | 1.39e+03  |
-------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 242       |
|    ep_rew_mean        | -4.84e+03 |
| time/                 |           |
|    fps                | 2024      |
|    iterations         | 19400     |
|    time_elapsed       | 47        |
|    total_t

<stable_baselines3.a2c.a2c.A2C at 0x28f840250>

In [None]:
# Roll out the trained policy in a GUI-enabled environment and plot the logged states.
# The environment is built with the same settings the agent was trained with
# (only the GUI is switched on), so the policy sees the same drone model,
# start pose and control rate as during training.
# NOTE: this rebinds `env`; re-run the environment-creation cell before training again.
EVAL_DURATION_SEC = 20   # simulated seconds to run the policy for
env = StraightFlightAviary(drone_model=ARGS.drone,
                           initial_xyzs=INIT_XYZS,
                           initial_rpys=INIT_RPYS,
                           physics=ARGS.physics,
                           freq=ARGS.simulation_freq_hz,
                           aggregate_phy_steps=AGGR_PHY_STEPS,
                           gui=True,
                           record=False
                           )
# Number of env.step() calls per simulated second; the same value is used as
# the logger's frequency, the loop length and the timestamp base below.
steps_per_sec = int(env.SIM_FREQ/env.AGGR_PHY_STEPS)
logger = Logger(logging_freq_hz=steps_per_sec,
                num_drones=1
                )
try:
    obs = env.reset()
    start = time.time()
    for i in range(EVAL_DURATION_SEC*steps_per_sec):
        action, _states = model.predict(obs,
                                        deterministic=True
                                        )
        obs, reward, done, info = env.step(action)
        # NOTE(review): with the 12-element observation and 2-element action
        # reported by the space check, this vector has 3 + 4 + 9 + 2 = 18
        # entries (obs[3:15] is clipped to obs[3:12]); confirm that this is
        # the length Logger.log() expects for `state`.
        logger.log(drone=0,
                   timestamp=i/steps_per_sec,
                   state=np.hstack([obs[0:3], np.zeros(4), obs[3:15],  np.resize(action, (2))]),
                   control=np.zeros(12)
                   )
        # Once per simulated second: render and report the current action and
        # done flag (printing on every step floods the cell output).
        if i%steps_per_sec == 0:
            env.render()
            print(action)
            print(done)
        # sync() is given the number of elapsed physics ticks so that playback
        # stays at wall-clock speed when several ticks are aggregated per step.
        sync(i*env.AGGR_PHY_STEPS, start, env.TIMESTEP)
        if done:
            obs = env.reset()
finally:
    # Always shut the simulator down, even if the rollout raises or is interrupted.
    env.close()
logger.plot()