In [1]:
import torch
import stable_baselines3
import sys
import numpy as np

# Report the software stack this notebook runs on
for label, value in (
    ("python version", sys.version),
    ("stable_baselines3 version", stable_baselines3.__version__),
    ("torch version", torch.__version__),
    ("cuda available", torch.cuda.is_available()),
    ("cuda version", torch.version.cuda),
    ("cudnn version", torch.backends.cudnn.version()),
):
    print(f"{label}:", value)

# Pick the GPU when available and make it the default device for new torch tensors
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)
torch.set_default_device(device)

python version: 3.11.6 | packaged by conda-forge | (main, Oct  3 2023, 10:40:35) [GCC 12.3.0]
stable_baselines3 version: 2.2.1
torch version: 2.1.0
cuda available: True
cuda version: 12.1
cudnn version: 8902
device: cuda


# Equations of Motion: 3D Quadcopter

In [12]:
# reload sympy
from sympy import *

# Equations of motion of a 3D quadcopter, from https://arxiv.org/pdf/2304.13460.pdf
#
# Conventions used below: gravity acts along +z of the world frame [W], the thrust T
# acts along -z of the body frame [B], and R = Rz*Ry*Rx maps [B] -> [W].
# w1,w2,w3,w4 are the motor speeds normalized to [-1,1]
# u1,u2,u3,u4 are the motor commands normalized to [-1,1]

t = symbols('t')
state = symbols('x y z v_x v_y v_z phi theta psi p q r w1 w2 w3 w4')
x,y,z,vx,vy,vz,phi,theta,psi,p,q,r,w1,w2,w3,w4 = state
control = symbols('U_1 U_2 U_3 U_4')    # normalized motor commands between [-1,1]
u1,u2,u3,u4 = control


g = 9.81

# The order of these symbols is the column order of every numeric parameter array
# passed to the lambdified functions below (built from list(params_values.values())).
params = symbols('k_x, k_y, k_w, k_p, k_q1, k_q2, k_r1, k_r2, tau, k, w_min, w_max, tR, tV')
k_x, k_y, k_w, k_p, k_q1, k_q2, k_r1, k_r2, tau, k, w_min, w_max, tR, tV = params

# Rotation matrix (ZYX Euler angles), body -> world
Rx = Matrix([[1, 0, 0], [0, cos(phi), -sin(phi)], [0, sin(phi), cos(phi)]])
Ry = Matrix([[cos(theta), 0, sin(theta)], [0, 1, 0], [-sin(theta), 0, cos(theta)]])
Rz = Matrix([[cos(psi), -sin(psi), 0], [sin(psi), cos(psi), 0], [0, 0, 1]])
R = Rz*Ry*Rx

# Body velocity
vbx, vby, vbz = R.T@Matrix([vx,vy,vz])

# normalized motor speeds -> rad/s
W1, W2, W3, W4 = [(w+1)/2*(w_max-w_min) + w_min for w in (w1, w2, w3, w4)]

# motor commands scaled to [0,1]
U1, U2, U3, U4 = [(u+1)/2 for u in (u1, u2, u3, u4)]

# first order delay:
# the steady-state motor speed response to the motor command U is
#   Wc = (w_max-w_min)*sqrt(k*U**2 + (1-k)*U) + w_min
Wc1, Wc2, Wc3, Wc4 = [(w_max-w_min)*sqrt(k*U**2 + (1-k)*U) + w_min for U in (U1, U2, U3, U4)]

# motor accelerations [rad/s^2]: dW/dt = (Wc - W)/tau
d_W1, d_W2, d_W3, d_W4 = [(Wc - W)/tau for Wc, W in zip((Wc1, Wc2, Wc3, Wc4), (W1, W2, W3, W4))]

# normalized motor speed rates: d/dt[(W - w_min)/(w_max-w_min)*2 - 1]
d_w1, d_w2, d_w3, d_w4 = [d_W/(w_max-w_min)*2 for d_W in (d_W1, d_W2, d_W3, d_W4)]

# Thrust and Drag (body frame)
T = -k_w*(W1**2 + W2**2 + W3**2 + W4**2)
Dx = -k_x*vbx*(W1+W2+W3+W4)
Dy = -k_y*vby*(W1+W2+W3+W4)

# Moments
Mx = k_p*(-W1**2-W2**2+W3**2+W4**2)
My = k_q1*(W1**2+W3**2) + k_q2*(W2**2+W4**2)
# Yaw moment: rotor drag torque (~W) plus the reaction torque of spinning rotors up or
# down (~dW/dt). Both scale with each rotor's spin direction, so both terms use the
# same per-rotor sign pattern (-,+,+,-).
Mz = k_r1*(-W1+W2+W3-W4) + k_r2*(-d_W1+d_W2+d_W3-d_W4)

# Dynamics
d_x = vx
d_y = vy
d_z = vz

d_vx, d_vy, d_vz = Matrix([0,0,g]) + R@Matrix([Dx, Dy,T])

# Euler angle kinematics (ZYX) from body rates p, q, r
d_phi   = p + q*sin(phi)*tan(theta) + r*cos(phi)*tan(theta)
d_theta = q*cos(phi) - r*sin(phi)
d_psi   = q*sin(phi)/cos(theta) + r*cos(phi)/cos(theta)

d_p     = Mx
d_q     = My
d_r     = Mz

# State space model
f = [d_x, d_y, d_z, d_vx, d_vy, d_vz, d_phi, d_theta, d_psi, d_p, d_q, d_r, d_w1, d_w2, d_w3, d_w4]

# lambdify; the wrappers take row-batched arrays (one row per sample) and return
# the same layout, transposing to/from lambdify's one-row-per-symbol layout
f_func_ = lambdify((Array(state), Array(control), Array(params)), Array(f), 'numpy')
f_func = lambda x,u,p: f_func_(x.T,u.T,p.T).T

# trajectory reference: circle of radius tR flown at speed tV, at z close to -1.
# The tiny cos term makes tz (and its derivatives) depend on t, presumably so the
# lambdified output broadcasts to the batch size — TODO confirm.
tx = tR*cos(t*tV/tR)
ty = tR*sin(t*tV/tR)
tz = -1+0.0001*cos(t*tV/tR)
tvx = diff(tx,t)
tvy = diff(ty,t)
tvz = diff(tz,t)
tax = diff(tvx,t)
tay = diff(tvy,t)
taz = diff(tvz,t)
f_t = [tx, ty, tz, tvx, tvy, tvz, tax, tay, taz]
f_traj_ = lambdify((Array([t]), Array(params)), Array(f_t), 'numpy')
f_traj = lambda t,p: f_traj_(t.T,p.T).T

# extra functions: world-frame acceleration as a function of state and params
f_acc_ = lambdify((Array(state), Array(params)), Array([d_vx, d_vy, d_vz]), 'numpy')
f_acc = lambda x,p: f_acc_(x.T,p.T).T
get_world_acceleration = f_acc

In [39]:
# Nominal model parameters, fitted on flight_data/aggressive_cmds2.csv.
# Insertion order matters: list(params_values.values()) is used as the numeric parameter
# vector, so the keys must follow the order of the sympy `params` symbols.
params_values = dict(
    k_x=6.05e-05,
    k_y=6.67e-05,
    k_w=2.26e-06,
    k_p=6.52e-05,
    k_q1=-5.16e-05,
    k_q2=5.88e-05,
    k_r1=5.40e-06,
    k_r2=1.97e-03,
    tau=0.038,
    k=0.95,
    w_min=238.48,
    w_max=3057.02+238.48,
    tR=3.0,      # circle trajectory radius
    tV=0.0,      # circle trajectory velocity
)

