In [None]:
import os
import time

import numpy as np
import jax.numpy as jnp
import numpy.random as npr

import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.animation as animation
from matplotlib import cm

matplotlib.rcParams["animation.embed_limit"] = 4096

import skimage
import skimage.io as sio
import skimage.transform
import fracatal

from fracatal.functional_pt.convolve import ft_convolve
from fracatal.functional_pt.compose import make_gaussian, \
        make_mixed_gaussian, \
        make_kernel_field, \
        make_update_function, \
        make_update_step, \
        make_make_kernel_function

from fracatal.functional_pt.metrics import compute_entropy, \
        compute_frequency_ratio, \
        compute_frequency_entropy
        
from fracatal.scripts import v_stability_sweep, stability_sweep     

import IPython

In [None]:
# Animation helpers: draw a simulation grid once, then refresh it frame by frame.


def get_fig(grid, my_cmap="jet"):
    """Create a black-background 6x6 figure showing ``grid``.

    The created AxesImage is stored in the module-level ``subplot_0`` so that
    ``update_frame`` can refresh it in place.

    Args:
        grid: tensor whose ``shape[1]`` selects the display mode. When it is 3,
            the squeezed grid is permuted with ``(1, 2, 0)`` and shown as RGB
            (presumably a ``(1, 3, H, W)`` input -- TODO confirm). Otherwise
            the squeezed grid is drawn with ``my_cmap``.
        my_cmap: matplotlib colormap name used for non-RGB grids.

    Returns:
        ``(fig, ax)``.
    """
    global subplot_0

    fig, ax = plt.subplots(1, 1, facecolor="black", figsize=(6, 6))

    squeezed = grid.squeeze()
    if grid.shape[1] != 3:
        subplot_0 = ax.imshow(squeezed, cmap=my_cmap)
    else:
        subplot_0 = ax.imshow(squeezed.permute(1, 2, 0))

    # Hide the numeric tick labels on both axes.
    for clear_labels in (ax.set_xticklabels, ax.set_yticklabels):
        clear_labels("")
    return fig, ax

def update_frame(ii):
    """FuncAnimation callback: redraw the current ``grid``, then advance it one step.

    ``ii`` (the frame index) is not used. Reads and rebinds the module-level
    ``grid``, and redraws through ``subplot_0`` (created by ``get_fig``).
    ``update_step`` must exist as a global at call time; it is not defined in
    these cells.
    """
    global grid

    squeezed = grid.squeeze()
    # Three-channel grids are permuted to channels-last for RGB display.
    frame_data = squeezed.permute(1, 2, 0) if grid.shape[1] == 3 else squeezed
    subplot_0.set_array(frame_data)
    plt.tight_layout()

    grid = update_step(grid)

In [None]:
def get_params(pattern_name, grid_dim=64):
    """Return simulation parameters and a kernel factory for a named pattern.

    Args:
        pattern_name: one of "gyrorbium", "gyropteron_arcus", "adorbium",
            "orbium_unicaudatus", "scutium_gravidus_single", "asymdrop", or any
            name containing "hydrogeminium".
        grid_dim: grid side length; the kernel factory is built with
            ``dim=grid_dim - 6``.

    Returns:
        Tuple ``(mu_g, sigma_g, dts, my_cmap, mode, make_kernel, k0)`` where
        ``make_kernel`` is the result of ``make_make_kernel_function`` and
        ``k0`` is 31 for hydrogeminium patterns and 13 otherwise.

    Raises:
        ValueError: if ``pattern_name`` is not a recognised pattern.
    """
    # pattern name -> (mu_g, sigma_g, dts, my_cmap, mode)
    named_params = {
        "gyrorbium": (0.156, 0.0224, 0.05, "afmhot", 0),
        "gyropteron_arcus": (0.293, 0.0511, 0.05, "magma", 0),
        "adorbium": (0.167, 0.013, 0.1, "viridis", 3),
        "orbium_unicaudatus": (0.15, 0.017, 0.05, "copper", 0),
        "scutium_gravidus_single": (0.283, 0.0369, 0.5, "bone", 0),
        "asymdrop": (0.12, 0.005, 0.2, "cividis", 1),
    }
    # Hydrogeminium variants are matched by substring, not exact name.
    is_hydrogeminium = "hydrogeminium" in pattern_name

    if pattern_name in named_params:
        mu_g, sigma_g, dts, my_cmap, mode = named_params[pattern_name]
    elif is_hydrogeminium:
        mu_g, sigma_g, dts, my_cmap, mode = 0.26, 0.056, 0.1, "jet", 0
    else:
        known = ", ".join(sorted(named_params))
        raise ValueError(
            f"unknown pattern_name {pattern_name!r}; expected one of {known} "
            "or a name containing 'hydrogeminium'"
        )

    # Hydrogeminium uses a three-component kernel spec; every other pattern one component.
    if is_hydrogeminium:
        amplitudes = [0.5, 1.0, 0.6667]
        means = [0.0938, 0.2814, 0.4690]
        standard_deviations = [0.0330, 0.0330, 0.0330]
        k0 = 31
    else:
        amplitudes = [1.0]
        means = [0.5]
        standard_deviations = [0.15]
        k0 = 13

    make_kernel = make_make_kernel_function(amplitudes, means, standard_deviations, dim=grid_dim - 6)

    return mu_g, sigma_g, dts, my_cmap, mode, make_kernel, k0


