## Flux random walk

In [None]:
# This is a simple python notebook to render random walks through flux latent
# space, simultaneously traversing prompt embed space (concepts) and init noise
# space (shapes).
#
# Works with flux-schnell-fp8 and flux-dev-fp8. Requires less than 12 GB RAM
# and less than 15 GB VRAM, so it should run fine on Google Colab. Faster with
# a better GPU, of course.
#
# Robert Luxemburg, 2024, Public Domain

## Create output directory

In [None]:
# Once you're certain that this notebook does not mess with your personal
# files, set trust_this_notebook to True. This will allow it to save the
# generated images and videos directly to your Google Drive.

import os
from google.colab import drive

trust_this_notebook = False

# Google Drive is only mounted after the user has explicitly opted in;
# otherwise all outputs stay under /content on the Colab machine.
if trust_this_notebook:
    drive.mount("/content/drive")
OUTPUT_DIR = (
    "/content/drive/MyDrive/FLUX.1/outputs"
    if trust_this_notebook
    else "/content/outputs"
)
os.makedirs(OUTPUT_DIR, exist_ok=True)

## Download the model

In [None]:
# This code is based on camenduru's flux notebook. You can find the original at
# https://github.com/camenduru/flux-jupyter/blob/main/flux.1-dev_jupyter.ipynb
#
# The flux model can be selected below. Schnell is schneller, dev looks better.
#
# If you're not planning to use "tokens" mode (see further down), you can set
# load_encoder to False. This will save some time on startup, and some memory.

# Which Flux checkpoint to download; also used for the weight filename in the
# "Load the model" cell, the default step count in render(), and dirname.
model = "schnell" # "schnell" or "dev"
# Download/load the T5-XXL and CLIP-L text encoders. Only "tokens" mode calls
# encode(), so the other modes can run without them.
load_encoder = True

%cd /content
# NOTE(review): on a re-run, git refuses to clone into the existing
# /content/TotoroUI; the shell error does not stop the cell, so the existing
# checkout is presumably reused.
!git clone -b totoro4 https://github.com/camenduru/ComfyUI /content/TotoroUI
%cd /content/TotoroUI

# drill a tiny hole through some comfy internals: insert an early return just
# before the line that seeds a torch generator, so that a tensor passed in as
# the "seed" is returned as-is. Presumably render() depends on this, since it
# hands its precomputed noise tensor to RandomNoise.get_noise.
# NOTE(review): the inserted text assumes the target line is indented by
# exactly 4 spaces in sample.py — confirm against the checked-out branch.
filename = "/content/TotoroUI/totoro/sample.py"
string = "generator = torch.manual_seed(seed)"
patch = "if type(seed) is torch.Tensor: return seed\n    "
with open(filename) as f:
    source = f.read()
if string not in source:
    # Without the patch, passing a noise tensor as the seed cannot work, so
    # fail here rather than with an obscure error in the first render().
    raise RuntimeError(f"patch target {string!r} not found in {filename}")
# Patch only once, so re-running this cell does not stack duplicate lines.
if patch + string not in source:
    with open(filename, "w") as f:
        f.write(source.replace(string, patch + string))

# Extra Python packages for TotoroUI (xformers version pinned).
!pip install -q torchsde einops diffusers accelerate xformers==0.0.27
# aria2 provides aria2c, used for all weight downloads below.
!apt -y install -qq aria2

# aria2c flags: -c resumes partial downloads, -x/-s 16 use up to 16
# connections, -d/-o choose the target directory and filename. The .sft
# names are the ones requested by the "Load the model" cell.
# Text encoders (T5-XXL fp8 + CLIP-L), only needed when load_encoder is True.
if load_encoder:
    !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/t5xxl_fp8_e4m3fn.safetensors -d /content/TotoroUI/models/clip -o t5xxl_fp8_e4m3fn.sft
    !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/clip_l.safetensors -d /content/TotoroUI/models/clip -o clip_l.sft

# Flux fp8 weights for the selected model, saved as flux1-{model}-fp8.sft.
if model == "schnell":
    !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/kijai/flux-fp8/resolve/main/flux1-schnell-fp8.safetensors -d /content/TotoroUI/models/unet -o flux1-schnell-fp8.sft
else:
    !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/kijai/flux-fp8/resolve/main/flux1-dev-fp8.safetensors -d /content/TotoroUI/models/unet -o flux1-dev-fp8.sft
# VAE used by render() to decode latents into images.
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/ae.sft -d /content/TotoroUI/models/vae -o ae.sft


## Load the model

In [None]:
import numpy as np
from PIL import Image
import torch

from nodes import NODE_CLASS_MAPPINGS
from totoro import model_management
from totoro_extras import nodes_custom_sampler, nodes_flux

