In [None]:
import os

import jax.numpy as np
import jax.numpy as jnp
#import porespy as ps

import matplotlib.pyplot as plt

import IPython

import matplotlib
import matplotlib.patches as patches
import matplotlib.animation as animation
from matplotlib import cm

# Raise matplotlib's size cap for embedded HTML animations so the long
# to_jshtml() animation below is not truncated (NOTE(review): the value is
# interpreted in MB per matplotlib's rcParams docs — confirm for your version).
matplotlib.rcParams.update({"animation.embed_limit": 4096})

In [None]:
def make_image(image, max_t=10):
    """Drop singleton axes from ``image`` and scale it by ``1 / max_t``.

    Args:
        image: array-like with a ``squeeze()`` method (e.g. a loaded ``.npy`` array).
        max_t: divisor applied to every element (default 10).

    Returns:
        The squeezed array divided by ``max_t``.
    """
    squeezed = image.squeeze()
    scaled = squeezed / max_t
    return scaled

def interpolate_history(hist1, hist2, alpha):
    """
    Interpolate the hyperparameter bounding box ("mnmx") between two zoom steps.

    Each history entry is a pair ``(image, mnmx)`` with
    ``mnmx = (min_0, max_0, min_1, max_1)``; only ``mnmx`` is used here.
    Box widths are interpolated geometrically, ``w(alpha) = w1 * (w2/w1)**alpha``,
    so the zoom proceeds at a constant rate. The centre moves so that the fixed
    point of the zoom (``cstar``) stays in place. On an axis whose width does
    not change there is no fixed point, so that axis is panned linearly from
    the first centre to the second instead.

    Args:
        hist1: ``(image, mnmx)`` shown at ``alpha == 0``.
        hist2: ``(image, mnmx)`` shown at ``alpha == 1``.
        alpha: interpolation fraction, expected in ``[0, 1]``.

    Returns:
        list of four floats ``[min_0, max_0, min_1, max_1]``.

    Assumes both boxes have strictly positive widths (``max > min``).
    """
    _, mnmx1 = hist1
    _, mnmx2 = hist2

    if alpha == 0:
        # Return the start box exactly, with no round-off from the formulas below.
        return [float(v) for v in mnmx1]

    w1 = np.array([mnmx1[1] - mnmx1[0], mnmx1[3] - mnmx1[2]])
    w2 = np.array([mnmx2[1] - mnmx2[0], mnmx2[3] - mnmx2[2]])
    c1 = np.array([(mnmx1[0] + mnmx1[1]) / 2, (mnmx1[2] + mnmx1[3]) / 2])
    c2 = np.array([(mnmx2[0] + mnmx2[1]) / 2, (mnmx2[2] + mnmx2[3]) / 2])

    ratio = w2 / w1
    # Geometric interpolation of the scale factor: 1 at alpha=0, w2/w1 at alpha=1.
    gamma = np.exp(alpha * np.log(ratio))

    # The centre follows ct = cstar + (c1 - cstar) * gamma, where cstar is the
    # point that stays fixed during the zoom:
    #   alpha=0: c1 = cstar + (c1 - cstar) * 1
    #   alpha=1: c2 = cstar + (c1 - cstar) * w2/w1
    # Solving for cstar divides by (1 - w2/w1). That is zero on an axis that is
    # not zoomed (previously giving NaN/inf limits), so such an axis is panned
    # linearly instead; the safe denominator keeps the unused branch finite.
    no_zoom = np.isclose(ratio, 1.0)
    denom = np.where(no_zoom, 1.0, 1.0 - ratio)
    cstar = (c2 - c1 * ratio) / denom
    ct = np.where(no_zoom, c1 + alpha * (c2 - c1), cstar + (c1 - cstar) * gamma)
    hwt = gamma * w1

    return [float(ct[0] - hwt[0] / 2), float(ct[0] + hwt[0] / 2),
            float(ct[1] - hwt[1] / 2), float(ct[1] + hwt[1] / 2)]
    