In [None]:
def my_colormap(color="red"):
    """Build a single-channel colormap function.

    Args:
        color: "red", "green" or "blue" -- the RGBA channel that receives the data.

    Returns:
        A function ``cm(x)`` returning a float array of shape ``(*x.shape[:2], 4)``.
        The chosen channel holds ``x / x.max()`` (squeezed); all other channels,
        including alpha, are zero. An all-zero ``x`` maps to zeros rather than NaN.

    Raises:
        ValueError: if ``color`` is not one of the supported names.
    """
    channels = {"red": 0, "green": 1, "blue": 2}
    if color not in channels:
        raise ValueError(f"unsupported color {color!r}; expected one of {sorted(channels)}")
    channel = channels[color]

    def cm(x):
        mapped = np.zeros((*x.shape[:2], 4))
        peak = x.max()
        # Avoid 0/0 -> NaN when every value in this region is zero.
        scaled = x / peak if peak != 0 else np.zeros_like(x, dtype=float)
        mapped[..., channel] = scaled.squeeze()
        return mapped

    return cm
    

def make_image(image, cms, not_done=None, exploded=None, vanished=None, invert=True, max_t=10):
    """Compose an RGB image from a per-pixel value map and outcome masks.

    Values are clipped to ``[0, max_t]`` and coloured per region:
    ``cms[0]`` for ``exploded`` pixels, ``cms[2]`` for ``vanished`` pixels and
    ``cms[1]`` for ``not_done`` pixels. Each region is painted in that order,
    so a later region wins where masks overlap. When ``not_done`` is None,
    ``cms[1]`` is first applied to the whole image and the exploded/vanished
    regions are painted over it.

    Args:
        image: 2-D value map (squeezed before use).
        cms: three colormap callables. Each maps an array of values to an array
            with a trailing channel axis (RGBA, presumably 4 channels -- TODO
            confirm for custom colormaps).
        not_done, exploded, vanished: optional boolean masks with the image's shape.
        invert: if True, return ``1 - rgb``.
        max_t: upper clip value.

    Returns:
        Float array of shape ``(H, W, 3)``.
    """
    image = image.squeeze()
    truncated_image = np.clip(image, 0, max_t)

    if not_done is None:
        # No "still running" mask: default every pixel to cms[1].
        new_image = cms[1](truncated_image)
    else:
        # Only the given masks are painted; pixels outside all masks stay black.
        new_image = np.zeros((image.shape[-2], image.shape[-1], 4))

    if exploded is not None and exploded.any():
        new_image[exploded] = cms[0](truncated_image[exploded])
    if vanished is not None and vanished.any():
        new_image[vanished] = cms[2](truncated_image[vanished])
    if not_done is not None and not_done.any():
        new_image[not_done] = cms[1](truncated_image[not_done])

    rgb = new_image[:, :, :3]
    return 1.0 - rgb if invert else rgb

