# Boids - Flocking model to study the impact of communication constraints



The setup code below is taken from Google's example Boids model implemented using JAX.

https://colab.research.google.com/github/google/jax-md/blob/master/notebooks/flocking.ipynb

Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [1]:
#@title Imports & Utils

# Imports

# !pip install -q git+https://www.github.com/google/jax-md

import numpy as onp

# from jax.config import config ; config.update('jax_enable_x64', True)
import jax.numpy as np
from jax import random
from jax import jit
from jax import vmap
from jax import lax
vectorize = np.vectorize

from functools import partial

from collections import namedtuple
import base64

import IPython
from google.colab import output

import os

from jax_md import space, smap, energy, minimize, quantity, simulate, partition, util
from jax_md.util import f32

# Plotting

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style(style='white')

# Dark-theme palette shared by the plotting helpers (grey RGB triples).
dark_color = [56 / 256, 56 / 256, 56 / 256]
light_color = [213 / 256, 213 / 256, 213 / 256]
axis_color = 'white'

def format_plot(x='', y='', grid=True):
  """Style the current axes for the dark theme and set the axis labels.

  Args:
    x: Label for the x axis.
    y: Label for the y axis.
    grid: Whether grid lines are drawn.
  """
  ax = plt.gca()

  # Spines, tick labels and axis labels all use the light axis colour.
  for side in ('bottom', 'top', 'right', 'left'):
    ax.spines[side].set_color(axis_color)

  for which in ('x', 'y'):
    ax.tick_params(axis=which, colors=axis_color)
  ax.yaxis.label.set_color(axis_color)
  ax.xaxis.label.set_color(axis_color)
  ax.set_facecolor(dark_color)

  plt.grid(grid)
  plt.xlabel(x, fontsize=20)
  plt.ylabel(y, fontsize=20)
  
def finalize_plot(shape=(1, 1)):
  """Give the current figure the dark background and resize it.

  Both the new width and the new height are derived from the figure's
  current height (times 1.5), scaled by `shape[0]` and `shape[1]`.
  """
  fig = plt.gcf()
  fig.patch.set_facecolor(dark_color)
  height = fig.get_size_inches()[1]
  fig.set_size_inches(shape[0] * 1.5 * height, shape[1] * 1.5 * height)
  plt.tight_layout()

# Progress Bars

from IPython.display import HTML, display
import time

def ProgressIter(iter_fun, iter_len=0):
  """Yield the items of `iter_fun` while updating an HTML progress bar.

  Args:
    iter_fun: Iterable to wrap.
    iter_len: Total number of items; when falsy, `len(iter_fun)` is used.
  """
  total = iter_len or len(iter_fun)
  handle = display(progress(0, total), display_id=True)
  done = 0
  for item in iter_fun:
    yield item
    done += 1
    handle.update(progress(done, total))

def progress(value, max):
  """Return an HTML <progress> element showing `value` out of `max`.

  Args:
    value: Number of completed steps.
    max: Total number of steps. (The name shadows the builtin but is kept so
      keyword callers keep working.)

  Returns:
    An `IPython.display.HTML` object.
  """
  # HTML attributes are whitespace-separated; a comma between them is invalid
  # markup and gets parsed as a bogus attribute.
  return HTML("""
      <progress
          value='{value}'
          max='{max}'
          style='width: 45%'
      >
          {value}
      </progress>
  """.format(value=value, max=max))

def normalize(v):
  """Scale `v` to unit Euclidean length along axis 1 (each row of a 2-D array)."""
  return v / np.linalg.norm(v, axis=1, keepdims=True)

# Rendering