# (low, high) parameter ranges for domain randomization, same key order as params_values
params_ranges = dict(
    k_x=(5.0e-05, 7.0e-05),
    k_y=(5.0e-05, 7.0e-05),
    k_w=(2.0e-06, 3.0e-06),
    k_p=(5.0e-05, 7.0e-05),
    k_q1=(-6.0e-05, -4.0e-05),
    k_q2=(5.0e-05, 7.0e-05),
    k_r1=(4.0e-06, 6.0e-06),
    k_r2=(1.0e-03, 3.0e-03),
    tau=(0.02, 0.05),
    k=(0., 1.0),
    w_min=(0.0, 500.0),
    w_max=(3000.0, 3500.0),
    tR=(3.0, 3.0),
    tV=(1.0, 3.0),
)

In [40]:
# Smoke test: evaluate f_func and f_traj on a small batch of random inputs
num = 2
params_values_ = np.array(num*[list(params_values.values())])

# dynamics: one row per sample
state_values = np.random.rand(num, 16)
control_values = np.random.rand(num, 4)
print(f_func(state_values, control_values, params_values_))

# trajectory reference at random times (column vector of times)
t_values = np.random.rand(num).reshape(num, 1)
print(f_traj(t_values, params_values_))

[[ 3.50273945e-01  5.96519944e-01  3.13467592e-01 -2.54594911e+01
   1.01236764e+01 -1.34019070e+01  1.44905267e+00  2.05039674e-01
   1.42201547e+00  1.48759768e+02  6.14538031e+01 -4.30536842e+01
   1.65321904e+01  3.68235379e+00  7.30032948e+00  5.85223747e+00]
 [ 7.78385189e-01  3.32688584e-01  3.67625639e-02 -4.17462729e+01
  -1.78029147e+00 -2.52884341e+01  1.11256356e+00 -2.30835226e-01
   1.04955143e+00  4.76645384e+02  2.72425171e+02 -2.45517487e+01
  -6.75718218e+00  1.51739042e+01  1.27966696e+01 -1.72853864e+01]]
[[ 3.      0.     -0.9999 -0.      0.     -0.     -0.     -0.     -0.    ]
 [ 3.      0.     -0.9999 -0.      0.     -0.     -0.     -0.     -0.    ]]


# Define the 3D Quad Trajectory Tracking Environment

In [54]:
# Efficient vectorized version of the environment
from gymnasium import spaces
from stable_baselines3.common.vec_env import VecEnv