def plot_fig(history):
    """
    Draw the first zoom level of ``history`` and return the figure handles.

    ``history[0]`` is ``(image, bounds)`` with ``bounds = (min_h, max_h, min_w, max_w)``.
    The image is drawn with horizontal extent ``[min_w, max_w]`` (labelled sigma)
    and vertical extent ``[min_h, max_h]`` (labelled Delta t). It is outlined by a
    black rectangle, and only the box corners get tick marks.

    Side effects: creates a new 8x8 inch figure, prints the image shape, and
    stores the image artist in the module-level global ``frame_0``.

    Returns:
        ``(fig, ax, frame_0)``: the figure, its single axes and the AxesImage.
    """
    global frame_0

    img_0 = history[0][0]
    bounds = history[0][1]
    min_h, max_h, min_w, max_w = bounds
    extent = [min_w, max_w, min_h, max_h]

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))

    frame_0 = ax.imshow(img_0,
                        extent=extent,
                        cmap="magma",
                        origin='upper',
                        aspect='auto',
                        interpolation='nearest')

    # Raw strings: "\D" and "\s" are invalid escape sequences in ordinary
    # string literals (SyntaxWarning since Python 3.12).
    ax.set_ylabel(r"$\Delta t$")
    ax.set_xlabel(r"$\sigma$")
    ax.set_title("Title")

    rect = patches.Rectangle((min_w, min_h), max_w - min_w, max_h - min_h,
                             linewidth=1, edgecolor='k', facecolor='none')
    ax.add_patch(rect)

    frame_0.set_extent(extent)

    # Label only the corners of the bounding box on each axis.
    ax.set_xticks([min_w, max_w])
    ax.set_yticks([min_h, max_h])

    fig.tight_layout()

    plt.draw()
    print(img_0.shape)

    return fig, ax, frame_0

def make_animator(history, timesteps=30):
    """
    Build a FuncAnimation that zooms through successive entries of ``history``.

    Each entry of ``history`` is ``(image, bounds)`` with
    ``bounds = (min_h, max_h, min_w, max_w)``, unpacked the same way as in
    ``plot_fig``. Between two consecutive entries, the axis limits are
    interpolated with ``interpolate_history`` over ``timesteps`` frames. At the
    same time, the next image is faded in on top of the current one.

    Args:
        history: non-empty sequence of ``(image, bounds)`` pairs.
        timesteps: number of animation frames per transition (positive int).

    Returns:
        matplotlib.animation.FuncAnimation with
        ``1 + timesteps * (len(history) - 1)`` frames.

    Raises:
        ValueError: if ``history`` is empty.
    """
    if len(history) == 0:
        raise ValueError("history must contain at least one (image, bounds) entry")

    fig, ax, frame_0 = plot_fig(history)

    # Overlay layer that fades in the next zoom level. With a single history
    # entry there is no next level (history[1] would raise), so reuse entry 0.
    img_next, bounds_next = history[min(1, len(history) - 1)]
    min_h1, max_h1, min_w1, max_w1 = bounds_next
    frame_1 = ax.imshow(img_next * 0,
                        origin='upper',
                        extent=[min_w1, max_w1, min_h1, max_h1],
                        aspect='auto',
                        vmin=0, vmax=1.0,
                        cmap="magma",
                        interpolation='nearest')
    # plot_fig lets the base layer autoscale its colours to the first image,
    # while the overlay uses a fixed [0, 1] range. Use the same range for both,
    # so an image does not change colour when it moves from the overlay to the
    # base layer at the start of each transition.
    frame_0.set_clim(0, 1.0)

    n_transitions = len(history) - 1

    def update_frame(ii):
        history_index = ii // timesteps

        hist0 = history[history_index]
        # The very last frame has no successor; hold the final entry.
        hist1 = history[history_index + 1] if history_index < n_transitions else hist0

        img_0, bounds_0 = hist0
        img_1, bounds_1 = hist1
        min_h0, max_h0, min_w0, max_w0 = bounds_0
        min_h1, max_h1, min_w1, max_w1 = bounds_1

        # Fraction of the way through the current transition, in [0, 1).
        alpha = (ii % timesteps) / timesteps

        limits = interpolate_history(hist0, hist1, alpha)

        ax.set_ylim(limits[0], limits[1])
        ax.set_xlim(limits[2], limits[3])

        ax.set_xticks([limits[2], limits[3]])
        ax.set_yticks([limits[0], limits[1]])
        ax.set_xticklabels([f"{limits[2]:.4e}", f"{limits[3]:.4e}"])
        ax.set_yticklabels([f"{limits[0]:.4e}", f"{limits[1]:.4e}"])

        print(f"index {history_index}, frame {ii} of {timesteps*n_transitions}", end="\r")

        frame_0.set_data(img_0)
        frame_0.set_extent([min_w0, max_w0, min_h0, max_h0])

        frame_1.set_data(img_1)
        frame_1.set_extent([min_w1, max_w1, min_h1, max_h1])
        frame_1.set_alpha(alpha)
        # Tick labels change every frame, so re-run the layout on this
        # animation's own figure (not whichever figure is current at render time).
        fig.tight_layout()

        return fig

    return animation.FuncAnimation(fig, update_frame,
                                   frames=1 + timesteps * n_transitions,
                                   interval=33, repeat=False)



