In [None]:
from IPython.display import HTML
from matplotlib import pyplot as plt
import jax.numpy as jnp
import numpy as np
import exponax as ex
import jax
import cmasher

import seaborn as sns
sns.set_theme()

from vape import diverging_alpha,render,viewer

In [None]:

# Colormaps used by the renders below: seaborn's "icefire" for the nonlinear
# demos (Burgers, KS), cmasher's "watermelon" for linear advection, and
# cmasher's "copper_s" for the Gray-Scott demo.
cmap_nonlinear = sns.color_palette("icefire", as_cmap=True)
cmap_linear, cmap_diff = cmasher.watermelon, cmasher.copper_s
cmap_diff  # preview

In [None]:
# Video settings passed to the create_video calls below:
# panel size in pixels and frames per second.
resolution, fps = 1024, 30

In [None]:

import copy
from matplotlib.colors import LinearSegmentedColormap, ListedColormap


def triangle_wave(x, p):
    """Triangle wave of period ``p``: 0 at multiples of ``p``, 1 halfway between.

    Works on Python scalars and elementwise on NumPy arrays.
    """
    phase = x / p
    # Distance from the phase to its nearest integer lies in [0, 0.5].
    distance_to_nearest = np.abs(phase - np.floor(phase + 0.5))
    return 2 * distance_to_nearest


def zigzag_alpha(cmap, min_alpha=0.2):
    """Return a copy of ``cmap`` whose alpha channel zig-zags along the colormap.

    The colors themselves are unchanged. Opacity is ``min_alpha`` at the
    normalized positions 0, 0.5 and 1, and rises linearly to 1.0 at 0.25 and
    0.75. Values near the two quarter points are therefore drawn opaque, while
    the middle and both ends stay translucent.

    Args:
        cmap (Colormap): source colormap (not modified).
        min_alpha (float): opacity at the troughs of the zig-zag.

    Returns:
        Colormap: a new ListedColormap or LinearSegmentedColormap, matching
        the type of ``cmap``.

    Raises:
        TypeError: if ``cmap`` is neither a ListedColormap nor a
            LinearSegmentedColormap.
    """
    if isinstance(cmap, ListedColormap):
        from matplotlib.colors import to_rgba_array

        # Normalize any accepted color spec (lists, tuples, (N, 3)/(N, 4)
        # arrays, color names) to a fresh float (N, 4) array. Appending to each
        # entry fails for NumPy rows and gives 5 components for RGBA entries.
        rgba = np.array(to_rgba_array(cmap.colors), dtype=float)
        num_colors = len(rgba)
        # Color i sits at i / (N - 1); guard N == 1 against division by zero.
        positions = np.arange(num_colors) / max(num_colors - 1, 1)
        rgba[:, 3] = triangle_wave(positions, 0.5) * (1 - min_alpha) + min_alpha
        return ListedColormap(rgba, name=cmap.name)
    if isinstance(cmap, LinearSegmentedColormap):
        segmentdata = copy.deepcopy(cmap._segmentdata)
        # The same zig-zag as (x, y_left, y_right) anchors, using min_alpha at
        # the troughs so both branches honor the argument.
        segmentdata["alpha"] = np.array([
            [0.0, min_alpha, min_alpha],
            [0.25, 1.0, 1.0],
            [0.5, min_alpha, min_alpha],
            [0.75, 1.0, 1.0],
            [1.0, min_alpha, min_alpha],
        ])
        # Keep the source lookup-table size instead of the default of 256.
        return LinearSegmentedColormap(cmap.name, segmentdata, N=cmap.N)
    raise TypeError(
        "cmap must be either a ListedColormap or a LinearSegmentedColormap"
    )


In [None]:
from IPython.display import Video
import imageio.v2 as iio
import os
from tqdm import tqdm
os.environ["IMAGEIO_FFMPEG_EXE"] = "ffmpeg"  

def symmetric_min_max(arr):
    """Return ``(-m, m)``, where ``m`` is the largest absolute value in ``arr``.

    Handy for centering a diverging colormap on zero. ``arr`` must provide
    ``.min()``/``.max()`` results that support ``.item()`` (e.g. NumPy or JAX
    arrays).
    """
    lowest, highest = arr.min().item(), arr.max().item()
    bound = max(abs(lowest), abs(highest))
    return -bound, bound

def chunk_list(lst, n):
    """Yield consecutive slices of ``lst`` of length ``n``; the last may be shorter."""
    yield from (lst[start:start + n] for start in range(0, len(lst), n))
        