class Quadcopter3DTT(VecEnv):
    """Vectorized 3D quadcopter environment implementing stable-baselines3's ``VecEnv``.

    All ``num_envs`` quadcopters are simulated together by explicit Euler integration of
    the symbolic model ``f_func`` with time step ``dt``. The observation is the 16-dim
    state pos[W], vel[W], att[eulerB->W], rates[B], normalized motor speeds.

    Reward: decrease of |pos| + |att| + |rates| over the step, +10 while close to the
    origin with small velocity/attitude/rates, and -10 when an env leaves the allowed region.

    Args:
        num_envs: number of parallel environments.
        pause_if_collision: if True, envs that are done are frozen instead of reset.
        params: model parameter dict; its value order must match the sympy ``params`` symbols.
        domain_randomization: each parameter is scaled by U(1-dr, 1+dr) at reset.
    """

    def __init__(self,
                 num_envs,
                 pause_if_collision=False,
                 params=params_values,
                 domain_randomization=0.0,
                 ):

        # If True, done envs are frozen instead of being auto-reset
        self.pause_if_collision = pause_if_collision

        # Domain randomization: one parameter row per env, resampled around the nominal values at reset
        self.domain_randomization = domain_randomization
        self.params = np.array([list(params.values())]*num_envs, dtype=np.float32)
        self.params_nominal = self.params.copy()

        # action space: [cmd1, cmd2, cmd3, cmd4], normalized motor commands
        action_space = spaces.Box(low=-1, high=1, shape=(4,))

        # observation space: pos[W], vel[W], att[eulerB->W], rates[B], rpms
        # [B] = body frame
        # [W] = world frame
        self.state_len = 16
        self.input_hist = 1   # number of stacked states in one observation
        obs_len = self.state_len*self.input_hist
        observation_space = spaces.Box(
            low  = np.array([-np.inf]*obs_len, dtype=np.float32),
            high = np.array([ np.inf]*obs_len, dtype=np.float32)
        )

        # Initialize the VecEnv
        VecEnv.__init__(self, num_envs, observation_space, action_space)

        # world state: pos[W], vel[W], att[eulerB->W], rates[B], rpms
        self.world_states = np.zeros((num_envs,16), dtype=np.float32)
        # trajectory ref: pos[W], vel[W], acc[W]
        self.traj_ref = np.zeros((num_envs, 9), dtype=np.float32)
        # observation state
        self.states = np.zeros((num_envs,self.state_len*self.input_hist), dtype=np.float32)
        # state history tracking (newest state at index 0)
        num_hist = 10
        self.state_hist = np.zeros((num_envs,num_hist,self.state_len), dtype=np.float32)

        # Define any other environment-specific parameters
        self.max_steps = 1200      # Maximum number of steps in an episode
        self.dt = np.float32(0.01) # Time step duration

        self.step_counts = np.zeros(num_envs, dtype=int)
        self.actions = np.zeros((num_envs,4), dtype=np.float32)

        self.pause = False

    def update_states(self):
        """Push the current world states into the history and rebuild ``self.states``."""
        new_states = np.zeros((self.num_envs,self.state_len), dtype=np.float32)
        new_states[:,0:16] = self.world_states

        # update history
        self.state_hist = np.roll(self.state_hist, 1, axis=1)
        self.state_hist[:,0] = new_states

        self.states = self.state_hist[:,0:self.input_hist].reshape((self.num_envs,-1))

    def _observation_for(self, world_states):
        """Return the observation ``update_states`` would build from ``world_states``,
        without modifying the state history."""
        newest = np.asarray(world_states[:, 0:16], dtype=np.float32)[:, None, :]
        stacked = np.concatenate([newest, self.state_hist[:, :self.input_hist-1]], axis=1)
        return stacked.reshape((self.num_envs, -1))

    def reset_(self, dones):
        """Re-initialize the envs selected by the boolean mask ``dones``.

        Resamples their parameters (domain randomization), trajectory reference and
        initial state, zeroes their step counters, then refreshes the observations of
        all envs via ``update_states``. Returns ``self.states``.
        """
        num_reset = dones.sum()

        # update params (domain randomization)
        scale = np.random.uniform(1-self.domain_randomization, 1+self.domain_randomization,
                                  size=(num_reset, self.params.shape[1]))
        self.params[dones] = self.params_nominal[dones] * scale

        self.traj_ref[dones] = f_traj(np.zeros((num_reset,1)), self.params[dones])

        # Uniform sampling bounds of the initial state, in state order.
        # Positions are sampled around the origin (not around the trajectory reference).
        init_bounds = (
            [(-5., 5.)]*3                   # x, y, z
            + [(-0.5, 0.5)]*3               # vx, vy, vz
            + [(-np.pi/9, np.pi/9)]*2       # phi, theta
            + [(-np.pi, np.pi)]             # psi
            + [(-0.1, 0.1)]*3               # p, q, r
            + [(-1, 1)]*4                   # w1..w4 (normalized motor speeds)
        )
        self.world_states[dones] = np.stack(
            [np.random.uniform(low, high, size=(num_reset,)) for low, high in init_bounds], axis=1)

        self.step_counts[dones] = 0

        # update states
        self.update_states()
        return self.states

    def reset(self):
        """Reset all envs and return their observations."""
        return self.reset_(np.ones(self.num_envs, dtype=bool))

    def step_async(self, actions):
        self.actions = actions

    def step_wait(self):
        """Advance all envs by one time step ``dt`` and auto-reset the finished ones.

        Returns ``(obs, rewards, dones, infos)``. For each done env,
        ``infos[i]["terminal_observation"]`` is the observation reached by this step
        (captured before any reset), and ``infos[i]["TimeLimit.truncated"]`` is set when
        the episode ended only because ``max_steps`` was reached.
        """
        # explicit Euler integration of the dynamics
        new_states = self.world_states + self.dt*f_func(self.world_states, self.actions, self.params)

        # update step counts
        self.step_counts += 1

        # trajectory reference at the new time (not used by the current reward)
        self.traj_ref = f_traj(self.step_counts.reshape(-1,1)*self.dt, self.params)

        # Progress reward: decrease of |pos| + |att| + |rates| over this step
        pos_error = np.linalg.norm(new_states[:,0:3], axis=1)
        vel_error = np.linalg.norm(new_states[:,3:6], axis=1)
        ang_error = np.linalg.norm(new_states[:,6:9], axis=1)
        rat_error = np.linalg.norm(new_states[:,9:12], axis=1)

        pos_error_old = np.linalg.norm(self.world_states[:,0:3], axis=1)
        ang_error_old = np.linalg.norm(self.world_states[:,6:9], axis=1)
        rat_error_old = np.linalg.norm(self.world_states[:,9:12], axis=1)

        cost_old = pos_error_old + ang_error_old + rat_error_old
        cost_new = pos_error + ang_error + rat_error
        rewards = cost_old - cost_new

        # bonus while close to the origin with small velocity, attitude and rates
        hover_state = (pos_error < 0.3) & (vel_error < 0.3) & (ang_error < 10*np.pi/180) & (rat_error < 10*np.pi/180)
        rewards[hover_state] += 10

        # Check out of bounds (all checks on the state reached by this step)
        out_of_bounds  = np.any(np.abs(new_states[:,0:3]) > 10, axis=1)     # outside box abs(x,y,z) > 10
        out_of_bounds |= np.any(np.abs(new_states[:,6:8]) > np.pi, axis=1)  # roll/pitch limits
        out_of_bounds |= np.any(np.abs(new_states[:,9:12]) > 1000, axis=1)  # prevent numerical issues: abs(p,q,r) < 1000
        rewards[out_of_bounds] = -10

        # Check number of steps
        max_steps_reached = self.step_counts >= self.max_steps

        # Check if the episode is done
        dones = max_steps_reached | out_of_bounds
        self.dones = dones

        terminal_obs = None
        if self.pause:
            # simulation paused: no env advances and none is reported done
            dones = np.zeros_like(dones)
            self.dones = dones
        elif self.pause_if_collision:
            # advance only the envs that are not done; done envs stay frozen (no reset)
            update = ~dones
            self.world_states[update] = new_states[update]
            self.update_states()
            terminal_obs = self.states.copy()
        else:
            # Update world states
            self.world_states = new_states
            # observation produced by this step, captured before auto-reset overwrites it
            terminal_obs = self._observation_for(new_states)
            # reset done envs (this also refreshes self.states for every env)
            self.reset_(dones)

        # Write info dicts (one independent dict per env)
        infos = [{} for _ in range(self.num_envs)]
        for i in range(self.num_envs):
            if not dones[i]:
                continue
            infos[i]["terminal_observation"] = terminal_obs[i]
            # truncation only when the time limit was the sole reason to stop
            if max_steps_reached[i] and not out_of_bounds[i]:
                infos[i]["TimeLimit.truncated"] = True
        return self.states, rewards, dones, infos

    def close(self):
        pass

    def seed(self, seed=None):
        pass

    def get_attr(self, attr_name, indices=None):
        # no per-env attributes are exposed
        raise AttributeError()

    def set_attr(self, attr_name, value, indices=None):
        pass

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        pass

    def env_is_wrapped(self, wrapper_class, indices=None):
        return [False]*self.num_envs

    def render(self, mode='human'):
        """Return a dict with the state, the [0,1]-scaled actions (and optionally the trajectory) for rendering."""
        state_dict = dict(zip(['x','y','z','vx','vy','vz','phi','theta','psi','p','q','r','w1','w2','w3','w4'], self.world_states.T))
        # Rescale actions to [0,1] for rendering
        action_dict = dict(zip(['u1','u2','u3','u4'], (np.array(self.actions.T)+1)/2))
        # trajectory rendering is currently disabled; to enable it use:
        # traj_dict = dict(zip(['traj_x','traj_y','traj_z'], self.traj_ref[:,0:3].T))
        traj_dict = {}
        return {**state_dict, **action_dict, **traj_dict}

# Visualize a Random Agent

In [55]:
import importlib
from quadcopter_animation import animation
importlib.reload(animation)

# Visualize `num` envs driven by uniformly random motor commands
num = 10
env = Quadcopter3DTT(num_envs=num, pause_if_collision=False)
env.reset()

done = False

def run():
    """Step every env with random actions in [-1, 1] and return the render dict."""
    global done
    random_actions = np.random.uniform(-1, 1, size=(num, 4))
    _state, _reward, done, _info = env.step(random_actions)
    return env.render()

# to record instead: animation.view(run, record_steps=1000, show_window=True)
animation.view(run)

# Train PPO Model

In [56]:
import os
from stable_baselines3 import PPO
from datetime import datetime
from stable_baselines3.common.vec_env import VecMonitor
import importlib
from quadcopter_animation import animation

models_dir = 'models/TT'
log_dir = 'logs/TT'
video_log_dir = 'videos/TT'

# Create the output folders (no-op if they already exist)
for output_dir in (models_dir, log_dir, video_log_dir):
    os.makedirs(output_dir, exist_ok=True)

# Date and time string for unique folder names
datetime_str = datetime.now().strftime("%Y%m%d-%H%M%S")

# Training env (100 parallel quadcopters) and a small visualization env whose
# done envs are frozen instead of reset
env = Quadcopter3DTT(num_envs=100)
test_env = Quadcopter3DTT(num_envs=10, pause_if_collision=True)

# Wrap the training env in a VecMonitor wrapper (episode statistics for logging)
env = VecMonitor(env)

# MODEL DEFINITION
# Separate 3x120 ReLU MLPs for the policy (pi) and the value function (vf).
# SB3 >= 1.8 expects net_arch as a dict; a list wrapping a dict is deprecated.
policy_kwargs = dict(activation_fn=torch.nn.ReLU, net_arch=dict(pi=[120,120,120], vf=[120,120,120]), log_std_init=0)
model = PPO(
    "MlpPolicy",
    env,
    policy_kwargs=policy_kwargs,
    verbose=0,
    tensorboard_log=log_dir,
    n_steps=1000,
    batch_size=5000,
    n_epochs=10,
    gamma=0.999
)

print(model.policy)
print(model.num_timesteps)

