In [1]:
from flygym.arena import FlatTerrain
from flygym import SingleFlySimulation, Camera, Fly
from flygym.examples.vision import ObstacleOdorArena
import numpy as np

import matplotlib.pyplot as plt

In [None]:
class LoomingBallArena(FlatTerrain):
    """
    Simulates a looming ball scenario where a ball approaches a fly entity from different angles
    with Poisson-distributed spawning events.
    """

    def __init__(
        self,
        timestep,
        fly,
        ball_radius=1.0,
        ball_approach_vel=50,
        ball_approach_start_radius=20,
        ball_overshoot_dist=3,
        looming_lambda=2.0,
        seed=0,
        approach_angles=np.array([np.pi / 4, 3 * np.pi / 4]),
        **kwargs
    ):
        super().__init__(**kwargs)

        self.fly = fly
        self.seed = seed
        self.random_state = np.random.RandomState(seed)
        self.dt = timestep
        self.ball_radius = ball_radius

        self._setup_probabilities(looming_lambda)
        self._setup_trajectory_params(ball_approach_vel, ball_approach_start_radius, ball_overshoot_dist)
        self._setup_ball_heights()

        self.ball_approach_angles = approach_angles
        self.is_looming = False
        self.ball_traj_advancement = 0

        self._setup_velocity_buffer()
        self.add_ball(ball_radius)

    def _setup_probabilities(self, looming_lambda):
        """Initialize Poisson probability settings."""
        self.p_no_looming = np.exp(-looming_lambda * self.dt)

    def _setup_trajectory_params(self, vel, start_radius, overshoot_dist):
        """Precompute trajectory parameters for the looming ball."""
        self.ball_approach_vel = vel
        self.ball_approach_start_radius = start_radius
        self.overshoot_dist = overshoot_dist

        interception_time = start_radius / vel
        self.n_interception_steps = int(interception_time / self.dt)
        self.n_overshoot_steps = int((overshoot_dist / vel) / self.dt)

        self.ball_trajectory = np.zeros((self.n_interception_steps + self.n_overshoot_steps, 2))

    def _setup_ball_heights(self):
        """Calculate ball positions for visible and resting states."""
        self.ball_rest_height = 10.0
        self.ball_act_height = -self.ball_rest_height + self.ball_radius + self._get_max_floor_height()

    def _setup_velocity_buffer(self):
        """Setup a fixed-length FIFO buffer for fly velocity estimation."""
        self.vel_buffer_size = 50
        self.fly_velocities = np.full((self.vel_buffer_size, 2), np.nan)
        self.fly_velocities_idx = 0

    def add_ball(self, ball_radius):
        """Add the ball to the scene with joints and geometry."""
        self.ball_body = self.root_element.worldbody.add(
            'body', name='ball', pos=[0, 0, self.ball_rest_height]
        )
        self.ball_jointx = self.ball_body.add("joint", type="slide", axis=[1, 0, 0], damping=0.01)
        self.ball_jointy = self.ball_body.add("joint", type="slide", axis=[0, 1, 0], damping=0.01)
        self.ball_jointz = self.ball_body.add("joint", type="slide", axis=[0, 0, 1], damping=0.01)
        self.ball_geom = self.ball_body.add("geom", name="ball", type='sphere', size=[ball_radius], rgba=[1, 0, 0, 0])

    def spawn_entity(self, entity, rel_pos, rel_angle):
        """Spawn the fly and setup collision pairs."""
        super().spawn_entity(entity, rel_pos, rel_angle)
        self._add_contacts()

    def _add_contacts(self):
        """Add contact pairs between the ball and key fly body parts."""
        ball_geom_name = self.ball_geom.name
        for animat_geom_name in ["Head", "Thorax", "A1A2", "A3", "A4", "A5", "A6"]:
            self.root_element.contact.add(
                "pair",
                name=f"{ball_geom_name}_{self.fly.name}_{animat_geom_name}",
                geom1=f"{self.fly.name}/{animat_geom_name}",
                geom2=ball_geom_name,
            )

    def set_ball_trajectory(self, start_pts, end_pts):
        """Generate a linear trajectory from start to end."""
        self.ball_trajectory = np.linspace(start_pts, end_pts, self.n_interception_steps + self.n_overshoot_steps)

    def make_ball_visible(self, physics):
        physics.bind(self.ball_geom).rgba[3] = 1

    def make_ball_invisible(self, physics):
        physics.bind(self.ball_geom).rgba[3] = 0

    def move_ball(self, physics, x, y, z):
        """Move ball to the desired location using joint positions."""
        physics.bind(self.ball_jointx).qpos = x
        physics.bind(self.ball_jointy).qpos = y
        physics.bind(self.ball_jointz).qpos = z

    def _get_mean_fly_velocity(self):
        """Compute average fly velocity from buffer."""
        return np.nanmean(self.fly_velocities, axis=0)

    def _should_trigger_ball(self):
        """Check if ball should start looming based on Poisson process."""
        return self.random_state.rand() > self.p_no_looming and not self.is_looming

    def _compute_trajectory_from_fly(self, fly_pos, fly_vel, fly_or_vec):
        """Generate start/end points of the ball trajectory based on fly state."""
        fly_roll = np.arctan2(fly_or_vec[1], fly_or_vec[0])
        approach_side = self.random_state.choice([-1, 1])
        rel_angles = self.ball_approach_angles * approach_side + fly_roll
        start_angle = self.random_state.uniform(low=rel_angles[0], high=rel_angles[1])

        interception_pos = fly_pos + fly_vel * self.n_interception_steps * self.dt
        start_pos = interception_pos + self.ball_approach_start_radius * np.array([
            np.cos(start_angle), np.sin(start_angle)
        ])
        end_pos = interception_pos - self.overshoot_dist * np.array([
            np.cos(start_angle), np.sin(start_angle)
        ])
        return start_pos, end_pos, interception_pos, start_angle

    def step(self, dt, physics):
        """Main loop: updates ball state, triggers looming events, and moves the ball."""
        # Update fly velocity buffer
        fly_vel = physics.bind(self.fly._body_sensors[0]).sensordata[:2].copy()
        self.fly_velocities[self.fly_velocities_idx % self.vel_buffer_size] = fly_vel
        self.fly_velocities_idx += 1

        if self._should_trigger_ball():
            self.is_looming = True
            self.ball_traj_advancement = 0
            self.make_ball_visible(physics)

            fly_pos = physics.bind(self.fly._body_sensors[0]).sensordata[:2].copy()
            fly_or_vec = physics.bind(self.fly._body_sensors[4]).sensordata.copy()
            fly_vel_mean = self._get_mean_fly_velocity()

            start_pts, end_pts, interception_pos, angle = self._compute_trajectory_from_fly(
                fly_pos, fly_vel_mean, fly_or_vec
            )

            self.set_ball_trajectory(start_pts, end_pts)
            self.move_ball(physics, *start_pts, -self.ball_act_height)
            self.ball_traj_advancement += 1

            # Optional: visualize
            #self._plot_trajectory_debug(fly_pos, fly_vel_mean, interception_pos, start_pts, fly_or_vec)

        elif self.is_looming:
            self._advance_ball(physics)
        else:
            self.move_ball(physics, 0, 0, self.ball_rest_height)

    def _advance_ball(self, physics):
        """Advance the ball along its trajectory."""
        pos = self.ball_trajectory[self.ball_traj_advancement]
        self.move_ball(physics, pos[0], pos[1], self.ball_act_height)
        self.ball_traj_advancement += 1

        if self.ball_traj_advancement >= self.n_interception_steps + self.n_overshoot_steps:
            self.is_looming = False
            self.make_ball_invisible(physics)
            self.move_ball(physics, 0, 0, self.ball_rest_height)

    def _plot_trajectory_debug(self, fly_pos, fly_vel, intercept_pos, start_pos, orientation_vec):
        """Visualize trajectory for debugging."""
        plt.scatter(fly_pos[0], fly_pos[1], label='fly pos', s=5)
        plt.scatter(intercept_pos[0], intercept_pos[1], label='fly interception pos', s=5)
        plt.plot(self.ball_trajectory[:, 0], self.ball_trajectory[:, 1], label='ball trajectory')
        plt.scatter(start_pos[0], start_pos[1], label='ball start pos', s=5)
        plt.arrow(fly_pos[0], fly_pos[1], fly_vel[0], fly_vel[1], head_width=0.5, fc='blue')
        plt.arrow(intercept_pos[0], intercept_pos[1], orientation_vec[0], orientation_vec[1], head_width=0.5, fc='green')
        plt.legend()
        plt.show()

    def reset(self, physics, seed=None):
        """Reset the environment and optionally reseed."""
        if seed is not None:
            self.seed = seed
        self.random_state = np.random.RandomState(self.seed)

        self.is_looming = False
        self.ball_traj_advancement = 0
        self.make_ball_invisible(physics)
        self.ball_trajectory = np.zeros((self.n_interception_steps, 2))
        self.move_ball(physics, 0, 0, self.ball_rest_height)


