# sortsol

In [None]:
import syrkis

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random
from jax.lax import scatter_add, ScatterDimensionNumbers
from evosax import OpenES

from functools import partial
from typing import Callable, Tuple, List, Dict, Any, Optional

In [None]:
# Population size handed to the OpenES strategy.
POPSIZE = 100
# Number of drones; grid_positions lays them out on a square grid.
N_DRONES = 16
# 3x3 all-ones filter shaped (H, W, in, out) = (3, 3, 1, 1) to match conv2d's "HWIO" layout.
kernel = jnp.ones((3, 3, 1, 1))

In [None]:
# TYPES
# Short alias for the JAX array type, used in the function annotations below.
Array = jnp.ndarray

In [None]:
def grid_positions(n_drones: int) -> Array:
    """Lay ``n_drones`` out on a square grid centred on the origin.

    The grid has ``side = sqrt(n_drones)`` points per axis, evenly spaced over
    ``[-side, side]`` (e.g. 16 drones sit on a 4x4 grid spanning [-4, 4]).

    Args:
        n_drones: number of drones; must be a non-negative perfect square.

    Returns:
        ``(n_drones, 2)`` array of ``(x, y)`` coordinates, with x varying fastest.

    Raises:
        ValueError: if ``n_drones`` is negative or not a perfect square.
    """
    if n_drones < 0:
        raise ValueError(f"n_drones must be non-negative, got {n_drones}")
    side = round(n_drones ** 0.5)
    if side * side != n_drones:
        # A non-square count (e.g. 8, a power of 2) used to silently produce a
        # smaller grid (2x2 = 4 drones), desynchronising the swarm size.
        raise ValueError(f"n_drones must be a perfect square, got {n_drones}")
    span = jnp.linspace(-side, side, side)
    x, y = jnp.meshgrid(span, span)
    return jnp.stack([x.ravel(), y.ravel()], axis=1)

def init_swarm(n_drones: int) -> Array:
    """Build the starting swarm state: grid positions, all drones at rest.

    Each row is ``(x, y, dx, dy)``; there is one row per position returned by
    ``grid_positions(n_drones)`` and every velocity component is zero.
    """
    xy = grid_positions(n_drones)
    at_rest = jnp.zeros_like(xy)
    return jnp.hstack((xy, at_rest))

def conv2d(x, w):
    """Stride-1, 'SAME'-padded 2-D convolution with all size-1 axes squeezed away.

    ``x`` is laid out NHWC and ``w`` HWIO; the output (NHWC) is passed through
    ``squeeze()``, so e.g. a single-image, single-channel input comes back as (H, W).
    """
    layout = ("NHWC", "HWIO", "NHWC")
    out = jax.lax.conv_general_dilated(
        x,
        w,
        window_strides=(1, 1),
        padding='SAME',
        dimension_numbers=layout,
    )
    return out.squeeze()

def positions(swarm: Array) -> Tuple[Array, Array, Array, Array]:
    """Polar view of every drone relative to the origin and to its peers.

    Args:
        swarm: ``(n_drones, 4)`` rows of ``(x, y, dx, dy)``; only ``x, y`` are used.

    Returns:
        ``(origin_theta, origin_dist, other_dists, other_theta)`` where

        * ``origin_theta``, ``origin_dist``: ``(n_drones,)`` angle and distance of
          each drone from ``(0, 0)``;
        * ``other_dists``, ``other_theta``: ``(n_drones, n_drones - 1)``; row ``i``
          holds, nearest first, the distance and the angle of ``xy[i] - xy[j]``
          for every other drone ``j``.
    """
    # Bug fix: the old slice `swarm[:, :N_DRONES]` kept all 4 columns, so the
    # velocity components leaked into every distance. Positions are columns 0-1.
    xy = swarm[:, :2]
    delta = xy[:, None, :] - xy[None, :, :]                  # delta[i, j] = xy[i] - xy[j]
    other_dists = jnp.sqrt(jnp.sum(delta ** 2, axis=-1))
    other_theta = jnp.arctan2(delta[..., 1], delta[..., 0])
    # Sort each row by distance; column 0 (distance 0, i.e. the drone itself) is dropped.
    idx = jnp.argsort(other_dists, axis=-1)[:, 1:]
    rows = jnp.arange(xy.shape[0])[:, None]
    other_dists = other_dists[rows, idx]
    other_theta = other_theta[rows, idx]
    origin_dist = jnp.sqrt(jnp.sum(xy ** 2, axis=-1))
    origin_theta = jnp.arctan2(xy[:, 1], xy[:, 0])
    return origin_theta, origin_dist, other_dists, other_theta

