Skip to content

Commit

Permalink
Clean up, moved enum definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
JacopoPan committed Apr 16, 2023
1 parent f7846ed commit 4b68901
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 102 deletions.
1 change: 0 additions & 1 deletion gym_pybullet_drones/envs/BaseAviary.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from gym_pybullet_drones.utils.enums import DroneModel, Physics, ImageType



class BaseAviary(gym.Env):
"""Base class for "drone aviary" Gym environments."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from gymnasium import spaces

from gym_pybullet_drones.envs.BaseAviary import BaseAviary
from gym_pybullet_drones.utils.enums import DroneModel, Physics
from gym_pybullet_drones.envs.single_agent_rl.BaseSingleAgentAviary import ActionType, ObservationType
from gym_pybullet_drones.utils.utils import nnlsRPM
from gym_pybullet_drones.utils.enums import DroneModel, Physics, ActionType, ObservationType
from gym_pybullet_drones.control.DSLPIDControl import DSLPIDControl

class BaseMultiagentAviary(BaseAviary):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,9 @@
import pybullet as p

from gym_pybullet_drones.envs.BaseAviary import BaseAviary
from gym_pybullet_drones.utils.enums import DroneModel, Physics, ImageType
from gym_pybullet_drones.utils.utils import nnlsRPM
from gym_pybullet_drones.utils.enums import DroneModel, Physics, ImageType, ActionType, ObservationType
from gym_pybullet_drones.control.DSLPIDControl import DSLPIDControl

class ActionType(Enum):
"""Action type enumeration class."""
RPM = "rpm" # RPMS
PID = "pid" # PID control
VEL = "vel" # Velocity input (using PID control)
ONE_D_RPM = "one_d_rpm" # 1D (identical input to all motors) with RPMs
ONE_D_PID = "one_d_pid" # 1D (identical input to all motors) with PID control

################################################################################

class ObservationType(Enum):
"""Observation type enumeration class."""
KIN = "kin" # Kinematic information (pose, linear and angular velocities)
RGB = "rgb" # RGB camera capture in each drone's POV

################################################################################

class BaseSingleAgentAviary(BaseAviary):
"""Base single drone environment class for reinforcement learning."""
Expand Down
17 changes: 16 additions & 1 deletion gym_pybullet_drones/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,19 @@ class ImageType(Enum):
SEG = 2 # Segmentation by object id
BW = 3 # Black and white

################################################################################
################################################################################

class ActionType(Enum):
"""Action type enumeration class."""
RPM = "rpm" # RPMS
PID = "pid" # PID control
VEL = "vel" # Velocity input (using PID control)
ONE_D_RPM = "one_d_rpm" # 1D (identical input to all motors) with RPMs
ONE_D_PID = "one_d_pid" # 1D (identical input to all motors) with PID control

################################################################################

class ObservationType(Enum):
"""Observation type enumeration class."""
KIN = "kin" # Kinematic information (pose, linear and angular velocities)
RGB = "rgb" # RGB camera capture in each drone's POV
79 changes: 0 additions & 79 deletions gym_pybullet_drones/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,82 +52,3 @@ def str2bool(val):
return False
else:
raise argparse.ArgumentTypeError("[ERROR] in str2bool(), a Boolean value is expected")

################################################################################

def nnlsRPM(thrust,
x_torque,
y_torque,
z_torque,
counter,
max_thrust,
max_xy_torque,
max_z_torque,
a,
inv_a,
b_coeff,
gui=False
):
"""Non-negative Least Squares (NNLS) RPMs from desired thrust and torques.
This function uses the NNLS implementation in `scipy.optimize`.
Parameters
----------
thrust : float
Desired thrust along the drone's z-axis.
x_torque : float
Desired drone's x-axis torque.
y_torque : float
Desired drone's y-axis torque.
z_torque : float
Desired drone's z-axis torque.
counter : int
Simulation or control iteration, only used for printouts.
max_thrust : float
Maximum thrust of the quadcopter.
max_xy_torque : float
Maximum torque around the x and y axes of the quadcopter.
max_z_torque : float
Maximum torque around the z axis of the quadcopter.
a : ndarray
(4, 4)-shaped array of floats containing the motors configuration.
inv_a : ndarray
(4, 4)-shaped array of floats, inverse of a.
b_coeff : ndarray
(4,1)-shaped array of floats containing the coefficients to re-scale thrust and torques.
gui : boolean, optional
Whether a GUI is active or not, only used for printouts.
Returns
-------
ndarray
(4,)-shaped array of ints containing the desired RPMs of each propeller.
"""
#### Check the feasibility of thrust and torques ###########
if gui and thrust < 0 or thrust > max_thrust:
print("[WARNING] iter", counter, "in utils.nnlsRPM(), unfeasible thrust {:.2f} outside range [0, {:.2f}]".format(thrust, max_thrust))
if gui and np.abs(x_torque) > max_xy_torque:
print("[WARNING] iter", counter, "in utils.nnlsRPM(), unfeasible roll torque {:.2f} outside range [{:.2f}, {:.2f}]".format(x_torque, -max_xy_torque, max_xy_torque))
if gui and np.abs(y_torque) > max_xy_torque:
print("[WARNING] iter", counter, "in utils.nnlsRPM(), unfeasible pitch torque {:.2f} outside range [{:.2f}, {:.2f}]".format(y_torque, -max_xy_torque, max_xy_torque))
if gui and np.abs(z_torque) > max_z_torque:
print("[WARNING] iter", counter, "in utils.nnlsRPM(), unfeasible yaw torque {:.2f} outside range [{:.2f}, {:.2f}]".format(z_torque, -max_z_torque, max_z_torque))
B = np.multiply(np.array([thrust, x_torque, y_torque, z_torque]), b_coeff)
sq_rpm = np.dot(inv_a, B)
#### NNLS if any of the desired ang vel is negative ########
if np.min(sq_rpm) < 0:
sol, res = nnls(a,
B,
maxiter=3*a.shape[1]
)
if gui:
print("[WARNING] iter", counter, "in utils.nnlsRPM(), unfeasible squared rotor speeds, using NNLS")
print("Negative sq. rotor speeds:\t [{:.2f}, {:.2f}, {:.2f}, {:.2f}]".format(sq_rpm[0], sq_rpm[1], sq_rpm[2], sq_rpm[3]),
"\t\tNormalized: [{:.2f}, {:.2f}, {:.2f}, {:.2f}]".format(sq_rpm[0]/np.linalg.norm(sq_rpm), sq_rpm[1]/np.linalg.norm(sq_rpm), sq_rpm[2]/np.linalg.norm(sq_rpm), sq_rpm[3]/np.linalg.norm(sq_rpm)))
print("NNLS:\t\t\t\t [{:.2f}, {:.2f}, {:.2f}, {:.2f}]".format(sol[0], sol[1], sol[2], sol[3]),
"\t\t\tNormalized: [{:.2f}, {:.2f}, {:.2f}, {:.2f}]".format(sol[0]/np.linalg.norm(sol), sol[1]/np.linalg.norm(sol), sol[2]/np.linalg.norm(sol), sol[3]/np.linalg.norm(sol)),
"\t\tResidual: {:.2f}".format(res))
sq_rpm = sol
return np.sqrt(sq_rpm)

0 comments on commit 4b68901

Please sign in to comment.