In [None]:
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib
import matplotlib.animation
import matplotlib.pyplot as plt
matplotlib.rcParams["animation.embed_limit"] = 128

import skimage
import skimage.io as sio
import skimage.transform

import yuca
from yuca.ca.neural import NCA
from yuca.ca.continuous import CCA
from yuca.cppn import CPPN

from yuca.zoo.librarian import Librarian
from yuca.kernels import get_kernel

torch.set_default_dtype(torch.float32)

import IPython
import IPython.display

from importlib import reload
reload(yuca)
reload(yuca.ca)

In [None]:
def plot_grid(grid, my_cmap=plt.get_cmap("magma"), title="CA animation", vmin=0.0, vmax=1):
    """Create a figure showing sample 0 of a CA grid and register its image artist.

    Args:
        grid: tensor indexed as (batch, channel, H, W); only ``grid[0]`` is shown.
            Multi-channel grids (3 or 4 channels) are drawn as RGB(A). A
            single-channel grid is drawn with ``my_cmap`` scaled to
            [``vmin``, ``vmax``].
        my_cmap: colormap for single-channel grids (matplotlib ignores it for RGB).
        title: figure suptitle.
        vmin, vmax: color limits for single-channel grids (ignored for RGB).

    Returns:
        ``(fig, ax)`` for the new figure. The image artist is also stored in the
        module-level ``subplot_0`` so ``update_fig`` can update it in place.
    """
    global subplot_0

    fig, ax = plt.subplots(1, 1, figsize=(4.5, 4.5), facecolor="white")

    grid_display = grid[0].permute(1, 2, 0)

    if grid_display.shape[-1] == 1:
        # imshow rejects (H, W, 1); drop the channel axis so the colormap and
        # color limits actually apply. TODO invert cmap
        subplot_0 = ax.imshow(grid_display[..., 0], cmap=my_cmap, vmin=vmin, vmax=vmax,
                              interpolation="nearest")
    else:
        # RGB(A) data: cmap/vmin/vmax would be ignored, so don't pass them.
        subplot_0 = ax.imshow(grid_display, interpolation="nearest")

    fig.suptitle(title, fontsize=8)

    ax.set_yticklabels('')
    ax.set_xticklabels('')

    plt.tight_layout()

    return fig, ax

def update_fig(i):
    """FuncAnimation callback: advance the global CA grid one step and redraw it.

    The frame index ``i`` is not used. The function reads and rebinds the
    module-level ``grid``, steps it with the module-level automaton ``aa``, and
    pushes sample 0 (as H x W x C) into the image artist ``subplot_0``.
    """
    global grid, subplot_0

    grid = aa(grid)
    subplot_0.set_array(grid[0].permute(1, 2, 0))

    plt.tight_layout()

