# Cart Pole Balancer

In [103]:
import numpy as np
import plotly.graph_objects as go

In [104]:
# Gravitational acceleration used by the cart-pole equations of motion
# (presumably m/s^2, i.e. standard gravity -- NOTE(review): confirm units).
GRAVITY: float = 9.8

In [105]:
class CartPole:
    """Simulated cart with a hinged pole, optionally drawn in a plotly FigureWidget.

    The state is (cart_displacement, cart_velocity, pole_angle,
    pole_angular_velocity).  In the rendering, a pole_angle of 0 draws the pole
    straight up and np.pi draws it straight down; a new instance starts at rest
    with the pole at np.pi.  Physical units are presumably SI (GRAVITY = 9.8)
    -- NOTE(review): confirm.
    """

    def __init__(self, graphical=True):
        """Create a cart-pole at rest with the pole at angle np.pi.

        Args:
            graphical: if True, build a plotly FigureWidget in ``self.figure``
                and redraw it at the end of every ``update`` call.
        """
        # Properties
        self.pole_length = 0.5
        self.pole_mass = 0.5
        self.cart_mass = 0.5
        self.mu_cart = 0.001  # Coefficient of friction between the cart and the track
        self.mu_pole = 0.001  # Coefficient of friction between the pole and the cart

        # State
        self.cart_displacement = 0.0
        self.cart_velocity = 0.0
        self.pole_angle = np.pi
        self.pole_angular_velocity = 0.0

        # Simulation
        self.simulation_dt = 0.004  # Time step for each simulation increment
        self.max_force = 20  # Soft limit on the applied force magnitude (see scale_force)

        # Rendering (plot coordinates; independent of the physical pole_length)
        self.graphical = graphical
        self.cart_width = 1
        self.cart_height = 0.2
        self.rendered_pole_length = 4
        if self.graphical:
            self.render(new_plot=True)

    @property
    def state(self):
        """Current state as an array [cart_displacement, cart_velocity,
        pole_angle, pole_angular_velocity]."""
        return np.array([self.cart_displacement, self.cart_velocity,
                         self.pole_angle, self.pole_angular_velocity])

    def set_state(self, state):
        """Overwrite the state from a 4-element sequence ordered like ``state``.

        The angle is wrapped into (-pi, pi] via ``remap_angle``.
        """
        self.cart_displacement = state[0]
        self.cart_velocity = state[1]
        self.pole_angle = self.remap_angle(state[2])
        self.pole_angular_velocity = state[3]

    def reset_state(self):
        """Return to rest with the pole at np.pi (does not redraw the figure)."""
        self.set_state([0, 0, np.pi, 0])

    def scale_force(self, force):
        """Soft-clip ``force`` into (-max_force, max_force) using tanh.

        Small forces pass through almost unchanged; large ones saturate.
        """
        return self.max_force * np.tanh(force / self.max_force)

    def remap_angle(self, angle):
        """Wrap ``angle`` into the half-open interval (-pi, pi].

        In-range values are returned unchanged.  Out-of-range values are first
        reduced exactly with ``np.fmod`` (so huge angles take constant time),
        and then a single 2*pi shift puts them in range.  NaN is returned
        as-is.

        Raises:
            ValueError: if ``angle`` is infinite (the old loop never terminated).
        """
        if -np.pi < angle <= np.pi:
            return angle
        if np.isinf(angle):
            raise ValueError(f"cannot remap non-finite angle {angle!r}")

        two_pi = 2 * np.pi
        # fmod is exact and leaves |angle| < 2*pi, so the loop below runs at most once.
        angle = np.fmod(angle, two_pi)
        while angle <= -np.pi or angle > np.pi:
            if angle <= -np.pi:
                angle += two_pi
            else:
                angle -= two_pi
        return angle

    def update(self, input_force=0, number_of_steps=50):
        """Advance the simulation by ``number_of_steps`` steps of ``simulation_dt``.

        The same soft-limited force is applied on every step.  The pole angle
        is NOT wrapped here; use ``remap_angle``/``set_state`` if a bounded
        angle is needed.

        Args:
            input_force: requested horizontal force on the cart; passed through
                ``scale_force`` before use.
            number_of_steps: number of integration steps to take.
        """
        force = self.scale_force(input_force)

        for _ in range(number_of_steps):
            s = np.sin(self.pole_angle)
            c = np.cos(self.pole_angle)
            # Shared denominator of both closed-form accelerations below.
            m = 4.0 * (self.cart_mass + self.pole_mass) - 3.0 * self.pole_mass * (c**2)

            # Closed-form accelerations, including velocity-proportional friction
            # on the cart (mu_cart * cart_velocity) and at the pivot
            # (mu_pole * pole_angular_velocity).
            cart_acceleration = ((2 * (self.pole_length * self.pole_mass * (self.pole_angular_velocity**2) * s
                                       + 2 * (force - self.mu_cart * self.cart_velocity))
                                  - 3 * self.pole_mass * GRAVITY * c * s
                                  + 6 * self.mu_pole * self.pole_angular_velocity * c / self.pole_length)
                                 / m)

            pole_angular_acceleration = ((-3 * c * (2.0 / self.pole_length)
                                          * (self.pole_length / 2.0 * self.pole_mass * (self.pole_angular_velocity**2) * s
                                             + force - self.mu_cart * self.cart_velocity)
                                          + 6 * (self.cart_mass + self.pole_mass) / (self.pole_mass * self.pole_length)
                                          * (self.pole_mass * GRAVITY * s
                                             - 2 / self.pole_length * self.mu_pole * self.pole_angular_velocity))
                                         / m)

            # Doing the velocity updates before the displacement updates means that the result is more stable
            # This is known as the semi-implicit Euler, and is a 'symplectic' method
            self.cart_velocity += self.simulation_dt * cart_acceleration
            self.pole_angular_velocity += self.simulation_dt * pole_angular_acceleration
            self.pole_angle += self.simulation_dt * self.pole_angular_velocity
            self.cart_displacement += self.simulation_dt * self.cart_velocity

        if self.graphical:
            self.render()

    def _pole_tip(self):
        """Return the (x, y) plot coordinates of the free end of the drawn pole."""
        tip_x = self.cart_displacement + self.rendered_pole_length * np.sin(self.pole_angle)
        tip_y = self.rendered_pole_length * np.cos(self.pole_angle)
        return tip_x, tip_y

    def render(self, new_plot=False):
        """Draw the cart and pole on ``self.figure``.

        Args:
            new_plot: if True, create a fresh FigureWidget with fixed axes and
                add the "cart" rectangle and "pole" line shapes.  Otherwise,
                move the existing shapes to match the current state.
        """
        # The pole is anchored at the cart centre (y=0), so its tip must move with the cart.
        tip_x, tip_y = self._pole_tip()

        if new_plot:
            self.figure = go.FigureWidget()
            self.figure.update_xaxes(range=[-10, 10], fixedrange=True)
            self.figure.update_yaxes(range=[-5, 5], fixedrange=True)
            self.figure.add_shape(name="cart",
                                  type="rect",
                                  x0=self.cart_displacement - self.cart_width / 2,
                                  y0=-self.cart_height / 2,
                                  x1=self.cart_displacement + self.cart_width / 2,
                                  y1=self.cart_height / 2)
            self.figure.add_shape(name="pole",
                                  type="line",
                                  x0=self.cart_displacement,
                                  y0=0,
                                  x1=tip_x,
                                  y1=tip_y)
        else:
            self.figure.update_shapes(patch={"x0": self.cart_displacement - self.cart_width / 2,
                                             "x1": self.cart_displacement + self.cart_width / 2},
                                      selector={"name": "cart"})
            self.figure.update_shapes(patch={"x0": self.cart_displacement,
                                             "x1": tip_x,
                                             "y1": tip_y},
                                      selector={"name": "pole"})

In [106]:
# Build the simulator together with its plotly FigureWidget (stored in cp.figure).
# NOTE(review): this assignment produces no cell output, so the widget is only
# shown once `cp.figure` is displayed somewhere.
cp = CartPole(graphical=True)

In [107]:
# Put the cart at x=1 with the pole drawn straight up (angle 0) and at rest,
# then move the existing figure shapes to match.
cp.set_state(state=[1, 0, 0, 0])
cp.render(new_plot=False)