In [172]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap
import seaborn as sns
from tqdm import tqdm
from evojax.algo import SimpleGA, PGPE
from evosax.problems import BBOBFitness
import seaborn as sns
from evosax import problems as p
from functools import partial
from typing import Callable, Tuple, List, Dict, Any, Optional

from src.plots import plot_multiples

In [130]:
# Root PRNG key for the notebook, seeded with 0.
rng = jax.random.PRNGKey(0)
# Largest perfect square not exceeding 128 (11 ** 2 == 121), i.e. a drone count that fills a square grid.
N_DRONES = int(128 ** 0.5) ** 2
# Upper bound of the coordinate range used by random_position; grid_positions centres its grid at N_METERS // 2.
N_METERS = 1024

In [131]:
# TYPES
# Aliases used in the function annotations of this notebook.
Array = jnp.ndarray
Swarm = Tuple[Array, Array, Array, Array] # (position, velocity, mask, battery)
# NOTE(review): no Terrain value is built in the cells shown here (callers pass None) — layout to be confirmed.
Terrain = Tuple[Array]
# NOTE(review): observe() is annotated with Observation but returns (neighbour features, distance matrix).
Observation = Tuple[Swarm, Terrain]
Population = List[Swarm]
# Structurally identical to Observation.
State = Tuple[Swarm, Terrain]

In [132]:
def random_position(n_drones, rng: Array) -> Array:
    """Sample uniformly random 3-D positions.

    Args:
        n_drones: Number of positions to draw.
        rng: JAX PRNG key used for sampling.

    Returns:
        Array of shape (n_drones, 3) with every coordinate drawn from
        [0, N_METERS).
    """
    shape = (n_drones, 3)
    return jax.random.uniform(rng, shape, minval=0, maxval=N_METERS)

def grid_positions(n_drones, dispersion=100) -> Array:
    """Lay drones out on a square grid centred at N_METERS // 2.

    Args:
        n_drones: Requested drone count. Only int(sqrt(n_drones)) ** 2
            positions are produced, so non-square counts are rounded down.
        dispersion: Half-width of the grid along x and y.

    Returns:
        Array of shape (side * side, 3) with z set to zero.
    """
    side = int(n_drones ** 0.5)
    centre = N_METERS // 2
    axis = jnp.linspace(centre - dispersion, centre + dispersion, side)
    grid_x, grid_y = jnp.meshgrid(axis, axis)
    flat_x = grid_x.ravel()
    flat_y = grid_y.ravel()
    # All drones start on the z = 0 plane.
    return jnp.stack([flat_x, flat_y, jnp.zeros_like(flat_x)], axis=-1)

def init_drones(position_fn: Callable = grid_positions, n_drones=128, **kwargs) -> Swarm:
    """Create a fresh swarm.

    Args:
        position_fn: Callable invoked as position_fn(n_drones, **kwargs) that
            returns the initial positions.
        n_drones: Requested drone count, forwarded to position_fn.
        **kwargs: Extra keyword arguments for position_fn.

    Returns:
        (position, velocity, mask, battery). The swarm size is taken from the
        positions actually returned, velocity is zero, every mask entry is
        True and every battery entry is 1.
    """
    position = position_fn(n_drones, **kwargs)
    # position_fn may return fewer drones than requested, so size everything from its output.
    count = position.shape[0]
    velocity = jnp.zeros((count, 3))
    mask = jnp.ones(count).astype(bool)
    battery = jnp.ones(count)
    return position, velocity, mask, battery

# Build the default swarm on the square grid.
drones = init_drones(grid_positions)
print(drones[0].shape)  # see the shape of the drones positions (N_DRONES, 3); the 128 requested drones become an 11 x 11 grid

(121, 3)


