In [None]:
import os
import subprocess
import sys

# Set to False on a hosted kernel (e.g. Colab) where the repository must be cloned first.
run_local = True

if run_local:
    # Assumes the notebook is opened from a subdirectory of the repository checkout.
    root_dir = ".."
else:
    # os.system hands its string to a shell, so IPython's "!" prefix must not be used
    # here: in a POSIX shell a leading "!" negates the command's exit status.
    os.system("git clone -b temp https://github.com/riveSunder/fractal_persistence/")
    os.system("mv fractal_persistence/* ./")
    # Install with the kernel's own interpreter so the package is importable from
    # this notebook (a bare "pip" may belong to a different Python).
    subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
    os.makedirs("assets", exist_ok=True)

    root_dir = "."

In [None]:
import os
import time

from jax import numpy as np
import numpy.random as npr

import matplotlib
import matplotlib.animation 
import matplotlib.pyplot as plt
matplotlib.rcParams["animation.embed_limit"] = 1024

import skimage
import skimage.io as sio
import skimage.transform
import fracatal

from fracatal.functional_jax.convolve import ft_convolve
from fracatal.functional_jax.pad import pad_2d
from fracatal.functional_jax.metrics import compute_entropy, compute_frequency_ratio, compute_frequency_entropy


# imports being deprecated
from fracatal.functional_jax.compose import make_gaussian, \
        make_mixed_gaussian, \
        make_kernel_field, \
        make_update_function, \
        make_update_step, \
        make_make_kernel_function, \
        sigmoid_1, \
        get_smooth_steps_fn, \
        make_make_smoothlife_kernel_function, \
        make_smooth_interval, \
        make_smoothlife_update_function, \
        make_smoothlife_update_step


from fracatal.scripts.v_stability_sweep import v_stability_sweep
from fracatal.scripts.stability_sweep import stability_sweep     
from fracatal.scripts.mpi_sweep import mpi_stability_sweep

import IPython

In [None]:
"""
animation functions
"""

def get_fig(grid):
    """Create a single-axes figure showing `grid` with the "magma" colormap.

    The image artist is stored in the module-level `subplot_0` so that
    `update_frame` can refresh its data on later frames.

    Returns:
        (fig, ax): the new matplotlib figure and its axes.
    """
    global subplot_0

    figure, axis = plt.subplots(1, 1)
    image_data = grid.squeeze()
    subplot_0 = axis.imshow(image_data, cmap="magma")

    return figure, axis

def update_frame(ii):
    """Draw the current global `grid`, then advance it by one `update_step`.

    Presumably used as a matplotlib.animation.FuncAnimation callback; the frame
    index `ii` is accepted but not used. Returns None.

    NOTE(review): `update_step` is not defined anywhere in this notebook
    (only `make_update_step` is imported), so it must be created before
    animating. `subplot_0` must first be set by `get_fig`.
    """
    global grid

    current_state = grid.squeeze()
    subplot_0.set_array(current_state)

    grid = update_step(grid)

In [None]:
# Common setup for the orbium pattern.

# Neighborhood kernel: parallel lists, one entry per kernel component
# (presumably one Gaussian ring per entry — confirm in make_make_kernel_function).
amplitudes = [1.0]
means = [0.5]
standard_deviations = [0.15]
kernel_radius = 13

make_kernel = make_make_kernel_function(amplitudes, means, standard_deviations)
kernel = make_kernel(kernel_radius)

# Growth function parameters: mean and standard deviation.
mean_g = 0.15
standard_deviation_g = 0.017


def clipping_fn(x):
    """Clamp values to the interval [0, 1]."""
    return np.clip(x, 0, 1.0)


my_update = make_update_function(mean_g, standard_deviation_g)


In [None]:
# Load the starting pattern from <root_dir>/patterns and preview it.
pattern_name = "orbium_unicaudatus"
pattern_file = f"{pattern_name}.npy"
pattern_filepath = os.path.join(root_dir, "patterns", pattern_file)

