# Example local diffusion with FLUX

This notebook requires at least 12 GB of GPU memory. It runs a local, uncensored copy of the [FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) model and exposes a few tunable parameters.

It also serves as a purely practical example of how text encodings (here CLIP-L and T5-XXL) are used to condition an image-generation model.

First, we install some dependencies that will simplify our calls to the model, and download the model weights:

In [None]:
# Fetch the TotoroUI fork of ComfyUI (branch totoro3) into /content/TotoroUI.
# The clone is skipped when the checkout already exists, so this cell can be
# re-run (e.g. after a kernel restart) without a "destination path already exists" error.
%cd /content
!test -d /content/TotoroUI/.git || git clone -b totoro3 https://github.com/camenduru/ComfyUI /content/TotoroUI
%cd /content/TotoroUI

# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install -q torchsde einops diffusers accelerate xformers==0.0.28.post2
!apt -y install -qq aria2

# Download the weights with aria2c (16 parallel connections; -c resumes partial downloads):
#   - FLUX.1-schnell diffusion model   -> models/unet
#   - FLUX autoencoder (ae.sft)        -> models/vae
#   - CLIP-L and fp8 T5-XXL encoders   -> models/clip
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors -d /content/TotoroUI/models/unet -o flux1-schnell.safetensors
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/ae.sft -d /content/TotoroUI/models/vae -o ae.sft
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/clip_l.safetensors -d /content/TotoroUI/models/clip -o clip_l.safetensors
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/FLUX.1-dev/resolve/main/t5xxl_fp8_e4m3fn.safetensors -d /content/TotoroUI/models/clip -o t5xxl_fp8_e4m3fn.safetensors

We import all necessary libraries:

In [None]:
import random
import torch
import numpy as np
from PIL import Image
import nodes
from nodes import NODE_CLASS_MAPPINGS
from totoro_extras import nodes_custom_sampler
from totoro import model_management
from IPython.display import display


Finally, we load all the components the model needs to work properly.

In [None]:
# Instantiate the ComfyUI/TotoroUI node classes. These are lightweight node
# objects; the model weights themselves are loaded in the block further down.
# Loads the pair of text encoders (T5-XXL + CLIP-L, see load_clip below) used to embed the prompt.
DualCLIPLoader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
# Loads the denoising model. ComfyUI names this loader "UNET" generically; here it loads flux1-schnell.
UNETLoader = NODE_CLASS_MAPPINGS["UNETLoader"]()
# Produces the seeded initial noise for sampling.
RandomNoise = nodes_custom_sampler.NODE_CLASS_MAPPINGS["RandomNoise"]()
# Binds the denoising model to the text conditioning (no negative prompt).
BasicGuider = nodes_custom_sampler.NODE_CLASS_MAPPINGS["BasicGuider"]()
# Selects the sampling algorithm by name (e.g. "euler").
KSamplerSelect = nodes_custom_sampler.NODE_CLASS_MAPPINGS["KSamplerSelect"]()
# Computes the sigma (noise-level) schedule for a given scheduler and step count.
BasicScheduler = nodes_custom_sampler.NODE_CLASS_MAPPINGS["BasicScheduler"]()
# Runs the iterative denoising loop from noise + guider + sampler + sigmas + latent.
SamplerCustomAdvanced = nodes_custom_sampler.NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
# Loads the VAE that maps refined latents back to pixels.
VAELoader = NODE_CLASS_MAPPINGS["VAELoader"]()
# Decodes a latent into an image using a given VAE.
VAEDecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
# Creates the empty starting latent for a given width/height.
EmptyLatentImage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()

# Load the models, not just the objects. inference_mode disables autograd
# tracking while the weights are loaded. Each loader returns a tuple; [0]
# takes its first output.
with torch.inference_mode():
    # Text encoders: fp8 (e4m3fn) T5-XXL + CLIP-L, configured for the "flux" model type.
    clip = DualCLIPLoader.load_clip("t5xxl_fp8_e4m3fn.safetensors", "clip_l.safetensors", "flux")[0]
    # FLUX.1-schnell weights loaded with weight dtype "fp8_e4m3fn"
    # (presumably to reduce GPU memory use — confirm against the loader).
    unet = UNETLoader.load_unet("flux1-schnell.safetensors", "fp8_e4m3fn")[0]
    # Autoencoder used at the end to decode latents into images.
    vae = VAELoader.load_vae("ae.sft")[0]

def closestNumber(n, m):
    """Return the multiple of ``m`` that is nearest to ``n``.

    ``n / m`` is truncated toward zero to get one candidate multiple. The
    other candidate is the next multiple further away from zero. When ``n``
    sits exactly halfway between the two, the one further from zero wins.

    Args:
        n: Value to round.
        m: Non-zero step to round to (e.g. 16 for image sizes).

    Returns:
        The closest multiple of ``m``.
    """
    toward_zero = m * int(n / m)
    # Step one multiple further from zero; the direction depends on the signs.
    away_from_zero = toward_zero + (m if n * m > 0 else -m)
    if abs(n - toward_zero) < abs(n - away_from_zero):
        return toward_zero
    return away_from_zero

In [None]:
# @title Image generator
positive_prompt = "Wizard using fireball in a room too small" #@param {type:"string"}

# These are the recommended settings, play around if you want.
# Slider minimums keep the values usable: at least one sampling step, and
# image sides of at least 256 px (a 0 px side would give an empty latent).
steps = 6 #@param {type:"slider", min:1, max:50, step:1}
width = 512 #@param {type:"slider", min:256, max:2048, step:256}
height = 512 #@param {type:"slider", min:256, max:2048, step:256}
n_img = 3 #@param {type:"slider", min:1, max:8, step:1}
# Base seed for the batch; 0 means a random seed is drawn in the generation cell.
seed = 45 #@param {type:"number"}

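FLUX generation loop: the prompt is encoded once, then for each image we draw seeded noise, denoise an empty latent guided by the prompt, decode it with the VAE, and save/display the result: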

In [None]:
# Sampler configuration shared by every image.
sampler_name = "euler"
scheduler = "simple"
SEED_STRIDE = 23  # offset between the seeds of consecutive images
MAX_RANDOM_SEED = 18446744073709551615  # 2**64 - 1, upper bound for a randomly drawn seed

# Round the requested size to multiples of 16 and reject values that cannot
# be sampled, before doing any GPU work.
image_width = closestNumber(width, 16)
image_height = closestNumber(height, 16)
if steps < 1:
    raise ValueError(f"steps must be at least 1, got {steps}")
if image_width < 16 or image_height < 16:
    raise ValueError(f"width and height must round to at least 16 px, got {width}x{height}")

# Derive per-image seeds without mutating the `seed` form value, so re-running
# this cell reproduces the same images. Seed 0 draws a random base r (images
# use r, r+23, r+46, ...); any other seed s gives s+23, s+46, ...
base_seed = random.randint(0, MAX_RANDOM_SEED) if seed == 0 else seed + SEED_STRIDE

with torch.inference_mode():
    # 1. Encode the prompt once; the conditioning is identical for every image.
    #    The pooled embedding is passed alongside the token embeddings.
    cond, pooled = clip.encode_from_tokens(clip.tokenize(positive_prompt), return_pooled=True)
    cond = [[cond, {"pooled_output": pooled}]]

    # 2. Choose the sampling algorithm and compute the sigma (noise-level)
    #    schedule for `steps` steps with full denoising (1.0).
    sampler = KSamplerSelect.get_sampler(sampler_name)[0]
    sigmas = BasicScheduler.get_sigmas(unet, scheduler, steps, 1.0)[0]

    for x in range(n_img):
        image_seed = base_seed + SEED_STRIDE * x

        # 3. Seeded initial noise for this image.
        noise = RandomNoise.get_noise(image_seed)[0]

        # 4. Guide the FLUX model (unet) with the text conditioning.
        guider = BasicGuider.get_guider(unet, cond)[0]

        # 5. Empty starting latent of the requested size.
        latent_image = EmptyLatentImage.generate(image_width, image_height)[0]

        # 6. Iteratively denoise along `sigmas`; `sample` is the refined latent.
        sample, _sample_denoised = SamplerCustomAdvanced.sample(noise, guider, sampler, sigmas, latent_image)
        model_management.soft_empty_cache()

        # 7. Decode the refined latent to pixels with the VAE.
        decoded = VAEDecode.decode(vae, sample)[0].detach().cpu()

        # 8. Convert to 8-bit and save/show. Clipping keeps out-of-range values
        #    from wrapping around in the uint8 cast. [0] takes the first image
        #    of the batch (assumes a (batch, H, W, C) layout — confirm).
        pixels = np.clip(np.array(decoded * 255), 0, 255).astype(np.uint8)[0]
        img = Image.fromarray(pixels)
        img.save(f"/content/flux_{x}.png")
        display(img)