def act(rng, params, swarm: Array, momentum=0.6) -> Array:
    """Advance the swarm by one time step.

    Positions move by the current velocity. The policy (``model``) is queried on
    ``observe(swarm)``, i.e. on the state *before* this step's move, for a
    heading ``theta`` and a ``speed``. The new velocity blends the previous
    velocity with that heading vector and is squashed to (-2, 2) per component.

    Args:
        rng: PRNG key forwarded to ``model`` for its action noise.
        params: policy weights forwarded to ``model``.
        swarm: ``(n_drones, 4)`` array of ``(x, y, dx, dy)`` rows.
        momentum: weight of the previous velocity in the blend.

    Returns:
        The ``(n_drones, 4)`` swarm state after the step.
    """
    # jnp.split: JAX arrays (like NumPy ndarrays) have no `.split` method.
    coords, velocity = jnp.split(swarm, 2, axis=1)
    new_coords = coords + velocity
    theta, speed = model(rng, params, observe(swarm))  # each (n_drones, 1)
    # concatenate keeps shape (n_drones, 2) even for a single drone, where the
    # previous stack(...).squeeze() collapsed it to (2,).
    delta = jnp.concatenate((speed * jnp.cos(theta), speed * jnp.sin(theta)), axis=1)
    # Momentum carries over the previous *velocity*; the old code blended in `speed`.
    new_velocity = momentum * velocity + (1 - momentum) * delta
    new_velocity = jax.nn.tanh(new_velocity) * 2
    return jnp.concatenate([new_coords, new_velocity], axis=1)

def model(rng, params, obs, noise_scale=0.1) -> Tuple[Array, Array]:
    """Linear-tanh policy with Gaussian exploration noise.

    Computes ``tanh(obs @ params + noise)`` with ``noise ~ N(0, noise_scale**2)``
    and splits the two output columns into the action pair.

    Args:
        rng: PRNG key for the exploration noise.
        params: ``(n_features, 2)`` weight matrix.
        obs: ``(n_drones, n_features)`` observations, one row per drone.
        noise_scale: standard deviation of the additive noise (default 0.1).

    Returns:
        ``(theta, speed)``, each ``(n_drones, 1)`` and in (-1, 1).
    """
    logits = jnp.dot(obs, params)
    # Noise shape follows the output, so it always matches `logits`.
    noise = random.normal(rng, logits.shape) * noise_scale
    theta, speed = jnp.split(jnp.tanh(logits + noise), 2, axis=1)
    return theta, speed

def loss_fn(image):
    """Crowding loss: one minus the image mass normalised by ``N_DRONES * 9``.

    If ``image`` is a 0/1 occupancy map, a larger sum means more distinct
    occupied pixels (fewer drones sharing a pixel) and therefore a lower loss.
    NOTE(review): the factor 9 presumably matches the 3x3 ``kernel`` (an image
    already smoothed with ``conv2d``); nothing here convolves the image, so
    confirm the intended input.
    """
    normaliser = N_DRONES * 9
    mass = jnp.sum(image)
    return 1 - mass / normaliser

def fitness_fn(genome, rng, n_steps=100) -> Array:
    """Mean ``loss_fn`` over an ``n_steps`` rollout of the swarm under ``genome``.

    The swarm starts from ``init_swarm(N_DRONES)``. At each step a fresh subkey
    is split off ``rng``, the swarm is advanced with ``act`` and the rasterised
    state (``quantize``) is scored with ``loss_fn``.

    The rollout uses ``jax.lax.fori_loop`` so the step is traced once instead of
    being unrolled ``n_steps`` times under ``vmap``/``jit``. The key sequence and
    the accumulation order match a plain Python loop.

    Args:
        genome: policy parameters passed to ``act``.
        rng: PRNG key for the rollout.
        n_steps: number of simulation steps; must be positive.

    Returns:
        Scalar average of the per-step losses.

    Raises:
        ValueError: if ``n_steps`` is not positive.
    """
    if n_steps <= 0:
        raise ValueError(f"n_steps must be positive, got {n_steps}")

    def step(_, carry):
        key_chain, swarm, total = carry
        key_chain, subkey = random.split(key_chain)
        swarm = act(subkey, genome, swarm)
        total = total + loss_fn(quantize(swarm))
        return key_chain, swarm, total

    # Float accumulator: a fori_loop carry must keep a fixed dtype.
    init = (rng, init_swarm(N_DRONES), jnp.zeros(()))
    _, _, total = jax.lax.fori_loop(0, n_steps, step, init)
    return total / n_steps

