In [None]:
import os
import time

from jax import numpy as np
import numpy.random as npr

import matplotlib
import matplotlib.animation 
import matplotlib.pyplot as plt
matplotlib.rcParams["animation.embed_limit"] = 1024

import skimage
import skimage.io as sio
import skimage.transform
import fracatal

from fracatal.functional_jax.convolve import ft_convolve
from fracatal.functional_jax.pad import pad_2d
from fracatal.functional_jax.metrics import compute_entropy, compute_frequency_ratio, compute_frequency_entropy


# imports being deprecated
from fracatal.functional_jax.compose import make_gaussian, \
        make_mixed_gaussian, \
        make_kernel_field, \
        make_update_function, \
        make_update_step, \
        make_make_kernel_function, \
        sigmoid_1, \
        get_smooth_steps_fn, \
        make_make_smoothlife_kernel_function, \
        make_smooth_interval, \
        make_smoothlife_update_function, \
        make_smoothlife_update_step


from fracatal.scripts.v_stability_sweep import v_stability_sweep
from fracatal.scripts.stability_sweep import stability_sweep     
from fracatal.scripts.mpi_sweep import mpi_stability_sweep

import IPython

In [None]:
"""
animation functions
"""

def get_fig(grid):
    
    global subplot_0
    
    fig, ax = plt.subplots(1,1)
    
    subplot_0 = ax.imshow(grid.squeeze(), cmap="magma")
    
    return fig, ax

def update_frame(ii):
    """FuncAnimation callback: show the current global ``grid``, then advance it one step.

    Args:
        ii: frame index supplied by ``FuncAnimation``; unused, because the
            simulation state lives in the global ``grid``.

    Returns:
        A 1-tuple containing the updated image artist. ``FuncAnimation(blit=True)``
        requires the callback to return the modified artists; with blitting off
        the return value is ignored.
    """
    global grid

    # draw the state *before* stepping, so the first frame shows the initial grid
    subplot_0.set_array(grid.squeeze())

    grid = update_step(grid)

    return (subplot_0,)

# _H. natans_

In [None]:
# common setup for H. natans

# growth-function parameters (consumed by make_update_function below)
mean_g = 0.26
standard_deviation_g = 0.036

# neighborhood-kernel parameters: one entry per kernel component, interpreted by
# make_make_kernel_function
amplitudes = [0.5, 1.0, 0.6667]
means = [0.0938, 0.2814, 0.4690]
standard_deviations = [0.0330] * 3
kernel_radius = 31

make_kernel = make_make_kernel_function(amplitudes, means, standard_deviations)
kernel = make_kernel(kernel_radius)


def clipping_fn(x):
    """Clamp cell states to the unit interval [0, 1]."""
    return np.clip(x, 0, 1.0)


my_update = make_update_function(mean_g, standard_deviation_g)


# The Platonic Pattern: _Hydrogeminium natans_ pickle

In [None]:
pattern_filepath = os.path.join("..", "patterns", "hydrogeminium_natans_pickle.npy")

pattern = np.load(pattern_filepath)[None,None,:,:]
plt.figure()
plt.imshow(pattern.squeeze(), cmap="magma")


In [None]:
# empty (N, 1, dim, dim) grid with the pattern pasted into its top-left corner
number_samples, dim = 1, 256
grid = np.zeros((number_samples, 1, dim, dim))
grid = grid.at[
    :, :,
    :pattern.shape[-2],
    :pattern.shape[-1],
].set(pattern)

# integration step size
dts = 0.1

clipping_fn = lambda x: np.clip(x, 0, 1.0)
update_step = make_update_step(my_update, kernel, dts, clipping_fn)

num_frames = 100

fig, ax = get_fig(grid[0])
plt.show()

anim = matplotlib.animation.FuncAnimation(fig, update_frame, frames=num_frames, interval=10)
IPython.display.HTML(anim.to_jshtml())

# The Non-Platonic _H. natans_ wobbler

In [None]:
# load the non-platonic H. natans "wobbler" and prepend (sample, channel) axes
pattern_name = "hydrogeminium_natans_wobbler"
pattern_filepath = os.path.join("..", "patterns", pattern_name + ".npy")

pattern = np.load(pattern_filepath)
pattern = pattern[None, None, :, :]
plt.figure()
plt.imshow(pattern.squeeze(), cmap="magma")
print(pattern.shape)

In [None]:
# common setup for H. natans, with the "crispus" pattern resampled to kernel_radius