pattern = np.load(pattern_filepath)
preview = pattern.squeeze()
plt.figure()
plt.imshow(preview, cmap="magma")


In [None]:
# Ensure <root_dir>/results exists. exist_ok avoids the check-then-create race,
# and makedirs also creates any missing parent directories.
os.makedirs(os.path.join(root_dir, "results"), exist_ok=True)

In [None]:
# Sweep configuration and experiment bookkeeping.
parameter_steps = 64  # number of dt and kr values along each sweep axis
stride = min([16, parameter_steps])
max_t = 10
max_steps = 10000
# NOTE(review): max_growth, min_growth and grid_dim are not used in this cell.
max_growth = 1.3
min_growth = 0.9
k0 = 13
grid_dim = 256
default_dtype = np.float32

# Rebuild the kernel factory with a fixed kernel dimension and dtype for the sweeps.
kernel_dim = 122
make_kernel = make_make_kernel_function(amplitudes, means, \
        standard_deviations, dim=kernel_dim, default_dtype=default_dtype)

results = []
t0 = time.time()

max_runtime = 60*360.0  # wall-clock budget in seconds (6 hours)
max_zooms = 12
total_zooms = 0

time_elapsed = time.time()-t0
time_stamp = int(t0*1000)  # milliseconds since the epoch; tags every output of this run
# Initial search window, ordered [min_kr, max_kr, min_dt, max_dt].
params = [4, 72, 0.01, 1.05]

freq_zoom_strides = 5  # placeholder; the sweep loop recomputes it from each sweep image
freq_zoom_fraction = 2  # zoom windows span 1/freq_zoom_fraction of the sweep image
idx = 0

exp_name = f"{pattern_name}_{time_stamp}"
save_dir = os.path.join(root_dir, "results", exp_name)

# makedirs also creates <root_dir>/results if it does not exist yet.
os.makedirs(save_dir, exist_ok=True)

# Write the CSV-style header; the sweep loop appends one row per sweep.
metadata_path = os.path.join(save_dir, f"metadata_{time_stamp}.txt")
metadata = "index, pattern_name, min_dt, max_dt, min_kr, max_kr, parameter_steps, max_t, max_steps, max_runtime, "
metadata += "time_stamp, sim_time_elapsed,  total_time_elapsed, "
metadata += "img_savepath, accumulated_t_savepath, total_steps_savepath, explode_savepath, vanish_savepath, grid_T_savepath\n"

with open(metadata_path,"w") as f:
    f.write(metadata)

# Figures go to <root_dir>/assets so local and hosted runs write to the same place
# relative to the repository; create it up front so savefig cannot fail on it.
assets_dir = os.path.join(root_dir, "assets")
os.makedirs(assets_dir, exist_ok=True)