In [None]:
class AdamAutomaton():
    """Continuous cellular automaton whose cell-state update mimics an Adam step.

    The grid is indexed as (batch, channel, H, W). Channel 0 holds cell states,
    channel 1 the first-moment estimate ``m`` and channel 2 the second-moment
    estimate ``v``. Each step:

    1. convolves the cell states with a normalized Gaussian-ring kernel,
       using circular padding;
    2. maps the neighborhood value through a Gaussian growth function in
       [-1, 1], used as the "gradient";
    3. updates ``m`` and ``v`` with exponential averaging (no bias correction);
    4. adds ``alpha * m / (sqrt(v) + epsilon)`` to the cell states.

    The resulting grid is clamped to [0, 1].

    Keyword Args:
        diameter: kernel size in cells (default 27). Odd values keep the output
            the same spatial size as the input.
        alpha: step size / dt (default 1e-1).
        beta_1: first-moment decay (default 1e-3).
        beta_2: second-moment decay (default 1e-4).
        epsilon: denominator stabilizer (default 1e-8).
    """

    def __init__(self, **kwargs):

        self.kernel_diameter = kwargs.get("diameter", 27)

        self.set_alpha(kwargs.get("alpha", 1e-1))
        self.set_beta_1(kwargs.get("beta_1", 1e-3))
        self.set_beta_2(kwargs.get("beta_2", 1e-4))
        self.set_epsilon(kwargs.get("epsilon", 1e-8))

        self.init_kernel()
        self.init_growth()

    def init_kernel(self, mu=0.5, sigma=0.15):
        """Build the neighborhood kernel and the fixed-weight conv that applies it.

        The kernel is exp(-(r - mu)^2 / (2 sigma^2)) over a square grid of
        normalized radius r (coordinates span [-1, 1]), normalized to sum to 1.
        """
        # linspace gives exactly `kernel_diameter` samples. A float-step arange
        # with an overshooting stop can round up to one extra sample, which
        # would make the kernel too large for the conv weight.
        my_range = np.linspace(-1.0, 1.0, self.kernel_diameter)
        xx, yy = np.meshgrid(my_range, my_range)

        rr = torch.tensor(np.sqrt(xx**2 + yy**2)[None, None, :, :])

        kernel = torch.exp(-(rr - mu)**2 / (2 * sigma**2))
        kernel = kernel / kernel.sum()
        self.kernel = kernel.to(torch.get_default_dtype())

        self.neighborhood = nn.Conv2d(1, 1,
                self.kernel_diameter,
                padding=(self.kernel_diameter - 1) // 2,
                groups=1,
                padding_mode="circular",
                bias=False)

        # Freeze the conv and load the kernel as its (only) weight.
        with torch.no_grad():
            self.neighborhood.weight.copy_(self.kernel)
        self.neighborhood.weight.requires_grad_(False)

    def init_growth(self, mu=0.167, sigma=0.013):
        """Set ``self.growth`` to x -> 2*exp(-(x - mu)^2 / (2 sigma^2)) - 1."""

        def growth(x):

            return 2*torch.exp(-(x-mu)**2/(2*sigma**2))-1

        self.growth = growth

    def __call__(self, grid):
        """Advance ``grid`` by one Adam-style step and return the new grid."""

        # cell states and the first / second moments (channel slices keep dim 1)
        a = grid[:, 0:1, :, :]
        m_0 = grid[:, 1:2, :, :]
        v_0 = grid[:, 2:3, :, :]

        # neighborhood sums, then growth used as the 'gradient'
        n = self.neighborhood(a)
        g = self.growth(n)

        m = self.beta_1 * m_0 + (1 - self.beta_1) * g
        v = self.beta_2 * v_0 + (1 - self.beta_2) * g**2

        # adam update for cell states
        new_a = a + self.alpha * (m / (torch.sqrt(v) + self.epsilon))

        # Any channels beyond the first three stay zero, as before.
        new_grid = torch.zeros_like(grid)
        new_grid[:, 0:1, :, :] = new_a
        new_grid[:, 1:2, :, :] = m
        new_grid[:, 2:3, :, :] = v

        return torch.clamp(new_grid, 0, 1.0)

    def set_alpha(self, new_alpha):
        self.alpha = 1.0 * new_alpha

    def get_alpha(self):
        return 1.0 * self.alpha

    def set_beta_1(self, new_beta_1):
        self.beta_1 = 1.0 * new_beta_1

    def get_beta_1(self):
        return 1.0 * self.beta_1

    def set_beta_2(self, new_beta_2):
        self.beta_2 = 1.0 * new_beta_2

    def get_beta_2(self):
        return 1.0 * self.beta_2

    def set_epsilon(self, new_epsilon):
        self.epsilon = 1.0 * new_epsilon

    def get_epsilon(self):
        return 1.0 * self.epsilon

In [None]:
my_seed = 13  # passed to torch.manual_seed before the random lattice initialization below

In [None]:
# Load the saved Adam Orbium pattern (a crop of a (1, 3, H, W) grid: state, m, v)
# and preview its three channels as an RGB image.
orbium_path = os.path.join("..", "patterns", "adam_orbium_1.pt")
adam_orbium_1 = torch.load(orbium_path)

plt.imshow(adam_orbium_1.squeeze().permute(1, 2, 0))

In [None]:
# Discretization-stability sweep over the step size (dt = alpha) at kernel
# diameter 27. The saved Adam Orbium is warmed up once. Each dt then runs
# `number_steps` from that same starting grid, and a before/after figure is
# saved for each dt.
grid_dim = 72
number_samples = 100
number_steps = 2048
warmup_steps = number_steps // 2

# linspace yields exactly `number_samples` values ending at 1.0. A float-step
# arange with an overshooting stop can round up to an extra sample (~1.01).
my_dts = np.linspace(1 / number_samples, 1.0, number_samples)

dir_name = os.path.join("..", "assets", "adamata_stability")
os.makedirs(dir_name, exist_ok=True)

aa = AdamAutomaton(diameter=27, beta_1=0.8, beta_2=0.99)
aa.init_growth(mu=0.167, sigma=0.013)