In [133]:
def drone_dists(drones: Swarm) -> Array:
    """Compute the pairwise Euclidean distance matrix of a swarm.

    Args:
        drones: (position, velocity, mask, battery) tuple; only position is used.

    Returns:
        Array of shape (n, n) whose entry [i, j] is the distance between
        drone i and drone j (zero on the diagonal).
    """
    position, _velocity, _mask, _battery = drones
    # Broadcast (n, 1, d) against (1, n, d) to get every pairwise offset.
    offsets = position[:, None, :] - position[None, :, :]
    squared = jnp.sum(offsets ** 2, axis=-1)
    return jnp.sqrt(squared)

def get_neighbours(drones: Swarm, dists: Array, n_neighbours: int = 8) -> Array:
    """Describe each drone's nearest neighbours in spherical coordinates.

    Args:
        drones: (position, velocity, mask, battery) tuple; only position is used.
        dists: (n, n) pairwise distance matrix for the same swarm.
        n_neighbours: Number of nearest neighbours to report per drone.

    Returns:
        Array of shape (n, n_neighbours, 3) holding, for every neighbour,
        (distance, phi, theta), where phi is the azimuth arctan2(dy, dx) and
        theta is the elevation above the x-y plane.
    """
    position = drones[0]
    # Column 0 of the sorted indices is skipped: it is the drone itself at distance 0.
    nearest = jnp.argsort(dists, axis=-1)[:, 1:n_neighbours + 1]
    ranges = jnp.take_along_axis(dists, nearest, axis=-1)
    offsets = position[nearest] - position[:, jnp.newaxis, :]
    dx, dy, dz = offsets[..., 0], offsets[..., 1], offsets[..., 2]
    phi = jnp.arctan2(dy, dx)
    # Elevation uses the full horizontal distance. The previous arctan2(dz, dx)
    # ignored dy, so a level neighbour at negative x was reported at angle pi.
    theta = jnp.arctan2(dz, jnp.hypot(dx, dy))
    return jnp.stack([ranges, phi, theta], axis=-1)

def observe(drones: Swarm, terrain: Terrain) -> Observation:
    """Build the per-drone observation for a swarm.

    Args:
        drones: (position, velocity, mask, battery) tuple.
        terrain: Currently unused; no terrain height lookup is performed yet.

    Returns:
        (neighbour_features, pairwise_distances): the output of
        get_neighbours (left un-flattened) and the matrix from drone_dists.
    """
    pairwise = drone_dists(drones)
    neighbour_features = get_neighbours(drones, pairwise)
    return neighbour_features, pairwise

# Smoke-test observe() on the default swarm; None is passed in place of a terrain.
obs, dists = observe(drones, None)
# Display the shapes of the neighbour features and of the pairwise distance matrix.
obs.shape, dists.shape

((121, 8, 3), (121, 121))

In [134]:
def action_fn(observation: Array, mask: Array, rng: Array) -> Array:
    """Compute the swarm's actions and zero them for masked-out drones.

    Args:
        observation: Input handed to model().
        mask: Per-drone boolean array; False rows get a zero action.
        rng: Forwarded to model() as its second argument (the model calls it
            ``chromo``).

    Returns:
        The model output multiplied by the mask broadcast over the last axis.
    """
    raw_action = model(observation, rng)
    alive = mask[:, None]
    return raw_action * alive

def model(observation: Array, chromo: Array) -> Array:
    """Linear policy: matrix-multiply the observation by the chromosome.

    Args:
        observation: Left operand of the product.
        chromo: Right operand (the weights).

    Returns:
        jnp.matmul(observation, chromo).
    """
    action = jnp.matmul(observation, chromo)
    return action

In [135]:
def collision_test(drones: Swarm, dists, threshold: float = 1.0) -> Array:
    """Flag the drones that are NOT in a collision.

    Args:
        drones: Unused; kept for signature symmetry with the other helpers.
        dists: (n, n) pairwise distance matrix.
        threshold: Two drones collide when their distance is below this value.

    Returns:
        Boolean array of shape (n,): True for drones with no other drone
        closer than ``threshold``.
    """
    # Strict upper triangle: each pair counted once, diagonal (self) excluded.
    too_close = jnp.triu(dists < threshold, k=1)
    hit_as_column = jnp.any(too_close, axis=0)
    hit_as_row = jnp.any(too_close, axis=1)
    collided = hit_as_column | hit_as_row
    return ~collided