def _resolve_vrange(volume, vrange):
    """Translate ``create_video``'s ``vrange`` argument into ``(vmin, vmax)``.

    Args:
        volume: array inspected (via ``symmetric_min_max``) only when
            ``vrange`` is None.
        vrange: None for a symmetric range taken from the data, a number ``v``
            for ``(-v, v)``, or a two-element tuple/list ``(vmin, vmax)``.

    Raises:
        ValueError: for any other kind of ``vrange``.
    """
    if vrange is None:
        return symmetric_min_max(volume)
    if isinstance(vrange, (int, float)):
        return -vrange, vrange
    if isinstance(vrange, (tuple, list)):
        vmin, vmax = vrange
        return vmin, vmax
    raise ValueError("vrange must be None, a number, or a tuple of two numbers")


def create_video(out_file: str, volume, cmap, resolution=1024, duration=10, fps=30,
                 vrange=None, codec='h264_nvenc', batch_size=64) -> Video:
    """
    Render a volume to a video file and return a notebook player for it.

    Each channel is rendered with ``render`` into its own
    ``resolution x resolution`` panel. Panels are placed side by side, so
    every frame is ``resolution`` pixels high and ``resolution * C`` wide.
    ``duration * fps`` frames are rendered at evenly spaced normalized times
    in [0, 1]. How those times map onto the N stored steps is decided by
    ``render`` (presumably it interpolates).

    Args:
        out_file (str): output file
        volume (np.ndarray): volume of shape [N,C,H,W,D]
        cmap (Colormap): colormap
        resolution (int): height and per-channel width of a frame in pixels
        duration (int): duration in seconds
        fps (int): frames per second
        vrange: None (symmetric range from the data), a number v (range
            [-v, v]) or a (vmin, vmax) pair
        codec (str): ffmpeg codec; the default "h264_nvenc" is NVIDIA's
            hardware encoder, use e.g. "libx264" on machines without one
        batch_size (int): how many frames are rendered and kept in memory at once

    Returns:
        Video: IPython video player pointing at ``out_file``.

    Raises:
        ValueError: if ``vrange`` is malformed or ``duration * fps`` is below
            one frame.
    """
    vmin, vmax = _resolve_vrange(volume, vrange)

    num_channels = volume.shape[1]
    # round() keeps a float duration/fps usable with range().
    n = int(round(duration * fps))
    if n < 1:
        raise ValueError("duration * fps must yield at least one frame")
    # Frame i is rendered at normalized time i / (n - 1); a single frame uses t = 0.
    time_scale = max(n - 1, 1)

    print("num_channels=",num_channels,", vmin=",vmin,",vmax=",vmax)

    # Host copies of each channel, made once instead of once per batch.
    channel_volumes = [np.array(volume[:, c]) for c in range(num_channels)]

    with iio.get_writer(out_file, format='FFMPEG', mode='I', fps=fps, codec=codec) as writer:
        for time_steps in tqdm(list(chunk_list(list(range(n)), batch_size))):
            times = [i / time_scale for i in time_steps]
            frames = np.zeros((len(time_steps), resolution, resolution * num_channels, 3), dtype=np.uint8)
            for c, channel_volume in enumerate(channel_volumes):
                imgs = render(
                    channel_volume,
                    cmap,
                    times,
                    background=(0, 0, 0, 255),
                    distance_scale=10,
                    vmin=vmin,
                    vmax=vmax,
                    width=resolution,
                    height=resolution,
                )
                # Power-law tone curve: normalized values are raised to 2.4,
                # which darkens mid-tones; results are truncated back to uint8.
                imgs = ((imgs / 255.) ** 2.4 * 255).astype(np.uint8)
                frames[:, :, resolution * c:resolution * (c + 1)] = imgs[..., :3]
            for frame in frames:
                writer.append_data(frame)
    # Leaving the `with` block closes the writer, so no explicit close() is needed.
    return Video(url=out_file)


In [None]:

from os import makedirs

# 3D viscous Burgers: a three-channel random initial condition rolled out for
# 128 steps. symmetric_min_max is reused from the video-helper cell above
# rather than redefined here.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 64
DT = 0.01
NU = 0.01

burgers_stepper = ex.stepper.Burgers(3, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=NU)

grid = ex.make_grid(3, DOMAIN_EXTENT, NUM_POINTS)

# The same truncated-Fourier-series generator is used for each of the three channels.
ic_gen = ex.ic.RandomTruncatedFourierSeries(3, cutoff=2, max_one=True)
multi_channel_ic_gen = ex.ic.RandomMultiChannelICGenerator([ic_gen, ic_gen, ic_gen])
u_0 = multi_channel_ic_gen(NUM_POINTS, key=jax.random.PRNGKey(1))

