In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output

from model.drone_model_jax import get_default_params, pack_params, N_STATES, IDX
from utils.swarm_utils import assemble_swarm, swarm_evalf


def visualizeNetwork(x, p, t=None, U=None, xlim=None, ylim=None):
    """Draw a snapshot of the drone swarm in the Y-Z plane on the current figure.

    The current pyplot figure is cleared and redrawn, so the function can be
    called repeatedly (e.g. once per animation frame).

    Parameters
    ----------
    x : array-like
        Flat swarm state vector made of consecutive per-drone blocks of length
        ``N_STATES``. Drone ``i``'s Y/Z coordinates are read from
        ``x[i * N_STATES + IDX['y']]`` and ``x[i * N_STATES + IDX['z']]``.
    p : object
        Model parameters. Currently unused; kept so existing call sites work.
    t : float, optional
        Simulation time in seconds, appended to the title when given.
    U : sequence, optional
        Targets; entry ``j`` is indexed as ``(U[j][0], U[j][1])`` and drawn as a
        red ``x`` at (Y, Z). Presumably ordered like the drones — TODO confirm
        against ``assemble_swarm``.
    xlim, ylim : (float, float), optional
        Fixed axis limits; a dimension left as None is autoscaled.

    Raises
    ------
    ValueError
        If ``len(x)`` is not a multiple of ``N_STATES``.
    """
    n_values = len(x)
    if n_values % N_STATES != 0:
        raise ValueError(
            f"state length {n_values} is not a multiple of N_STATES={N_STATES}"
        )
    N = n_values // N_STATES

    # One row per drone; pick out the Y and Z columns.
    states = np.asarray(x).reshape(N, N_STATES)
    y_positions = states[:, IDX['y']]
    z_positions = states[:, IDX['z']]

    fig = plt.gcf()
    fig.clear()
    ax = fig.gca()

    # Targets: iterate over U itself (not the drone count) so a length mismatch
    # cannot raise IndexError or silently hide targets. Only the first marker
    # is labelled so the legend shows a single "Targets" entry.
    if U is not None:
        for j, target in enumerate(U):
            ty, tz = target[0], target[1]
            ax.plot(ty, tz, 'x', c='red', markersize=10, alpha=0.7,
                    label='Targets' if j == 0 else '_nolegend_')
            ax.text(ty + 0.05, tz + 0.05, f'{j}', fontsize=9)

    # Drones: a single marker series -> a single "Drones" legend entry;
    # the per-drone index is shown as a text annotation instead.
    ax.plot(y_positions, z_positions, 'o', c='blue',
            markersize=10, alpha=0.7, label='Drones')
    for i, (yi, zi) in enumerate(zip(y_positions, z_positions)):
        ax.text(yi + 0.05, zi + 0.05, f'{i}', fontsize=9)

    title = f'Drone Swarm Visualization ({N} Drones)'
    if t is not None:
        title += f' - Time: {t:.2f} s'
    ax.set_title(title)

    ax.set_xlabel('Y Position (meters)')
    ax.set_ylabel('Z Position (meters)')
    ax.legend()
    ax.grid(True)

    # Equal aspect. When both limits are pinned, 'datalim' would have to
    # override one of them (matplotlib warns "Ignoring fixed ... limits"), so
    # shrink the axes box instead; otherwise keep the data-limit adjustment.
    both_fixed = xlim is not None and ylim is not None
    ax.set_aspect('equal', adjustable='box' if both_fixed else 'datalim')
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

In [None]:
%matplotlib inline

# Initialize
p = get_default_params()
p_tuple = pack_params(p)
targets = [(1.0, 2.0), (0.5, 1.5), (0.0, 1.0)]
X0_flat, U = assemble_swarm(targets)
N = len(targets)

# Simulation parameters
dt = 1e-4
T_final = 10.0
frame_skip = 1000

# Initialize state
x = X0_flat.copy()
t = 0.0
step = 0

# Simulation loop
while t < T_final:
    # Euler integration
    dxdt = swarm_evalf(x, p_tuple, U)
    x = x + dxdt * dt
    
    t += dt
    step += 1
    
    # Update visualization periodically
    if step % frame_skip == 0:
        clear_output(wait=True)
        plt.figure()
        visualizeNetwork(x, p_tuple, t, U, xlim=[-2,3], ylim=[-2,4])
        plt.show()