# Neural Cellular Automata

## NCA

# Glaberish

## A general framework for continuous CA that distinguishes birth and survival


### Overview

![](../assets/neural_neuroscutium_valvatus_summary.png)

$$
\frac{\partial A}{\partial t} = U(A_t) = (1 - K_i \circledast A_t) G(K_n \circledast A_t) + (K_i \circledast A_t) P(K_n \circledast A_t)
$$

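where $K_i$ is the identity kernel, $K_n$ the neighborhood kernel, $G$ the genesis (birth) function and $P$ the persistence (survival) function. The state is advanced with an explicit Euler step:

$$
A_{t + \Delta t} = A_t + \Delta t \, U(A_t)
$$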

When $P = G$, i.e. birth and survival share a single update function, the rule reduces to

$$
U(A_t) = G(K_n \circledast A_t)
$$

### Neighborhoods

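$$
K_n(r) = a \, e^{-\frac{(r - \mu_K)^2}{2 \sigma_K^2}}
$$

where $r$ is the distance from the kernel center and the normalization constant $a$ makes the kernel sum to one:

$$
\frac{1}{a} = \sum_{i,j} e^{-\frac{(\sqrt{i^2 + j^2} - \mu_K)^2}{2 \sigma_K^2}}
$$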

### Updates

$$
n = K_n \circledast A_t
$$


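The update function $G$ is a small neural network applied to the neighborhood value $n$, with nonlinearity $\sigma$:

$$
G(n) = W_{hy} \, \sigma(W_{hh} \, \sigma(W_{xh} \, n))
$$

or, written in convolutional form:

$$
G(n) = K_{hy} \circledast \sigma(K_{hh} \circledast \sigma(K_{xh} \circledast n))
$$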


<!-- Image link to be used after making repo public -->
<!-- <img src="https://raw.github.com/riveSunder/DiscoGliders/assets/nca_4.gif"> -->

<!-- local image -->
![](../assets/nca_neurosingle_glider000.gif)


In [None]:
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib
import matplotlib.animation
import matplotlib.pyplot as plt
matplotlib.rcParams["animation.embed_limit"] = 128

import skimage
import skimage.io as sio
import skimage.transform

import yuca
from yuca.ca.neural import NCA
from yuca.ca.continuous import CCA
from yuca.cppn import CPPN

from yuca.zoo.librarian import Librarian
from yuca.kernels import get_kernel

torch.set_default_dtype(torch.float32)

import IPython

from importlib import reload
reload(yuca)
reload(yuca.ca)



In [None]:
def plot_ca(ca, my_suptitle = ""):
    """Plot a CA's kernels and update function(s) side by side.

    Panels: the identity kernel ``K_i``, the neighborhood kernel ``K_n``,
    and the update evaluated for neighborhood values in [0, 1).

    The update is evaluated twice: for an occupied state (persistence,
    ``P``) and for an empty state (genesis, ``G``). When the two agree, a
    single "Growth" curve is drawn; otherwise both curves are plotted.

    Args:
        ca: CA object exposing ``id_kernel``, ``neighborhood_kernels`` and
            ``update_universe(state, neighborhood)``.
        my_suptitle: figure title.

    Returns:
        ``(fig, ax)``: the figure and its array of three axes.
    """

    my_cmap = lambda x: np.array(plt.get_cmap("magma")(x))
    # inverted magma: low values render light, high values dark
    colorify = lambda x: 1.0 - my_cmap(x)[..., :3]

    fig, ax = plt.subplots(1, 3, figsize=(9, 3), gridspec_kw={
            "width_ratios": [0.1, 0.1, 0.7], "height_ratios": [1]})

    # For CPU tensors .numpy() shares memory with the CA's kernels, so
    # normalize out of place: an in-place `/=` would rescale the live
    # kernels and change the CA's dynamics.
    K_i = ca.id_kernel.detach().cpu().squeeze().numpy()
    K_n = ca.neighborhood_kernels.detach().cpu().squeeze().numpy()
    K_i = K_i / K_i.max()
    K_n = K_n / K_n.max()

    display_kernels = (colorify(K_i), colorify(K_n))

    x = torch.arange(0, 1.0, 0.001).reshape(1, 1, 1, -1)

    # Same neighborhood values, evaluated on an all-ones state (persistence)
    # and on an all-zeros state (genesis).
    P_x = ca.update_universe(torch.ones_like(x), x).detach().cpu().numpy().squeeze()
    G_x = ca.update_universe(torch.zeros_like(x), x).detach().cpu().numpy().squeeze()
    xx = x.squeeze().numpy()

    ax[0].imshow(display_kernels[0])
    ax[1].imshow(display_kernels[1])

    ax[0].set_title("$K_i$")
    ax[1].set_title("$K_n$")

    my_color = (1.0 - my_cmap(192)[:3], 1.0 - my_cmap(128)[:3])

    # Compare elementwise: a zero *sum* of differences does not imply the
    # curves are equal, since positive and negative differences can cancel.
    if np.allclose(P_x, G_x):
        ax[2].plot(xx, P_x, lw=3, label="Growth $G$", alpha=0.85, color=my_color[0])
    else:
        ax[2].plot(xx, P_x, "--", lw=3,
                label="Persistence $P$", alpha=0.85, color=my_color[0])
        ax[2].plot(xx, G_x, "-", lw=3,
                label="Genesis $G$", alpha=0.85, color=my_color[1])

    ax[2].set_title("Update function(s)")

    fig.suptitle(my_suptitle, fontsize=22)
    ax[2].legend()
    fig.tight_layout()

    return fig, ax

def plot_ca_pattern(pattern, ca, my_suptitle = "", row_letter=""):
    """Plot a CA's kernels, update function(s) and a pattern in one row.

    Panels: the identity kernel ``K_i``, the neighborhood kernel ``K_n``,
    the update evaluated for neighborhood values in [0, 1), and ``pattern``
    padded by half its largest dimension.

    The update is evaluated twice: for an occupied state (persistence,
    ``P``) and for an empty state (genesis, ``G``). When the two agree, a
    single "Growth" curve is drawn; otherwise both curves are plotted.

    Args:
        pattern: numpy array holding the pattern to display (squeezed to 2-D
            before padding -- presumably shaped (1, 1, H, W); confirm).
        ca: CA object exposing ``id_kernel``, ``neighborhood_kernels`` and
            ``update_universe(state, neighborhood)``.
        my_suptitle: figure title.
        row_letter: label drawn to the left of the first panel.

    Returns:
        ``(fig, ax)``: the figure and its array of four axes.
    """

    my_cmap = lambda x: np.array(plt.get_cmap("magma")(x))
    # inverted magma: low values render light, high values dark
    colorify = lambda x: 1.0 - my_cmap(x)[..., :3]

    fig, ax = plt.subplots(1, 4, figsize=(9, 2.5), gridspec_kw={
            "width_ratios": [0.1, 0.1, 0.6, 0.1], "height_ratios": [1]})

    # For CPU tensors .numpy() shares memory with the CA's kernels, so
    # normalize out of place: an in-place `/=` would rescale the live
    # kernels and change the CA's dynamics for every later step.
    K_i = ca.id_kernel.detach().cpu().squeeze().numpy()
    K_n = ca.neighborhood_kernels.detach().cpu().squeeze().numpy()
    K_i = K_i / K_i.max()
    K_n = K_n / K_n.max()

    display_kernels = (colorify(K_i), colorify(K_n))

    x = torch.arange(0, 1.0, 0.001).reshape(1, 1, 1, -1)

    # Same neighborhood values, evaluated on an all-ones state (persistence)
    # and on an all-zeros state (genesis).
    P_x = ca.update_universe(torch.ones_like(x), x).detach().cpu().numpy().squeeze()
    G_x = ca.update_universe(torch.zeros_like(x), x).detach().cpu().numpy().squeeze()
    xx = x.squeeze().numpy()

    ax[0].imshow(display_kernels[0])
    ax[1].imshow(display_kernels[1])

    ax[0].set_title("$K_i$")
    ax[1].set_title("$K_n$")

    my_color = (1.0 - my_cmap(192)[:3], 1.0 - my_cmap(128)[:3])

    # Compare elementwise: a zero *sum* of differences does not imply the
    # curves are equal, since positive and negative differences can cancel.
    if np.allclose(P_x, G_x):
        ax[2].plot(xx, P_x, lw=3, label="Growth $G$", alpha=0.85, color=my_color[0])
    else:
        ax[2].plot(xx, P_x, "--", lw=3,
                label="Persistence $P$", alpha=0.85, color=my_color[0])
        ax[2].plot(xx, G_x, "-", lw=3,
                label="Genesis $G$", alpha=0.85, color=my_color[1])

    ax[2].set_title("Update function(s)")

    fig.suptitle(my_suptitle, fontsize=22)
    fig.legend(loc=[0.65, 0.450])

    # Pad the pattern so it does not touch the panel border.
    pad_by = max(pattern.shape) // 2
    display_pattern = colorify(np.pad(pattern.squeeze(), pad_by))
    ax[3].imshow(display_pattern)
    ax[3].set_title("Glider")
    ax[0].set_ylabel(row_letter, fontsize=36, rotation=0)

    fig.tight_layout()

    return fig, ax

def plot_grid(grid, my_cmap=plt.get_cmap("magma"), title="Neural CA animation", vmin=0.0, vmax=1):
    """Create a single-panel figure showing channel 0 of ``grid``.

    The created image artist is stored in the global ``subplot_0`` so that
    ``update_fig`` can redraw it during an animation.

    Args:
        grid: torch tensor or numpy array indexed as ``grid[0, 0]`` to get
            a 2-D frame.
        my_cmap: matplotlib colormap; it is displayed inverted (``1 - rgb``).
        title: figure suptitle.
        vmin, vmax: grid values mapped to the two ends of the colormap.
            Values outside the range get the colormap's under/over colors
            (by default its end colors).

    Returns:
        ``(fig, ax)``: the figure and its single axes.

    Raises:
        ValueError: if ``vmax == vmin``.
    """

    global subplot_0

    if vmax == vmin:
        raise ValueError("vmax must differ from vmin")

    fig, ax = plt.subplots(1, 1, figsize=(4.5, 4.5), facecolor="white")

    frame = grid[0, 0]
    if isinstance(frame, torch.Tensor):
        frame = frame.detach().cpu().numpy()

    # Map [vmin, vmax] onto the colormap's [0, 1] input range, then invert
    # the RGB channels so low values render light.
    frame = (frame - vmin) / (vmax - vmin)
    grid_display = 1.0 - my_cmap(frame)[:, :, :3]

    subplot_0 = ax.imshow(grid_display, interpolation="nearest")

    fig.suptitle(title, fontsize=8)

    ax.set_yticklabels('')
    ax.set_xticklabels('')

    fig.tight_layout()

    return fig, ax

def update_fig(i):
    """FuncAnimation callback: advance the global ``grid`` one CA step and redraw it.

    Reads the global CA ``ca`` and the image artist ``subplot_0`` (created
    by ``plot_grid``). Rebinds the global ``grid`` to the stepped state.

    Args:
        i: frame index supplied by ``FuncAnimation``; unused.

    Returns:
        A one-element tuple with the updated image artist, as
        ``FuncAnimation`` expects from its callback (required for blitting).
    """

    global subplot_0
    global grid

    grid = ca(grid)

    frame = grid[0, 0]
    if isinstance(frame, torch.Tensor):
        frame = frame.detach().cpu().numpy()

    # inverted magma, matching plot_grid's default colormap
    grid_display = 1.0 - plt.get_cmap("magma")(frame)[:, :, :3]

    # Layout is fixed by plot_grid; recomputing tight_layout every frame only
    # costs time, so only the image data is updated here.
    subplot_0.set_array(grid_display)

    return (subplot_0,)

In [None]:
# Pattern library; `lib.load(name)` below returns (pattern, metadata) pairs.
lib = Librarian()
# Last expression of the cell, so the notebook displays it
# (presumably the names accepted by lib.load -- confirm).
lib.index

In [None]:
# Shared NCA instance; the cells below reconfigure it with restore_config()
# and it is stepped (as the global `ca`) by update_fig.
ca = NCA()

In [None]:
# For every stored NCA: plot a kernel/update/glider summary and, when
# save_figs is set, write the summary PNG and a glider animation GIF.
num_frames = 256
grid_dim = 64
save_figs = True

nca_names = ["neurorbium000",
             "neurosynorbium000",
             "neurosingle_glider000",
             "neurowobble_glider000"]

lib.verbose = False

for name in nca_names:

    p, m = lib.load(name)

    ca.restore_config(m["ca_config"])
    print(f"\n NCA {name} :\n", f"config name: {m['ca_config']}\n", ca.weights_layer)

    ca.eval()
    ca.no_grad()
    # NOTE(review): every pattern runs at dt = 0.08; a per-pattern time step
    # (e.g. for the neurorbium patterns) would be chosen here.
    ca.set_dt(0.08)

    fig0, ax0 = plot_ca_pattern(p, ca, my_suptitle=f"{name}", row_letter="  ")

    if save_figs:
        # save neighborhood-update-glider ca summary
        fig0.savefig(os.path.join("..", "assets", f"nca_{name}.png"))

        # Save glider animation. update_fig steps the global `grid`, so
        # start it from the pattern placed in the top-left corner.
        grid = torch.zeros(1, 1, grid_dim, grid_dim)
        grid[:, :, :p.shape[-2], :p.shape[-1]] = torch.tensor(p)

        fig, ax = plot_grid(grid)

        matplotlib.animation.FuncAnimation(fig, update_fig, frames=num_frames, interval=10).save(
                os.path.join("..", "assets", f"nca_{name}.gif"))
    plt.show()

In [None]:
"""
nca_names = ["neurorbium000",\
             "neurosynorbium000",\
             "neurosingle_glider000",\
             "neurowobble_glider000"]
"""

# try running gliders with different scales and time steps

save_figs = False
glider_idx = 0
my_scale = 1
my_dt = 0.08
num_frames = 512

p, m = lib.load(nca_names[glider_idx])
ca.restore_config(m["ca_config"])

ca.set_dt(my_dt)
ca.set_kernel_radius(ca.kernel_radius*my_scale)

grid_dim = max([64, 3*ca.kernel_radius])

ca.eval()
ca.no_grad()

p = skimage.transform.rescale(p, \
                         scale=(1,1, my_scale, my_scale), \
                         anti_aliasing=True,\
                         mode="constant", cval=0.0)

grid = ca.initialize_grid(dim=grid_dim) #torch.zeros(1, 1, grid_dim, grid_dim)

grid[:,:,:p.shape[-2], -p.shape[-1]:] = torch.tensor(p)
    
fig, ax = plot_grid(grid)

if save_figs:
    matplotlib.animation.FuncAnimation(fig, update_fig, frames=num_frames, interval=10).save(f"{glider_idx}_0.gif")

IPython.display.HTML(\
        matplotlib.animation.FuncAnimation(fig, update_fig, frames=num_frames, interval=10).to_jshtml())