In [None]:
# Load one zoom run: locate its metadata .txt file, then read every
# "accumulated_t" image listed in it together with the (dt, sigma) box it covers.
folder_name = "gyrorbium_gyrorbium_test_1715052411_694593"
data_directory = os.path.join("..", "results", folder_name)

directory_list = os.listdir(data_directory)

# Comma-separated column indices used when parsing metadata rows.
# NOTE(review): taken from the original parsing code; confirm against the
# header row printed below.
COL_MIN_DT, COL_MAX_DT = 4, 5
COL_MIN_SIGMA, COL_MAX_SIGMA = 6, 7
COL_IMAGE_PATH = 18

# Previously metadata_path was left unbound (NameError) when no .txt existed,
# and an arbitrary one was taken when several existed.
txt_files = sorted(elem for elem in directory_list if os.path.splitext(elem)[1] == ".txt")
if not txt_files:
    raise FileNotFoundError(f"no .txt metadata file found in {data_directory}")
if len(txt_files) > 1:
    print(f"warning: several .txt files in {data_directory}; using {txt_files[-1]}")
metadata_path = os.path.join(data_directory, txt_files[-1])

with open(metadata_path, "r") as f:
    metadata_lines = f.readlines()

print(metadata_lines[0])

zoom_history = []
for my_line in metadata_lines[1:]:
    if not my_line.strip():
        # Skip blank lines (e.g. a trailing empty line at the end of the file).
        continue

    # Strip the line ending so a path in the last column does not end in "\n".
    fields = my_line.rstrip("\r\n").split(",")

    filename = os.path.split(fields[COL_IMAGE_PATH])[-1]
    if "accumulated_t" in filename:
        print(filename)
        filepath = os.path.join(data_directory, filename)

        image = make_image(np.load(filepath))

        min_dt, max_dt = float(fields[COL_MIN_DT]), float(fields[COL_MAX_DT])
        min_sigma, max_sigma = float(fields[COL_MIN_SIGMA]), float(fields[COL_MAX_SIGMA])
        zoom_history.append([image, (min_dt, max_dt, min_sigma, max_sigma)])

if not zoom_history:
    raise ValueError(f"no 'accumulated_t' images listed in {metadata_path}")

plt.imshow(zoom_history[0][0])


In [None]:
# Build the zoom animation (33 frames per transition) and render it inline.
anim = make_animator(zoom_history, timesteps=33)
anim_html = anim.to_jshtml()
IPython.display.HTML(anim_html)

In [None]:
# Write the same animation to a video file in the working directory.
video_path = "gyrorbium_zoom_test.mp4"
anim.save(filename=video_path)