def animate_policy(model, env, deterministic=False, log_times=False, **kwargs):
    """Reset ``env`` and animate ``model`` controlling it.

    Each animation frame runs one policy step on all envs and returns ``env.render()``.
    Extra ``kwargs`` are forwarded to ``animation.view``. If ``log_times`` is set, the
    simulated time is printed whenever env 0 receives a reward of exactly 10.
    """
    env.reset()

    def policy_step():
        observations = env.states
        actions = model.predict(observations, deterministic=deterministic)[0]
        _, step_rewards, _, _ = env.step(actions)
        if log_times and step_rewards[0] == 10:
            print(env.step_counts[0]*env.dt)
        return env.render()

    animation.view(policy_step, **kwargs)

# animate untrained policy (use this to set the recording camera position)
animate_policy(model, test_env)

# training loop saves model every 10 policy rollouts and saves a video animation
def train(model, test_env, log_name, n=10000000000):
    """Train ``model`` in chunks of 10 policy rollouts, for ``n`` chunks.

    After each chunk a checkpoint is saved to ``models_dir/log_name/<num_timesteps>``
    and a 1200-step policy animation on ``test_env`` is recorded to
    ``video_log_dir/log_name/<num_timesteps>.mp4``.
    """
    # one chunk = 10 rollouts of n_steps on every env the model itself trains on
    # (model.n_envs, rather than the global `env`, so any model/env pair works)
    timesteps_per_chunk = model.n_steps*model.n_envs*10
    for _ in range(n):
        model.learn(total_timesteps=timesteps_per_chunk, reset_num_timesteps=False, tb_log_name=log_name)
        time_steps = model.num_timesteps
        # save model
        model.save(os.path.join(models_dir, log_name, str(time_steps)))
        # save policy animation
        animate_policy(
            model,
            test_env,
            record_steps=1200,
            record_file=os.path.join(video_log_dir, log_name, f"{time_steps}.mp4"),
            show_window=False
        )