# neighborhood-kernel parameters (same components as the setup above, smaller radius)
amplitudes = [0.5, 1.0, 0.6667]
means = [0.0938, 0.2814, 0.4690]
standard_deviations = [0.0330, 0.0330, 0.0330]
kernel_radius = 25
# presumably the kernel radius the stored pattern was made at; the pattern is
# rescaled by kernel_radius / k0 below
k0 = 25

make_kernel = make_make_kernel_function(amplitudes, means, standard_deviations, dim=122)
kernel = make_kernel(kernel_radius)

# growth-function parameters
mean_g = 0.26
standard_deviation_g = 0.036

clipping_fn = lambda x: np.clip(x, 0, 1.0)
my_update = make_update_function(mean_g, standard_deviation_g)


# load the stored (N, C, H, W) pattern and resample only its two spatial axes
pattern_name = "hydrogeminium_natans_crispus"
pattern_filepath = os.path.join("..", "patterns", f"{pattern_name}.npy")

scale_factor = kernel_radius / k0
pattern = np.load(pattern_filepath)
rescaled = skimage.transform.rescale(
    pattern,
    (1, 1, scale_factor, scale_factor),
    order=5,
    anti_aliasing=True,
)
scaled_pattern = np.array(rescaled, dtype=np.float32)

print(pattern.shape, scaled_pattern.shape, kernel.shape)
pattern = scaled_pattern
plt.figure()
plt.imshow(pattern.squeeze(), cmap="magma")


In [None]:
# empty (N, 1, dim, dim) grid with the pattern pasted at row 32, column 48
number_samples, dim = 1, 128
grid = np.zeros((number_samples, 1, dim, dim))
grid = grid.at[
    :, :,
    32:32 + pattern.shape[-2],
    48:48 + pattern.shape[-1],
].set(pattern)

# integration step size
dts = 0.591

clipping_fn = lambda x: np.clip(x, 0, 1.0)
update_step = make_update_step(my_update, kernel, dts, clipping_fn)

num_frames = 1000

fig, ax = get_fig(grid[0])
plt.show()

anim = matplotlib.animation.FuncAnimation(fig, update_frame, frames=num_frames, interval=10)
IPython.display.HTML(anim.to_jshtml())

In [None]:
grid_0 = grid.copy()  # snapshot of the current grid state

In [None]:
# Tile copies of the cropped pattern across an empty grid, running 32 update steps
# after each placement (so earlier tiles have evolved further than later ones).
# `temp` is produced by the crop cell further down; on a fresh kernel fall back to the
# loaded crispus `pattern` (presumably the same crop, saved earlier) instead of NameError.
tile = globals().get("temp", pattern)
tile_h, tile_w = tile.shape[-2], tile.shape[-1]

# zeros_like instead of `grid *= 0`: multiplying by 0 keeps NaN/inf from an exploded run
grid = np.zeros_like(grid)
for ii in range(0, grid.shape[-2] - tile_h, kernel_radius + tile_h):
    for jj in range(0, grid.shape[-1] - tile_w, kernel_radius + tile_w):

        grid = grid.at[:, :, ii:ii + tile_h, jj:jj + tile_w].set(tile)

        for _ in range(32):
            grid = update_step(grid)
plt.imshow(grid.squeeze())

In [None]:
# animate the current grid for 1024 frames
num_frames = 1024
fig, ax = get_fig(grid)
anim = matplotlib.animation.FuncAnimation(fig, update_frame, frames=num_frames, interval=10)
IPython.display.HTML(anim.to_jshtml())

In [None]:
# shape of the cropped pattern (requires `temp` from the crop cell below)
np.shape(temp)

In [None]:
# hard-coded 21 x 21 window cut out of the evolved grid (rows 16-36, columns 33-53)
temp = grid[:, :, 16:37, 33:54]

plt.imshow(temp.squeeze(), cmap="magma")
plt.show()

In [None]:
# Save the cropped pattern as a colormapped PNG preview and as a raw .npy array.
# (skimage / skimage.io are already imported in the first cell.)
patterns_dir = os.path.join("..", "patterns")

my_cmap = plt.get_cmap("magma")
# bytes=True returns uint8 RGBA in [0, 255], so imsave gets an image dtype directly
# instead of float RGBA that it would have to convert (lossily) itself
sio.imsave(os.path.join(patterns_dir, "hydrogeminium_natans_crispus.png"),
           my_cmap(temp.squeeze(), bytes=True))