grid = torch.zeros(1, 3, grid_dim, grid_dim)
grid[:, :, :adam_orbium_1.shape[-2], :adam_orbium_1.shape[-1]] = adam_orbium_1

for my_step in range(warmup_steps):
    grid = aa(grid)

grid_0 = 1.0 * grid

for my_dt in my_dts:

    grid = 1.0 * grid_0
    aa.set_alpha(my_dt)

    for my_step in range(number_steps):
        grid = aa(grid)

    fig, ax = plt.subplots(1, 2, figsize=(4.5, 3.5), facecolor="white")

    ax[0].imshow(grid_0[0].permute(1, 2, 0), interpolation="nearest")
    ax[0].set_title("Starting grid")
    ax[1].imshow(grid[0].permute(1, 2, 0), interpolation="nearest")
    ax[1].set_title(f"step {my_step+1}\n dt: {aa.get_alpha():.4f} \n kernel diameter: {aa.kernel.shape[-1]}")

    for axis in ax:
        axis.set_yticklabels('')
        axis.set_xticklabels('')

    fig.suptitle("Adam Orbium discretization stability")

    # Dots are replaced so the dt value doesn't introduce extra file extensions.
    fig_name = f"step{my_step+1}_dt{aa.get_alpha():.4f}_kd{aa.kernel.shape[-1]}"
    fig_name = fig_name.replace(".", "x") + ".png"
    fig.savefig(os.path.join(dir_name, fig_name))

    plt.show()


In [None]:
# Stability map over (kernel radius, dt). For each kernel diameter:
#   - rescale the saved Adam Orbium to match;
#   - warm it up for `warmup_steps`;
#   - run `number_steps` more at each dt.
# Final grids are tiled into one image: rows are kernel radii, columns are
# dts. The map is checkpointed to disk after every completed row.
grid_dim = 180
number_steps = 1024
warmup_steps = number_steps // 2

my_dts = np.arange(0.01, 1.02, 0.04)
my_krs = np.arange(9, 32, 1)
native_diameter = 27  # kernel diameter the saved pattern was found at

dir_name = os.path.join("..", "assets", "adamata_stability")
os.makedirs(dir_name, exist_ok=True)
stability_figure_name = os.path.join(dir_name, "stability_map.png")
stability_numpy_name = os.path.join(dir_name, "stability_map.npy")

# Load the pattern once, as an (H, W, C) array for skimage. A separate name
# keeps the global `adam_orbium_1` tensor intact for other cells.
orbium_hwc = np.array(torch.load(os.path.join("..", "patterns", "adam_orbium_1.pt"))).squeeze().transpose(1, 2, 0)

stability_map = None
for my_kr in my_krs:
    my_diameter = 2 * my_kr + 1

    # The rescaled pattern depends only on the diameter, so build it once per row.
    scale = my_diameter / native_diameter
    scale_dims = [int(np.round(elem * scale)) for elem in orbium_hwc.shape[:2]]
    adam_orbium_scaled = skimage.transform.resize(orbium_hwc, scale_dims)
    adam_orbium_scaled = torch.tensor(adam_orbium_scaled).permute(2, 0, 1).unsqueeze(0)

    temp_map = None
    for my_dt in my_dts:

        aa = AdamAutomaton(diameter=my_diameter, beta_1=0.8, beta_2=0.99)
        aa.init_growth(mu=0.167, sigma=0.013)
        aa.set_alpha(my_dt)

        # Start from a fresh grid. Previously a stale `grid_0` from an earlier
        # cell was copied here and then immediately overwritten.
        grid = torch.zeros(1, 3, grid_dim, grid_dim)
        grid[:, :, :adam_orbium_scaled.shape[-2], :adam_orbium_scaled.shape[-1]] = adam_orbium_scaled

        for my_step in range(warmup_steps):
            grid = aa(grid)

        grid_0 = 1.0 * grid

        for my_step in range(number_steps):
            grid = aa(grid)

        grid_display = grid[0].permute(1, 2, 0).numpy()
        grid_0_display = grid_0[0].permute(1, 2, 0)

        fig, ax = plt.subplots(1, 2, figsize=(4.5, 3.5), facecolor="white")

        ax[0].imshow(grid_0_display, interpolation="nearest")
        ax[0].set_title("Starting grid")
        ax[1].imshow(grid_display, interpolation="nearest")
        ax[1].set_title(f"step {my_step+1}\n dt: {aa.get_alpha():.4f} \n kernel diameter: {aa.kernel.shape[-1]}")

        for axis in ax:
            axis.set_yticklabels('')
            axis.set_xticklabels('')

        fig.suptitle("Adam Orbium discretization stability")

        fig_name = f"step{my_step+1}_dt{aa.get_alpha():.4f}_kd{aa.kernel.shape[-1]}"
        fig_name = fig_name.replace(".", "x") + ".png"
        fig.savefig(os.path.join(dir_name, fig_name))

        # Append this run's final grid as the next column tile of the current row.
        temp_map = grid_display if temp_map is None else np.append(temp_map, grid_display, axis=1)

        plt.show()

    stability_map = temp_map if stability_map is None else np.append(stability_map, temp_map, axis=0)

    # Checkpoint after every row so a long sweep can be inspected early.
    sio.imsave(stability_figure_name, stability_map)
    np.save(stability_numpy_name, stability_map)