def step_fn(drones: Swarm, rng: Array) -> Tuple[Array, Array, Array]:
    """Advance the swarm by one simulation step.

    Args:
        drones: (position, velocity, mask, battery) tuple.
        rng: Passed unchanged to action_fn, which hands it to model() as the
            weights; it is not consumed as a PRNG key in this function.

    Returns:
        (position, velocity, battery) after the step. The mask is not
        returned; the caller keeps its own copy.
    """
    position, velocity, mask, battery = drones
    # observe() takes the whole swarm plus a terrain and returns
    # (observation, dists); only the observation feeds the policy. No terrain
    # is available here, so None is passed.
    observation, _ = observe(drones, None)
    # NOTE(review): the action must broadcast against velocity (n_drones, 3)
    # and the mask — confirm the model's output shape.
    action = action_fn(observation, mask, rng)
    # Masked-out drones stop moving.
    velocity = velocity * mask[:, jnp.newaxis]
    # Move by the current velocity. The previous `velocity * action` scaled the
    # displacement by the control, so a drone stood still whenever its action was zero.
    position = position + velocity
    velocity = velocity + action
    # Each step drains the battery by the magnitude of the action.
    battery = battery - jnp.sqrt(jnp.sum(action ** 2, axis=-1))
    return position, velocity, battery

In [136]:
def quantize(drone: Swarm, n_bins: int = 1000, extent: Optional[float] = None) -> Array:
    """Rasterise drone x/y positions into a 2-D occupancy count grid.

    Args:
        drone: (position, velocity, mask, battery) tuple; only position is used.
        n_bins: Number of bins along each of the two axes.
        extent: Side length of the square world mapped onto the grid.
            Defaults to N_METERS. The previous hard-coded 1000 did not match
            the coordinate range positions are created in, so everything
            beyond 1000 piled up in the clipped last bin.

    Returns:
        int32 array of shape (n_bins, n_bins) where entry [i, j] is the number
        of drones whose (x, y) falls in that cell. Positions outside
        [0, extent) are clipped into the border cells.
    """
    if extent is None:
        extent = N_METERS
    position = drone[0]
    scaled = position[:, :2] / extent * n_bins
    # Clip before the integer cast so out-of-range drones land in a border cell.
    cells = jnp.clip(scaled, 0, n_bins - 1).astype(jnp.int32)
    grid = jnp.zeros((n_bins, n_bins), dtype=jnp.int32)
    # Scatter-add accumulates when several drones share a cell.
    return grid.at[cells[:, 0], cells[:, 1]].add(1)


In [174]:
def reward_fn(drones: Swarm, terrain: Terrain) -> Array:
    """Placeholder for the swarm reward.

    TODO: not implemented yet — both arguments are ignored and None is returned.
    """
    return None

def generation(drones: Swarm, rng: Array, n_steps=100) -> Swarm:
    """Simulate one swarm for ``n_steps`` steps and return its final state.

    Args:
        drones: Initial (position, velocity, mask, battery) tuple.
        rng: Handed unchanged to step_fn on every step. step_fn forwards it to
            action_fn, which passes it to model() as the weights, so it is no
            longer split as a PRNG key here.
        n_steps: Number of simulation steps.

    Returns:
        The final (position, velocity, mask, battery) tuple. An occupancy map
        can be derived from it with quantize().
    """
    position, velocity, mask, battery = drones
    for _ in range(n_steps):
        swarm = (position, velocity, mask, battery)
        # collision_test expects the distance matrix and returns True for
        # drones that are still collision-free.
        mask = mask & collision_test(swarm, drone_dists(swarm))
        # step_fn takes the swarm as one tuple and returns everything but the mask.
        position, velocity, battery = step_fn((position, velocity, mask, battery), rng)
    return position, velocity, mask, battery

