In [184]:
import tensorflow as tf

In [246]:
from typing import Callable


class CoralReef:
    """Coral Reefs Optimization (CRO) over a fixed-size grid of candidate solutions.

    The reef has ``n_corals`` slots. Each slot holds a ``dim``-vector
    (``grid_values``), its fitness (``grid_fitness``, float32) and whether the
    slot is occupied (``grid_alive``). Higher fitness is better: a larva only
    displaces an occupied slot when its fitness is strictly greater.

    Args:
        fitness_fn: Maps a single ``[dim]`` candidate to a scalar fitness. It is
            applied through ``tf.vectorized_map`` and its results are stored in
            a float32 tensor, so it should return a float32 scalar.
        dtype: dtype of the candidate vectors (e.g. ``tf.int32``).
        dim: Length of each candidate vector.
        domain: ``(low, high)`` value bounds. Initial values come from
            ``tf.random.uniform`` (``high`` exclusive); mutated values are
            clipped to ``[low, high]``.
        mutation_domain: ``(low, high)`` range of the additive brooding
            mutation, drawn with ``tf.random.uniform``. For integer dtypes
            ``high`` is exclusive, e.g. ``(-1, 1)`` only ever yields -1 or 0.
        settling_trials: Max number of placement attempts per larva.
        frac_init_alive: Fraction of slots occupied at construction.
        n_corals: Number of slots in the reef.
        fract_broadcast: Fraction of living corals used for broadcast
            spawning (crossover); the remaining living corals brood (mutate).
        fract_duplication: Fraction of living corals duplicated by asexual
            reproduction; the same fraction bounds depredation.
        prob_die: Probability that each coral selected for depredation dies.
    """

    def __init__(
        self,
        fitness_fn: Callable,
        dtype: tf.dtypes.DType,
        dim: int,
        domain: tuple[float, float],
        mutation_domain: tuple[float, float],
        settling_trials: int = 10,
        frac_init_alive: float = 0.2,
        n_corals: int = 100,
        fract_broadcast: float = 0.5,
        fract_duplication: float = 0.1,
        prob_die: float = 0.01,
    ):
        self.fitness_fn = fitness_fn
        self.dtype = dtype
        self.domain = domain
        self.mutation_domain = mutation_domain
        self.settling_trials = settling_trials
        self.fract_broadcast = fract_broadcast
        self.fract_duplication = fract_duplication
        # Misspelled alias kept so external code reading it keeps working.
        self.frac_duplication = fract_duplication
        self.prob_die = prob_die

        # Pick distinct random slots to populate initially.
        alive_ind = tf.random.shuffle(tf.range(n_corals))[
            : int(frac_init_alive * n_corals)
        ]
        n_alive = len(alive_ind)
        alive_ind = tf.reshape(alive_ind, [-1, 1])

        init_values = tf.random.uniform(
            [n_alive, dim],
            minval=domain[0],
            maxval=domain[1],
            dtype=self.dtype,
        )
        # Evaluate fitness on [dim] rows, exactly as _larvae_settling does, so
        # fitness_fn always sees the same input rank. Skip the map when empty.
        if n_alive > 0:
            init_fitness = tf.vectorized_map(fitness_fn, init_values)
        else:
            init_fitness = tf.zeros([0], dtype=tf.float32)

        self.grid_values = tf.tensor_scatter_nd_update(
            tf.zeros([n_corals, dim], dtype=self.dtype), alive_ind, init_values
        )
        self.grid_fitness = tf.tensor_scatter_nd_update(
            tf.zeros(n_corals, dtype=tf.float32), alive_ind, init_fitness
        )
        self.grid_alive = tf.tensor_scatter_nd_update(
            tf.zeros(n_corals, dtype=tf.bool),
            alive_ind,
            tf.ones(n_alive, dtype=tf.bool),
        )

    def _broadcast_spawning(self, alive: tf.Tensor, n_broadcasters: int) -> tf.Tensor:
        """Single-point crossover between consecutive pairs of broadcasters.

        Uses the corals at ``alive[:n_broadcasters]``; rows ``2k`` and ``2k+1``
        form a pair and share one random crossover point. Each child keeps its
        own genes at positions ``0..point`` (inclusive) and takes the rest
        from its partner. ``n_broadcasters`` must be even.

        Returns:
            Tensor of shape ``[n_broadcasters, dim]`` with the new larvae.
        """
        broadcasters = tf.gather(self.grid_values, alive[:n_broadcasters])
        dim = broadcasters.shape[-1]

        crossover_points = tf.random.uniform(
            [n_broadcasters // 2],
            minval=0,
            maxval=dim,
            dtype=tf.int32,
        )
        # Both members of a pair use the same crossover point.
        crossover_points = tf.repeat(crossover_points, 2)

        # Swap rows within each pair so row i lines up with its partner.
        partners = tf.reshape(
            tf.reverse(tf.reshape(broadcasters, [-1, 2, dim]), [1]), [-1, dim]
        )

        # Broadcast [1, dim] positions against [n, 1] crossover points.
        mask = tf.range(dim)[tf.newaxis, :] <= crossover_points[:, tf.newaxis]
        return tf.where(mask, broadcasters, partners)

    def _brooding(self, alive: tf.Tensor, n_broadcasters: int) -> tf.Tensor:
        """Mutate the corals at ``alive[n_broadcasters:]``.

        Adds uniform noise drawn from ``mutation_domain`` (``maxval`` exclusive
        for integer dtypes) and clips the result to ``domain``.

        Returns:
            Tensor of shape ``[len(alive) - n_broadcasters, dim]``.
        """
        brooders = tf.gather(self.grid_values, alive[n_broadcasters:])
        mutation = tf.random.uniform(
            brooders.shape,
            minval=self.mutation_domain[0],
            maxval=self.mutation_domain[1],
            dtype=self.dtype,
        )
        new_corals = brooders + mutation
        return tf.clip_by_value(new_corals, self.domain[0], self.domain[1])

    def _larvae_settling(self, new_corals: tf.Tensor):
        """Try to place each larva into a random slot of the reef.

        On every trial each remaining larva is paired with a distinct random
        slot. It settles if the slot is empty or its fitness is strictly
        higher than the occupant's. Larvae still unsettled after
        ``settling_trials`` trials are discarded. Prints
        ``settled in {trial}`` when all larvae settle before trials run out.

        Assumes ``len(new_corals)`` does not exceed the number of slots.
        """
        n_slots = len(self.grid_values)
        # Larval fitness does not change between trials: evaluate it once.
        if len(new_corals) > 0:
            new_fitness = tf.vectorized_map(self.fitness_fn, new_corals)
        else:
            new_fitness = tf.zeros([0], dtype=self.grid_fitness.dtype)

        for i in range(self.settling_trials):
            if len(new_corals) == 0:
                print(f"settled in {i}")
                break

            # Distinct target slots, one per remaining larva.
            indices = tf.random.shuffle(tf.range(n_slots))[: len(new_corals)]

            occupied = tf.gather(self.grid_alive, indices)
            old_fitness = tf.gather(self.grid_fitness, indices)
            settled_mask = ~occupied | (new_fitness > old_fitness)
            settled_indices = tf.reshape(
                tf.boolean_mask(indices, settled_mask), [-1, 1]
            )

            self.grid_values = tf.tensor_scatter_nd_update(
                self.grid_values,
                settled_indices,
                tf.boolean_mask(new_corals, settled_mask),
            )
            self.grid_fitness = tf.tensor_scatter_nd_update(
                self.grid_fitness,
                settled_indices,
                tf.boolean_mask(new_fitness, settled_mask),
            )
            self.grid_alive = tf.tensor_scatter_nd_update(
                self.grid_alive,
                settled_indices,
                tf.ones([len(settled_indices)], dtype=tf.bool),
            )

            # Keep only the larvae (and their fitness) that failed to settle.
            new_corals = tf.boolean_mask(new_corals, ~settled_mask)
            new_fitness = tf.boolean_mask(new_fitness, ~settled_mask)

    def _asexual_reproduction(self):
        """Duplicate the fittest living corals and try to settle the copies.

        The number of copies is ``int(fract_duplication * n_alive)``.
        """
        n_alive = int(tf.math.count_nonzero(self.grid_alive))
        n_duplication = int(self.fract_duplication * n_alive)
        # Dead slots keep stale/zero fitness; rank them last so only living
        # corals can be chosen even when living fitness values are negative.
        ranked_fitness = tf.where(
            self.grid_alive,
            self.grid_fitness,
            tf.constant(-float("inf"), dtype=self.grid_fitness.dtype),
        )
        best_indices = tf.math.top_k(ranked_fitness, n_duplication).indices
        self._larvae_settling(tf.gather(self.grid_values, best_indices))

    def _depredation(self):
        """Kill some of the worst living corals.

        The ``int(fract_duplication * n_alive)`` living corals with the lowest
        fitness are candidates; each dies independently with probability
        ``prob_die``. Dead slots are only marked in ``grid_alive``.
        """
        alive = tf.reshape(tf.where(self.grid_alive), [-1])
        n_depredation = int(self.fract_duplication * len(alive))
        if n_depredation == 0:
            return

        alive_fitness = tf.gather(self.grid_fitness, alive)
        # top_k on negated fitness selects the n lowest-fitness living corals.
        worst = tf.gather(alive, tf.math.top_k(-alive_fitness, n_depredation).indices)
        dies = tf.random.uniform([n_depredation]) < self.prob_die
        dead_indices = tf.reshape(tf.boolean_mask(worst, dies), [-1, 1])
        self.grid_alive = tf.tensor_scatter_nd_update(
            self.grid_alive,
            dead_indices,
            tf.zeros([len(dead_indices)], dtype=tf.bool),
        )

    def step(self):
        """Run one CRO generation.

        Stages:
            0. randomly split the living corals into broadcasters/brooders
            1. broadcast spawning (crossover)
            2. brooding (mutation)
            3. larvae settling
            4. asexual reproduction (duplication) + larvae settling
            5. depredation
        """
        alive = tf.reshape(tf.where(self.grid_alive), [-1])
        # tf.random.shuffle returns a new tensor; the result must be kept.
        alive = tf.random.shuffle(alive)

        # Broadcasters are paired, so force an even count.
        n_broadcasters = int(self.fract_broadcast * len(alive)) // 2 * 2
        broadcasted = self._broadcast_spawning(alive, n_broadcasters)
        brooded = self._brooding(alive, n_broadcasters)

        settling_candidates = tf.concat([broadcasted, brooded], axis=0)
        self._larvae_settling(settling_candidates)

        self._asexual_reproduction()
        self._depredation()


# Demo: maximize the sum of a 4-dim integer vector.
reef = CoralReef(
    fitness_fn=lambda x: tf.cast(tf.reduce_sum(x), dtype=tf.float32),
    n_corals=100,
    domain=(0, 10),
    # tf.random.uniform's maxval is exclusive for ints: (-1, 2) gives {-1, 0, +1};
    # (-1, 1) would only ever mutate downward or not at all.
    mutation_domain=(-1, 2),
    dim=4,
    dtype=tf.int32,
)
reef.step()

settled in 4
tf.Tensor(
[[7 5 9 9]
 [4 5 9 9]
 [8 4 7 6]
 [7 6 8 3]
 [6 5 7 5]
 [7 3 7 5]
 [4 6 8 3]
 [4 7 6 4]
 [4 7 6 4]
 [6 5 7 2]], shape=(10, 4), dtype=int32)
settled in 2