np.save(os.path.join(patterns_dir, "hydrogeminium_natans_crispus.npy"), temp)

In [None]:
# Iterative "zoom" sweep over (kernel radius, dt): run a stability sweep over the current
# window, save/plot the results, then move to the sub-window with the highest entropy.
parameter_steps = 16
stride = min([16, parameter_steps])
max_t = 18
max_steps = 16000
max_growth = 2.
min_growth = 0.5
k0 = 31
grid_dim = 192

kernel_dim = 122
make_kernel = make_make_kernel_function(amplitudes, means, standard_deviations, dim=kernel_dim)

# Initial search window [min_kr, max_kr, min_dt, max_dt]. A `params` already defined in
# the kernel (e.g. by the hand-picked window cell) is reused; on a fresh kernel start
# from a broad default window instead of raising NameError.
params = globals().get("params", [5, 53, 0.01, 1.05])

results = []
t0 = time.time()

max_runtime = 60*20.0  # wall-clock budget for the whole zoom loop, in seconds

time_elapsed = time.time() - t0
time_stamp = int(t0*1000)

freq_zoom_fraction = 2  # each candidate window spans 1/freq_zoom_fraction of an axis
freq_zoom_stride = 4    # offset, in sweep cells, between neighbouring candidate windows
idx = 0

exp_name = f"{pattern_name}_{time_stamp}"
save_dir = os.path.join("results", exp_name)

# makedirs also creates missing parents (os.mkdir failed when "results" was absent)
os.makedirs(save_dir, exist_ok=True)
os.makedirs(os.path.join("..", "assets"), exist_ok=True)

metadata_path = os.path.join(save_dir, f"metadata_{time_stamp}.txt")
metadata = "index, pattern_name, min_dt, max_dt, min_kr, max_kr, parameter_steps, max_t, max_steps, max_runtime, "
metadata += "time_stamp, sim_time_elapsed,  total_time_elapsed, "
metadata += "img_savepath, accumulated_t_savepath, total_steps_savepath, explode_savepath, vanish_savepath, grid_T_savepath\n"

with open(metadata_path, "w") as f:
    f.write(metadata)