# Repeatedly sweep (kr, dt), save every output, then zoom into the sub-window of
# 1 - done with the highest frequency entropy. Stops on the runtime budget, the
# zoom budget, or when no informative sub-window is left.
while time_elapsed <= max_runtime and total_zooms <= max_zooms:

    min_kr = params[0]
    max_kr = params[1]
    min_dt = params[2]
    max_dt = params[3]

    t1 = time.time()

    results.append(stability_sweep(pattern, make_kernel, my_update, \
            max_t=max_t, max_steps=max_steps, parameter_steps=parameter_steps, stride=stride,\
            min_dt=min_dt, max_dt=max_dt,\
            min_kr = min_kr, max_kr=max_kr, k0=k0, \
            default_dtype=default_dtype))

    t2 = time.time()

    fig, ax = plt.subplots(1,1, figsize=(12,12))
    ax.imshow(results[-1][0])
    # Same values as arange(min, max, (max-min)/parameter_steps), but always exactly
    # parameter_steps of them; a float-step arange can emit an extra element.
    dts = np.linspace(min_dt, max_dt, parameter_steps, endpoint=False)
    krs = np.linspace(min_kr, max_kr, parameter_steps, endpoint=False)

    number_ticklabels = 16
    # max(1, ...) avoids a modulo-by-zero when parameter_steps < number_ticklabels.
    ticklabel_period = max(1, parameter_steps // number_ticklabels)
    yticklabels = [f"{elem.item():.6e}" if not(mm % ticklabel_period) else "" for mm, elem in enumerate(dts)]
    xticklabels = [f"{elem.item():.6e}" if not(mm % ticklabel_period) else "" for mm, elem in enumerate(krs)]

    _ = ax.set_yticks(np.arange(0,dts.shape[0]))
    _ = ax.set_yticklabels(yticklabels, fontsize=16,  rotation=0)
    _ = ax.set_xticks(np.arange(0,krs.shape[0]))
    _ = ax.set_xticklabels(xticklabels, fontsize=16, rotation=90)
    _ = ax.set_ylabel("step size dt", fontsize=22)
    _ = ax.set_xlabel("kernel radius", fontsize=22)

    msg2 = f"total elapsed: {t2-t0:.3f} s, last sweep: {t2-t1:.3f}\n"
    msg = f"    dt from {min_dt:.2e} to {max_dt:.2e}\n"
    msg += f"    kr from {min_kr:.2e} to {max_kr:.2e}\n"

    ax.set_title("disco persistence \n" +msg, fontsize=24)
    plt.savefig(os.path.join(assets_dir, f"disco{time_stamp}_{idx}.png"))
    plt.show()

    print(msg2 + msg)
    # save results
    # results_img, accumulated_t, total_steps, explode, vanish, done, grid_0, grid

    img_savepath = os.path.join(save_dir, f"{exp_name}_img_{idx}.png")
    img_npy_savepath = os.path.join(save_dir, f"{exp_name}_img_{idx}.npy")

    accumulated_t_savepath = os.path.join(save_dir, f"{exp_name}_accumulated_t_{idx}.npy")
    total_steps_savepath = os.path.join(save_dir, f"{exp_name}_total_steps_{idx}.npy")

    explode_savepath = os.path.join(save_dir, f"{exp_name}_explode_{idx}.npy")
    vanish_savepath = os.path.join(save_dir, f"{exp_name}_vanish_{idx}.npy")
    done_savepath = os.path.join(save_dir, f"{exp_name}_done_{idx}.npy")

    grid_0_savepath = os.path.join(save_dir, f"{exp_name}_grid_0_{idx}.npy")
    grid_T_savepath = os.path.join(save_dir, f"{exp_name}_grid_T_{idx}.npy")

    sio.imsave(img_savepath, results[-1][0])
    np.save(img_npy_savepath, results[-1][0])
    np.save(accumulated_t_savepath, results[-1][1])
    np.save(total_steps_savepath, results[-1][2])
    np.save(explode_savepath, results[-1][3])
    np.save(vanish_savepath, results[-1][4])
    np.save(done_savepath, results[-1][5])
    np.save(grid_0_savepath, results[-1][6])
    np.save(grid_T_savepath, results[-1][7])

    # log experiment metadata (one row matching the header written before the loop)
    metadata = f"{idx}, {pattern_name}, {min_dt}, {max_dt}, {min_kr}, {max_kr}, {parameter_steps}, {max_t}, {max_steps}, {max_runtime}, "
    metadata += f"{time_stamp}, {t2-t1:2f}, {t2-t0:2f}, "
    metadata += f"{img_savepath}, {accumulated_t_savepath}, {total_steps_savepath}, {explode_savepath}, {vanish_savepath}, {grid_T_savepath}\n"
    with open(metadata_path,"a") as f:
        f.write(metadata)

    # determine next parameter range: slide a square window over the sweep image
    freq_zoom_dim = (results[-1][0].shape[-2]) // freq_zoom_fraction
    freq_zoom_stride = 4 + int(parameter_steps/16)
    freq_zoom_strides = (results[-1][0].shape[-2]-freq_zoom_dim) // freq_zoom_stride +1

    fzd = freq_zoom_dim
    fzs = freq_zoom_stride

    # Sweep-grid spacing along each axis, as Python floats. Window edges are computed
    # from it instead of indexing dts/krs: the far edge of the last window sits one
    # past the final grid value (index 64 of 64 with the defaults, assuming a 64-wide
    # sweep image), which JAX silently clamps to the last element rather than raising.
    dt_lo = float(min_dt)
    kr_lo = float(min_kr)
    dt_step = (float(max_dt) - dt_lo) / parameter_steps
    kr_step = (float(max_kr) - kr_lo) / parameter_steps

    params_list = []
    entropy = []
    frequency_entropy = []
    frequency_ratio = []
    # Map scored for zooming: 1 - done (results[-1][5] in the order listed above).
    gray_image = (1.0 - results[-1][5])

    for ll in range(freq_zoom_strides**2):
        cx = ll // freq_zoom_strides  # window row    -> dt axis (image rows)
        cy = ll % freq_zoom_strides   # window column -> kr axis (image columns)

        # Candidate window, ordered like `params`: [min_kr, max_kr, min_dt, max_dt].
        params_list.append([kr_lo + (cy*fzs)*kr_step, \
                kr_lo + (cy*fzs+fzd)*kr_step, \
                dt_lo + (cx*fzs)*dt_step, \
                dt_lo + (cx*fzs+fzd)*dt_step])

        subimage = gray_image[cx*fzs:cx*fzs+fzd,cy*fzs:cy*fzs+fzd]

        frequency_ratio.append(compute_frequency_ratio(subimage))
        entropy.append(compute_entropy(subimage))
        frequency_entropy.append(compute_frequency_entropy(subimage))

    plt.figure()
    plt.subplot(221)
    plt.imshow(gray_image.squeeze())
    plt.title("results image")
    plt.subplot(222)
    plt.imshow(np.array(frequency_ratio).reshape(freq_zoom_strides, freq_zoom_strides))
    plt.title("freq. ratio")
    plt.subplot(223)
    plt.imshow(np.array(entropy).reshape(freq_zoom_strides, freq_zoom_strides))
    plt.title("entropy")
    plt.subplot(224)
    plt.imshow(np.array(frequency_entropy).reshape(freq_zoom_strides, freq_zoom_strides))
    plt.title("frequency entropy")
    plt.tight_layout()
    plt.savefig(os.path.join(assets_dir, f"frequency_entropy_{time_stamp}_{idx}.png"))
    plt.show()

    t3 = time.time()
    idx += 1
    time_elapsed = t3-t0
    total_zooms += 1

    # Check before selecting the next window: with no positive-entropy windows the
    # argmax below would run on an empty array and raise instead of stopping cleanly.
    is_nonblank = np.array(entropy) > 0
    if np.sum(gray_image) == 0 or not bool(np.any(is_nonblank)):
        print("zoom no longer interesting, quitting")
        break

    params_list_nonblank = np.array(params_list)[is_nonblank]
    frequency_entropy_nonblank = np.array(frequency_entropy)[is_nonblank]
    params = params_list_nonblank[np.argmax(frequency_entropy_nonblank)]

In [None]:
# Timing notes from earlier runs of the sweep above. The string below has no
# effect beyond being echoed as this cell's output (it is the last expression).
"""
stability_sweep 4x at 32x32
total elapsed: 1355.300 s, last sweep: 203.198
    dt from 1.72e-01 to 3.02e-01
    kr from 2.650000e+01 to 3.40e+01

v_stability_sweep only 2x at 32x32
total elapsed: 1041.628 s, last sweep: 509.198
    dt from 1.00e-02 to 5.30e-01
    kr from 1.900000e+01 to 4.90e+01
"""

# Wall-clock time for 3 local runs, by dtype:
# 3x local runs in float16 takes ~255 s (didn't work, update_step used float32)
# 3x local runs in float16 takes ~224 s (after fix)
# 3x local runs in float32 takes ~213 s (after fix)