def interpolate_history(history_1, history_2, alpha):
    """Interpolate plot limits for a smooth zoom between two levels.

    This functionality is adapted from (MIT Licensed)
    https://github.com/Sohl-Dickstein/fractal

    Args:
        history_1, history_2: ``(image, limits)`` pairs, where ``limits`` is
            ``(a_min, a_max, b_min, b_max)`` for the two plotted axes.
        alpha: transition progress; 0 corresponds to ``history_1`` and 1 to
            ``history_2``.

    Returns:
        ``limits_1`` unchanged when ``alpha == 0``; otherwise a list
        ``[a_min, a_max, b_min, b_max]``. Window widths change geometrically
        (``w1 * (w2 / w1) ** alpha``), giving a constant zoom factor per unit of
        alpha. An axis whose width does not change is panned linearly instead.
    """
    _, limits_1 = history_1
    _, limits_2 = history_2

    if alpha == 0:
        return limits_1

    w1 = np.array([limits_1[1] - limits_1[0], limits_1[3] - limits_1[2]], dtype=float)
    w2 = np.array([limits_2[1] - limits_2[0], limits_2[3] - limits_2[2]], dtype=float)
    c1 = np.array([(limits_1[0] + limits_1[1]) / 2, (limits_1[2] + limits_1[3]) / 2])
    c2 = np.array([(limits_2[0] + limits_2[1]) / 2, (limits_2[2] + limits_2[3]) / 2])

    ratio = w2 / w1
    gamma = np.exp(alpha * np.log(ratio))

    # Zoom: the centre follows c(alpha) = cstar + (c1 - cstar) * gamma, which is c1 at
    # alpha=0 and c2 at alpha=1; cstar keeps a constant relative position in the window.
    # When an axis keeps its width (ratio == 1, a pure pan) cstar divides by zero, so use
    # the limiting form c(alpha) = c1 + alpha * (c2 - c1) for that axis instead.
    is_pan = np.isclose(ratio, 1.0)
    with np.errstate(divide="ignore", invalid="ignore"):
        cstar = (c2 - c1 * ratio) / (1 - ratio)
        zoom_centre = cstar + (c1 - cstar) * gamma
    pan_centre = c1 + alpha * (c2 - c1)
    ct = np.where(is_pan, pan_centre, zoom_centre)

    width_t = gamma * w1

    return [ct[0] - width_t[0] / 2, ct[0] + width_t[0] / 2, ct[1] - width_t[1] / 2, ct[1] + width_t[1] / 2]
    
def plot_fig(history):
    """Draw the first zoom level of ``history`` and return the artists to animate.

    ``history[0]`` is ``(image, (min_h, max_h, min_w, max_w))``. The image is shown
    with ``extent=[min_w, max_w, min_h, max_h]``, so the first bound pair is on the
    vertical axis. The AxesImage is also stored in the module-level ``frame_0``.

    Returns:
        ``(fig, ax, frame_0)``.
    """
    global frame_0

    img_0 = history[0][0]
    min_h, max_h, min_w, max_w = history[0][1]

    fig, ax = plt.subplots(1, 1, figsize=(18, 16), facecolor="black")

    frame_0 = ax.imshow(
        img_0,
        extent=[min_w, max_w, min_h, max_h],
        cmap="magma",
        origin="upper",
        aspect="auto",
    )

    # Raw strings: "\D" / "\s" are invalid escape sequences in normal literals.
    ax.set_ylabel(r"$\Delta t$")
    ax.set_xlabel(r"$\sigma$")
    ax.set_title("Title")

    # Thin outline around the initial parameter window.
    rect = patches.Rectangle((min_w, min_h), max_w - min_w, max_h - min_h,
                             linewidth=1, edgecolor="k", facecolor="none")
    ax.add_patch(rect)

    # Only mark the edges of the window.
    ax.set_xticks([min_w, max_w])
    ax.set_yticks([min_h, max_h])

    plt.tight_layout()
    plt.draw()
    print(img_0.shape)

    return fig, ax, frame_0

