# MDS and Latents Visualization

This notebook reads MDS shards (the output of the prepare step) and `mds_latents` shards (the output of the precompute step). It then decodes the precomputed latents with the FLUX.2 VAE to visualize the reconstructed images.

## Setup

In [None]:
import numpy as np
import torch
from diffusers import AutoencoderKLFlux2
from matplotlib import pyplot as plt
from streaming import StreamingDataset

# Flat MDS directories read by this notebook:
#   MDS_DIR     - output of the prepare step
#   LATENTS_DIR - output of the precompute step
MDS_DIR = "./mds"
LATENTS_DIR = "./mds_latents_flux2"

## Visualize MDS (prepare output)

In [None]:
def visualize_mds(mds_dir: str, num_samples: int = 4):
    """Display the first samples of an MDS dataset in a single row of plots.

    Args:
        mds_dir: Local directory containing the MDS shards to read.
        num_samples: Maximum number of samples to show; fewer are shown when
            the dataset holds fewer samples.

    Each sample's ``"image"`` field is drawn, and its ``"caption"`` field (if
    present) becomes the title, truncated to 60 characters. When there is
    nothing to show, a short message is printed and no figure is created.
    """
    ds = StreamingDataset(local=mds_dir, batch_size=1, shuffle=False)
    n = min(num_samples, len(ds))
    if n <= 0:
        # plt.subplots raises ValueError for a grid with zero columns.
        print(f"No samples to visualize in {mds_dir!r}")
        return

    # squeeze=False always returns a 2-D array of axes, so n == 1 needs no
    # special casing.
    _, axes = plt.subplots(1, n, figsize=(4 * n, 4), squeeze=False)
    for i, ax in enumerate(axes[0]):
        sample = ds[i]
        img = sample["image"]
        if not isinstance(img, np.ndarray):
            # Objects exposing .convert (e.g. PIL images) are normalized to
            # RGB first; anything else is handed to numpy as-is.
            img = np.array(img.convert("RGB")) if hasattr(img, "convert") else np.array(img)
        ax.imshow(img)
        # Tolerate a missing, None or non-string caption.
        cap = str(sample.get("caption") or "")
        ax.set_title(cap[:60] + "..." if len(cap) > 60 else cap, fontsize=8)
        ax.axis("off")
    plt.tight_layout()
    plt.show()


# Show up to four samples (the default num_samples) from the prepare output.
visualize_mds(MDS_DIR)

## Visualize mds_latents (VAE decode)

In [None]:
def bytes_to_latent(data: bytes, resolution: int = 512, latent_channels: int = 16) -> torch.Tensor:
    """Rebuild a float32 latent tensor from its raw float16 byte serialization.

    Args:
        data: Buffer holding ``latent_channels * (resolution // 8) ** 2``
            float16 values.
        resolution: Resolution the latent corresponds to; the latent's
            spatial side length is ``resolution // 8``.
        latent_channels: Number of channels in the latent.

    Returns:
        A float32 tensor of shape
        ``(latent_channels, resolution // 8, resolution // 8)``.

    Raises:
        ValueError: If the buffer size does not match the requested shape.
    """
    side = resolution // 8
    half_values = np.frombuffer(data, dtype=np.float16)
    # Upcasting copies the data out of the buffer view, so the resulting
    # tensor is backed by its own writable memory.
    full_values = half_values.astype(np.float32).reshape(latent_channels, side, side)
    return torch.from_numpy(full_values)


# Resolution of the precomputed latents to decode; passed to bytes_to_latent.
RESOLUTION = 512
# Sample field from which the serialized latents for that resolution are read.
LATENT_KEY = f"latents_{RESOLUTION}"
# Checkpoint whose "vae" subfolder is loaded as the decoder.
MODEL_ID = "black-forest-labs/FLUX.2-klein-base-4B"

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load only the VAE from the checkpoint's "vae" subfolder, move it to the
# selected device and put it in evaluation mode.
vae = AutoencoderKLFlux2.from_pretrained(MODEL_ID, subfolder="vae")
vae = vae.to(device)
vae = vae.eval()

In [None]:
def visualize_latents(latents_dir: str, num_samples: int = 4):
    """Decode precomputed latents with the VAE and display the reconstructions.

    Args:
        latents_dir: Local directory containing the MDS shards with latents.
        num_samples: Maximum number of samples to show; fewer are shown when
            the dataset holds fewer samples.

    For each sample the bytes stored under ``LATENT_KEY`` are turned into a
    tensor with ``bytes_to_latent``, decoded by the module-level ``vae`` on
    ``device``, and drawn in one row of plots. The ``"caption"`` field (if
    present) becomes the title, truncated to 60 characters. When there is
    nothing to show, a short message is printed and no figure is created.
    """
    ds = StreamingDataset(local=latents_dir, batch_size=1, shuffle=False)
    n = min(num_samples, len(ds))
    if n <= 0:
        # plt.subplots raises ValueError for a grid with zero columns.
        print(f"No samples to visualize in {latents_dir!r}")
        return

    # squeeze=False always returns a 2-D array of axes, so n == 1 needs no
    # special casing.
    _, axes = plt.subplots(1, n, figsize=(4 * n, 4), squeeze=False)
    for i, ax in enumerate(axes[0]):
        sample = ds[i]
        # Add a leading batch dimension before handing the latent to the VAE.
        lat = bytes_to_latent(sample[LATENT_KEY], RESOLUTION).unsqueeze(0).to(device)
        with torch.no_grad():
            decoded = vae.decode(lat).sample
        # Channels-first -> channels-last, as expected by imshow.
        img = decoded[0].permute(1, 2, 0).cpu().numpy()
        # Map [-1, 1] onto [0, 1] and clip the rest
        # (assumes the VAE emits images in [-1, 1] - TODO confirm).
        img = (img / 2 + 0.5).clip(0, 1)
        ax.imshow(img)
        # Look the caption up once; tolerate a missing, None or non-string
        # value.
        cap = str(sample.get("caption") or "")
        ax.set_title(cap[:60] + "..." if len(cap) > 60 else cap, fontsize=8)
        ax.axis("off")
    plt.tight_layout()
    plt.show()


# Decode and show up to four samples (the default num_samples) from the precompute output.
visualize_latents(LATENTS_DIR)