# Instantiate the ComfyUI node objects once; their methods are called directly
# (no node graph). The three loaders are used just below, the rest by render().
DualCLIPLoader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
UNETLoader = NODE_CLASS_MAPPINGS["UNETLoader"]()
FluxGuidance = nodes_flux.NODE_CLASS_MAPPINGS["FluxGuidance"]()
RandomNoise = nodes_custom_sampler.NODE_CLASS_MAPPINGS["RandomNoise"]()
BasicGuider = nodes_custom_sampler.NODE_CLASS_MAPPINGS["BasicGuider"]()
KSamplerSelect = nodes_custom_sampler.NODE_CLASS_MAPPINGS["KSamplerSelect"]()
BasicScheduler = nodes_custom_sampler.NODE_CLASS_MAPPINGS["BasicScheduler"]()
SamplerCustomAdvanced = nodes_custom_sampler.NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
VAELoader = NODE_CLASS_MAPPINGS["VAELoader"]()
VAEDecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
EmptyLatentImage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()

# Load the weights once. The resulting globals clip, unet and vae are read by
# encode() and render() in the next cell.
with torch.inference_mode():
    # clip is only defined when load_encoder is True; "tokens" mode needs it.
    if load_encoder:
        clip = DualCLIPLoader.load_clip("t5xxl_fp8_e4m3fn.sft", "clip_l.sft", "flux")[0]
    # Load the selected model's fp8 checkpoint with the fp8_e4m3fn weight dtype.
    unet = UNETLoader.load_unet(f"flux1-{model}-fp8.sft", "fp8_e4m3fn")[0]
    vae = VAELoader.load_vae("ae.sft")[0]

## Define functions

In [None]:
from scipy.ndimage import gaussian_filter

def get_noise(seed, shape, sigma):
    """Draw standard-normal noise smoothed along its first axis.

    Args:
        seed: an int seed, or an existing torch.Generator to draw from (which
            is advanced by the draw).
        shape: output shape; only axis 0 is blurred.
        sigma: std of the gaussian blur along axis 0, in samples. The blur
            wraps around, so the last entries blend into the first ones.

    Returns:
        A float16 tensor of the given shape, re-standardized to mean 0 and
        std 1 after blurring.
    """
    if type(seed) is int:
        generator = torch.Generator().manual_seed(seed)
    else:
        generator = seed
    raw = torch.randn(shape, generator=generator)
    # blur axis 0 only; every other axis gets a sigma of 0
    per_axis_sigma = [sigma, *([0] * (len(shape) - 1))]
    smoothed = gaussian_filter(raw, sigma=per_axis_sigma, mode="wrap")
    standardized = (smoothed - smoothed.mean()) / smoothed.std()
    return torch.Tensor(standardized).to(torch.float16)

def get_tokens(seed):
    """Build a prompt of random token ids for the CLIP-L and T5-XXL encoders.

    Args:
        seed: an int seed, or a torch.Generator to draw from (which is
            advanced; CLIP ids are drawn before T5 ids).

    Returns:
        A dict with "l" (77 ids: 49406, then 75 random ids in [1, 49405],
        then 49407) and "t5xxl" (255 random ids in [2, 32127] followed by 1).
        Each is wrapped as [[(id, weight), ...]] with every weight 1.0,
        presumably the format clip.encode_from_tokens expects.
    """
    rng = torch.Generator().manual_seed(seed) if type(seed) is int else seed
    clip_start, clip_end = 49406, 49407
    # draw order matters for reproducibility: CLIP body first, then T5 body
    clip_body = torch.randint(1, clip_start, (75,), generator=rng).tolist()
    t5_body = torch.randint(2, 32128, (255,), generator=rng).tolist()
    clip_ids = [clip_start, *clip_body, clip_end]
    t5_ids = [*t5_body, 1]
    return {
        "l": [[(token, 1.0) for token in clip_ids]],
        "t5xxl": [[(token, 1.0) for token in t5_ids]],
    }