In [None]:
# Overlay tile borders on the saved stability map (one tile per
# (kernel radius, dt) run) and re-save it as a PNG.
dir_name = os.path.join("..", "assets", "adamata_stability")
file_name = os.path.join(dir_name, "stability_map.npy")
stability_figure_name = os.path.join(dir_name, "stability_map.png")

tile_size = 180   # grid_dim used when the map was built
line_width = 5    # border thickness in pixels
line_boost = 0.5  # brightness added along each border before clipping

stability_map_lined = np.load(file_name)

# Column and row borders are drawn in separate loops. The column loop used to
# be nested inside the row loop, so column borders were brightened once per
# tile row (0.05 * n_rows) while row borders got a flat +0.5.
for jj in range(0, stability_map_lined.shape[1], tile_size):
    stability_map_lined[:, jj:jj + line_width] += line_boost
for ii in range(0, stability_map_lined.shape[0], tile_size):
    stability_map_lined[ii:ii + line_width, :] += line_boost

stability_map_lined = np.clip(stability_map_lined, 0, 1.0)

plt.figure(figsize=(10, 10))
plt.imshow(stability_map_lined)
plt.show()
sio.imsave(stability_figure_name, stability_map_lined)
        

In [None]:
# Build a fresh automaton, show its kernel, and seed a grid_dim x grid_dim grid
# with uniform noise confined to a lattice. Within every `gap_size` period,
# the first `gap_size - gap_minus` rows and columns are zeroed, leaving
# `gap_minus`-wide bands of noise.
num_frames = 100
grid_dim = 512

aa = AdamAutomaton(diameter=27, beta_1=0.8, beta_2=0.99)
aa.init_growth(mu=0.167, sigma=0.013)

fig_kernel, ax_kernel = plt.subplots()
ax_kernel.imshow(aa.kernel.squeeze())
ax_kernel.set_title("Adam Automaton neighborhood kernel")
plt.show()

torch.manual_seed(my_seed)
grid = torch.rand(1, 3, grid_dim, grid_dim)

gap_size = 96
gap_minus = 10
blank_span = gap_size - gap_minus

# Zeroing is multiplicative, so rows and columns can be cleared in one pass.
for start in range(0, grid_dim, gap_size):
    grid[:, :, start:start + blank_span, :] *= 0.0
    grid[:, :, :, start:start + blank_span] *= 0.0

fig, ax = plot_grid(grid)

plt.show()

In [None]:
# Animate the lattice initialization. Each frame advances the global `grid`
# one automaton step via update_fig.
fig, ax = plot_grid(grid)

lattice_animation = matplotlib.animation.FuncAnimation(
        fig, update_fig, frames=num_frames, interval=10)
IPython.display.HTML(lattice_animation.to_jshtml())

In [None]:
# Re-thin the evolved grid with wider surviving bands (gap_minus = 20), then
# continue the animation from the masked state.
num_frames = 50

gap_size = 96
gap_minus = 20

for ii in range(0, grid_dim, gap_size):
    grid[:, :, ii:ii + gap_size - gap_minus, :] = 0.0

for jj in range(0, grid_dim, gap_size):
    grid[:, :, :, jj:jj + gap_size - gap_minus] = 0.0

# Build the figure after masking so its initial image matches the state being
# animated. It used to be created first and showed the pre-mask grid.
fig, ax = plot_grid(grid)

IPython.display.HTML(
        matplotlib.animation.FuncAnimation(fig, update_fig, frames=num_frames, interval=10).to_jshtml())

In [None]:
# Keep animating the existing figure for another 250 steps, starting from the
# current global `grid`.
num_frames = 250

continued_animation = matplotlib.animation.FuncAnimation(
        fig, update_fig, frames=num_frames, interval=10)