def observe(swarm: Array) -> Array:
    """Flatten each drone's polar view (from ``positions``) into one feature row.

    Row ``i`` is ``[origin_theta_i, peer angles..., origin_dist_i, peer distances...]``.
    That is ``n_drones`` angles (origin plus the ``n_drones - 1`` peers) followed by
    ``n_drones`` distances, giving an ``(n_drones, 2 * n_drones)`` array.
    """
    o_theta, o_dist, peer_dists, peer_theta = positions(swarm)
    angle_block = jnp.hstack((o_theta[:, None], peer_theta))  # (n_drones, n_drones)
    dist_block = jnp.hstack((o_dist[:, None], peer_dists))    # (n_drones, n_drones)
    return jnp.hstack((angle_block, dist_block))

def resize(source, target):
    """Resample ``source`` so its first two axes match ``target``'s (for plotting).

    Uses nearest-neighbour sampling, so no new values are invented. Any trailing
    axes of ``source`` (e.g. colour channels) are preserved; previously they made
    the requested shape's rank disagree with ``source`` and the call failed.

    Args:
        source: array with at least 2 axes; the first two are spatial.
        target: array whose first two axes give the output spatial size.

    Returns:
        ``source`` resampled to ``target.shape[:2] + source.shape[2:]``.
    """
    out_shape = tuple(target.shape[:2]) + tuple(source.shape[2:])
    return jax.image.resize(source, out_shape, method='nearest')

def quantize(swarm: Array, resolution: int = 256) -> Array:
    """Rasterise drone positions into a binary occupancy image.

    Positions (columns 0-1 of ``swarm``) are scaled by ``int(sqrt(resolution))``
    pixels per unit and floored to integer indices. Pixel ``[x, y]`` of a
    ``(resolution, resolution)`` image is then set to 1 for every drone. Several
    drones in one pixel still give a single 1, so ``image.sum()`` counts distinct
    occupied pixels.

    Negative indices wrap around (NumPy-style). Out-of-range indices are dropped
    by JAX's default scatter behaviour. NOTE(review): the swarm is centred on the
    origin, so half of it wraps to the far edges — confirm whether it should be
    re-centred instead.

    Args:
        swarm: ``(n_drones, 4)`` rows of ``(x, y, dx, dy)``.
        resolution: side length of the square output image.

    Returns:
        ``(resolution, resolution)`` float image of 0s and 1s.
    """
    scale = int(resolution ** 0.5)
    # floor (not truncation toward zero) so the cell at index 0 is not twice as wide.
    pixels = jnp.floor(swarm[:, :2] * scale).astype(int)
    image = jnp.zeros((resolution, resolution))
    # Index with separate row/column arrays; the old single (2, n) index array
    # addressed whole rows of axis 0 instead of individual pixels.
    return image.at[pixels[:, 0], pixels[:, 1]].set(1)

# Policy weights for `model`: each drone's observation has 2 * N_DRONES features
# (angle and distance to the origin plus angles and distances to the other
# N_DRONES - 1 drones, see `observe`), mapped to 2 outputs (theta, speed).
# This zero array is passed as OpenES's `pholder_params`, presumably only to fix
# the parameter shape — confirm against the evosax version in use.
params    = jnp.zeros((N_DRONES * 2, 2))
rng, key  = random.split(random.PRNGKey(0))
strategy  = OpenES(popsize=POPSIZE, pholder_params=params)
es_params = strategy.default_params
state     = strategy.initialize(key, es_params)

for _ in range(100):
    rng, key       = random.split(rng)
    genome, state  = strategy.ask(key, state, es_params)
    fitnesses      = vmap(fitness_fn, in_axes=(0, None))(genome, rng)