def slerp(vs, t, loop=True, DOT_THRESHOLD=0.9995):
    """Spherical linear interpolation along a chain of tensors.

    t (nominally in [0, 1]) is mapped onto the chain vs[0] -> vs[1] -> ...;
    with loop=True the chain closes back onto vs[0]. Between two neighbours
    the result follows the great circle through them. It falls back to linear
    interpolation when they are (nearly) parallel or antiparallel, or when
    their normalized dot product is NaN.

    Args:
        vs: sequence of same-shaped tensors. A stacked tensor also works; it
            is only indexed along its first axis.
        t: position along the chain.
        loop: whether the last element connects back to the first.
        DOT_THRESHOLD: |cosine| above which linear interpolation is used.

    Returns:
        The interpolated tensor (vs[0] itself if vs has a single element).
    """
    count = len(vs)
    if count == 1:
        return vs[0]
    segments = count if loop else count - 1
    position = t * segments
    start = vs[int(position) % count]
    end = vs[int(position + 1) % count]
    frac = position % 1
    cosine = torch.sum(
        start * end / (torch.linalg.norm(start) * torch.linalg.norm(end))
    )
    if torch.isnan(cosine) or torch.abs(cosine) > DOT_THRESHOLD:
        return (1 - frac) * start + frac * end
    omega = torch.acos(cosine)
    sin_omega = torch.sin(omega)
    swept = omega * frac
    weight_start = torch.sin(omega - swept) / sin_omega
    weight_end = torch.sin(swept) / sin_omega
    return weight_start * start + weight_end * end

def encode(tokens):
    """Encode a token dict with the globally loaded text encoder(s).

    Args:
        tokens: token dict, e.g. as built by get_tokens().

    Returns:
        (cond, pooled) from clip.encode_from_tokens(..., return_pooled=True).
        NOTE(review): the global clip only exists if load_encoder was True
        when the model was loaded.
    """
    with torch.inference_mode():
        encoded = clip.encode_from_tokens(tokens, return_pooled=True)
    cond, pooled = encoded
    return cond, pooled

def render(
    filename,
    prompt, # tuple of tensors ((1, 256, 4096), (1, 768))
    noise,  # tensor (16, height//8, width//8)
    steps=4 if model == "schnell" else 20,
    guidance=3.5
):
    """Render one frame with Flux and save it to `filename`.

    Frames that already exist on disk are skipped, so an interrupted run can
    be restarted and resumes where it stopped.

    Args:
        filename: output image path; its parent directory is created if needed.
        prompt: (cond, pooled) tensors used directly as conditioning.
        noise: initial latent noise; its last two dims set the image size
            (8 pixels per latent cell).
        steps: sampler steps (default depends on the loaded model).
        guidance: value passed to FluxGuidance.

    Returns:
        The rendered PIL image, or None if `filename` already existed.
    """
    if os.path.exists(filename):
        return
    print(filename.replace(f"{OUTPUT_DIR}/", ""))
    cond = [[prompt[0], {"pooled_output": prompt[1]}]]
    width, height = noise.shape[2] * 8, noise.shape[1] * 8

    with torch.inference_mode():
        cond = FluxGuidance.append(cond, guidance)[0]
        # The noise tensor is passed as the "seed"; presumably this works
        # thanks to the sample.py patch applied in the download cell.
        random_noise = RandomNoise.get_noise(noise)[0]
        guider = BasicGuider.get_guider(unet, cond)[0]
        sampler = KSamplerSelect.get_sampler("euler")[0]
        sigmas = BasicScheduler.get_sigmas(unet, "simple", steps, 1.0)[0]
        latent_image = EmptyLatentImage.generate(width, height)[0]
        sample, sample_denoised = SamplerCustomAdvanced.sample(
            random_noise, guider, sampler, sigmas, latent_image
        )
        model_management.soft_empty_cache()
        decoded = VAEDecode.decode(vae, sample)[0].detach()
        # Clamp before the uint8 cast so out-of-range values cannot wrap
        # around; .cpu() also covers a decoder that returns a GPU tensor.
        pixels = np.clip(decoded.cpu().numpy() * 255, 0, 255).astype(np.uint8)
        image = Image.fromarray(pixels[0])

    os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)
    # Save under a temporary name and rename. Otherwise an interrupted save
    # (e.g. a Colab disconnect) leaves a truncated file that the exists()
    # check above would skip forever.
    root, ext = os.path.splitext(filename)
    partial = f"{root}.partial{ext}"
    image.save(partial)
    os.replace(partial, filename)
    return image

## Render images

In [None]:
# There are three different modes, "gauss", "slerp" and "tokens".
#
# "gauss" applies gaussian blur to the input noise, resulting in smoother
# transitions between frames. For long sequences, this will run out of memory.
# With plot_noise = True, you can check how the input looks for a given sigma.
#
# "slerp" performs spherical linear interpolation along a sequence of random
# samples. This uses less RAM, but results in faster transitions and can lead
# to visible changes of direction (if you know what you're looking for).
#
# "tokens" works like "slerp", but instead of random noise uses embeddings of
# random tokens. This will take some time on startup, but should produce fewer
# blank, blurry or monochromatic images (i.e. totally "meaningless" noise).
#
# You can set mode, width, height, steps, sigma and seed below. Width and
# height must be multiples of 16, steps are the number of frames, and sigma is
# the amount of blur ("gauss") or the number of samples ("slerp" or "tokens").
#
# This first renders two consecutive frames, the last and the first (the walk
# loops), so you can estimate the rate of change. It then renders a few
# "keyframes". If you don't like them, pick a new seed.