def evolve(rng: Array, n_generations: int = 100, n_swarms=32, population_size=128) -> Array:
    """Optimise the policy parameters with a simple genetic algorithm.

    The previous version built SimpleGA from a list of swarms and a function,
    vmapped over a Python list and never reported fitness back. This runs the
    usual ask / evaluate / tell loop instead.

    Args:
        rng: JAX PRNG key; used only to derive the integer seed of the strategy.
        n_generations: Number of ask/tell iterations.
        n_swarms: Currently unused; kept so existing callers keep working.
        population_size: Number of candidate parameter vectors per generation.

    Returns:
        The best parameter vector reported by the strategy.
    """
    # NOTE(review): 8 * 3 matches the reshape done in fitness_fn — keep the two in sync.
    param_size = 8 * 3
    seed = int(jax.random.randint(rng, (), 0, 2 ** 31 - 1))
    strategy = SimpleGA(param_size=param_size, pop_size=population_size, seed=seed)
    for _ in tqdm(range(n_generations)):
        candidates = strategy.ask()
        fitness = jnp.array([fitness_fn(candidate) for candidate in candidates])
        strategy.tell(fitness)
    return strategy.best_params

def fitness_fn(params):
    """Score one candidate parameter vector by simulating a swarm with it.

    Args:
        params: Flat sequence of 24 values, reshaped to (8, 3) policy weights.

    Returns:
        Whatever reward_fn returns for the final swarm state.
    """
    weights = jnp.array(params).reshape((8, 3))
    drones = init_drones(position_fn=partial(grid_positions, dispersion=100))
    # generation() takes the swarm first and the per-step policy input second;
    # the two arguments were previously swapped.
    final_swarm = generation(drones, weights)
    # reward_fn takes (drones, terrain); no terrain exists yet, so None is passed.
    return reward_fn(final_swarm, None)

# Rebuild the default swarm (default position function and drone count).
drones = init_drones()

In [173]:

# Flattened policy size; must equal the (8, 3) reshape performed in fitness_fn.
# The previous value of 8 * 3 + 1 could not be reshaped to (8, 3).
param_size = 8 * 3
pgpe = PGPE(32, param_size)
num_generations = 50
rng, key = jax.random.split(rng)
# The loop variable is deliberately not called `generation`: that name would
# shadow the generation() function that fitness_fn calls.
for gen_idx in range(num_generations):
    solutions = pgpe.ask()
    fitness_values = jnp.array([fitness_fn(sol) for sol in solutions])
    pgpe.tell(fitness_values)
# Only the parameters after the final update are kept.
best_params = pgpe.best_params

Generation 0, Best Params: [0.00065204 0.00040044]
Generation 1, Best Params: [ 0.00030255 -0.00034209]
Generation 2, Best Params: [-0.00018763 -0.00014254]
Generation 3, Best Params: [ 1.9429062e-04 -2.7376693e-05]
Generation 4, Best Params: [-4.0013625e-04  6.9630820e-05]
Generation 5, Best Params: [-0.00029281 -0.00054618]
Generation 6, Best Params: [-1.26137151e-04  1.20906625e-05]
Generation 7, Best Params: [0.00046332 0.00039844]
Generation 8, Best Params: [-0.00016824  0.00025088]
Generation 9, Best Params: [-0.00023508 -0.0002191 ]
Generation 10, Best Params: [-5.8207705e-05  7.2770554e-04]
Generation 11, Best Params: [-5.1661953e-04 -9.4003626e-06]
Generation 12, Best Params: [-1.2868433e-05 -1.2605096e-04]
Generation 13, Best Params: [-2.846879e-05  3.197082e-04]
Generation 14, Best Params: [ 3.5489429e-04 -5.1322306e-05]
Generation 15, Best Params: [-7.6900818e-05 -3.2153766e-05]
Generation 16, Best Params: [0.00020799 0.00046257]
Generation 17, Best Params: [-3.0939438e-05 