In [1]:
import matplotlib.pyplot as plt
import numpy as np
from tqdm import trange
from gymnasium.utils.env_checker import check_env
from IPython.display import Video
from enum import Enum, auto

from flygym.mujoco import Parameters
from flygym.mujoco.arena import FlatTerrain
from flygym.mujoco.examples.obstacle_arena import ObstacleOdorArena
from flygym.mujoco.examples.turning_controller import HybridTurningNMF

In [2]:
import flygym.mujoco
from tqdm import trange

# Base ground; the visual/olfactory arena below is layered on top of it.
flat_terrain_arena = FlatTerrain()

# Pillar obstacles: (x, y) positions and RGBA colours, listed in matching order.
pillar_positions = np.array([(7.5, 0), (12.5, 5), (17.5, -5)])
pillar_colors = [(0.14, 0.14, 0.2, 1), (0.2, 0.8, 0.2, 1), (0.2, 0.2, 0.8, 1)]
# NOTE(review): presumably (camera position, camera rotation in radians, field
# of view) - confirm against ObstacleOdorArena.
user_camera = ((13, -18, 9), (np.deg2rad(65), 0, 0), 45)

arena = ObstacleOdorArena(
    terrain=flat_terrain_arena,
    obstacle_positions=pillar_positions,
    marker_size=0.5,
    obstacle_colors=pillar_colors,
    user_camera_settings=user_camera,
)

# One contact sensor per (leg, segment): 6 segments for each of the 6 legs,
# leg-major order (all LF sensors first, then LM, LH, RF, RM, RH).
_sensor_segments = ["Tibia"] + [f"Tarsus{k}" for k in range(1, 6)]
contact_sensor_placements = [
    leg + segment
    for leg in ("LF", "LM", "LH", "RF", "RM", "RH")
    for segment in _sensor_segments
]

run_time = 1  # overridden by the controller cell
sim_params = Parameters(
    timestep=1e-4,
    render_mode="saved",
    render_playspeed=0.1,
    draw_contacts=False,
    render_camera="user_cam",
)

nmf = HybridTurningNMF(
    sim_params=sim_params,
    init_pose="stretch",
    # Spawn height 0.5: the NeuroMechFly default (the turning controller uses 0.2).
    spawn_pos=(10, -5, 0.5),
    # NOTE(review): third component presumably the heading angle (rad) - confirm.
    spawn_orientation=(0, 0, np.pi / 2 + np.deg2rad(80)),
    contact_sensor_placements=contact_sensor_placements,
    arena=arena,
)


In [None]:
from enum import Enum, auto

# Random-state seed passed to nmf.reset for reproducibility.
seed = 1


class State(Enum):
    """States of the contact-driven obstacle-avoidance controller below."""

    STATIC = 0  # TODO: not implemented yet
    GO_STRAIGHT = 1
    TURN_LEFT = 2
    TURN_RIGHT = 3
    REVERSE = 4  # TODO: never held yet - see the strong-contact TODOs below
    REACH_AND_TURN = 5  # TODO: the reach-and-turn behaviour is not implemented yet


# The fly starts by walking straight.
current_state = State.GO_STRAIGHT

decision_interval = 0.2
run_time = 1.2
# round() rather than int(): float division is inexact (1.2 / 1e-4 evaluates to
# 11999.999999999998 and 1.2 / 0.2 to 5.999999999999999), so int() would
# truncate these step counts by one.
num_decision_steps = round(run_time / decision_interval)
physics_steps_per_decision_step = round(decision_interval / sim_params.timestep)

# Thresholds on the summed |contact force| of one front leg's sensors.
low_force_thresh = 2  # above this, turn away from the touched side
high_force_thresh = 5  # above this, the plan is to reverse first (TODO)
# Simulation time (s) before which the current turn is not released.
enforce_time = 0
# How long (s) a turn is held after the most recent above-threshold contact.
delay = 0.2

# Rows of obs["contact_forces"] read by the controller. With the ordering of
# contact_sensor_placements (6 sensors per leg; LF first, RF starting at 18)
# these are LF Tibia..Tarsus4 and RF Tibia..Tarsus4.
# NOTE(review): Tarsus5 (rows 5 and 23) is left out - confirm this is intended.
left_front_rows = slice(0, 5)
right_front_rows = slice(18, 23)