In [3]:
from flygym.examples.locomotion import HybridTurningFly
from flygym import YawOnlyCamera
from tqdm import trange
from dm_control.rl.control import PhysicsError


run_time = 2.0
timestep = 1e-4
contact_sensor_placements = [
    f"{leg}{segment}"
    for leg in ["LF", "LM", "LH", "RF", "RM", "RH"]
    for segment in ["Tibia", "Tarsus1", "Tarsus2", "Tarsus3", "Tarsus4", "Tarsus5"]
]

np.random.seed(0)

fly = HybridTurningFly(
    enable_adhesion=True,
    draw_adhesion=True,
    contact_sensor_placements=contact_sensor_placements,
    seed=0,
    draw_corrections=True,
    timestep=timestep,
)

arena = LoomingBallArena(timestep, fly)

cam = YawOnlyCamera(
    attachment_point=fly.model.worldbody,
    camera_name="camera_top_zoomout",
    targeted_fly_names=[fly.name],
    play_speed=0.1,
)

sim = SingleFlySimulation(
    fly=fly,
    cameras=[cam],
    timestep=timestep,
    arena=arena,
)

obs_list = []
delay = 0.1
delay_steps = int(delay / sim.timestep)
delay_timer = 0

for avoid in [False, True]:
    for seed in range(2):
        obs, info = sim.reset(seed=0)
        arena.reset(sim.physics, seed=seed)
        print(f"Spawning fly at {obs['fly'][0]} mm")

        for i in trange(int(run_time / sim.timestep)):
            curr_time = i * sim.timestep

            action = np.array([1.0, 1.0])
            if arena.is_looming and avoid:
                if delay_timer < delay_steps:
                    delay_timer += 1
                else:
                    action = np.array([-1.0, -1.0])
            try:
                obs, reward, terminated, truncated, info = sim.step(action)
                obs_list.append(obs)
                sim.render()
            except PhysicsError:
                print("Simulation was interrupted because of a physics error")
                break

        x_pos = obs_list[-1]["fly"][0][0]
        print(f"Final x position: {x_pos:.4f} mm")
        print(f"Simulation terminated: {obs_list[-1]['fly'][0] - obs_list[0]['fly'][0]}")

        vid_name = f"./outputs/ball_arena_avoid_seed{seed}.mp4" if avoid else f"./outputs/ball_arena_seed{seed}.mp4"
        cam.save_video(vid_name, 0)

Spawning fly at [0.01634914 0.00730949 1.7814164 ] mm


100%|██████████| 20000/20000 [04:04<00:00, 81.79it/s] 


Final x position: 31.3175 mm
Simulation terminated: [31.301048    5.520554    0.20880663]
Spawning fly at [0.01634914 0.00730949 1.7814164 ] mm


100%|██████████| 20000/20000 [12:53<00:00, 25.86it/s]  


Final x position: 16.4227 mm
Simulation terminated: [16.406263   -0.37365618 -1.8275015 ]
Spawning fly at [0.01634914 0.00730949 1.7814164 ] mm


100%|██████████| 20000/20000 [16:03<00:00, 20.75it/s]   


Final x position: -2.7703 mm
Simulation terminated: [-2.7867138  -0.87256736 -0.6102643 ]
Spawning fly at [0.01634914 0.00730949 1.7814164 ] mm


100%|██████████| 20000/20000 [1:28:28<00:00,  3.77it/s]   


Final x position: 3.1918 mm
Simulation terminated: [ 3.175324    1.066345   -0.61291254]