burgers_trj_3d = ex.rollout(burgers_stepper, 128, include_init=True)(u_0)

# create_video writes into videos/, so make sure the directory exists first.
makedirs("videos", exist_ok=True)
create_video("videos/burgers.mp4",burgers_trj_3d,zigzag_alpha(cmap_nonlinear, 0.1),resolution=resolution, duration=10, fps=fps, vrange=0.4)

In [None]:
# Linear advection in 3D with the constant velocity VELOCITY * (2, 0.4, 1).
DOMAIN_EXTENT, NUM_POINTS, DT, VELOCITY = 1.0, 64, 0.01, 1.0

advection_stepper = ex.stepper.Advection(
    3, DOMAIN_EXTENT, NUM_POINTS, DT,
    velocity=VELOCITY * np.array([2.0, 0.4, 1.0]),
)

# Diffused random noise (with the max_one / zero_mean options) as the initial state.
diffused_noise = ex.ic.DiffusedNoise(3, max_one=True, zero_mean=True)
u_0 = diffused_noise(NUM_POINTS, key=jax.random.PRNGKey(0))
advection_trj_3d = ex.rollout(advection_stepper, 64, include_init=True)(u_0)

create_video(
    "videos/advection.mp4",
    advection_trj_3d,
    zigzag_alpha(cmap_linear, 0.1),
    resolution=resolution,
    duration=10,
    fps=fps,
    vrange=0.5,
)

In [None]:
# Kuramoto-Sivashinsky in 3D.
DOMAIN_EXTENT, NUM_POINTS, DT = 30.0, 64, 0.1

ks_stepper = ex.stepper.KuramotoSivashinsky(3, DOMAIN_EXTENT, NUM_POINTS, DT)

# Start from single-channel white noise and advance 500 steps before
# recording, so the particular initial condition does not matter.
u_0 = jax.random.normal(jax.random.PRNGKey(0), (1,) + (NUM_POINTS,) * 3)
warm_up = ex.repeat(ks_stepper, 500)
warmed_up_u_0 = warm_up(u_0)
record = ex.rollout(ks_stepper, 64, include_init=True)
ks_trj_3d = record(warmed_up_u_0)

create_video(
    "videos/ks.mp4",
    ks_trj_3d,
    zigzag_alpha(cmap_nonlinear, 0.1),
    resolution=resolution,
    duration=10,
    fps=fps,
    vrange=3.0,
)

In [None]:
# Gray-Scott reaction-diffusion in 3D with two channels.
DOMAIN_EXTENT = 1.0
NUM_POINTS = 64
DT = 30.0
DIFFUSIVITY_0 = 2e-5  # passed as diffusivity_1
DIFFUSIVITY_1 = 1e-5  # passed as diffusivity_2
FEED_RATE = 0.04
KILL_RATE = 0.06

# Inner stepper with step size DT / 30, wrapped in RepeatedStepper(..., 15).
# NOTE(review): assuming RepeatedStepper applies the inner step 15 times, each
# rollout step advances 15 * DT / 30 = DT / 2 rather than DT. Confirm whether
# 30 sub-steps were intended.
gray_scott_inner = ex.reaction.GrayScott(
    3,
    DOMAIN_EXTENT,
    NUM_POINTS,
    DT / 30,
    diffusivity_1=DIFFUSIVITY_0,
    diffusivity_2=DIFFUSIVITY_1,
    feed_rate=FEED_RATE,
    kill_rate=KILL_RATE,
)
gray_scott_stepper = ex.RepeatedStepper(gray_scott_inner, 15)

# One Gaussian-blob generator per channel; the first uses one_complement=True.
blob_generators = [
    ex.ic.RandomGaussianBlobs(3, one_complement=True, num_blobs=1),
    ex.ic.RandomGaussianBlobs(3, num_blobs=1),
]
u_0 = ex.ic.RandomMultiChannelICGenerator(blob_generators)(
    NUM_POINTS, key=jax.random.PRNGKey(0)
)

gray_scott_trj_3d = ex.rollout(gray_scott_stepper, 128, include_init=True)(u_0)

# Both channels are rendered side by side with a fixed [0, 1] value range.
create_video(
    "videos/gray_scott.mp4",
    gray_scott_trj_3d,
    zigzag_alpha(cmap_diff, 0.0),
    resolution=resolution,
    duration=10,
    fps=fps,
    vrange=(0, 1),
)