In [75]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap
import seaborn as sns
from tqdm import tqdm
from evosax import CMA_ES
import seaborn as sns
from evosax import problems as p
from functools import partial

from typing import Callable, Tuple, List, Dict, Any, Optional

from src.plots import plot_multiples

In [105]:
rng = jax.random.PRNGKey(0)
N_DRONES = int(128 ** 0.5) ** 2
N_METERS = 1024

In [106]:
# TYPES
Array = jnp.ndarray
Drone = Tuple[Array, Array, Array, Array] # (position, velocity, mask, battery)
Terrain = Tuple[Array]
Observation = Tuple[Drone, Terrain]
State = Tuple[Drone, Terrain]

In [118]:
def random_position(n_drones, rng: Array) -> Array:
    return jax.random.uniform(rng, (n_drones, 3), minval=0, maxval=N_METERS)

def grid_positions(n_drones, dispersion=100) -> Array:
    x_range = jnp.linspace((N_METERS // 2)-dispersion, (N_METERS // 2)+dispersion, int(n_drones ** 0.5))
    x, y = jnp.meshgrid(x_range, x_range)
    position = jnp.stack([x.ravel(), y.ravel(), jnp.zeros_like(x.ravel())], axis=-1)
    return position

def init_drones(position_fn: Callable = grid_positions, n_drones=128, **kwargs) -> Drone:
    position = position_fn(n_drones, **kwargs)
    velocity = jnp.zeros((position.shape[0], 3))
    battery = jnp.ones((position.shape[0]))
    mask = jnp.ones((position.shape[0])).astype(bool)
    return position, velocity, mask, battery

drones = init_drones(grid_positions)
print(drones[0].shape)  # see the shape of the drones positions (N_DRONES, 3)

(121, 3)


In [119]:
def drone_dists(drones: Drone) -> Array:
    position, _, mask, _ = drones
    delta = position[:, jnp.newaxis, :] - position[jnp.newaxis, :, :]
    dists = jnp.sqrt(jnp.sum(delta ** 2, axis=-1))
    return dists

def get_neighbours(drones: Drone, dists: Array, n_neighbours: int = 8) -> Array:
    idxs = jnp.argsort(dists, axis=-1)[:, 1:n_neighbours+1]
    dists = jnp.take_along_axis(dists, idxs, axis=-1)
    neigh = drones[0][idxs] - drones[0][:, jnp.newaxis, :]
    phi = jnp.arctan2(neigh[..., 1], neigh[..., 0])
    theta = jnp.arctan2(neigh[..., 2], neigh[..., 0])
    neigh = jnp.stack([dists, phi, theta], axis=-1)
    return neigh

def observe(drones: Drone, terrain: Terrain) -> Observation:
    # coords = drones[0].astype(jnp.int32)[:,:2]
    # height = terrain[coords[:, 0], coords[:, 1]]
    dists = drone_dists(drones)
    neighs = get_neighbours(drones, dists)
    obs = neighs # .reshape((drones[0].shape[0], -1))
    return obs, dists

obs, dists = observe(drones, None)
obs.shape, dists.shape

((121, 8, 3), (121, 121))

In [120]:
def action_fn(observation: Array, mask: Array, rng: Array) -> Array:
    action = model(observation, rng)
    action = action * mask[:, jnp.newaxis]
    return action

def model(observation: Array, chromo: Array) -> Array:
    return jnp.matmul(observation, chromo)

In [121]:
def collision_test(drones: Drone, dists, threshold: float = 1.0) -> Array:
    mask = jnp.triu(dists < threshold, k=1)
    mask = jnp.any(mask, axis=0) | jnp.any(mask, axis=1)
    return ~mask

def step_fn(drones: Drone, rng: Array) -> Drone:
    position, velocity, mask, battery = drones
    dists = drone_dists(drones)
    action = action_fn(observe(position), mask, rng)
    velocity = velocity * mask[:, jnp.newaxis]
    position = position + (velocity * action)
    velocity = velocity + action
    battery = battery - jnp.sqrt(jnp.sum(action ** 2, axis=-1))
    return position, velocity, battery

In [122]:
def quantize(drone: Drone, n_bins: int = 1000) -> Array:
    # position is between 0 and 1000
    position, _, _, _ = drone
    env = jnp.zeros((n_bins, n_bins)).astype(jnp.int32)
    position = position / 1000 * n_bins
    return env.at[jnp.clip(position[:, 0], 0, n_bins - 1).astype(jnp.int32),
                  jnp.clip(position[:, 1], 0, n_bins - 1).astype(jnp.int32)].add(1)


In [126]:
def fitness_fn(drones: Drone, terrain: Terrain) -> Array:
    position, _, mask, battery = drones
    coords = position.astype(jnp.int32)[:,:2]
    height = terrain[coords[:, 0], coords[:, 1]]
    reward = jnp.where(mask, height, 0)
    return reward

def generation(drones: Drone, rng: Array, n_steps=100) -> Drone:
    position, velocity, mask, battery = drones
    for step in range(n_steps):
        rng, key = jax.random.split(rng)
        mask &= collision_test(position, velocity, mask)
        position, velocity, battery = step_fn(position, velocity, mask, battery, key)
        img = quantize(position)

def evolve(rng: Array, n_generations: int = 100, n_swarms=32, population_size=128) -> Drone:
    # strategy = CMA_ES(popsize=20, num_dims=2, elite_ratio=0.5)
    population = [init_drones() for _ in range(population_size)]
    for gen in tqdm(range(n_generations)):
        # vmap over the population
        population = vmap(generation, in_axes=(0, None))(population, rng)
    return population

drones = init_drones()
evolve(rng)

(121, 3) (121, 3) (121,) (121,)
None


[(Array([[412., 412.,   0.],
         [432., 412.,   0.],
         [452., 412.,   0.],
         [472., 412.,   0.],
         [492., 412.,   0.],
         [512., 412.,   0.],
         [532., 412.,   0.],
         [552., 412.,   0.],
         [572., 412.,   0.],
         [592., 412.,   0.],
         [612., 412.,   0.],
         [412., 432.,   0.],
         [432., 432.,   0.],
         [452., 432.,   0.],
         [472., 432.,   0.],
         [492., 432.,   0.],
         [512., 432.,   0.],
         [532., 432.,   0.],
         [552., 432.,   0.],
         [572., 432.,   0.],
         [592., 432.,   0.],
         [612., 432.,   0.],
         [412., 452.,   0.],
         [432., 452.,   0.],
         [452., 452.,   0.],
         [472., 452.,   0.],
         [492., 452.,   0.],
         [512., 452.,   0.],
         [532., 452.,   0.],
         [552., 452.,   0.],
         [572., 452.,   0.],
         [592., 452.,   0.],
         [612., 452.,   0.],
         [412., 472.,   0.],
         [432.