# WebGL2 renderer for Google Colab (needs the `google.colab` runtime).
# Displaying this HTML runs the embedded script. The script:
#   - fetches data from the Python callbacks registered by `render`
#     (notebook.GetSimulationInfo, notebook.GetObstacles,
#     notebook.GetBoidStates, notebook.GetPredators);
#   - draws each boid/predator as a triangle pointing along its heading and
#     each obstacle as a 32-segment disk;
#   - replays the animation a limited number of times.
renderer_code = IPython.display.HTML('''
<canvas id="canvas"></canvas>
<script>
  Rg = null;
  Ng = null;

  var current_scene = {
      R: null,
      N: null,
      is_loaded: false,
      frame: 0,
      frame_count: 0,
      boid_vertex_count: 0,
      boid_buffer: [],
      predator_vertex_count: 0,
      predator_buffer: [],
      disk_vertex_count: 0,
      disk_buffer: null,
      box_size: 0
  };

  google.colab.output.setIframeHeight(0, true, {maxHeight: 5000});

  async function load_simulation() {
    const buffer_size = 400;
    let result;
    let data;

    // Fetch the metadata first: the frame count bounds the loading loop and
    // box_size must be valid before the first frame is drawn.
    result = await google.colab.kernel.invokeFunction(
        'notebook.GetSimulationInfo', [], {});
    const info = result.data['application/json'];
    current_scene.box_size = info.box_size;
    const frame_total = info.frames;

    result = await google.colab.kernel.invokeFunction(
        'notebook.GetObstacles', [], {});
    data = result.data['application/json'];

    if(data.hasOwnProperty('Disk')) {
      current_scene = put_obstacle_disk(current_scene, data.Disk);
    }

    for (var i = 0 ; i < frame_total ; i += buffer_size) {
      console.log(i);
      result = await google.colab.kernel.invokeFunction(
          'notebook.GetBoidStates', [i, i + buffer_size], {});

      data = result.data['application/json'];
      current_scene = put_boids(current_scene, data);
    }

    result = await google.colab.kernel.invokeFunction(
        'notebook.GetPredators', [], {});
    data = result.data['application/json'];
    if (data.hasOwnProperty('R'))
      current_scene = put_predators(current_scene, data);

    // Start drawing only once every buffer has been uploaded.
    current_scene.is_loaded = true;
  }

  function initialize_gl() {
    const canvas = document.getElementById("canvas");
    canvas.width = 640;
    canvas.height = 640;

    const gl = canvas.getContext("webgl2");

    if (!gl) {
        alert('Unable to initialize WebGL.');
        return;
    }

    gl.viewport(0, 0, gl.drawingBufferWidth, gl.drawingBufferHeight);
    gl.clearColor(0.2, 0.2, 0.2, 1.0);
    gl.enable(gl.DEPTH_TEST);

    const shader_program = initialize_shader(
        gl, VERTEX_SHADER_SOURCE_2D, FRAGMENT_SHADER_SOURCE_2D);
    const shader = {
      program: shader_program,
      attribute: {
          vertex_position: gl.getAttribLocation(shader_program, 'vertex_position'),
      },
      uniform: {
          screen_position: gl.getUniformLocation(shader_program, 'screen_position'),
          screen_size: gl.getUniformLocation(shader_program, 'screen_size'),
          color: gl.getUniformLocation(shader_program, 'color'),
      },
    };
    gl.useProgram(shader_program);

    const half_width = 200.0;

    gl.uniform2f(shader.uniform.screen_position, half_width, half_width);
    gl.uniform2f(shader.uniform.screen_size, half_width, half_width);
    gl.uniform4f(shader.uniform.color, 0.9, 0.9, 1.0, 1.0);

    return {gl: gl, shader: shader};
  }

  var loops = 0;

  function update_frame() {
    gl.clear(gl.COLOR_BUFFER_BIT | gl.DEPTH_BUFFER_BIT);

    if (!current_scene.is_loaded) {
      window.requestAnimationFrame(update_frame);
      return;
    }

    var half_width = current_scene.box_size / 2.;
    gl.uniform2f(shader.uniform.screen_position, half_width, half_width);
    gl.uniform2f(shader.uniform.screen_size, half_width, half_width);

    // Wrap around to the first frame and count completed loops.
    if (current_scene.frame >= current_scene.frame_count) {
      loops++;
      current_scene.frame = 0;
    }

    gl.enableVertexAttribArray(shader.attribute.vertex_position);

    gl.bindBuffer(gl.ARRAY_BUFFER, current_scene.boid_buffer[current_scene.frame]);
    gl.uniform4f(shader.uniform.color, 0.0, 0.35, 1.0, 1.0);
    gl.vertexAttribPointer(
      shader.attribute.vertex_position,
      2,
      gl.FLOAT,
      false,
      0,
      0
    );
    gl.drawArrays(gl.TRIANGLES, 0, current_scene.boid_vertex_count);

    if(current_scene.predator_buffer.length > 0)  {
      gl.bindBuffer(gl.ARRAY_BUFFER, current_scene.predator_buffer[current_scene.frame]);
      gl.uniform4f(shader.uniform.color, 1.0, 0.35, 0.35, 1.0);
      gl.vertexAttribPointer(
        shader.attribute.vertex_position,
        2,
        gl.FLOAT,
        false,
        0,
        0
      );
      gl.drawArrays(gl.TRIANGLES, 0, current_scene.predator_vertex_count);
    }
    
    if(current_scene.disk_buffer) {
      gl.bindBuffer(gl.ARRAY_BUFFER, current_scene.disk_buffer);
      gl.uniform4f(shader.uniform.color, 0.9, 0.9, 1.0, 1.0);
      gl.vertexAttribPointer(
        shader.attribute.vertex_position,
        2,
        gl.FLOAT,
        false,
        0,
        0
      );
      gl.drawArrays(gl.TRIANGLES, 0, current_scene.disk_vertex_count);
    }

    current_scene.frame++;
    if ((current_scene.frame_count > 1 && loops < 5) || 
        (current_scene.frame_count == 1 && loops < 240))
      window.requestAnimationFrame(update_frame);
    
    if (current_scene.frame_count > 1 && loops == 5 && current_scene.frame < current_scene.frame_count - 1)
      window.requestAnimationFrame(update_frame);
  }

  function put_boids(scene, boids) {
    const R = decode(boids['R']);
    const R_shape = boids['R_shape'];
    const theta = decode(boids['theta']);
    const theta_shape = boids['theta_shape'];

    function index(i, b, xy) {
      return i * R_shape[1] * R_shape[2] + b * R_shape[2] + xy; 
    }

    var steps = R_shape[0];
    var boids = R_shape[1];
    var dimensions = R_shape[2];

    if(dimensions != 2) {
      alert('Can only deal with two-dimensional data.')
    }

    // First flatten the data.
    var buffer_data = new Float32Array(boids * 6);
    var size = 8.0;
    for (var i = 0 ; i < steps ; i++) {
      var buffer = gl.createBuffer();
      for (var b = 0 ; b < boids ; b++) {
        var xi = index(i, b, 0);
        var yi = index(i, b, 1);
        var ti = i * boids + b;
        var Nx = size * Math.cos(theta[ti]); //N[xi];
        var Ny = size * Math.sin(theta[ti]); //N[yi];
        buffer_data.set([
          R[xi] + Nx, R[yi] + Ny,
          R[xi] - Nx - 0.5 * Ny, R[yi] - Ny + 0.5 * Nx,
          R[xi] - Nx + 0.5 * Ny, R[yi] - Ny - 0.5 * Nx,             
        ], b * 6);
      }
      gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
      gl.bufferData(gl.ARRAY_BUFFER, buffer_data, gl.STATIC_DRAW);

      scene.boid_buffer.push(buffer);
    }
    scene.boid_vertex_count = boids * 3;
    scene.frame_count += steps;
    return scene;
  }

  function put_predators(scene, boids) {
    // TODO: Unify this with the put_boids function.
    const R = decode(boids['R']);
    const R_shape = boids['R_shape'];
    const theta = decode(boids['theta']);
    const theta_shape = boids['theta_shape'];

    function index(i, b, xy) {
      return i * R_shape[1] * R_shape[2] + b * R_shape[2] + xy; 
    }

    var steps = R_shape[0];
    var boids = R_shape[1];
    var dimensions = R_shape[2];

    if(dimensions != 2) {
      alert('Can only deal with two-dimensional data.')
    }

    // First flatten the data.
    var buffer_data = new Float32Array(boids * 6);
    var size = 18.0;
    for (var i = 0 ; i < steps ; i++) {
      var buffer = gl.createBuffer();
      for (var b = 0 ; b < boids ; b++) {
        var xi = index(i, b, 0);
        var yi = index(i, b, 1);
        var ti = theta_shape[1] * i + b;
        var Nx = size * Math.cos(theta[ti]);
        var Ny = size * Math.sin(theta[ti]);
        buffer_data.set([
          R[xi] + Nx, R[yi] + Ny,
          R[xi] - Nx - 0.5 * Ny, R[yi] - Ny + 0.5 * Nx,
          R[xi] - Nx + 0.5 * Ny, R[yi] - Ny - 0.5 * Nx,             
        ], b * 6);
      }
      gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
      gl.bufferData(gl.ARRAY_BUFFER, buffer_data, gl.STATIC_DRAW);

      scene.predator_buffer.push(buffer);
    }
    scene.predator_vertex_count = boids * 3;
    return scene;
  }

  function put_obstacle_disk(scene, disk) {
    const R = decode(disk.R);
    const R_shape = disk.R_shape;
    const radius = decode(disk.D);
    const radius_shape = disk.D_shape;

    const disk_count = R_shape[0];
    const dimensions = R_shape[1];
    if (dimensions != 2) {
        alert('Can only handle two-dimensional data.');
    }
    if (radius_shape[0] != disk_count) {
        alert('Inconsistent disk radius count found.');
    }
    const segments = 32;

    function index(o, xy) {
        return o * R_shape[1] + xy;
    }

    // TODO(schsam): Use index buffers here.
    var buffer_data = new Float32Array(disk_count * segments * 6);
    for (var i = 0 ; i < disk_count ; i++) {
      var xi = index(i, 0);
      var yi = index(i, 1);
      for (var s = 0 ; s < segments ; s++) {
        const th = 2 * s / segments * Math.PI;
        const th_p = 2 * (s + 1) / segments * Math.PI;
        const rad = radius[i] * 0.8;
        buffer_data.set([
          R[xi], R[yi],
          R[xi] + rad * Math.cos(th), R[yi] + rad * Math.sin(th),
          R[xi] + rad * Math.cos(th_p), R[yi] + rad * Math.sin(th_p),
        ], i * segments * 6 + s * 6);
      }
    }
    var buffer = gl.createBuffer();
    gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
    gl.bufferData(gl.ARRAY_BUFFER, buffer_data, gl.STATIC_DRAW);
    scene.disk_vertex_count = disk_count * segments * 3;
    scene.disk_buffer = buffer;
    return scene;
  }

  // SHADER CODE

  const VERTEX_SHADER_SOURCE_2D = `
    // Vertex Shader Program.
    attribute vec2 vertex_position;
    
    uniform vec2 screen_position;
    uniform vec2 screen_size;

    void main() {
      vec2 v = (vertex_position - screen_position) / screen_size;
      gl_Position = vec4(v, 0.0, 1.0);
    }
  `;

  const FRAGMENT_SHADER_SOURCE_2D = `
    precision mediump float;

    uniform vec4 color;

    void main() {
      gl_FragColor = color;
    }
  `;

  function initialize_shader(
    gl, vertex_shader_source, fragment_shader_source) {

    const vertex_shader = compile_shader(
      gl, gl.VERTEX_SHADER, vertex_shader_source);
    const fragment_shader = compile_shader(
      gl, gl.FRAGMENT_SHADER, fragment_shader_source);

    const shader_program = gl.createProgram();
    gl.attachShader(shader_program, vertex_shader);
    gl.attachShader(shader_program, fragment_shader);
    gl.linkProgram(shader_program);

    if (!gl.getProgramParameter(shader_program, gl.LINK_STATUS)) {
      alert(
        'Unable to initialize shader program: ' + 
        gl.getProgramInfoLog(shader_program)
        );
        return null;
    }
    return shader_program;
  }

  function compile_shader(gl, type, source) {
    const shader = gl.createShader(type);
    gl.shaderSource(shader, source);
    gl.compileShader(shader);

    if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
      alert('An error occured compiling shader: ' + gl.getShaderInfoLog(shader));
      gl.deleteShader(shader);
      return null;
    }

    return shader;
  }

  // SERIALIZATION UTILITIES
  function decode(sBase64, nBlocksSize) {
    var chrs = atob(atob(sBase64));
    var array = new Uint8Array(new ArrayBuffer(chrs.length));

    for(var i = 0 ; i < chrs.length ; i++) {
      array[i] = chrs.charCodeAt(i);
    }

    return new Float32Array(array.buffer);
  }

  // RUN CELL

  load_simulation();
  var gl_and_shader = initialize_gl();
  var gl = gl_and_shader.gl;
  var shader = gl_and_shader.shader;
  update_frame();
</script>
''')