def make_animator(history, timesteps=30, axes_labels=(r"$\sigma$", r"$\Delta t$")):
    """Animate a smooth zoom through the successive parameter windows in ``history``.

    Each consecutive pair of levels gets ``timesteps`` frames. The axis limits are
    interpolated with ``interpolate_history`` while the next level's image fades in
    on top of the current one.

    Args:
        history: list of ``[image, (a_min, a_max, b_min, b_max)]`` entries. The first
            bound pair is drawn on the y axis and the second on the x axis.
        timesteps: number of frames per transition between consecutive levels.
        axes_labels: ``(x_label, y_label)`` drawn on every frame; the defaults
            match the labels set by ``plot_fig``.

    Returns:
        A ``matplotlib.animation.FuncAnimation`` with
        ``1 + timesteps * (len(history) - 1)`` frames.
    """
    n_levels = len(history)
    fig, ax, frame_0 = plot_fig(history)

    # Overlay for the next zoom level; blank until update_frame sets its data and alpha.
    next_level = history[min(1, n_levels - 1)]
    min_h1, max_h1, min_w1, max_w1 = next_level[1]
    frame_1 = ax.imshow(next_level[0] * 0,
                        origin="upper",
                        extent=[min_w1, max_w1, min_h1, max_h1],
                        aspect="auto",
                        vmin=0, vmax=1.0,
                        cmap="magma")

    label_color = [0.9, 0.9, 0.9]

    def update_frame(ii):
        """Render frame ``ii``: choose the level pair, interpolate limits, cross-fade."""
        history_index = ii // timesteps
        hist0 = history[history_index]
        # After the last transition there is no next level; hold the final one.
        hist1 = history[min(history_index + 1, n_levels - 1)]

        img_0, bounds_0 = hist0[0], hist0[1]
        img_1, bounds_1 = hist1[0], hist1[1]
        min_h0, max_h0, min_w0, max_w0 = bounds_0
        min_h1, max_h1, min_w1, max_w1 = bounds_1

        alpha = (ii % timesteps) / timesteps

        # Compare bounds element-wise: a signed sum of differences cancels to zero for
        # a zoom centred on the current window (both edges moving inward equally).
        if any(b0 != b1 for b0, b1 in zip(bounds_0, bounds_1)):
            limits = interpolate_history(hist0, hist1, alpha)
        else:
            limits = bounds_0
            alpha = 0.0

        ax.set_ylim(limits[0], limits[1])
        ax.set_xlim(limits[2], limits[3])

        ax.set_xticks([limits[2], limits[3]])
        ax.set_yticks([limits[0], limits[1]])
        ax.set_xticklabels([f"{limits[2]:.4e}", f"{limits[3]:.4e}"],
                           color=label_color, fontsize=20, rotation=10)
        ax.set_yticklabels([f"{limits[0]:.4e}", f"{limits[1]:.4e}"],
                           color=label_color, fontsize=20)

        ax.set_xlabel(axes_labels[0], color=label_color, fontsize=48)
        ax.set_ylabel(axes_labels[1], color=label_color, fontsize=48)

        print(f"index {history_index}, frame {ii} of {timesteps*(len(history)-1)}", end="\r")

        frame_0.set_data(img_0)
        frame_0.set_extent([min_w0, max_w0, min_h0, max_h0])

        frame_1.set_data(img_1)
        frame_1.set_extent([min_w1, max_w1, min_h1, max_h1])
        # Cross-fade: the next level becomes more opaque as the transition progresses.
        frame_1.set_alpha(alpha)

        plt.tight_layout()

        return fig

    my_animator = animation.FuncAnimation(fig, update_frame, frames=1 + (timesteps * (n_levels - 1)),
                                          interval=33, repeat=False)

    return my_animator



In [None]:
#folder_name = "asymdrop_asymdrop_test_1717471068_926391" # sweep_mu = True; sweep_kr = False
#folder_name = "orbium_unicaudatus_orbium_unicaudatus_sigma_kr_zoom_1718061217_8708138" # sweep_kr = True; sweep_mu = False
folder_name = "adorbium_fractal_zoom_p768" # sweep_mu = False; sweep_kr = False

# Clip value for accumulated time, and which parameter pair the sweep varied
# (see the notes next to each folder_name above).
max_t = 16
sweep_mu = False
sweep_kr = False