IPython.display.HTML(continued_animation.to_jshtml())

In [None]:
# Take an independent copy of the current grid and show it as RGB.
grid_0 = grid.clone()

fig_snapshot, ax_snapshot = plt.subplots()
ax_snapshot.imshow(grid_0.squeeze().permute(1, 2, 0))
plt.show()

In [None]:


# Show the current grid state as RGB.
fig_current, ax_current = plt.subplots()
ax_current.imshow(grid.squeeze().permute(1, 2, 0))
plt.show()

In [None]:
# Zero the cell-state channel (0) in three regions, leaving channels 1-2
# untouched, then show the result.
cell_states = grid[:, 0]  # view of shape (batch, H, W); writes go through to grid

cell_states[:, 250:355, :256] *= 0
cell_states[:, :, :50] *= 0
cell_states[:, :, 452:] *= 0

fig_cleared, ax_cleared = plt.subplots()
ax_cleared.imshow(grid.squeeze().permute(1, 2, 0))
plt.show()

In [None]:
# Crop a candidate pattern (all three channels) out of the grid.
# .clone() detaches the crop from `grid`'s storage. Without it, later in-place
# edits to `grid` would alter the pattern, and torch.save would serialize the
# whole underlying grid storage instead of just the crop.
adam_orbium_0 = grid[:, :, 410:450, 70:100].clone()

fig_crop_0, ax_crop_0 = plt.subplots()
ax_crop_0.imshow(adam_orbium_0.squeeze().permute(1, 2, 0))

In [None]:
# Snapshot the current grid as grid_1, then animate 300 more steps.
grid_1 = grid.clone()

num_frames = 300
fig, ax = plot_grid(grid)

long_animation = matplotlib.animation.FuncAnimation(
        fig, update_fig, frames=num_frames, interval=10)
IPython.display.HTML(long_animation.to_jshtml())

In [None]:
# Preview the grid_1 snapshot as RGB.
grid_1_rgb = grid_1.squeeze().permute(1, 2, 0)
plt.imshow(grid_1_rgb)

In [None]:
# Crop a 150x150 window from grid_1 and record it. A GIF of `num_frames` steps
# is written to the working directory, and the same steps are shown inline.
pattern_name = "adam_orbium"
num_frames = 1000

grid = 1.0 * grid_1[:, :, 300:450, 0:150]
# Remember the starting state. update_fig mutates the global `grid`, so without
# a reset the inline animation would continue from where the GIF ended instead
# of replaying the same steps.
grid_start = grid.clone()

fig, ax = plot_grid(grid)
matplotlib.animation.FuncAnimation(fig, update_fig, frames=num_frames, interval=25).save(
        f"temp_{pattern_name}.gif")

grid = grid_start.clone()
IPython.display.HTML(
        matplotlib.animation.FuncAnimation(fig, update_fig, frames=num_frames, interval=10).to_jshtml())

In [None]:
# Crop the second candidate pattern (all three channels). .clone() detaches it
# from `grid`'s storage so saving it doesn't serialize the whole grid.
# NOTE(review): the window (cols 365:410) assumes `grid` is still the full
# lattice. After the 150x150 crop in the recording cell the slice is empty, so
# fail loudly instead of saving an empty pattern.
adam_orbium_1 = grid[:, :, 315:350, 365:410].clone()
assert adam_orbium_1.numel() > 0, "crop window lies outside the current grid"

fig_crop_1, ax_crop_1 = plt.subplots()
ax_crop_1.imshow(adam_orbium_1.squeeze().permute(1, 2, 0))

In [None]:
# Persist both candidate patterns to the working directory. The stability
# cells load from ../patterns/, so move the files there to reuse them.
for pattern_file, pattern in (("adam_orbium_0.pt", adam_orbium_0),
                              ("adam_orbium_1.pt", adam_orbium_1)):
    torch.save(pattern, pattern_file)

In [None]:
#print(aa.kernel_diameter, aa.epsilon, aa.alpha, aa.beta_1, aa.beta_2, "mu_g", 0.167, "sigma_g", 0.013, "mu_k", 0.5, "sigma_k", 0.15)
# Output of the print above (automaton settings used in this notebook):
# 27 1e-08 0.1 0.8 0.99 mu_g 0.167 sigma_g 0.013 mu_k 0.5 sigma_k 0.15

In [None]:
# Inspect channel 2 (the second-moment estimate v) over the top-left 10x10 cells.
grid[:,2,:10,:10]