while time_elapsed <= max_runtime:

    min_kr, max_kr, min_dt, max_dt = params[:4]

    t1 = time.time()

    results.append(v_stability_sweep(pattern, make_kernel, my_update, k0=k0,
            max_t=max_t, max_steps=max_steps, parameter_steps=parameter_steps, stride=stride,
            min_dt=min_dt, max_dt=max_dt,
            min_kr=min_kr, max_kr=max_kr, default_dtype=np.float32))
    t2 = time.time()
    # result tuple: results_img, accumulated_t, total_steps, explode, vanish, done, grid_0, grid
    latest = results[-1]

    # Axis values min + k*step for k in [0, parameter_steps). A float-step arange can
    # round up to parameter_steps + 1 entries and misalign the tick labels.
    dt_step = (max_dt - min_dt) / parameter_steps
    kr_step = (max_kr - min_kr) / parameter_steps
    dts = min_dt + dt_step * np.arange(parameter_steps)
    krs = min_kr + kr_step * np.arange(parameter_steps)

    fig, ax = plt.subplots(1, 1, figsize=(12, 12))
    ax.imshow(latest[0])

    number_ticklabels = 16
    # label every ticklabel_period-th tick; max(1, ...) avoids a modulo by zero when
    # parameter_steps < number_ticklabels
    ticklabel_period = max(1, parameter_steps // number_ticklabels)
    yticklabels = [f"{elem.item():.6e}" if not(mm % ticklabel_period) else "" for mm, elem in enumerate(dts)]
    xticklabels = [f"{elem.item():.6e}" if not(mm % ticklabel_period) else "" for mm, elem in enumerate(krs)]

    _ = ax.set_yticks(np.arange(0, dts.shape[0]))
    _ = ax.set_yticklabels(yticklabels, fontsize=16, rotation=0)
    _ = ax.set_xticks(np.arange(0, krs.shape[0]))
    _ = ax.set_xticklabels(xticklabels, fontsize=16, rotation=90)
    _ = ax.set_ylabel("step size dt", fontsize=22)
    _ = ax.set_xlabel("kernel radius", fontsize=22)

    msg2 = f"total elapsed: {t2-t0:.3f} s, last sweep: {t2-t1:.3f}\n"
    msg = f"    dt from {min_dt:.2e} to {max_dt:.2e}\n"
    msg += f"    kr from {min_kr:.2e} to {max_kr:.2e}\n"

    ax.set_title("disco persistence \n" + msg, fontsize=24)
    fig.savefig(f"../assets/disco{time_stamp}_{idx}.png")
    plt.show()
    plt.close(fig)  # free the figure: the loop can run for up to max_runtime seconds

    print(msg2 + msg)

    # save results
    img_savepath = os.path.join(save_dir, f"{exp_name}_img_{idx}.png")
    img_npy_savepath = os.path.join(save_dir, f"{exp_name}_img_{idx}.npy")

    accumulated_t_savepath = os.path.join(save_dir, f"{exp_name}_accumulated_t_{idx}.npy")
    total_steps_savepath = os.path.join(save_dir, f"{exp_name}_total_steps_{idx}.npy")

    explode_savepath = os.path.join(save_dir, f"{exp_name}_explode_{idx}.npy")
    vanish_savepath = os.path.join(save_dir, f"{exp_name}_vanish_{idx}.npy")
    done_savepath = os.path.join(save_dir, f"{exp_name}_done_{idx}.npy")

    grid_0_savepath = os.path.join(save_dir, f"{exp_name}_grid_0_{idx}.npy")
    grid_T_savepath = os.path.join(save_dir, f"{exp_name}_grid_T_{idx}.npy")

    sio.imsave(img_savepath, latest[0])
    np.save(img_npy_savepath, latest[0])
    np.save(accumulated_t_savepath, latest[1])
    np.save(total_steps_savepath, latest[2])
    np.save(explode_savepath, latest[3])
    np.save(vanish_savepath, latest[4])
    np.save(done_savepath, latest[5])
    np.save(grid_0_savepath, latest[6])
    np.save(grid_T_savepath, latest[7])

    # log experiment metadata (columns match the header written above)
    metadata = f"{idx}, {pattern_name}, {min_dt}, {max_dt}, {min_kr}, {max_kr}, {parameter_steps}, {max_t}, {max_steps}, {max_runtime}, "
    metadata += f"{time_stamp}, {t2-t1:2f}, {t2-t0:2f}, "
    metadata += f"{img_savepath}, {accumulated_t_savepath}, {total_steps_savepath}, {explode_savepath}, {vanish_savepath}, {grid_T_savepath}\n"
    with open(metadata_path, "a") as f:
        f.write(metadata)

    # --- determine next parameter window ---
    # metric input: 1 where a run is not `done`, 0 where it is
    gray_image = np.array(1. - latest[5], dtype=np.float32).squeeze()

    fzd = freq_zoom_dim = latest[0].shape[-2] // freq_zoom_fraction  # window size in cells
    fzs = freq_zoom_stride
    freq_zoom_strides = (latest[0].shape[-2] - freq_zoom_dim) // freq_zoom_stride + 1  # windows per axis

    params_list = []
    entropy = []
    frequency_entropy = []
    frequency_ratio = []
    for ll in range(freq_zoom_strides**2):
        cx, cy = divmod(ll, freq_zoom_strides)  # cx -> dt (rows), cy -> kr (columns)
        row0, col0 = cx * fzs, cy * fzs

        # Window bounds from the axis formula instead of indexing krs/dts: for the last
        # window row0 + fzd can reach the axis length (one past the end), and JAX clamps
        # out-of-bounds reads to the last element, silently shrinking the window.
        params_list.append([
            float(min_kr + kr_step * col0),
            float(min_kr + kr_step * (col0 + fzd)),
            float(min_dt + dt_step * row0),
            float(min_dt + dt_step * (row0 + fzd)),
        ])

        subimage = gray_image[row0:row0 + fzd, col0:col0 + fzd]

        frequency_ratio.append(compute_frequency_ratio(subimage))
        entropy.append(compute_entropy(subimage))
        frequency_entropy.append(compute_frequency_entropy(subimage))

    plt.figure()
    plt.subplot(221)
    plt.imshow(gray_image, cmap="magma")
    plt.title("1 - done")
    plt.subplot(222)
    plt.imshow(np.array(frequency_ratio).reshape(freq_zoom_strides, freq_zoom_strides))
    plt.title("freq. ratio")
    plt.subplot(223)
    plt.imshow(np.array(entropy).reshape(freq_zoom_strides, freq_zoom_strides))
    plt.title("entropy")
    plt.subplot(224)
    plt.imshow(np.array(frequency_entropy).reshape(freq_zoom_strides, freq_zoom_strides))
    plt.title("frequency entropy")
    plt.tight_layout()
    plt.savefig(f"../assets/frequency_entropy_{time_stamp}_{idx}.png")
    plt.show()
    plt.close()

    # zoom into the highest-entropy window (alternative criterion: frequency_entropy)
    params = params_list[int(np.argmax(np.array(entropy)))]

    t3 = time.time()
    idx += 1
    time_elapsed = t3 - t0

In [None]:
# accumulated simulation time from the most recent sweep (element 1 of the result tuple)
latest_accumulated_t = results[-1][1]
latest_accumulated_t

In [None]:
# Re-simulate one cell of the latest sweep and save it as a video.
# x indexes the dt axis and y the kernel-radius axis (rows = dt, columns = kr in the
# sweep plot — presumably the same layout for grid_0; TODO confirm).
x = 0   # -> step size dts[x]
y = 15  # -> kernel radius krs[y]

# grid_0 (element 6 of the result tuple) holds the initial state of each sweep cell
initial_grid = np.array(results[-1][6][x, y], dtype=np.float32)
plt.imshow(initial_grid)

clipping_fn = lambda v: np.clip(v, 0, 1.0)
# derive radius and step from (x, y) so changing the selected cell stays consistent
kernel = make_kernel(krs[y])
update_step = make_update_step(my_update, kernel, dts[x], clipping_fn)

grid = initial_grid
fig, ax = get_fig(grid)
plt.show()

num_frames = 8192
matplotlib.animation.FuncAnimation(fig, update_frame, frames=num_frames, interval=10).save("h_natans_new_pseudorganism.mp4")

In [None]:

# keep animating from the current `grid` state on the existing figure
num_frames = 1000
anim = matplotlib.animation.FuncAnimation(fig, update_frame, frames=num_frames, interval=10)
IPython.display.HTML(anim.to_jshtml())

In [None]:
# Two statements were fused on one line (SyntaxError). A bare FuncAnimation as the last
# expression only prints its repr, so render it as JS/HTML like the other cells.
num_frames = 2048
IPython.display.HTML(matplotlib.animation.FuncAnimation(fig, update_frame, frames=num_frames, interval=10).to_jshtml())

In [None]:
# hand-picked sweep window, unpacked by the sweep cell as [min_kr, max_kr, min_dt, max_dt]
params = [
    25.7, 25.78448486328125,        # kernel radius range
    0.59, 0.5917350053787231,       # step size (dt) range
]

In [None]:
# Grayscale summary of the latest sweep: the fraction of max_t each run accumulated
# (clipped to [0, 1]) weighted by its outcome flags,
#   0.6 * (1 - done) + 0.29 * (explode > 0) + 0.11 * (vanish > 0)
# result tuple: results_img, accumulated_t, total_steps, explode, vanish, done, grid_0, grid
latest = results[-1]
time_fraction = np.clip(latest[1] / max_t, 0, 1.0)
outcome_weight = (0.6 * (1 - 1.0 * latest[5])
                  + 0.29 * (latest[3] > 0)
                  + 0.11 * (latest[4] > 0))
gray_image = np.array(time_fraction * outcome_weight, dtype=np.float32).squeeze()

plt.figure()
plt.subplot(121)
plt.imshow(gray_image, cmap="gray")
plt.subplot(122)
plt.imshow(latest[0])
plt.show()


In [None]:
# explode flags of the latest sweep on its parameter_steps x parameter_steps (dt, kr) grid
results[-1][3].reshape(parameter_steps, parameter_steps)

In [None]:
# vanish flags of the latest sweep on its parameter_steps x parameter_steps (dt, kr) grid
results[-1][4].reshape(parameter_steps, parameter_steps)

In [None]:

# Earlier sweep (6th most recent): its summary image next to a mask of the runs whose
# accumulated_t reached max_t.
earlier_sweep = results[-6]
reached_max_t = np.array((earlier_sweep[1].squeeze() >= max_t), dtype=np.float32)

compare_fig, (ax_summary, ax_mask) = plt.subplots(1, 2)
ax_summary.imshow(earlier_sweep[0], cmap="gray")
mask_image = ax_mask.imshow(reached_max_t)
compare_fig.colorbar(mask_image, ax=ax_mask)
plt.show()


In [None]:
# shape of the latest sweep's accumulated_t array
np.shape(results[-1][1])