In [2]:
import tensorflow as tf

In [160]:
from typing import Callable


class CoralReef:
    """Population grid for a coral-reef style evolutionary search.

    The reef is a fixed number of slots. Each slot holds one candidate
    solution (a row of ``grid_values``) and a flag saying whether the slot is
    currently occupied (``grid_alive``).

    Attributes:
        grid_values: Tensor of shape ``[n_corals, dim]`` and dtype ``dtype``;
            rows of unoccupied slots are zero after construction.
        grid_alive: Boolean tensor of shape ``[n_corals]``; ``True`` marks an
            occupied slot.
    """

    def __init__(
        self,
        fitness_fn: Callable,
        dtype: tf.dtypes.DType,
        dim: int,
        domain: tuple[float, float],
        mutation_domain: tuple[float, float],
        settling_trials: int = 10,
        frac_init_alive: float = 0.2,
        n_corals: int = 100,
        fract_broadcast: float = 0.5,
        fract_duplication: float = 0.1,
        prob_die: float = 0.01,
    ):
        """Create the reef and populate a random subset of its slots.

        Args:
            fitness_fn: Callable applied to each row of ``grid_values`` via
                ``tf.map_fn`` during larvae settling.
            dtype: Element dtype of the coral values.
            dim: Number of components per coral.
            domain: ``(low, high)`` bounds used to sample initial corals and
                to clip mutated corals.
            mutation_domain: ``(low, high)`` bounds of the additive noise
                sampled in brooding.
            settling_trials: Number of settling attempts per step.
            frac_init_alive: Fraction of slots occupied at construction;
                exactly ``int(frac_init_alive * n_corals)`` slots are filled.
            n_corals: Total number of slots in the reef.
            fract_broadcast: Fraction of alive corals selected for broadcast
                spawning (crossover); the remaining ones brood (mutate).
            fract_duplication: Stored on the instance; not read by any method
                of this class yet.
            prob_die: Stored on the instance; not read by any method of this
                class yet.
        """
        self.fitness_fn = fitness_fn
        self.dtype = dtype
        self.grid_values = tf.zeros([n_corals, dim], dtype=self.dtype)
        self.grid_alive = tf.zeros(n_corals, dtype=tf.bool)
        self.domain = domain
        self.mutation_domain = mutation_domain
        self.settling_trials = settling_trials
        self.fract_broadcast = fract_broadcast
        self.fract_duplication = fract_duplication
        self.prob_die = prob_die

        # Pick distinct random slots to start alive: shuffling the full index
        # range and taking a prefix guarantees no slot is chosen twice.
        alive_ind = tf.random.shuffle(tf.range(n_corals))[
            : int(frac_init_alive * n_corals)
        ]
        # tensor_scatter_nd_update expects one index vector per updated row.
        alive_ind = tf.reshape(alive_ind, [-1, 1])
        n_alive = len(alive_ind)

        # NOTE: for integer dtypes tf.random.uniform excludes `maxval`, so the
        # initial values lie in [domain[0], domain[1]).
        initial_values = tf.random.uniform(
            [n_alive, dim],
            minval=domain[0],
            maxval=domain[1],
            dtype=self.dtype,
        )
        self.grid_values = tf.tensor_scatter_nd_update(
            self.grid_values, alive_ind, initial_values
        )
        self.grid_alive = tf.tensor_scatter_nd_update(
            self.grid_alive, alive_ind, tf.ones(n_alive, dtype=tf.bool)
        )

    def _broadcast_spawning(self, alive: tf.Tensor, n_broadcasters: int) -> tf.Tensor:
        """Produce offspring by one-point crossover between paired corals.

        The first ``n_broadcasters`` entries of ``alive`` are taken as
        parents and paired consecutively (rows 0-1, 2-3, ...). Each pair
        shares one random crossover point; every child keeps its own parent's
        components at positions ``<= point`` and takes the partner's
        components after it.

        Args:
            alive: 1-D tensor of indices of occupied slots.
            n_broadcasters: Number of parents to use; must be even because
                parents are reshaped into pairs.

        Returns:
            Tensor of shape ``[n_broadcasters, dim]`` with the offspring.
        """
        broadcasters = tf.gather(self.grid_values, alive[:n_broadcasters])
        dim = broadcasters.shape[-1]

        # One crossover point per pair, duplicated so both partners of a pair
        # cut at the same position; kept as a column for broadcasting.
        crossover_points = tf.random.uniform(
            [n_broadcasters // 2],
            minval=0,
            maxval=dim,
            dtype=tf.int32,
        )
        crossover_points = tf.reshape(tf.repeat(crossover_points, 2), [-1, 1])

        # Swap the two rows inside every pair so that row i of
        # `broadcasters_swap` is the partner of row i of `broadcasters`.
        broadcasters_swap = tf.reshape(broadcasters, [-1, 2, dim])
        broadcasters_swap = tf.reverse(broadcasters_swap, [1])
        broadcasters_swap = tf.reshape(broadcasters_swap, [-1, dim])

        # [1, dim] <= [n, 1] broadcasts to an [n, dim] mask: True where the
        # child keeps its own parent's component.
        positions = tf.reshape(tf.range(dim), [1, -1])
        mask = positions <= crossover_points

        new_corals = tf.where(mask, broadcasters, broadcasters_swap)
        return new_corals

    def _brooding(self, alive: tf.Tensor, n_broadcasters: int) -> tf.Tensor:
        """Produce offspring by mutating the corals not used for broadcasting.

        Every entry of ``alive`` from position ``n_broadcasters`` onwards is a
        brooder. Uniform noise drawn from ``mutation_domain`` is added to each
        component and the result is clipped to ``domain``.

        Args:
            alive: 1-D tensor of indices of occupied slots.
            n_broadcasters: Number of leading entries of ``alive`` to skip.

        Returns:
            Tensor with one mutated row per brooder.
        """
        brooders = tf.gather(self.grid_values, alive[n_broadcasters:])
        # NOTE: for integer dtypes tf.random.uniform excludes `maxval`, so
        # e.g. mutation_domain=(-1, 1) only ever yields -1 or 0.
        mutation = tf.random.uniform(
            brooders.shape,
            minval=self.mutation_domain[0],
            maxval=self.mutation_domain[1],
            dtype=self.dtype,
        )
        new_corals = brooders + mutation
        new_corals = tf.clip_by_value(new_corals, self.domain[0], self.domain[1])
        return new_corals

    def _larvae_settling(self, new_corals: tf.Tensor) -> tf.Tensor:
        """Attempt to place ``new_corals`` on the reef (unfinished).

        For each of ``settling_trials`` rounds this draws one distinct random
        slot per larva and evaluates ``fitness_fn`` on every row of the grid.

        NOTE(review): the settling itself is not implemented yet. The grid is
        never modified and the method implicitly returns ``None`` despite the
        ``tf.Tensor`` return annotation.

        Args:
            new_corals: Candidate corals, one per row.
        """
        for _ in range(self.settling_trials):
            indices = tf.random.shuffle(tf.range(len(self.grid_values)))[
                : len(new_corals)
            ]
            old_fitness = tf.map_fn(self.fitness_fn, self.grid_values)
            # TODO: try to settle new corals on the grid if they are better
            # than the old ones or if the grid is empty

    def step(self) -> None:
        """Run one generation of the reef.

        Planned phases:
            0. selection for 1. and 2.
            1. broadcast spawning (crossover)
            2. brooding (mutation)
            3. larvae settling
            4. asexual reproduction (duplication) + larvae settling
            5. depredation

        Phases 0-3 are wired up here (phase 3 is still a stub); phases 4 and
        5 are not implemented yet.
        """
        alive = tf.reshape(tf.where(self.grid_alive), [-1])
        # tf.random.shuffle returns a new tensor rather than shuffling in
        # place; the result must be kept, otherwise the broadcasters would
        # always be the lowest-indexed alive corals.
        alive = tf.random.shuffle(alive)

        # Round down to an even count so broadcasters can be paired up.
        n_broadcasters = int(self.fract_broadcast * len(alive)) // 2 * 2
        broadcasted = self._broadcast_spawning(alive, n_broadcasters)
        brooded = self._brooding(alive, n_broadcasters)

        settling_candidates = tf.concat([broadcasted, brooded], axis=0)
        self._larvae_settling(settling_candidates)


# Demo run: a 100-slot reef of 4-component int32 corals with values sampled
# from domain (0, 10), mutation noise drawn from (-1, 1), and fitness defined
# as the element-wise negation of a coral.
reef = CoralReef(
    fitness_fn=lambda x: -x,
    dtype=tf.int32,
    dim=4,
    domain=(0, 10),
    mutation_domain=(-1, 1),
    n_corals=100,
)
# Advance the reef by a single generation.
reef.step()

tf.Tensor(
[[ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [-8 -1 -6 -5]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [-3 -7 -5  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [-9 -4 -1 -5]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [-8 -1 -9 -5]
 [ 0  0  0  0]
 [-1 -1 -2 -7]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [-3 -1 -3 -7]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0 -8 -8 -1]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [-2 -6 -2 -1]
 [-8 -5 -2  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [-6 -3 -9 -1]
 [ 0 -1 -2 -4]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [-4 -7 -7 -1]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [-7 -3 -6 -6]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [-3 -4 -3 -1]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]
 [ 0  0  0  0]