def encode(R):
  """Return the float32 bytes of `R`, base64-encoded for the JS renderer."""
  raw = onp.asarray(R, dtype=onp.float32).tobytes()
  return base64.b64encode(raw)

def render(box_size, states, obstacles=None, predators=None):
  """Register the Colab callbacks that serve simulation data to `renderer_code`.

  Args:
    box_size: Side length of the simulation box. Converted to `float` so it
      can be sent to the renderer as JSON.
    states: A single `Boids` state (rendered as one frame) or a non-empty list
      of `Boids` states, one per frame.
    obstacles: Optional object with `R` and `D` arrays (disk centers and
      radii, as sent to the renderer); `None` for no obstacles.
    predators: Optional list of per-frame states whose first two fields are
      positions and headings, or a single such state (rendered as one frame);
      `None` for no predators.

  Returns:
    `renderer_code`; displaying it runs the WebGL animation.

  Raises:
    TypeError: If `states` is neither a `Boids` nor a list of `Boids`.
    ValueError: If `states` is an empty list.
  """
  # numpy/jax scalars are not JSON-serializable; a plain float is.
  box_size = float(box_size)

  if isinstance(states, Boids):
    # Add a leading frame axis so a single state is a one-frame trajectory.
    R = np.reshape(states.R, (1,) + states.R.shape)
    theta = np.reshape(states.theta, (1,) + states.theta.shape)
  elif isinstance(states, list) and all(isinstance(x, Boids) for x in states):
    if not states:
      raise ValueError('states must contain at least one Boids frame.')
    R, theta = zip(*states)
    R = onp.stack(R)
    theta = onp.stack(theta)
  else:
    # Fail here instead of leaving R/theta unbound, which would only surface
    # later as a NameError inside a Colab callback.
    raise TypeError('states must be a Boids instance or a list of Boids, '
                    'got {}.'.format(type(states).__name__))

  R_predators = theta_predators = None
  if isinstance(predators, list):
    R_predators, theta_predators, *_ = zip(*predators)
    R_predators = onp.stack(R_predators)
    theta_predators = onp.stack(theta_predators)
  elif predators is not None:
    # A single predator state: add a leading frame axis, as for `states`.
    R_single, theta_single, *_ = predators
    R_predators = onp.asarray(R_single)[None]
    theta_predators = onp.asarray(theta_single)[None]

  def get_boid_states(start, end):
    # Serve frames [start, end); the renderer requests them in chunks.
    R_, theta_ = R[start:end], theta[start:end]
    return IPython.display.JSON(data={
        "R_shape": R_.shape,
        "R": encode(R_), 
        "theta_shape": theta_.shape,
        "theta": encode(theta_)
        })
  output.register_callback('notebook.GetBoidStates', get_boid_states)

  def get_obstacles():
    if obstacles is None:
      return IPython.display.JSON(data={})
    else:
      return IPython.display.JSON(data={
          'Disk': {
              'R': encode(obstacles.R),
              'R_shape': obstacles.R.shape,
              'D': encode(obstacles.D),
              'D_shape': obstacles.D.shape
          }
      })
  output.register_callback('notebook.GetObstacles', get_obstacles)

  def get_predators():
    if predators is None:
      return IPython.display.JSON(data={})
    else:
      return IPython.display.JSON(data={
          'R': encode(R_predators),
          'R_shape': R_predators.shape,
          'theta': encode(theta_predators),
          'theta_shape': theta_predators.shape
      })
  output.register_callback('notebook.GetPredators', get_predators)

  def get_simulation_info():
    return IPython.display.JSON(data={
        'frames': R.shape[0],
        'box_size': box_size
        })
  output.register_callback('notebook.GetSimulationInfo', get_simulation_info)

  return renderer_code