obs_hist = []
odor_history = []  # not populated by this loop
obs, _ = nmf.reset(seed)

# Added to the baseline action [1, 1]; a -1 lowers that side's entry.
# NOTE(review): assumes the action is [left drive, right drive] - confirm
# against HybridTurningNMF.
bias = np.array([0, 0])
num_physics_steps = round(run_time / nmf.sim_params.timestep)
for i in trange(num_physics_steps):
    curr_time = i * nmf.sim_params.timestep
    left_sense = np.array(obs["contact_forces"][left_front_rows, :])
    right_sense = np.array(obs["contact_forces"][right_front_rows, :])

    # Sum absolute values: force components can be negative, and the contact
    # magnitude is what matters, not its direction.
    left_sense_sum = np.abs(left_sense).sum()
    right_sense_sum = np.abs(right_sense).sum()
    right_contact = right_sense_sum > low_force_thresh
    left_contact = left_sense_sum > low_force_thresh

    if right_contact:
        if current_state == State.GO_STRAIGHT:
            # TODO: above high_force_thresh the fly should first reverse
            # (State.REVERSE) for a while and then turn left. Until that is
            # implemented, strong and weak contacts both just turn left.
            current_state = State.TURN_LEFT
            bias = np.array([-1, 0])
        elif current_state == State.TURN_RIGHT:
            # Contact on the right while already turning right.
            # TODO: call the reach-and-turn behaviour.
            current_state = State.REACH_AND_TURN
        # Set (not accumulate) the release time. Adding `delay` on every
        # in-contact physics step would push it past the end of the run, so
        # the turn would never be released.
        enforce_time = curr_time + delay

    if left_contact:
        if current_state == State.GO_STRAIGHT:
            # TODO: above high_force_thresh the fly should first reverse
            # (State.REVERSE) for a while and then turn right. Until that is
            # implemented, strong and weak contacts both just turn right.
            current_state = State.TURN_RIGHT
            bias = np.array([0, -1])
        elif current_state == State.TURN_LEFT:
            # Contact on the left while already turning left.
            # TODO: call the reach-and-turn behaviour.
            current_state = State.REACH_AND_TURN
        enforce_time = curr_time + delay

    # No contact on either side and the hold time has elapsed: walk straight.
    if not (left_contact or right_contact) and curr_time >= enforce_time:
        current_state = State.GO_STRAIGHT
        bias = np.array([0, 0])

    control_signal = np.array([1, 1]) + bias

    if current_state == State.REVERSE:
        # Unreachable for now: REVERSE is never held (see the TODOs above).
        # NOTE(review): assumes nmf.step accepts an `orig` keyword - confirm.
        obs, reward, terminated, truncated, info = nmf.step(control_signal, orig=True)
    else:
        obs, reward, terminated, truncated, info = nmf.step(control_signal)
    nmf.render()
    obs_hist.append(obs)

nmf.save_video("./outputs/pillars.mp4")
Video("./outputs/pillars.mp4")

In [None]:
# First sensor row, all three force components, of the latest observation.
first_sensor_force = obs["contact_forces"][0:1, 0:3]
np.array(first_sensor_force)

In [None]:
# Quick sanity check of the recorded observation history.
n_recorded = len(obs_hist)
print(n_recorded)
first_recorded = obs_hist[0]
print(first_recorded.keys())
second_left_rows = obs_hist[1]["contact_forces"][:5, :]
print(second_left_rows.shape)


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Stack the recorded observations into arrays of shape (n_steps, 6, 3).
# Rows 0-5 / 18-23 are the LF / RF sensors (Tibia..Tarsus5), following the
# ordering of contact_sensor_placements.
forces_left = np.array([step_obs["contact_forces"][:6, :] for step_obs in obs_hist])
forces_right = np.array([step_obs["contact_forces"][18:24, :] for step_obs in obs_hist])

# One row per sensor; columns are x/y/z for the left leg, then x/y/z for the right.
fig, axs = plt.subplots(nrows=6, ncols=6, figsize=(15, 10))

directions = ['x', 'y', 'z']