ActorCriticPolicy(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (pi_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (vf_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (mlp_extractor): MlpExtractor(
    (policy_net): Sequential(
      (0): Linear(in_features=16, out_features=120, bias=True)
      (1): ReLU()
      (2): Linear(in_features=120, out_features=120, bias=True)
      (3): ReLU()
      (4): Linear(in_features=120, out_features=120, bias=True)
      (5): ReLU()
    )
    (value_net): Sequential(
      (0): Linear(in_features=16, out_features=120, bias=True)
      (1): ReLU()
      (2): Linear(in_features=120, out_features=120, bias=True)
      (3): ReLU()
      (4): Linear(in_features=120, out_features=120, bias=True)
      (5): ReLU()
    )
  )
  (action_net): Linear(in_features=120, out_features=4, bias=True)
  (value_net): Linear(in_fea

In [57]:
# run training loop (checkpoints and videos go to subfolders named after log_name)
train(model, test_env, log_name='test30_low_motor_limit_10x10box_prog_rew6')



recording started
recording ended
recording saved in videos/TT/test30_low_motor_limit_10x10box_prog_rew6/1000000.mp4
recording started
recording ended
recording saved in videos/TT/test30_low_motor_limit_10x10box_prog_rew6/2000000.mp4
recording started
recording ended
recording saved in videos/TT/test30_low_motor_limit_10x10box_prog_rew6/3000000.mp4
recording started
recording ended
recording saved in videos/TT/test30_low_motor_limit_10x10box_prog_rew6/4000000.mp4
recording started
recording ended
recording saved in videos/TT/test30_low_motor_limit_10x10box_prog_rew6/5000000.mp4
recording started
recording ended
recording saved in videos/TT/test30_low_motor_limit_10x10box_prog_rew6/6000000.mp4
recording started
recording ended
recording saved in videos/TT/test30_low_motor_limit_10x10box_prog_rew6/7000000.mp4
recording started
recording ended
recording saved in videos/TT/test30_low_motor_limit_10x10box_prog_rew6/8000000.mp4
recording started
recording ended
recording saved in videos/TT/t

KeyboardInterrupt: 

# Simulate PPO model

In [None]:
# let the test envs auto-reset when done instead of freezing, then watch the policy
test_env.pause_if_collision = False
animate_policy(model, test_env, deterministic=False)

In [None]:
def animate_policy2(model, env, deterministic=False, log_times=False, **kwargs):
    """Like ``animate_policy``, but also prints the speed of env 0 before every step.

    Gate positions/yaws are forwarded to ``animation.view`` only when ``env`` has
    ``gate_pos`` / ``gate_yaw`` attributes. ``Quadcopter3DTT`` has neither, so reading
    them unconditionally raised AttributeError. Explicit ``kwargs`` take precedence.
    """
    env.reset()

    def run():
        actions, _ = model.predict(env.states, deterministic=deterministic)

        # speed (norm of the world-frame velocity) of env 0 before this step
        print(np.linalg.norm(env.world_states[0, 3:6]))

        states, rewards, dones, infos = env.step(actions)
        if log_times and rewards[0] == 10:
            print(env.step_counts[0]*env.dt)
        return env.render()

    gate_kwargs = {key: getattr(env, key) for key in ('gate_pos', 'gate_yaw') if hasattr(env, key)}
    animation.view(run, **{**gate_kwargs, **kwargs})

In [None]:
# watch the policy while printing the speed of env 0
animate_policy2(model, env=test_env)

# Generate C code

In [None]:
import torch.nn as nn

# E2E_Net
path = 'models/E2E/test1/3000000.zip'

model = PPO.load(path)

# get network
network = list(model.policy.mlp_extractor.policy_net) + [model.policy.action_net]
network = nn.Sequential(*network)
print('NETWORK:')
print(network)

print(model.policy.action_dist)
print(model.policy.log_std)
print(model.policy.log_std.exp())
network_std = model.policy.log_std.exp().cpu().detach().numpy()
print(network_std)


In [None]:
# FIXME(review): `Quadcopter3DGates`, `gate_pos`, `gate_yaw` and `start_pos` are not
# defined anywhere in this notebook (only `Quadcopter3DTT` is), so this cell raises
# NameError on a fresh kernel. It appears to come from a gate-racing notebook; define
# or import that environment before running it.
test_env = Quadcopter3DGates(num_envs=1, gates_pos=gate_pos, gate_yaw=gate_yaw, start_pos=start_pos, gates_ahead=1, pause_if_collision=True)

# animate the loaded policy; prints the time whenever env 0 receives a reward of exactly 10
animate_policy(model, test_env, deterministic=False, log_times=True)

In [None]:
import torch
import torch.nn as nn
import os
import subprocess

import shutil

# Recreate an empty "c_code" output folder so no stale generated files survive.
# shutil/os instead of shelling out to `rm -rf` / `mkdir`: portable, and a failure to
# create the folder raises instead of only returning a non-zero exit code.
output_folder = "c_code"
shutil.rmtree(output_folder, ignore_errors=True)
os.makedirs(output_folder, exist_ok=True)

# Generate the C file and the header file inside the "c_code" folder
source_file_path = os.path.join(output_folder, "neural_network.c")
header_file_path = os.path.join(output_folder, "neural_network.h")

# np.float32 to str
float_to_str = lambda x: str(float(x))

# Generate the C file: weight/bias tables, layer kernels and nn_forward() for `network`
with open(source_file_path, "w") as file:
    file.write('#include "neural_network.h"\n')
    file.write("#include <stdio.h>\n")
    file.write("#include <math.h>\n\n")

    # Define weights and biases as global constant float arrays (fc1, fc2, ... in layer order).
    # Weights are written row by row, matching weights[i * in_features + j] in nn_linear.
    i = 1
    for layer in network:
        if isinstance(layer, nn.Linear):
            weights_layer = layer.weight.data.cpu().numpy()
            biases_layer = layer.bias.data.cpu().numpy()

            file.write(f"const float weights_fc{i}[] = {{\n")
            file.write(",\n".join([", ".join(map(float_to_str, row)) for row in weights_layer]))
            file.write("\n};\n\n")

            file.write(f"const float biases_fc{i}[] = {{\n")
            file.write(", ".join(map(float_to_str, biases_layer)))
            file.write("\n};\n\n")

            i+=1

    # LINEAR LAYER: output = W @ input + b
    file.write("void nn_linear(const float* weights, const float* biases, const float* input, int in_features, int out_features, float* output) {\n")
    file.write("    for (int i = 0; i < out_features; ++i) {\n")
    file.write("        float neuron = biases[i];\n")
    file.write("        for (int j = 0; j < in_features; ++j) {\n")
    file.write("            neuron += input[j] * weights[i * in_features + j];\n")
    file.write("        }\n")
    file.write("        output[i] = neuron;\n")
    file.write("    }\n")
    file.write("}\n\n")

    # RELU LAYER (in place)
    file.write("void nn_relu(float* input, int size) {\n")
    file.write("    for (int i = 0; i < size; ++i) {\n")
    file.write("        input[i] = fmaxf(0, input[i]);\n")
    file.write("    }\n")
    file.write("}\n\n")

    # TANH LAYER (in place; tanhf keeps the computation in single precision like the buffers)
    file.write("void nn_tanh(float* input, int size) {\n")
    file.write("    for (int i = 0; i < size; ++i) {\n")
    file.write("        input[i] = tanhf(input[i]);\n")
    file.write("    }\n")
    file.write("}\n\n")

    # FORWARD FUNCTION: chain the layers; the last Linear layer writes directly to `output`
    file.write("void nn_forward(const float* input, float* output) {\n")
    layer_size = network[0].out_features
    num_layers = sum(isinstance(layer, nn.Linear) for layer in network)
    i=0
    input_array = "input"
    for layer in network:
        if isinstance(layer, nn.Linear):
            i+=1
            if i<num_layers:
                file.write(f"    float fc{i}_output[{layer.out_features}];\n")
                file.write(f"    nn_linear(weights_fc{i}, biases_fc{i}, {input_array}, {layer.in_features}, {layer.out_features}, fc{i}_output);\n")
                input_array = f"fc{i}_output"
            else:
                file.write(f"    nn_linear(weights_fc{i}, biases_fc{i}, {input_array}, {layer.in_features}, {layer.out_features}, output);\n")
                input_array = "output"
            layer_size = layer.out_features
        elif isinstance(layer, nn.ReLU):
            file.write(f"    nn_relu({input_array}, {layer_size});\n")
        elif isinstance(layer, nn.Tanh):
            file.write(f"    nn_tanh({input_array}, {layer_size});\n")
        else:
            raise Exception(f"Unsupported layer: {layer}")
    file.write("}\n")

# Generate the header file
with open(header_file_path, "w") as header_file:
    header_file.write("#ifndef NEURAL_NETWORK_H\n")
    header_file.write("#define NEURAL_NETWORK_H\n\n")
    # Declare the forward function in the header file
    header_file.write("void nn_forward(const float* input, float* output);\n")
    header_file.write("\n#endif // NEURAL_NETWORK_H\n")


# Print the generated files
print(f"Generated {source_file_path}")
print(f"Generated {header_file_path}")

In [None]:
name = 'nn_controller'
# Create the "c_code" folder if it doesn't exist
output_folder = "c_code"
os.makedirs(output_folder, exist_ok=True)

# Generate the C file and the header file inside the "c_code" folder
source_file_path = os.path.join(output_folder, f"{name}.c")
header_file_path = os.path.join(output_folder, f"{name}.h")

# NOTE(review): num_gates, gates_ahead and disturbance_ranges are not attributes of
# Quadcopter3DTT; this cell expects `test_env` to be a gate-racing environment.
num_gates = test_env.num_gates
gates_ahead = test_env.gates_ahead
disturbance_input = True
ranges = test_env.disturbance_ranges

# Generate the header file.
# Globals are only declared here (`extern`) and defined once in the .c file. A plain
# `const float x[N];` or `uint8_t y;` in a header is a tentative definition in every
# translation unit that includes it, which fails to link under -fno-common (GCC >= 10 default).
with open(header_file_path, "w") as file:
    file.write(f"#ifndef {name.upper()}_H\n")
    file.write(f"#define {name.upper()}_H\n")
    file.write("\n")
    file.write("#include <stdint.h>\n")
    file.write("#include <stdbool.h>\n")
    file.write("\n")
    file.write(f'#define GATES_AHEAD {gates_ahead}\n')
    file.write(f'#define NUM_GATES {num_gates}\n')
    file.write("\n")
    # include neural network code
    file.write("// Include the neural network code\n")
    file.write("#include \"neural_network.h\"\n")
    file.write("\n")
    file.write("extern const float gate_pos[NUM_GATES][3];\n")
    file.write("extern const float gate_yaw[NUM_GATES];\n")
    file.write("extern const float start_pos[3];\n")
    file.write("extern uint8_t target_gate_index;\n")
    file.write("\n")
    # nn_reset function that resets the target gate index
    file.write("void nn_reset(void);\n")
    # nn_control function that takes as input a float array of size 16 (world_state)
    # (plus 4 disturbances when enabled) and outputs an array of size 4 (rpms)
    if disturbance_input:
        file.write("void nn_control(const float world_state[16], const float disturbances[4], float rpms[4]);\n")
    else:
        file.write("void nn_control(const float world_state[16], float rpms[4]);\n")

    file.write("\n")
    file.write("#endif\n")
# Generate the C file
with open(source_file_path, "w") as file:
    file.write(f"#include \"{name}.h\"\n")
    file.write("#include <math.h>\n")
    file.write("#include <stdlib.h>\n")
    file.write("\n")
    # define boolean to set controller to determistic
    file.write("bool deterministic = false;\n")
    file.write("\n")
    file.write("const float output_std[4] = {\n")
    for i in range(4):
        file.write(f"    {network_std[i]},\n")
    file.write("};\n")
    file.write("\n")
    # define the gate positions and headings as const float arrays
    file.write("const float gate_pos[NUM_GATES][3] = {\n")
    for i in range(num_gates):
        file.write(f"    {{{test_env.gate_pos[i][0]}, {test_env.gate_pos[i][1]}, {test_env.gate_pos[i][2]}}},\n")
    file.write("};\n")
    file.write("\n")
    file.write("const float gate_yaw[NUM_GATES] = {\n")
    for i in range(num_gates):
        file.write(f"    {test_env.gate_yaw[i]},\n")
    file.write("};\n")
    file.write("\n")
    # define the start pos as a const float array
    file.write("const float start_pos[3] = {\n")
    file.write(f"    {test_env.start_pos[0]}, {test_env.start_pos[1]}, {test_env.start_pos[2]}\n")
    file.write("};\n")
    file.write("\n")
    # define the relative gate positions and headings as const float arrays
    file.write("const float gate_pos_rel[NUM_GATES][3] = {\n")
    for i in range(num_gates):
        file.write(f"    {{{test_env.gate_pos_rel[i][0]}, {test_env.gate_pos_rel[i][1]}, {test_env.gate_pos_rel[i][2]}}},\n")
    file.write("};\n")
    file.write("\n")
    file.write("const float gate_yaw_rel[NUM_GATES] = {\n")
    for i in range(num_gates):
        file.write(f"    {test_env.gate_yaw_rel[i]},\n")
    file.write("};\n")
    file.write("\n")
    # define the target gate index and set it to 0
    file.write("uint8_t target_gate_index = 0;\n")
    file.write("\n")
    file.write("void nn_reset() {\n")
    file.write("    target_gate_index = 0;\n")
    file.write("}\n")
    file.write("\n")
    if disturbance_input:
        file.write("void nn_control(const float world_state[16], const float disturbances[4], float rpms[4]) {\n")
    else:
        file.write("void nn_control(const float world_state[16], float rpms[4]) {\n")
    file.write("    // Get the current position, velocity and heading\n")
    file.write("    float pos[3] = {world_state[0], world_state[1], world_state[2]};\n")
    file.write("    float vel[3] = {world_state[3], world_state[4], world_state[5]};\n")
    file.write("    float yaw = world_state[8];\n")
    file.write("\n")
    file.write("    // Get the position and heading of the target gate\n")
    file.write("    float target_pos[3] = {gate_pos[target_gate_index][0], gate_pos[target_gate_index][1], gate_pos[target_gate_index][2]};\n")
    file.write("    float target_yaw = gate_yaw[target_gate_index];\n")
    file.write("\n")
    file.write("    // Set the target gate index to the next gate if we passed through the current one\n")
    file.write("    if (cosf(target_yaw) * (pos[0] - target_pos[0]) + sinf(target_yaw) * (pos[1] - target_pos[1]) > 0) {\n")
    file.write("        target_gate_index++;\n")
    file.write("        // loop back to the first gate if we reach the end\n")
    file.write("        target_gate_index = target_gate_index % NUM_GATES;\n")
    file.write("        // reset the target position and heading\n")
    file.write("        target_pos[0] = gate_pos[target_gate_index][0];\n")
    file.write("        target_pos[1] = gate_pos[target_gate_index][1];\n")
    file.write("        target_pos[2] = gate_pos[target_gate_index][2];\n")
    file.write("        target_yaw = gate_yaw[target_gate_index];\n")
    file.write("    }\n")
    file.write("\n")
    file.write("    // Get the position of the drone in gate frame\n")
    file.write("    float pos_rel[3] = {\n")
    file.write("        cosf(target_yaw) * (pos[0] - target_pos[0]) + sinf(target_yaw) * (pos[1] - target_pos[1]),\n")
    file.write("        -sinf(target_yaw) * (pos[0] - target_pos[0]) + cosf(target_yaw) * (pos[1] - target_pos[1]),\n")
    file.write("        pos[2] - target_pos[2]\n")
    file.write("    };\n")
    file.write("\n")
    file.write("    // Get the velocity of the drone in gate frame\n")
    file.write("    float vel_rel[3] = {\n")
    file.write("        cosf(target_yaw) * vel[0] + sinf(target_yaw) * vel[1],\n")
    file.write("        -sinf(target_yaw) * vel[0] + cosf(target_yaw) * vel[1],\n")
    file.write("        vel[2]\n")
    file.write("    };\n")
    file.write("\n")
    file.write("    // Get the heading of the drone in gate frame\n")
    file.write("    float yaw_rel = yaw - target_yaw;\n")
    file.write("    while (yaw_rel > M_PI) {yaw_rel -= 2*M_PI;}\n")
    file.write("    while (yaw_rel < -M_PI) {yaw_rel += 2*M_PI;}\n")
    file.write("\n")
    file.write("    // Get the neural network input\n")
    if disturbance_input:
        file.write("    float nn_input[16+4*GATES_AHEAD+4];\n")
    else:
        file.write("    float nn_input[16+4*GATES_AHEAD];\n")
    file.write("    // position and velocity\n")
    file.write("    for (int i = 0; i < 3; i++) {\n")
    file.write("        nn_input[i] = pos_rel[i];\n")
    file.write("        nn_input[i+3] = vel_rel[i];\n")
    file.write("    }\n")
    file.write("    // attitude\n")
    file.write("    nn_input[6] = world_state[6];\n")
    file.write("    nn_input[7] = world_state[7];\n")
    file.write("    nn_input[8] = yaw_rel;\n")
    file.write("    // body rates\n")
    file.write("    nn_input[9] = world_state[9];\n")
    file.write("    nn_input[10] = world_state[10];\n")
    file.write("    nn_input[11] = world_state[11];\n")
    file.write("    // motor rpms scaled to [-1,1]\n")
    file.write(f"    float w_min = {w_min};\n")
    file.write(f"    float w_max = {w_max};\n")
    file.write("    nn_input[12] = (world_state[12] - w_min) * 2 / (w_max - w_min) - 1;\n")
    file.write("    nn_input[13] = (world_state[13] - w_min) * 2 / (w_max - w_min) - 1;\n")
    file.write("    nn_input[14] = (world_state[14] - w_min) * 2 / (w_max - w_min) - 1;\n")
    file.write("    nn_input[15] = (world_state[15] - w_min) * 2 / (w_max - w_min) - 1;\n")
    file.write("\n")
    file.write("    // relative gate positions and headings\n")
    file.write("    for (int i = 0; i < GATES_AHEAD; i++) {\n")
    file.write("        uint8_t index = target_gate_index + i + 1;\n")
    file.write("        // loop back to the first gate if we reach the end\n")
    file.write("        index = index % NUM_GATES;\n")
    file.write("        nn_input[16+4*i]   = gate_pos_rel[index][0];\n")
    file.write("        nn_input[16+4*i+1] = gate_pos_rel[index][1];\n")
    file.write("        nn_input[16+4*i+2] = gate_pos_rel[index][2];\n")
    file.write("        nn_input[16+4*i+3] = gate_yaw_rel[index];\n")
    file.write("    }\n")
    # DISTURBANCES
    if disturbance_input:
        # file.write("    // disturbance input\n")
        # file.write(f"    float Mx_mean = {test_env.Mx_mean};\n")
        # file.write(f"    float Mx_std = {test_env.Mx_std};\n")
        # file.write(f"    float My_mean = {test_env.My_mean};\n")
        # file.write(f"    float My_std = {test_env.My_std};\n")
        # file.write(f"    float Mz_mean = {test_env.Mz_mean};\n")
        # file.write(f"    float Mz_std = {test_env.Mz_std};\n")
        # file.write(f"    float Fz_mean = {test_env.Fz_mean};\n")
        # file.write(f"    float Fz_std = {test_env.Fz_std};\n")
        # file.write("\n")
        # file.write("    nn_input[16+4*GATES_AHEAD]   = (disturbances[0] - Mx_mean) / Mx_std;\n")
        # file.write("    nn_input[16+4*GATES_AHEAD+1] = (disturbances[1] - My_mean) / My_std;\n")
        # file.write("    nn_input[16+4*GATES_AHEAD+2] = (disturbances[2] - Mz_mean) / Mz_std;\n")
        # file.write("    nn_input[16+4*GATES_AHEAD+3] = (disturbances[3] - Fz_mean) / Fz_std;\n")
        file.write("    // disturbance input\n")
        file.write(f"    float Mx_min = {ranges[0][0]};\n")
        file.write(f"    float Mx_max = {ranges[0][1]};\n")
        file.write(f"    float My_min = {ranges[1][0]};\n")
        file.write(f"    float My_max = {ranges[1][1]};\n")
        file.write(f"    float Mz_min = {ranges[2][0]};\n")
        file.write(f"    float Mz_max = {ranges[2][1]};\n")
        file.write(f"    float Fz_min = {ranges[5][0]};\n")
        file.write(f"    float Fz_max = {ranges[5][1]};\n")
        file.write("    nn_input[16+4*GATES_AHEAD]   = (disturbances[0] - Mx_min) * 2 / (Mx_max - Mx_min) - 1;\n")
        file.write("    nn_input[16+4*GATES_AHEAD+1] = (disturbances[1] - My_min) * 2 / (My_max - My_min) - 1;\n")
        file.write("    nn_input[16+4*GATES_AHEAD+2] = (disturbances[2] - Mz_min) * 2 / (Mz_max - Mz_min) - 1;\n")
        file.write("    nn_input[16+4*GATES_AHEAD+3] = (disturbances[3] - Fz_min) * 2 / (Fz_max - Fz_min) - 1;\n")
        file.write("\n")
    file.write("    // Get the neural network output and write to the action array\n")
    file.write("    float nn_output[4];\n")
    file.write("    nn_forward(nn_input, nn_output);\n")
    file.write("\n")
    # If `deterministic` is false, the generated controller adds Gaussian noise
    # (scaled per output by `output_std`) to the raw network output before clipping.
    file.write("    // add gaussian noise to the output\n")
    file.write("    if (!deterministic) {\n")
    file.write("        for (int i = 0; i < 4; i++) {\n")
    # Box–Muller transform. u1 must be strictly positive: rand() can return 0 and
    # logf(0) = -inf would turn the noise into inf/NaN (NaN also slips through the clip below).
    file.write("            // generate random gaussian variables using the Box–Muller transform\n")
    file.write("            // u1 lies in (0, 1] so that logf(u1) is always finite\n")
    file.write("            float u1 = ((float)rand() + 1.0f) / ((float)RAND_MAX + 1.0f);\n")
    file.write("            float u2 = (float)rand() / RAND_MAX;\n")
    file.write("            float rand_std = sqrtf(-2 * logf(u1)) * cosf(2 * M_PI * u2);\n")
    file.write("            // add the noise to the output\n")
    file.write("            nn_output[i] += output_std[i] * rand_std;\n")
    file.write("        }\n")
    file.write("    }\n")
    file.write("\n")
    # Clip to [-1, 1], then map linearly onto the motor-speed range [w_min, w_max]
    file.write("    for (int i = 0; i < 4; i++) {\n")
    file.write("        // clip the output to the range [-1, 1]\n")
    file.write("        if (nn_output[i] > 1) {nn_output[i] = 1;}\n")
    file.write("        if (nn_output[i] < -1) {nn_output[i] = -1;}\n")
    file.write("        // map the output to the range [w_min, w_max]\n")
    file.write("        rpms[i] = (w_max - w_min) * (nn_output[i] + 1) / 2 + w_min;\n")
    file.write("    }\n")
    file.write("}\n")

# Print the generated files
print(f"Generated {source_file_path}")
print(f"Generated {header_file_path}")

# Test C code

In [None]:
import os
import subprocess
import ctypes
import numpy as np
import importlib
importlib.reload(ctypes)

# Build every generated C file in c_code/ into one shared library and load it with ctypes.
# Approach from https://cu7ious.medium.com/how-to-use-dynamic-libraries-in-c-46a0f9b98270
def _build_shared_library(src_dir, lib_name='libtools.so'):
    """Compile all ``.c`` files in ``src_dir`` into the shared library ``lib_name``.

    Returns the absolute path of the built library. Raises FileNotFoundError when there
    is nothing to compile and subprocess.CalledProcessError when gcc fails, so a broken
    build can never silently fall back to a stale library from an earlier run.
    """
    c_sources = sorted(f for f in os.listdir(src_dir) if f.endswith('.c'))
    if not c_sources:
        raise FileNotFoundError(f"No .c files to compile in {src_dir}")
    object_files = [os.path.splitext(src)[0] + '.o' for src in c_sources]
    try:
        # Create object files
        subprocess.run(['gcc', '-fPIC', '-c', *c_sources], cwd=src_dir, check=True)
        # Create library; -lm because the generated code calls <math.h> functions (cosf, logf, ...)
        subprocess.run(
            ['gcc', '-shared', f'-Wl,-soname,{lib_name}', '-o', lib_name, *object_files, '-lm'],
            cwd=src_dir, check=True)
    finally:
        # Remove object files, also when a compile step failed
        for obj in object_files:
            obj_path = os.path.join(src_dir, obj)
            if os.path.exists(obj_path):
                os.remove(obj_path)
    return os.path.join(src_dir, lib_name)


path = os.path.abspath('c_code')
lib_path = _build_shared_library(path)
# NOTE(review): dlopen reuses an already-loaded library with the same path, so after rebuilding
# inside a running kernel the previously loaded code may still be used — restart the kernel to be sure.
fun = ctypes.CDLL(lib_path)

# define argument types (all exported functions return void)
fun.nn_forward.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.POINTER(ctypes.c_float)]
fun.nn_forward.restype = None
if disturbance_input:
    # nn_control(world_state, disturbances, rpms)
    fun.nn_control.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.POINTER(ctypes.c_float), ctypes.POINTER(ctypes.c_float)]
else:
    # nn_control(world_state, rpms)
    fun.nn_control.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.POINTER(ctypes.c_float)]
fun.nn_control.restype = None
fun.nn_reset.restype = None

In [None]:
def c_network(x):
    """Evaluate the compiled C network (``nn_forward``) on a single input vector.

    The input is cast to float32 and copied into a ctypes buffer; the four values the
    C code writes back are returned as a float64 array clipped to [-1, 1].
    """
    inputs = np.asarray(x, dtype=np.float32)
    in_buf = (ctypes.c_float * len(inputs))(*inputs)
    out_buf = (ctypes.c_float * 4)()
    fun.nn_forward(in_buf, out_buf)
    return np.clip(np.array(list(out_buf)), -1, 1)

def torch_network(x):
    """Evaluate the PyTorch reference network on one input and clip the output to [-1, 1].

    Serves as the reference for the generated C network (``c_network``). Runs under
    ``torch.no_grad()`` so no autograd graph is recorded for these evaluation-only calls.
    """
    with torch.no_grad():
        inputs = torch.tensor(x, dtype=torch.float32)
        out = network(inputs).cpu().numpy()
    return np.clip(out, -1, 1)

def _normalized_to_rpm(w):
    """Map normalized motor speeds in [-1, 1] onto [w_min, w_max] (notebook-level globals)."""
    return (w + 1) / 2 * (w_max - w_min) + w_min


def _rpm_to_normalized(w):
    """Inverse of ``_normalized_to_rpm``: map [w_min, w_max] back onto [-1, 1]."""
    return (w - w_min) / (w_max - w_min) * 2 - 1


def _as_c_floats(values):
    """Copy a 1-D sequence of floats into a freshly allocated ctypes float array."""
    return (ctypes.c_float * len(values))(*values)


def _call_nn_control(x, *extra_inputs):
    """Run the C ``nn_control`` on a normalized state; return normalized motor commands.

    ``x`` is copied as float32 (the caller's array is never modified). Its motor-speed
    entries 12..15 are rescaled from [-1, 1] to [w_min, w_max], because the generated
    C code converts ``world_state[12..15]`` back from that range itself. Any
    ``extra_inputs`` (the disturbance vector) are passed as extra float arrays between
    the state and the output buffer. The four C outputs are mapped back to [-1, 1].
    """
    state = np.array(x, dtype=np.float32)
    state[12:16] = _normalized_to_rpm(state[12:16])
    c_args = [_as_c_floats(state)]
    c_args += [_as_c_floats(np.array(extra, dtype=np.float32)) for extra in extra_inputs]
    rpm_out = (ctypes.c_float * 4)()
    fun.nn_control(*c_args, rpm_out)
    return _rpm_to_normalized(np.array(rpm_out[:]))


def nn_control_d_input(x, d):
    """C controller generated with ``disturbance_input=True``: state ``x`` plus disturbances ``d``."""
    return _call_nn_control(x, d)


def nn_control(x):
    """C controller generated with ``disturbance_input=False``: state ``x`` only."""
    return _call_nn_control(x)


# Sanity check: the generated C network must reproduce the PyTorch network on a random input.
# NOTE(review): 24 is assumed to be the policy's input size — confirm against the network.
x = np.random.default_rng(0).random(24)
full = lambda x: [float(xi) for xi in x]
c_out = c_network(x)
torch_out = torch_network(x)
print(full(c_out))
print(full(torch_out))
np.testing.assert_allclose(c_out, torch_out, atol=1e-4,
                           err_msg="C network output differs from the PyTorch network")

In [None]:
# Simulate C Network
# Start a fresh episode and reset the C controller's target gate index (nn_reset sets it to 0)
test_env.reset()
fun.nn_reset()

crashes = []  # one entry per finished episode: True = crashed, False = reached max_steps
# NOTE(review): this global is never updated — run() assigns its own local `steps`
steps = 0

def run():
    """Advance ``test_env`` by one step using the compiled C controller and return the frame.

    Queries the C ``nn_control`` variant that matches how the library was generated
    (with or without the disturbance input), steps environment 0, and when its episode
    ends records crash/success in ``crashes`` and resets the C controller's gate index.
    """
    world_state = test_env.world_states[0]
    if disturbance_input:
        # disturbance channels 0, 1, 2, 5 — the Mx, My, Mz, Fz values the C code normalizes
        d_input = test_env.disturbances[0, [0, 1, 2, 5]]
        action = nn_control_d_input(world_state.copy(), d_input.copy())
    else:
        # library built without the disturbance argument: nn_control takes the state only
        action = nn_control(world_state.copy())

    # episode length after this step; finishing before max_steps counts as a crash
    steps = test_env.step_counts[0] + 1
    states, rewards, dones, infos = test_env.step(np.array([action]))

    if dones[0]:
        crash = steps != test_env.max_steps
        crashes.append(crash)
        print('crash' if crash else 'success')
        fun.nn_reset()
    if len(crashes) == 100:
        print(f"Crash rate: {np.mean(crashes)}")
    return test_env.render()

# Interactive playback with the gate layout of test_env; animation.view presumably calls run()
# repeatedly and displays the frame each call returns — confirm in the animation module
animation.view(run, gate_pos=test_env.gate_pos, gate_yaw=test_env.gate_yaw)

# Add C code for NN Drone Model

In [None]:
import torch
import torch.nn as nn
import os

# Target folder for all generated C sources (created on first use)
output_folder = "c_code"
os.makedirs(output_folder, exist_ok=True)

# Symbol prefix for everything emitted below, and the torch model being exported
name = 'nn_moment'
network = moment_model

# Generated source and header live side by side inside output_folder
source_file_path, header_file_path = (os.path.join(output_folder, f'{name}{ext}') for ext in ('.c', '.h'))


def float_to_str(x):
    """Render a numpy float32 (or any real number) as the decimal string of its Python float."""
    return str(float(x))

# Generate the C file
with open(source_file_path, "w") as file:
    # Include this model's own header so the forward definition is checked against its
    # prototype (including another model's header would fail when that header is absent)
    file.write(f'#include "{name}.h"\n')
    file.write("#include <stdio.h>\n")
    file.write("#include <math.h>\n\n")

    # Define weights and biases as global constant float arrays.
    # Weights are flattened row by row (one row per output neuron), matching the
    # weights[i * in_features + j] indexing of the linear layer below.
    i = 1
    for layer in network:
        if isinstance(layer, nn.Linear):
            weights_layer = layer.weight.data.cpu().numpy()
            biases_layer = layer.bias.data.cpu().numpy()

            file.write(f"const float {name}_weights_fc{i}[] = {{\n")
            file.write(",\n".join([", ".join(map(float_to_str, row)) for row in weights_layer]))
            file.write("\n};\n\n")

            file.write(f"const float {name}_biases_fc{i}[] = {{\n")
            file.write(", ".join(map(float_to_str, biases_layer)))
            file.write("\n};\n\n")

            i+=1

    # LINEAR LAYER
    file.write("void "+name+"_linear(const float* weights, const float* biases, const float* input, int in_features, int out_features, float* output) {\n")
    file.write("    for (int i = 0; i < out_features; ++i) {\n")
    file.write("        float neuron = biases[i];\n")
    file.write("        for (int j = 0; j < in_features; ++j) {\n")
    file.write("            neuron += input[j] * weights[i * in_features + j];\n")
    file.write("        }\n")
    file.write("        output[i] = neuron;\n")
    file.write("    }\n")
    file.write("}\n\n")

    # RELU LAYER
    file.write("void "+name+"_relu(float* input, int size) {\n")
    file.write("    for (int i = 0; i < size; ++i) {\n")
    file.write("        input[i] = fmaxf(0, input[i]);\n")
    file.write("    }\n")
    file.write("}\n\n")

    # TANH LAYER — tanhf keeps the computation in single precision, like fmaxf above
    # (tanh would promote every activation to double)
    file.write("void "+name+"_tanh(float* input, int size) {\n")
    file.write("    for (int i = 0; i < size; ++i) {\n")
    file.write("        input[i] = tanhf(input[i]);\n")
    file.write("    }\n")
    file.write("}\n\n")

    # FORWARD FUNCTION: hidden linear layers write into stack buffers fc{i}_output,
    # the last linear layer writes straight into `output`; activations act in place
    # on whichever buffer the preceding linear layer produced.
    file.write("void "+name+"_forward(const float* input, float* output) {\n")
    layer_size = network[0].out_features
    num_layers = sum(isinstance(layer, nn.Linear) for layer in network)
    i=0
    input_array = "input"
    for layer in network:
        if isinstance(layer, nn.Linear):
            i+=1
            if i<num_layers:
                file.write(f"    float fc{i}_output[{layer.out_features}];\n")
                file.write(f"    {name}_linear({name}_weights_fc{i}, {name}_biases_fc{i}, {input_array}, {layer.in_features}, {layer.out_features}, fc{i}_output);\n")
                input_array = f"fc{i}_output"
            else:
                file.write(f"    {name}_linear({name}_weights_fc{i}, {name}_biases_fc{i}, {input_array}, {layer.in_features}, {layer.out_features}, output);\n")
                input_array = "output"
            layer_size = layer.out_features
        elif isinstance(layer, nn.ReLU):
            file.write(f"    {name}_relu({input_array}, {layer_size});\n")
        elif isinstance(layer, nn.Tanh):
            file.write(f"    {name}_tanh({input_array}, {layer_size});\n")
        else:
            raise Exception(f"Unsupported layer: {layer}")
    file.write("}\n")

# Generate the header file: an include guard around the forward-function prototype
with open(header_file_path, "w") as header_file:
    header_file.write("".join([
        f'#ifndef {name.upper()}_H\n',
        f'#define {name.upper()}_H\n\n',
        # the only symbol other translation units need is the forward pass
        f"void {name}_forward(const float* input, float* output);\n",
        f'\n#endif // {name.upper()}_H\n',
    ]))


# Report where the generated source and header were written
print(f"Generated {source_file_path}\nGenerated {header_file_path}")