# JAX Fractal Renderer

[https://github.com/tripplyons/jax-fractal-renderer](https://github.com/tripplyons/jax-fractal-renderer)

Hint: If you are using Colab, select "Runtime" > "Change runtime type" > "Hardware accelerator" > "GPU" so the notebook runs on a GPU.

In [None]:
!pip install jax tqdm optax matplotlib pillow numpy psgd-jax flax
!mkdir -p output

In [None]:
import jax
import jax.numpy as jnp
import optax
from tqdm import tqdm
from PIL import Image
from IPython.display import display
import numpy as np
# the optimizer that I found to work the best
from psgd_jax.affine import scale_by_affine

# Integer dtype used for pixel indices; float dtype used for point coordinates and the density image.
index_dtype, dtype = jnp.int32, jnp.float32

def apply_transform(x, y, settings):
    a, b, c, d, e, f, power = settings

    # apply a matrix multiplication and vector addition using the constants
    x_prime = x * a + y * b + c
    y_prime = x * d + y * e + f

    # use a non-linearity that helps keep the points centered
    scale_squared = x_prime ** 2 + y_prime ** 2
    # divide by 2 for square root
    x_prime /= scale_squared ** (power / 2)
    y_prime /= scale_squared ** (power / 2)

    return x_prime, y_prime

def update(x, y, settings, rng):
    """Advance every point by one chaos-game step.

    Each point independently applies the first transform with probability
    ``rate`` and the second transform otherwise.

    Args:
        x, y: 1-D coordinate arrays, one entry per point.
        settings: ``(rate, settings_1, settings_2)``. Each ``settings_i`` is
            the 7-tuple consumed by ``apply_transform``.
        rng: JAX PRNG key used for the per-point choice.

    Returns:
        The new ``(x, y)`` coordinate arrays.
    """
    rate, settings_1, settings_2 = settings

    # evaluate both candidate transforms for every point
    x_prime_1, y_prime_1 = apply_transform(x, y, settings_1)
    x_prime_2, y_prime_2 = apply_transform(x, y, settings_2)

    # one uniform draw per point; a draw below `rate` selects the first transform
    draws = jax.random.uniform(rng, shape=(x.shape[0],), minval=0, maxval=1, dtype=dtype)
    use_first = draws < rate

    # jnp.where selects without arithmetic. A NaN/inf produced by the
    # *unchosen* transform (e.g. dividing by |q| == 0) therefore cannot leak
    # into the result, as it would with `a * mask + b * (1 - mask)`
    # (NaN * 0 == NaN).
    x = jnp.where(use_first, x_prime_1, x_prime_2)
    y = jnp.where(use_first, y_prime_1, y_prime_2)

    return x, y

def optimize(density, batch_size, x_res, y_res, settings, iterations, bounds, rng):
    """Accumulate a density image from the orbit of a two-transform chaos game.

    ``batch_size`` points start at the origin and advance one chaos-game step
    per iteration (see ``update``). After each step a "gradient" is built:
    1 on every pixel hit by at least one point, 0 elsewhere. It is
    standardized, passed through the optimizer chain and applied to
    ``density`` with ``optax.apply_updates``, which adds the update. The
    chain itself applies no sign flip, so the apparent intent is that
    frequently visited pixels grow. NOTE(review): confirm the sign convention
    of ``scale_by_affine``.

    Args:
        density: flat vector of length ``x_res * y_res``. Pixel ``(i, j)``
            (``i`` along x, ``j`` along y) is stored at ``i * y_res + j``.
        batch_size: number of points iterated in parallel.
        x_res, y_res: image resolution along x and y.
        settings: ``(rate, settings_1, settings_2)`` as consumed by ``update``.
        iterations: number of chaos-game / optimizer steps.
        bounds: ``(min_x, max_x, min_y, max_y)`` of the rendered region.
        rng: JAX PRNG key.

    Returns:
        The optimized flat density vector.
    """
    min_x, max_x, min_y, max_y = bounds

    # linear warmup from 0 to 1 over the first 10% of iterations, then constant;
    # presumably this also down-weights the early steps taken before the points
    # settle onto the attractor
    schedule = optax.schedules.warmup_constant_schedule(0, 1, iterations // 10)
    optimizer = optax.chain(
        scale_by_affine(),  # PSGD affine preconditioner
        optax.scale_by_schedule(schedule),  # learning-rate schedule
    )
    opt_state = optimizer.init(density)

    # every point starts at the origin
    x = jnp.zeros(batch_size, dtype=dtype)
    y = jnp.zeros(batch_size, dtype=dtype)

    # pixel spacing: x_res samples span [min_x, max_x] inclusive (same for y)
    dx = (max_x - min_x) / (x_res - 1)
    dy = (max_y - min_y) / (y_res - 1)

    # flat index one past the end of `density`; scatters to it are dropped
    out_of_bounds = x_res * y_res

    @jax.jit
    def step(density, x, y, opt_state, rng):
        update_rng, rng = jax.random.split(rng, 2)

        # advance the points
        x, y = update(x, y, settings, update_rng)

        # nearest pixel to each point
        x_index = jnp.round((x - min_x) / dx).astype(index_dtype)
        y_index = jnp.round((y - min_y) / dy).astype(index_dtype)

        # points outside the canvas are routed to the single sentinel index
        in_bounds = (
            (x_index >= 0) & (x_index < x_res) & (y_index >= 0) & (y_index < y_res)
        )
        flat_index = jnp.where(in_bounds, x_index * y_res + y_index, out_of_bounds)

        # 1 on every visited pixel, 0 elsewhere. mode="drop" makes discarding
        # the sentinel explicit; the default mode only *promises* that indices
        # are in bounds.
        grad = jnp.zeros((x_res * y_res,), dtype=dtype)
        grad = grad.at[flat_index].set(1, mode="drop")
        # zero mean / unit variance to keep the update scale stable across steps
        grad = jax.nn.standardize(grad, axis=0)

        updates, opt_state = optimizer.update(grad, opt_state, density)
        density = optax.apply_updates(density, updates)

        return density, x, y, opt_state, rng

    # run the optimization for the requested number of iterations
    for _ in tqdm(range(iterations)):
        density, x, y, opt_state, rng = step(density, x, y, opt_state, rng)

    return density

def main(rng, res, *, batch_size=1024, iterations=10000, scale=3):
    """Render one randomly parameterized two-transform fractal.

    Two affine transforms and a mixing rate are drawn from ``rng``. The
    drawn scale/bias/power values and the rate are printed. ``optimize`` is
    then run over the square region ``[-scale, scale]^2``.

    Args:
        rng: JAX PRNG key; all random parameter draws come from it.
        res: output resolution (the image is ``res x res``).
        batch_size: number of points iterated in parallel (default 1024).
        iterations: number of optimization steps (default 10000).
        scale: half-width of the rendered square region (default 3).

    Returns:
        Density array of shape ``(res, res)``; axis 0 indexes x and axis 1
        indexes y.
    """
    x_res = res
    y_res = res
    bounds = (-scale, scale, -scale, scale)

    density = jnp.zeros((x_res * y_res,), dtype=dtype)

    params_key, rng = jax.random.split(rng, 2)

    # sampling ranges for the random parameters
    min_scale, max_scale = 1, 1.5  # length of each matrix column
    min_bias, max_bias = 0.25, 0.75  # length of the translation vector
    min_power, max_power = 0.2, 0.3  # radial normalization exponent
    min_rate, max_rate = 0.4, 0.5  # probability of picking the first transform

    def make_settings(rng):
        """Draw one ``apply_transform`` 7-tuple ``(a, b, c, d, e, f, power)``.

        The matrix columns ``(a, d)`` and ``(b, e)`` are vectors of length
        ``scale1`` / ``scale2`` at random angles. The offset ``(c, f)`` has
        length ``bias_scale`` at a third random angle.
        """
        angle1_key, angle2_key, angle3_key, scale1_key, scale2_key, bias_scale_key, power_key = jax.random.split(rng, 7)

        angle1 = jax.random.uniform(angle1_key, shape=(), minval=0, maxval=2 * jnp.pi, dtype=dtype)
        angle2 = jax.random.uniform(angle2_key, shape=(), minval=0, maxval=2 * jnp.pi, dtype=dtype)
        angle3 = jax.random.uniform(angle3_key, shape=(), minval=0, maxval=2 * jnp.pi, dtype=dtype)
        scale1 = jax.random.uniform(scale1_key, shape=(), minval=min_scale, maxval=max_scale, dtype=dtype)
        scale2 = jax.random.uniform(scale2_key, shape=(), minval=min_scale, maxval=max_scale, dtype=dtype)
        bias_scale = jax.random.uniform(bias_scale_key, shape=(), minval=min_bias, maxval=max_bias, dtype=dtype)
        power = jax.random.uniform(power_key, shape=(), minval=min_power, maxval=max_power, dtype=dtype)

        print(scale1, scale2, bias_scale, power)

        return (
            jnp.cos(angle1) * scale1,
            jnp.cos(angle2) * scale2,
            jnp.cos(angle3) * bias_scale,
            jnp.sin(angle1) * scale1,
            jnp.sin(angle2) * scale2,
            jnp.sin(angle3) * bias_scale,
            power,
        )

    settings1_key, settings2_key, rate_key = jax.random.split(params_key, 3)
    settings1 = make_settings(settings1_key)
    settings2 = make_settings(settings2_key)
    rate = jax.random.uniform(rate_key, shape=(), minval=min_rate, maxval=max_rate, dtype=dtype)
    print(rate)

    settings = (rate, settings1, settings2)

    density = optimize(density, batch_size, x_res, y_res, settings, iterations, bounds, rng)

    # flat layout is x-major (index = i * y_res + j), so axis 0 is x
    return density.reshape((x_res, y_res))

def plot(name, density):
    """Convert ``density`` to 8-bit grayscale, display it and save it.

    The image is written to ``output/{name}.png``; the ``output`` directory
    must already exist.

    Args:
        name: file stem for the saved PNG.
        density: 2-D array. Values are min-max scaled to ``[0, 255]``. A
            constant array (zero range) is rendered as all zeros instead of
            producing 0/0.

    Returns:
        The ``uint8`` image as a NumPy array.
    """
    # work on a fresh JAX array so a NumPy input is never modified in place
    density = jnp.asarray(density)

    shifted = density - density.min()
    peak = shifted.max()
    # a zero range would give 0/0 = NaN, whose uint8 cast is undefined;
    # dividing by 1 instead leaves the (all-zero) image unchanged
    normalized = shifted / jnp.where(peak > 0, peak, 1)

    pixels = np.array((255 * normalized).astype(jnp.uint8))

    img = Image.fromarray(pixels)
    display(img)
    img.save(f'output/{name}.png')

    return pixels

if __name__ == '__main__':
    # render a num_rows x num_cols contact sheet of fractals, one per seed
    res = 1024
    num_rows, num_cols = 5, 5
    num_images = num_rows * num_cols

    images = []
    for seed in range(num_images):
        # each seed yields an independent set of random fractal parameters
        rng = jax.random.PRNGKey(seed)
        density = main(rng, res)
        # plot() also writes output/<seed>.png and returns the uint8 image
        images.append(plot(seed, density))

    # stitch row-major: image k lands at row k // num_cols, column k % num_cols
    strips = [
        np.concatenate(images[row * num_cols:(row + 1) * num_cols], axis=1)
        for row in range(num_rows)
    ]
    grid = np.concatenate(strips, axis=0)

    # downscale the (num_rows*res) x (num_cols*res) mosaic to res x res
    grid_img = Image.fromarray(grid).resize((res, res))
    display(grid_img)
    grid_img.save('output/grid.png')