import matplotlib.pyplot as plt
from tqdm import tqdm

mode = "gauss" # "gauss", "slerp" or "tokens"
width, height = 1024, 1024
steps, sigma = 225, 15
seed = 42
plot_noise = True

# Every parameter goes into the output directory name, so each configuration
# gets its own folder (and render() can resume a partially rendered one).
dirname = f"{model},{mode},{width},{height},{steps},{sigma},{seed}"
# A single generator feeds everything below. The ORDER of draws from g is part
# of the walk's identity: reordering them changes which images a given
# dirname would contain, mixing old and new frames on resume.
g = torch.Generator().manual_seed(seed)

if mode == "gauss":
    # one sample per frame, blurred along the frame axis with wrap-around,
    # so the walk loops back onto its start
    cond = get_noise(g, (steps, 1, 256, 4096), sigma)
    pooled = get_noise(g, (steps, 1, 768), sigma)
    init = get_noise(g, (steps, 16, height//8, width//8), sigma)
elif mode == "slerp":
    # sigma independent samples; frames are slerped between them later
    cond = torch.randn((sigma, 1, 256, 4096), generator=g)
    pooled = torch.randn((sigma, 1, 768), generator=g)
    init = torch.randn((sigma, 16, height//8, width//8), generator=g)
else:
    # "tokens": sigma random-token prompts run through the text encoders;
    # cond and pooled end up as tuples of per-sample tensors
    cond_pooled = (encode(get_tokens(g)) for _ in tqdm(range(sigma)))
    cond, pooled = zip(*cond_pooled)
    init = torch.randn((sigma, 16, height//8, width//8), generator=g)

# trying to match mean and std of T5+Clip encoder output
# (in place, but safe to re-run: the cell regenerates cond/pooled above)
if mode in ("gauss", "slerp"):
    cond *= 0.14
    pooled -= 0.11

if plot_noise:
    # Trace one latent component across all frames to eyeball how fast the
    # input changes from frame to frame.
    if mode == "gauss":
        plt.plot(init[:,0,0,0])
    else:
        # Interpolate the full latents and only then pick the component.
        # Slerping scalars degenerates to linear interpolation (their
        # normalized dot product is always +-1), which is not what render()
        # gets. The t grid is the same i / steps grid as the render loop.
        plt.plot([slerp(init, i / steps)[0, 0, 0].item() for i in range(steps)])
    plt.title(f"{mode}: init noise component [0, 0, 0] per frame")
    plt.xlabel("frame")
    plt.show()

# Render in three passes; render() skips files that already exist, so each
# pass only fills gaps:
#   1. the last and the first frame (adjacent, since the walk loops),
#   2. keyframes every steps // sigma frames,
#   3. every frame.
# The keyframe spacing is steps // sigma in every mode. For "gauss" this is the
# original spacing. For "slerp"/"tokens" it lands on (or next to) the sigma
# sample points; using sigma as the spacing there was only correct when
# sigma * sigma == steps.
# max(1, ...) keeps range() valid when sigma > steps or steps == 1.
n = max(1, steps // sigma)
for step in (max(1, steps - 1), n, 1):
    for i in range(0, steps, step):
        if mode == "gauss":
            prompt = cond[i], pooled[i]
            noise = init[i]
        else:
            t = i / steps
            prompt = slerp(cond, t), slerp(pooled, t)
            noise = slerp(init, t)
        render(
            f"{OUTPUT_DIR}/random_walk/{dirname}/{i:08d}.png",
            prompt,
            noise
        )

## Render video

In [None]:
# For the video output, you can choose a start frame and the direction (1 for
# forward, -1 for reversed). You may also want to upscale the video slightly,
# to keep sites like YouTube from downscaling it later. fps is the frame rate.

start = 0
direction = 1 # 1 or -1
width, height = 1080, 1080
fps = 15

path = f"{OUTPUT_DIR}/random_walk/{dirname}"
filename = f"{path}.txt"
indices = ((start + i * direction) % steps for i in range(steps))
frames = (f"file '{path}/{i:08d}.png'" for i in indices)
open(filename, "w").write("\n".join(frames))

!ffmpeg -y -r {fps} -f concat -safe 0 -i {filename} \
    -vf scale={width}:{height}:flags=lanczos \
    -vcodec libx264 -pix_fmt yuv420p -crf 17 {path}.mp4

os.remove(filename)