# Interactive visualization of all possible first tokens of FlexTok finetuned on CelebA-HQ

The first token in the FlexTok sequence captures the most essential high-level information about an image. Use this notebook to interactively explore different first-token values and see how they affect the image reconstruction. Each slider sets the value of one FSQ dimension, and for each token index we show 9 random samples from the FlexTok d18-d18 decoder. When we consider only the first token, there are 64,000 possible indices, the product of the FSQ levels [8, 8, 8, 5, 5, 5]. In essence, this means FlexTok partitions the distribution of all possible images into 64,000 clusters, each represented by a single token.
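The 64,000 figure is simply the product of the per-dimension level counts, 8·8·8·5·5·5 = 512·125. A quick plain-Python check, which also enumerates the combinations the way the sliders do (one level index per dimension):

```python
import math
from itertools import product

fsq_levels = [8, 8, 8, 5, 5, 5]  # number of values per FSQ dimension

# Each first token is one combination of per-dimension level indices
num_first_tokens = math.prod(fsq_levels)
num_enumerated = sum(1 for _ in product(*(range(L) for L in fsq_levels)))

print(num_first_tokens, num_enumerated)  # 64000 64000
```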

In [1]:
# Switch path to root of project
import os
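# Expose only GPU 2 to this process; must be set before CUDA is initialized
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
# IPython's directory history; _dh[0] is the directory the notebook started in
current_folder = globals()['_dh'][0]
# The project root is the parent of that directory
os.chdir(os.path.dirname(os.path.abspath(current_folder)))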

%load_ext autoreload
%autoreload 2

In [2]:
from PIL import Image
import matplotlib.pyplot as plt

import einops
import torch
import torchvision.transforms.functional as TF

from diffusers.models import AutoencoderKL

from flextok.flextok_wrapper import FlexTokFromHub, FlexTok
from flextok.utils.demo import imgs_from_urls, denormalize, batch_to_pil
from flextok.utils.misc import detect_bf16_support, get_bf16_context, get_generator
from flextok.utils.dataloader import CelebAHQDataset, create_celebahq_dataloader

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Global no_grad
torch.set_grad_enabled(False)

# Automatically set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)

# Detect if bf16 is enabled or not
enable_bf16 = detect_bf16_support()
print('BF16 enabled:', enable_bf16)

# Set up plotting
plt.rcParams['figure.figsize'] = (15, 10)
plt.rcParams['figure.dpi'] = 100

Device: cuda
BF16 enabled: True


## 1 Sampling tokens

The FlexTok encoder maps an image into a one-dimensional sequence of 256 register tokens. These tokens are discretized using FSQ, resulting in a vocabulary of 64,000 tokens.
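To make the quantization step concrete, here is a minimal pure-Python sketch of FSQ for a single token. It assumes FlexTok follows the convention of the FSQ paper (Mentzer et al., *Finite Scalar Quantization: VQ-VAE Made Simple*): each coordinate is squashed with a shifted tanh, rounded to an integer level, and rescaled into [-1, 1], and the per-dimension level indices are combined as a mixed-radix number with the first dimension varying fastest. That last ordering agrees with the token indices printed further down in this notebook (e.g. 12800 when only the last dimension moves from -1.0 to -0.5). The function names and the `eps` value are illustrative, not FlexTok's API.

```python
import math
from typing import List, Sequence


def fsq_quantize(z: Sequence[float], levels: Sequence[int], eps: float = 1e-3) -> List[float]:
    """Quantize one latent vector with FSQ: bound, round to an integer level, rescale."""
    codes = []
    for zi, L in zip(z, levels):
        half_l = (L - 1) * (1 - eps) / 2
        offset = 0.5 if L % 2 == 0 else 0.0  # even L needs an asymmetric integer range
        shift = math.atanh(offset / half_l)  # keeps z = 0 mapped to level 0
        bounded = math.tanh(zi + shift) * half_l - offset
        half_width = L // 2
        codes.append(round(bounded) / half_width)  # e.g. L=8 -> {-1, -0.75, ..., 0.75}
    return codes


def fsq_codes_to_index(codes: Sequence[float], levels: Sequence[int]) -> int:
    """Combine per-dimension level indices into one token index (first dimension fastest)."""
    index, basis = 0, 1
    for c, L in zip(codes, levels):
        half_width = L // 2
        digit = round(c * half_width + half_width)  # code -> level index in 0..L-1
        index += digit * basis
        basis *= L
    return index


levels = [8, 8, 8, 5, 5, 5]
codes = fsq_quantize([0.3, -2.0, 5.0, 0.0, 0.9, -0.4], levels)
token = fsq_codes_to_index(codes, levels)
print(codes, token)
```

Note that under this convention an 8-level dimension never reaches +1.0: its top code is 0.75, because an even number of integer levels cannot be symmetric around 0.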

FlexTok was trained so that these quantized register tokens represent images in a hierarchical, ordered manner. Specifically, truncated subsequences of lengths _1, 2, 4, 8, 16, 32, 64, 128, and 256_ tokens all decode to valid images, and as more tokens are used, the reconstructions increasingly resemble the encoded image. Although FlexTok was trained solely with rectified flow and REPA objectives, the truncated token sequences emerge as highly semantic compressions of the original image, capturing its most salient aspects with the fewest tokens.
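To put these prefix lengths in perspective: k tokens drawn from a 64,000-entry vocabulary can carry at most k·log2(64,000) bits, so the first token alone is a budget of just under 16 bits. A quick calculation of these upper bounds (capacities, not measured rates):

```python
import math

codebook_size = 8 * 8 * 8 * 5 * 5 * 5      # FSQ levels [8, 8, 8, 5, 5, 5]
bits_per_token = math.log2(codebook_size)  # capacity of a single token

# Nested prefix lengths that FlexTok was trained to decode
prefix_lengths = [1, 2, 4, 8, 16, 32, 64, 128, 256]
bits_per_prefix = {k: k * bits_per_token for k in prefix_lengths}

for k, bits in bits_per_prefix.items():
    print(f"{k:>3} tokens -> at most {bits:8.1f} bits")
```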

Here, I load the FlexTok d18-d18 model pretrained on ImageNet-1K from the Hugging Face Hub and then overwrite its weights with a local checkpoint fine-tuned on CelebA-HQ.

In [3]:
# Load the FlexTok d18-d18 model pretrained on ImageNet-1K from the Hugging Face Hub (fine-tuned weights are loaded in the next cell)
flextok = FlexTokFromHub.from_pretrained('EPFL-VILAB/flextok_d18_d18_in1k')

In [4]:
# Overwrite the ImageNet-1K weights with the CelebA-HQ fine-tuned checkpoint
ckpt_path = "/home/iyu/ml-flextok/checkpoints/celebahq_ft/checkpoint_latest.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
flextok.load_state_dict(checkpoint['model_state_dict'])
flextok = flextok.to(device).eval()

### 1.1 Sampling possible first tokens
In this section we enumerate every possible value of the first token of FlexTok fine-tuned on CelebA-HQ and decode each one; the interactive GUI in Section 2 then lets you explore them with one slider per FSQ dimension (the FSQ levels are [8, 8, 8, 5, 5, 5]).

First, we write functions that enumerate every combination of FSQ quantization values and convert each combination into a token index.

In [7]:
def get_possible_combos(flextok_model: FlexTok) -> torch.Tensor:
    """
    Enumerate all possible first zhats (quantized latents) of the FlexTok model.
    Args:
        flextok_model: The FlexTok model.
    Returns:
        Tensor of shape (codebook_size, d) with one row per FSQ code, where d = len(fsq_levels).
    """
    # Get the FSQ regularizer from the FlexTok model
    fsq = flextok_model.regularizer
    fsq_levels = fsq._levels  # e.g., [8, 8, 8, 5, 5, 5]
    print("FSQ levels:", fsq_levels)
    print("codebook size:", fsq.codebook_size)

    # Assumes each dimension with L levels takes the values linspace(-1, 1, L);
    # verify that the resulting token indices are unique.
    quantizations = [torch.linspace(-1, 1, steps=int(L)) for L in fsq_levels]
    # Same ordering as itertools.product: the last dimension varies fastest
    all_combinations = torch.cartesian_prod(*quantizations)  # (prod(levels), d)
    print("Total combinations (must equal codebook size):", all_combinations.shape[0])

    return all_combinations

def zhat_to_tokens(flextok_model: FlexTok, zhats: torch.Tensor):
    """
    Convert a batch of zhats into FSQ token indices.
    Args:
        flextok_model: The FlexTok model.
        zhats: zhats (N, d).
    Returns:
        tokens (N,).
    """
    fsq = flextok_model.regularizer
    tokens = fsq.codes_to_indices(zhats)  # (N,)
    print("tokens shape:", tokens.shape)
    return tokens.long()  # type: ignore
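`get_possible_combos` builds the value grid with `torch.linspace(-1, 1, L)`. That is exact for the 5-level dimensions, but if FlexTok's FSQ follows the FSQ paper's convention (an assumption worth checking against the FlexTok source), an 8-level dimension takes the values (k - 4)/4, i.e. -1, -0.75, ..., 0.75, so most linspace values for those dimensions are off-grid. The pure-Python sketch below (hypothetical helper `index_to_codes`, not FlexTok's API) enumerates codes from token indices instead, which visits every token exactly once by construction, and compares the two grids:

```python
from typing import List, Sequence


def index_to_codes(index: int, levels: Sequence[int]) -> List[float]:
    """Hypothetical inverse of FSQ indexing: mixed-radix digits (first dimension fastest) -> codes.

    Assumes the FSQ-paper grid (k - L//2) / (L//2) for k = 0..L-1.
    """
    codes = []
    for L in levels:
        index, digit = divmod(index, L)
        half_width = L // 2
        codes.append((digit - half_width) / half_width)
    return codes


levels = [8, 8, 8, 5, 5, 5]
codebook_size = 8 ** 3 * 5 ** 3

# Enumerating token indices (rather than value grids) visits every token exactly once
all_codes = [tuple(index_to_codes(i, levels)) for i in range(codebook_size)]
num_unique = len(set(all_codes))

# Values taken by an 8-level and a 5-level dimension under this assumption
grid_8 = sorted({c[0] for c in all_codes})
grid_5 = sorted({c[-1] for c in all_codes})
# Values torch.linspace(-1, 1, steps=8) would produce for the same dimension
linspace_8 = [-1 + 2 * k / 7 for k in range(8)]

print(num_unique)  # 64000
print(grid_8)      # [-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75]
print(grid_5)      # [-1.0, -0.5, 0.0, 0.5, 1.0]
print(sum(v in grid_8 for v in linspace_8), "of 8 linspace values lie on the 8-level grid")
```

Under this convention, the linspace endpoint +1.0 on an 8-level dimension would map to level index 1.0·4 + 4 = 8, outside the valid range 0..7, so some first tokens could be decoded twice and others never. A cheap way to test the assumption in the notebook is to compare `tokens.unique().numel()` with `fsq.codebook_size` after calling `zhat_to_tokens`.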

Then, we detokenize these tokens back into images using the FlexTok rectified flow decoder. There are three important hyperparameters for the rectified flow decoder:

- `timesteps`: Number of denoising steps. 25 steps provide a good balance between reconstruction quality and inference speed.
- `guidance_scale`: Classifier-free guidance scale. See the paper appendix for guidance scale sweeps for all models. We recommend guidance scale 7.5, except for the FlexTok d12-d12 model where we recommend guidance scale 15.
- `perform_norm_guidance`: Whether or not to perform Adaptive Projected Guidance (APG), see https://arxiv.org/abs/2410.02416. We recommend setting this to True.
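For intuition about `guidance_scale` and `perform_norm_guidance`, here is a minimal pure-Python sketch of standard classifier-free guidance and of the projection step at the heart of APG (https://arxiv.org/abs/2410.02416). It is simplified: APG's momentum term is omitted, the parameter names `eta` and `r` are illustrative, and it assumes the guidance convention `v_uncond + w * (v_cond - v_uncond)` applied to a single prediction vector. The APG paper applies the projection to the model's denoised prediction; FlexTok's exact formulation for its rectified-flow decoder may differ.

```python
import math
from typing import List, Optional, Sequence

Vector = Sequence[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def cfg(v_cond: Vector, v_uncond: Vector, w: float) -> List[float]:
    """Classifier-free guidance: extrapolate from the unconditional toward the conditional prediction."""
    return [u + w * (c - u) for c, u in zip(v_cond, v_uncond)]


def apg(v_cond: Vector, v_uncond: Vector, w: float,
        eta: float = 0.0, r: Optional[float] = None) -> List[float]:
    """Simplified Adaptive Projected Guidance (no momentum)."""
    diff = [c - u for c, u in zip(v_cond, v_uncond)]
    # Optional norm threshold: rescale the guidance update if its norm exceeds r
    if r is not None:
        norm = math.sqrt(dot(diff, diff))
        if norm > r:
            diff = [d * r / norm for d in diff]
    # Split the update into parts parallel and orthogonal to the conditional prediction
    norm_sq = dot(v_cond, v_cond)
    coef = dot(diff, v_cond) / norm_sq if norm_sq > 0 else 0.0
    parallel = [coef * c for c in v_cond]
    orthogonal = [d - p for d, p in zip(diff, parallel)]
    # CFG equals v_cond + (w - 1) * diff; APG keeps the orthogonal part and scales the parallel part by eta
    return [c + (w - 1) * (o + eta * p) for c, o, p in zip(v_cond, orthogonal, parallel)]


v_c, v_u = [1.0, 0.0], [0.0, 1.0]
print(cfg(v_c, v_u, 7.5))  # [7.5, -6.5]
print(apg(v_c, v_u, 7.5))  # [1.0, -6.5]
```

With `eta=1` and no norm threshold, this APG sketch reduces to plain CFG; with `eta=0`, only the part of the guidance update orthogonal to the conditional prediction is kept, the parallel part being what the APG paper identifies as the main source of oversaturation at high guidance scales.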

In [8]:
zhats = get_possible_combos(flextok).to(device)  # (64000, 6)
tokens_list = zhat_to_tokens(flextok, zhats).unsqueeze(-1)  # (64000, 1)
tokens_list = tokens_list.split(1)  # list of (1, 1) tensors
print("zhats", zhats[:10])
print("tokens list", tokens_list[:10])

FSQ levels: tensor([8, 8, 8, 5, 5, 5], device='cuda:0', dtype=torch.int32)
codebook size: 64000
Total combinations (must equal codebook size): 64000
tokens shape: torch.Size([64000])
zhats tensor([[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
        [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.5000],
        [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000,  0.0000],
        [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000,  0.5000],
        [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000,  1.0000],
        [-1.0000, -1.0000, -1.0000, -1.0000, -0.5000, -1.0000],
        [-1.0000, -1.0000, -1.0000, -1.0000, -0.5000, -0.5000],
        [-1.0000, -1.0000, -1.0000, -1.0000, -0.5000,  0.0000],
        [-1.0000, -1.0000, -1.0000, -1.0000, -0.5000,  0.5000],
        [-1.0000, -1.0000, -1.0000, -1.0000, -0.5000,  1.0000]],
       device='cuda:0')
tokens list (tensor([[0]], device='cuda:0'), tensor([[12800]], device='cuda:0'), tensor([[25600]], device='cuda:0'), tensor([[38400]], device='cud

In [None]:
# now detokenize in batches to avoid OOM
num_samples_per_quantization_combo = 9
batch_size = 32
fsq_levels = flextok.regularizer._levels
print("FSQ levels:", fsq_levels)

import tqdm
img_output_dir = "/home/iyu/flextok_first_token_samples/"
os.makedirs(img_output_dir, exist_ok=True)

for i in tqdm.tqdm(range(0, len(tokens_list), batch_size)):
    batch_tokens_list = tokens_list[i:i+batch_size]  # list of (1, 1) tensors
    batch_zhats = zhats[i:i+batch_size]  # (B, 6)
    for sample in range(num_samples_per_quantization_combo):
        with get_bf16_context(enable_bf16):
            with torch.no_grad():
                reconst = flextok.detokenize(
                    batch_tokens_list,
                    timesteps=25, # Number of denoising steps
                    guidance_scale=7.5, # Classifier-free guidance scale
                    perform_norm_guidance=True, # APG, see https://arxiv.org/abs/2410.02416
                    # Optionally control initial noise. Note that while the initial noise is deterministic, the rest of the model isn't.
                    generator=None,
                    verbose=False, # Enable to show denoising progress bar with tqdm
                )
        # save image samples to disk to avoid OOM
        for j in range(reconst.shape[0]):
            img = reconst[j]
            img = denormalize(img).clamp(0, 1)
            img_pil = TF.to_pil_image(img.cpu())
            zhat_tuple = tuple(batch_zhats[j].cpu().numpy())
            # The filename encodes the zhat values via str() of float32 numpy scalars; the GUI below must reproduce this format
            save_path = os.path.join(img_output_dir, f"quant_{'_'.join(str(v) for v in zhat_tuple)}_sample_{sample+1}.png")
            img_pil.save(save_path)
        del reconst  # free memory
    torch.cuda.empty_cache()

FSQ levels: tensor([8, 8, 8, 5, 5, 5], device='cuda:0', dtype=torch.int32)
 24%|████████████████████████████████▎                                                                                                   | 489/2000 [13:22:06<37:30:22, 89.36s/it]

  1%|▉                                                                                                                                     | 14/2000 [54:15<126:26:47, 229.21s/it]
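Because the loop above always starts from the first batch, an interrupted run (the first progress bar projects a total of more than 50 hours) redoes work that is already on disk. A minimal sketch of a resume step, assuming images are only ever written by that loop: the hypothetical helpers `sample_path` and `pending_combos` rebuild the writer's filename pattern and return the indices of combinations that are still missing at least one sample.

```python
import os
from typing import List, Sequence


def sample_path(out_dir: str, value_strs: Sequence[str], sample: int) -> str:
    """Hypothetical helper: rebuild the writer's filename, quant_<v0>_..._<vd>_sample_<k>.png (k starts at 1)."""
    return os.path.join(out_dir, f"quant_{'_'.join(value_strs)}_sample_{sample}.png")


def pending_combos(out_dir: str, all_value_strs: Sequence[Sequence[str]],
                   num_samples: int = 9) -> List[int]:
    """Indices of combinations missing at least one of their num_samples images."""
    return [
        idx
        for idx, value_strs in enumerate(all_value_strs)
        if not all(
            os.path.exists(sample_path(out_dir, value_strs, k))
            for k in range(1, num_samples + 1)
        )
    ]
```

In the notebook, `all_value_strs = [[str(v) for v in row] for row in zhats.cpu().numpy()]` reproduces the writer's `str()` formatting of float32 values; the detokenization loop could then iterate over batches of the pending indices (indexing `tokens_list` and `zhats` with them) instead of `range(0, len(tokens_list), batch_size)`.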

## 2 Interactive GUI for Exploring First Tokens

Use the sliders below to explore different quantization combinations for the first token. Each slider sets the value of one FSQ dimension, and the GUI displays the 9 random samples that were decoded and saved to disk for the selected quantization combination.

In [9]:
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# FSQ levels configuration
fsq_levels = [8, 8, 8, 5, 5, 5]
img_output_dir = "/home/iyu/flextok_first_token_samples/"

# Create the quantization value mappings for each level
def get_quant_values(level):
    """Get quantization values for a given FSQ level (same grid as get_possible_combos)"""
    return torch.linspace(-1, 1, steps=level).tolist()

# Pre-compute all quantization values
quant_values_per_level = [get_quant_values(level) for level in fsq_levels]

def load_and_display_images(slider_values):
    """Load and display images for the selected quantization combination"""
    # Convert slider indices to actual quantization values
    quant_combo = [quant_values_per_level[i][slider_values[i]] for i in range(len(fsq_levels))]
    
    # Format the combo exactly as the writer did: it called str() on float32 numpy scalars,
    # so cast back to float32 first (str() of the Python float prints extra digits, e.g. for 1/7)
    quant_str = "_".join(str(np.float32(v)) for v in quant_combo)
    
    # Load the 9 sample images
    fig, axes = plt.subplots(3, 3, figsize=(12, 12))
    axes = axes.flatten()
    
    images_found = 0
    for sample_num in range(1, 10):
        img_path = os.path.join(img_output_dir, f"quant_{quant_str}_sample_{sample_num}.png")
        
        if os.path.exists(img_path):
            img = Image.open(img_path)
            axes[sample_num - 1].imshow(img)
            axes[sample_num - 1].axis('off')
            images_found += 1
        else:
            axes[sample_num - 1].text(0.5, 0.5, 'Not found', 
                                     ha='center', va='center', fontsize=12)
            axes[sample_num - 1].axis('off')
    
    # Title showing the quantization values
    fig.suptitle(f'Quantization: [{", ".join([f"{v:.2f}" for v in quant_combo])}]', 
                 fontsize=14, fontweight='bold')
    
    plt.tight_layout()
    plt.show()
    
    if images_found == 0:
        print(f"Warning: No images found for quantization {quant_combo}")
        print(f"Looking for pattern: quant_{quant_str}_sample_*.png")

# Create sliders for each FSQ level
sliders = []
slider_labels = []

for i, level in enumerate(fsq_levels):
    # Get the quantization values for this level
    quant_vals = quant_values_per_level[i]
    
    # Create slider
    slider = widgets.IntSlider(
        value=level // 2,  # Start at middle value
        min=0,
        max=level - 1,
        step=1,
        description=f'Dim {i}:',
        continuous_update=False,
        layout=widgets.Layout(width='600px')
    )
    
    # Create label showing the actual quantization value
    label = widgets.Label(value=f'{quant_vals[slider.value]:.2f}')
    
    # Update label when slider changes
    def make_update_label(slider, label, quant_vals):
        def update_label(change):
            label.value = f'{quant_vals[change.new]:.2f}'
        return update_label
    
    slider.observe(make_update_label(slider, label, quant_vals), names='value')
    
    sliders.append(slider)
    slider_labels.append(label)

# Create output widget
output = widgets.Output()

def on_slider_change(change):
    """Update display when any slider changes"""
    with output:
        clear_output(wait=True)
        slider_values = [s.value for s in sliders]
        load_and_display_images(slider_values)

# Attach the update function to all sliders
for slider in sliders:
    slider.observe(on_slider_change, names='value')

# Create the UI layout
slider_boxes = [widgets.HBox([slider, label]) for slider, label in zip(sliders, slider_labels)]
ui = widgets.VBox(slider_boxes + [output])

# Display the UI
display(ui)

# Initial display
with output:
    slider_values = [s.value for s in sliders]
    load_and_display_images(slider_values)