# Flocks, Herds, and Schools: A Distributed Behavioral Model

We will go over the paper ["Flocks, Herds, and Schools: A Distributed Behavioral Model"](https://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=E252054B1C02D387E8C20827CB414543?doi=10.1.1.103.7187&rep=rep1&type=pdf), published by C. W. Reynolds in SIGGRAPH 1987. The paper itself is fantastic and, as far as a description of flocking is concerned, there is little that we can add. Therefore, rather than go through the paper directly, we will use [JAX](https://www.github.com/google/jax) and [JAX, MD](https://www.github.com/google/jax-md) to interactively build a simulation similar to Reynolds' in Colab. To simplify our discussion, we will build a two-dimensional version of Reynolds' simulation.

In nature there are many examples in which large numbers of animals exhibit complex collective motion (schools of fish, flocks of birds, herds of horses, colonies of ants). In his seminal paper, Reynolds introduces a model of such collective behavior (henceforth referred to as "flocking") based on simple rules that can be computed locally for each entity (referred to as a "boid") in the flock based on its environment. The paper is written in the context of computer graphics, so Reynolds aims for biologically inspired simulations that look right rather than for accuracy in any statistical sense. Ultimately, Reynolds measures success in terms of the "delight" people find in watching the simulations; we will use a similar metric here.

Note, we recommend running this notebook in "Dark" mode.

## Boids

Reynolds is interested in simulating bird-like entities that are described by a position, $R$, and an orientation, $\theta$. This state can optionally be augmented with extra information (for example, hunger or fear). We can define a Boids type that stores data for a collection of boids as two arrays. `R` is an `ndarray` of shape `[boid_count, spatial_dimension]` and `theta` is an `ndarray` of shape `[boid_count]`. An individual boid is an index into these arrays. It will often be useful to refer to the vector orientation of the boid $N = (\cos\theta, \sin\theta)$.

In [2]:
# State of a flock: positions `R` and headings `theta`, one entry per boid.
Boids = namedtuple('Boids', 'R theta')

We can instantiate a collection of boids randomly in a box of side length $L$. We will use [periodic boundary conditions](https://en.wikipedia.org/wiki/Periodic_boundary_conditions) for our simulation which means that boids will be able to wrap around the sides of the box. To do this we will use the `space.periodic` command in [JAX, MD](https://github.com/google/jax-md#spaces-spacepy).

In [3]:
# Simulation parameters.
box_size = 800.0   # Side length of the square simulation box.
boid_count = 200   # Number of boids in the flock.
dim = 2            # Spatial dimension of the simulation.

# JAX uses explicit, splittable PRNG keys rather than global random state
# (see the JAX documentation on pseudorandom numbers).
rng = random.PRNGKey(0)

# Periodic boundary conditions: boids leaving one side of the box re-enter on
# the opposite side. `shift` is what `dynamics` uses to move them.
displacement, shift = space.periodic(box_size)

# Random initial positions inside the box and headings in [0, 2*pi).
rng, R_rng, theta_rng = random.split(rng, 3)

boids = Boids(
    R=box_size * random.uniform(R_rng, (boid_count, dim)),
    theta=random.uniform(theta_rng, (boid_count,), maxval=2.0 * np.pi),
)

In [4]:
@vmap
def normal(theta):
  """Turn headings into unit direction vectors (cos, sin), vmapped over axis 0."""
  cos_t, sin_t = np.cos(theta), np.sin(theta)
  return np.stack((cos_t, sin_t))

def dynamics(energy_fn, dt, speed):
  """Build a jitted `lax.fori_loop` body that advances the boids by one step.

  Each step moves every boid along its heading at constant `speed` and adds
  the corrections given by `quantity.force(energy_fn)` at the current state:

    R     <- shift(R, dt * (speed * normal(theta) + dR))
    theta <- theta + dt * dtheta

  `shift` is the module-level periodic shift function.

  Args:
    energy_fn: Function of the full state dict returning a scalar energy.
    dt: Time step.
    speed: Constant forward speed of every boid.

  Returns:
    `update(i, state) -> new_state`. `state` is a dict whose `'boids'` entry is
    a `Boids`; all other entries are passed through unchanged.
  """
  @jit
  def update(_, state):
    R, theta = state['boids']

    dR, dtheta = quantity.force(energy_fn)(state)['boids']
    n = normal(theta)

    # Return a new dict instead of mutating `state`, keeping the update pure
    # as JAX transformations (jit, fori_loop) expect.
    return {**state, 'boids': Boids(shift(R, dt * (speed * n + dR)),
                                     theta + dt * dtheta)}

  return update

In [5]:
# Baseline run: a constant (zero) energy produces no forces, so every boid
# keeps its heading and moves at `speed`.
update = dynamics(energy_fn=lambda state: 0., dt=1e-1, speed=1.)

boids_buffer = []

state = {'boids': boids}

# Record one frame every 50 integration steps, 400 frames in total.
for i in ProgressIter(range(400)):
  state = lax.fori_loop(0, 50, update, state)
  boids_buffer.append(state['boids'])

# display(render(box_size, boids_buffer))

In [6]:
# Michael's implementation of the display code: a matplotlib fallback for
# notebooks without Colab's WebGL output (the HTML output works in VSCode).
from matplotlib import animation, rc, rcParams
rcParams['animation.embed_limit'] = 2**128

fig, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(5, 5), dpi=120,
                        facecolor='w', edgecolor='k')
ax1.set_xlim([0, box_size])
ax1.set_ylim([0, box_size])

# Artists updated inside the animation loop.
# Agents: one triangle marker per boid (all markers share one orientation,
# so headings are not shown).
line, = ax1.plot(
    [], [], color="xkcd:cerulean blue", marker=(3, 0, 1), markersize=6,
    markeredgecolor="none", linestyle="None", alpha=0.9)
# Shadow plotting (disabled):
# line1, = ax1.plot([], [], 'bh', markersize = 6, markeredgecolor="black", alpha = 0.2)
# line2, = ax1.plot([], [], 'bh', markersize = 6, markeredgecolor="black", alpha = 0.2)
# line3, = ax1.plot([], [], 'bh', markersize = 6, markeredgecolor="black", alpha = 0.2)

# low_emp, = ax1.plot([], [], 'rh', markersize = 6, markeredgecolor="black", alpha = 0.9)
# high_emp, = ax1.plot([], [], 'gh', markersize = 6, markeredgecolor="green", alpha = 0.9)

fsize = 12  # Font size reserved for the (disabled) text overlays below.

# time_text = ax1.text(-20, 42, '', fontsize = fsize)
# box_text = ax1.text(3, 42, '', color = 'red', fontsize = fsize)
# cov_text = ax1.text(20, 42, '', color = 'blue', fontsize = fsize)

# empow_text = ax1.text(-40, 46, '', color = 'purple', fontsize = fsize)

def init():
    """Clear the boid markers before the first frame (FuncAnimation init_func).

    Returns the tuple of artists to redraw, as blitting requires.
    """
    line.set_data([], [])
    artists = (line,)
    return artists

# Animate every recorded frame. Deriving the count from the buffer avoids an
# IndexError (or silently dropped frames) when the number of recorded frames
# changes.
timesteps = len(boids_buffer)

def animate(i):
    """Draw frame `i`: move the markers to the boid positions of that frame.

    Returns the tuple of artists to redraw, as blitting requires.
    """
    positions = boids_buffer[i].R.T  # Row 0 holds x, row 1 holds y.
    line.set_data(positions[0], positions[1])
    # Disabled overlays (time_text, box_text, cov_text) would be returned here.
    return (line,)

# Display code. Note: the 'jshtml' setting is what makes the animation render
# on Colab.
rc('animation', html='jshtml')
anim = animation.FuncAnimation(
    fig,
    animate,
    init_func=init,
    frames=timesteps,
    interval=10,
    blit=True,
    repeat=True,
)
fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)
fig.set_facecolor('xkcd:dark gray')
plt.axis('off')
# Close the figure; the animation is shown through `anim.to_jshtml()` instead.
plt.close()

In [7]:
# Render every frame into a self-contained JavaScript/HTML player and show it.
animation_html = anim.to_jshtml()
display(HTML(animation_html))