data_directory = os.path.join("..", "results", f"{folder_name}")

# Column indices within each comma-separated metadata row.
MU_COLS = (2, 3)
DT_COLS = (4, 5)
SIGMA_COLS = (6, 7)
KR_COLS = (8, 9)
ACC_T_COL = 18
EXPLODE_COL = 20
VANISH_COL = 21

# The sweep metadata is the .txt file in the results folder (sorted for a deterministic pick).
metadata_names = sorted(
    name for name in os.listdir(data_directory) if os.path.splitext(name)[1] == ".txt"
)
if not metadata_names:
    raise FileNotFoundError(f"no .txt metadata file found in {data_directory}")
metadata_path = os.path.join(data_directory, metadata_names[-1])

with open(metadata_path, "r") as f:
    metadata_lines = f.readlines()

print(metadata_lines[0])


def load_result_array(fields, column, keyword):
    """Load the .npy file named in ``fields[column]`` from ``data_directory``.

    Only the basename of the recorded path is used, and the filename is printed.
    Returns the squeezed array, or None when the filename does not contain ``keyword``.
    """
    filename = os.path.basename(fields[column].strip())
    if keyword not in filename:
        return None
    print(filename)
    return np.load(os.path.join(data_directory, filename)).squeeze()


zoom_history = []
cms = [my_colormap("red"), my_colormap("green"), my_colormap("blue")]

for my_line in metadata_lines[1:]:
    if not my_line.strip():
        continue  # tolerate blank (e.g. trailing) lines

    fields = my_line.split(",")

    explode = load_result_array(fields, EXPLODE_COL, "explode")
    vanish = load_result_array(fields, VANISH_COL, "vanish")
    acc_t = load_result_array(fields, ACC_T_COL, "accumulated_t")
    # Fail loudly instead of silently reusing the previous row's arrays.
    if explode is None or vanish is None or acc_t is None:
        raise ValueError(f"metadata row lacks explode/vanish/accumulated_t files: {my_line.strip()}")

    explode = explode > 0
    vanish = vanish > 0
    not_done = ~(explode | vanish)

    image = make_image(acc_t, cms=cms, not_done=not_done, exploded=explode, vanished=vanish,
                       invert=False, max_t=max_t)

    min_mu, max_mu = (float(fields[col]) for col in MU_COLS)
    min_dt, max_dt = (float(fields[col]) for col in DT_COLS)
    min_sigma, max_sigma = (float(fields[col]) for col in SIGMA_COLS)
    min_kr, max_kr = (float(fields[col]) for col in KR_COLS)

    # Bounds are (vertical_min, vertical_max, horizontal_min, horizontal_max).
    if sweep_kr:
        bounds = (max_sigma, min_sigma, min_kr, max_kr)
    elif sweep_mu:
        bounds = (min_mu, max_mu, min_dt, max_dt)
    else:
        bounds = (min_dt, max_dt, min_sigma, max_sigma)
    zoom_history.append([image, bounds])

assert zoom_history, f"no zoom levels loaded from {metadata_path}"

preview_fig, (ax_first, ax_last) = plt.subplots(1, 2)
ax_first.imshow(zoom_history[0][0])
ax_first.set_title("first zoom level")
ax_last.imshow(zoom_history[-1][0])
ax_last.set_title("last zoom level")
plt.show()


In [None]:
my_timesteps = 128

# Axis labels follow the parameter pair the sweep varied. Raw strings avoid
# invalid-escape warnings for "\s", "\m" and "\D" in the LaTeX.
if sweep_kr:
    y_label, x_label = r"$\sigma$", r"$k_r$"
elif sweep_mu:
    y_label, x_label = r"$\mu$", r"$\Delta t$"
else:
    y_label, x_label = r"$\Delta t$", r"$\sigma$"

anim = make_animator(zoom_history, timesteps=my_timesteps, axes_labels=[x_label, y_label])

In [None]:
# Write the zoom animation to an mp4 named after the results folder and frame count.
output_path = f"exp_{folder_name}_zoom_{my_timesteps}timesteps.mp4"
anim.save(output_path)