for i in range(6):
    # Label with the real sensor names so each panel identifies its segment.
    left_name = contact_sensor_placements[i]
    right_name = contact_sensor_placements[18 + i]
    for j, direction in enumerate(directions):
        axs[i, j].plot(forces_left[:, i, j])
        axs[i, j].set_title(f'{left_name}, direction {direction}')

        axs[i, j + 3].plot(forces_right[:, i, j])
        axs[i, j + 3].set_title(f'{right_name}, direction {direction}')

for ax in axs[-1, :]:
    ax.set_xlabel('step')
for ax in axs[:, 0]:
    ax.set_ylabel('contact force')
fig.suptitle('LF (left three columns) and RF (right three columns) contact forces')

plt.tight_layout()
plt.show()

In [None]:
# Dimensions of the stacked left-leg force history.
np.shape(forces_left)

In [None]:
import os

import matplotlib.animation as animation

assert len(obs_hist) > 0, "run the simulation cell first: obs_hist is empty"

# Same rows the controller reads: LF / RF Tibia..Tarsus4, following the
# ordering of contact_sensor_placements. Shape: (n_steps, 5, 3).
forces_left = np.array([step_obs["contact_forces"][:5, :] for step_obs in obs_hist])
forces_right = np.array([step_obs["contact_forces"][18:23, :] for step_obs in obs_hist])

n_segments = forces_left.shape[1]
n_frames = len(forces_left)
# Draw every `frame_stride`-th step; raise it to shorten rendering of long runs.
frame_stride = 1

fig, axs = plt.subplots(nrows=n_segments, ncols=6, figsize=(15, 10))

directions = ['x', 'y', 'z']

# Ordered so that lines[2 * (i * 3 + j)] is the left trace for segment i,
# direction j, and the next entry is the matching right trace.
lines = []


def _set_static_limits(ax, trace):
    """Fix `ax` limits to the whole trace so they stay put during the animation."""
    lo, hi = float(np.min(trace)), float(np.max(trace))
    pad = 0.05 * (hi - lo) if hi > lo else 1.0  # avoid identical ylim bounds
    ax.set_xlim(0, max(len(trace) - 1, 1))
    ax.set_ylim(lo - pad, hi + pad)


for i in range(n_segments):
    for j, direction in enumerate(directions):
        # Left leg: start empty; update() reveals the trace frame by frame.
        ax_left = axs[i, j]
        line, = ax_left.plot([], [])
        ax_left.set_title(f'{contact_sensor_placements[i]}, direction {direction}')
        _set_static_limits(ax_left, forces_left[:, i, j])
        lines.append(line)

        # Right leg.
        ax_right = axs[i, j + 3]
        line, = ax_right.plot([], [])
        ax_right.set_title(f'{contact_sensor_placements[18 + i]}, direction {direction}')
        _set_static_limits(ax_right, forces_right[:, i, j])
        lines.append(line)


def update(num, forces_left_hist, forces_right_hist, lines):
    """Show every force trace from step 0 up to and including step `num`.

    Args:
        num: frame index supplied by FuncAnimation (a step index into the history).
        forces_left_hist, forces_right_hist: arrays of shape (n_steps, n_segments, 3).
        lines: Line2D objects ordered as described where `lines` is built.

    Returns:
        The updated lines (required by FuncAnimation when blitting).
    """
    steps = np.arange(num + 1)
    for i in range(forces_left_hist.shape[1]):
        for j in range(3):
            lines[2 * (i * 3 + j)].set_data(steps, forces_left_hist[: num + 1, i, j])
            lines[2 * (i * 3 + j) + 1].set_data(steps, forces_right_hist[: num + 1, i, j])
    return lines


ani = animation.FuncAnimation(fig, update, frames=range(0, n_frames, frame_stride),
                              fargs=(forces_left, forces_right, lines),
                              interval=10, blit=False)

# ani.save does not create missing directories.
os.makedirs("./outputs", exist_ok=True)
ani.save("./outputs/force_plot.mp4")
plt.show()

In [None]:
# Embed the force animation rendered by the previous cell.
force_plot_path = "./outputs/force_plot.mp4"
Video(force_plot_path)