In [None]:
import os
import sys

# Make the in-repo package (src/musubi_tuner) importable from the notebook's working directory.
sys.path.append("src/")

# Emulate a command-line invocation so the argparse parser built below can be reused as-is.
# Options that take a value, in the order they appear on the command line.
_CLI_OPTIONS = [
    ("--dataset_config", "/data/whisperer/datasets/storyviz/OpenS2V-Nexus/datasets/test2/train.toml"),
    ("--vae", "/data/stale/patrickkwon/video/stable-diffusion-webui/models/VAE/hunyuan-video-t2v-720p-vae.pt"),
    ("--image_encoder", "/shared/video/ComfyUI/models/clip_vision/sigclip_vision_patch14_384.safetensors"),
    ("--vae_chunk_size", "32"),
    ("--vae_spatial_tile_sample_min_size", "128"),
]
# Boolean switches (store_true flags).
_CLI_FLAGS = ["--skip_existing", "--keep_cache", "--one_frame", "--one_frame_no_2x", "--one_frame_no_4x"]

sys.argv = ["fpack_cache_latents.py"] + [token for pair in _CLI_OPTIONS for token in pair] + _CLI_FLAGS
from typing import List, Optional
import math

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import SiglipImageProcessor, SiglipVisionModel, SiglipVisionConfig
from PIL import Image
from safetensors.torch import load_file

from musubi_tuner.dataset import config_utils
from musubi_tuner.dataset.config_utils import BlueprintGenerator, ConfigSanitizer
from musubi_tuner.dataset.image_video_dataset import ImageDataset, ItemInfo, save_latent_cache_framepack, ARCHITECTURE_FRAMEPACK
from musubi_tuner.frame_pack import hunyuan
from musubi_tuner.frame_pack.framepack_utils import load_image_encoders, load_vae, FEATURE_EXTRACTOR_CONFIG, IMAGE_ENCODER_CONFIG
from musubi_tuner.hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from musubi_tuner.frame_pack.clip_vision import hf_clip_vision_encode
import musubi_tuner.cache_latents as cache_latents
from musubi_tuner.cache_latents import preprocess_contents

# Build the CLI parser: common caching options + VAE-related options from cache_latents.
parser = cache_latents.setup_parser_common()
parser = cache_latents.hv_setup_parser(parser)  # VAE
# FramePack-specific options are declared inline below instead of through framepack_setup_parser(parser),
# which is not imported in this notebook.
parser.add_argument("--image_encoder", type=str, required=True)
parser.add_argument("--f1", action="store_true")
parser.add_argument("--one_frame", action="store_true")
parser.add_argument("--one_frame_no_2x", action="store_true")
parser.add_argument("--one_frame_no_4x", action="store_true")
args = parser.parse_args()  # reads the emulated sys.argv set above

# Honor --device when the common parser defines it and it was given; otherwise keep the cuda:0 default.
device = torch.device(getattr(args, "device", None) or "cuda:0")

2025-08-01 16:35:49.676419: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754033749.798634  152801 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754033749.866710  152801 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-08-01 16:35:50.044851: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
# Load the VAE and the SigLIP image encoder. Paths and VAE tiling/chunking settings come from the
# parsed (emulated) command line so they are defined in exactly one place.
vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device=device)
vae.to(device)

feature_extractor = SiglipImageProcessor(**FEATURE_EXTRACTOR_CONFIG)

config = SiglipVisionConfig(**IMAGE_ENCODER_CONFIG)
image_encoder = SiglipVisionModel._from_config(config, torch_dtype=torch.float16)
state_dict = load_file(args.image_encoder)
# assign=True makes the module adopt the loaded tensors instead of copying into freshly allocated ones.
image_encoder.load_state_dict(state_dict, strict=True, assign=True)
image_encoder.eval()
# Previously the encoder was left on CPU (in float16) while the VAE ran on `device`; move it alongside the VAE.
# NOTE(review): assumes hf_clip_vision_encode sends its inputs to image_encoder.device -- confirm.
image_encoder.to(device)
# Drop the notebook-level reference to the loaded CPU tensors now that the weights live on `device`.
del state_dict

print("Model Loaded")

Model Loaded


In [None]:
# The dataset is constructed directly rather than from the TOML blueprint
# (BlueprintGenerator / generate_dataset_group_by_blueprint), so args.dataset_config is
# parsed but not used to build the dataset in this notebook.
dataset = ImageDataset(
    # bucketing / batching
    resolution=(960, 544),
    caption_extension='.txt',
    batch_size=1,
    num_repeats=1,
    enable_bucket=True,
    bucket_no_upscale=False,
    # image sources and cache location
    image_directory=None,
    image_jsonl_file='/data/whisperer/datasets/storyviz/OpenS2V-Nexus/datasets/test2/train-clean1.jsonl',
    control_directory=None,
    cache_directory='/data/whisperer/datasets/storyviz/OpenS2V-Nexus/datasets/test2/cache_clean',
    # FramePack one-frame settings
    fp_latent_window_size=9,
    fp_1f_clean_indices=[0],
    fp_1f_target_index=9,
    fp_1f_no_post=True,
    debug_dataset=False,
    architecture='fp',
)
dataset.set_seed(0)

In [None]:
def _select_embed_image(item: ItemInfo):
    """Pick the image fed to the SigLIP image encoder for one batch item.

    Uses ``item.embed_content`` when the item carries one; otherwise falls back to the first
    control image (``item.control_content[0]``, the start image). A 4th (alpha) channel, if
    present, is dropped so only the first three channels are encoded.

    Raises:
        ValueError: if the item has neither an embed image nor any control image.
    """
    image = getattr(item, "embed_content", None)
    if image is None:
        controls = getattr(item, "control_content", None)
        if controls is not None and len(controls) > 0:
            image = controls[0]
    if image is None:
        raise ValueError(f"Item {item.item_key} has no embed_content or control_content to encode")
    if image.shape[-1] == 4:
        image = image[..., :3]
    return image


@torch.no_grad()  # pure inference: avoid building autograd graphs for the VAE and image encoder
def encode_and_save_batch_one_frame(
    vae: AutoencoderKLCausal3D,
    feature_extractor: SiglipImageProcessor,
    image_encoder: SiglipVisionModel,
    batch: List[ItemInfo],
    vanilla_sampling: bool = False,
    one_frame_no_2x: bool = False,
    one_frame_no_4x: bool = False,
):
    """Encode a batch of one-frame FramePack items and save one latent cache file per item.

    All frames returned by ``preprocess_contents`` are VAE-encoded; every frame but the last becomes
    a clean (conditioning) latent and the last frame is the target latent. The image embedding is
    the encoder's ``last_hidden_state`` for ``item.embed_content`` (or ``item.control_content[0]``
    when no embed image is present).

    Args:
        vae: causal 3D VAE used to encode frames one at a time.
        feature_extractor: SigLIP preprocessor passed to ``hf_clip_vision_encode``.
        image_encoder: SigLIP vision model producing the cached image embedding.
        batch: items to encode; each provides ``latent_cache_path`` and the ``fp_1f_*`` /
            ``fp_latent_window_size`` settings read below.
        vanilla_sampling: unused; kept for signature compatibility.
        one_frame_no_2x: if True, no 2x clean-latent indices are cached.
        one_frame_no_4x: if True, no 4x clean-latent indices are cached.

    Raises:
        ValueError: if an item has neither ``embed_content`` nor ``control_content``.
    """
    # item.content: target image (H, W, C)
    # item.control_content: list of images (H, W, C)
    _, _, contents, content_masks = preprocess_contents(batch)
    contents = contents.to(vae.device, dtype=vae.dtype)  # B, C, F, H, W

    # VAE encode one frame at a time: the encoder has stride 4 along time except for the first frame.
    latents = torch.cat(
        [hunyuan.vae_encode(contents[:, :, idx : idx + 1], vae).to("cpu") for idx in range(contents.shape[2])],
        dim=2,
    )  # B, C, F, H/8, W/8

    # Apply per-frame alpha masks (when present) to the latents.
    for b in range(len(batch)):
        for i, content_mask in enumerate(content_masks[b]):
            if content_mask is not None:
                latents[b : b + 1, :, i : i + 1] *= content_mask

    # Vision encoding, once per item.
    image_embeddings = torch.cat(
        [
            hf_clip_vision_encode(_select_embed_image(item), feature_extractor, image_encoder).last_hidden_state
            for item in batch
        ],
        dim=0,
    )  # B, LEN, 1152
    image_embeddings = image_embeddings.to("cpu")  # Save memory

    # Save one cache per item.
    for b, item in enumerate(batch):
        # Index generation mirrors inference; each item may carry its own clean-latent indices.
        clean_index_list = list(item.fp_1f_clean_indices or [0])
        if not item.fp_1f_no_post:
            # "post" slot placed right after the latent window
            clean_index_list.append(1 + item.fp_latent_window_size)
        clean_latent_indices = torch.tensor(clean_index_list, dtype=torch.long)  # N
        latent_index = torch.tensor([item.fp_1f_target_index], dtype=torch.long)  # 1

        # Zero-valued 2x/4x latents are not cached even when one_frame_no_2x / no_4x is False;
        # only their indices are.
        clean_latents_2x = None
        clean_latents_4x = None

        window_end = 1 + item.fp_latent_window_size
        clean_latent_2x_indices = None if one_frame_no_2x else torch.arange(window_end + 1, window_end + 3)  # 2
        clean_latent_4x_indices = None if one_frame_no_4x else torch.arange(window_end + 3, window_end + 19)  # 16

        # Clean latents (emulating inference): every frame except the last one.
        clean_latents = latents[b, :, :-1]  # C, F-1, H, W
        if not item.fp_1f_no_post:
            # post slot enabled: append a zero frame at the end
            clean_latents = F.pad(clean_latents, (0, 0, 0, 0, 0, 1), value=0.0)

        # Target latents (ground truth): the last frame.
        target_latents = latents[b, :, -1:]  # C, 1, H, W

        print(f"Saving cache for item {item.item_key} at {item.latent_cache_path}. no_post: {item.fp_1f_no_post}")
        print(f"  Clean latent indices: {clean_latent_indices}, latent index: {latent_index}")
        print(f"  Clean latents: {clean_latents.shape}, target latents: {target_latents.shape}")
        print(f"  Clean latents 2x indices: {clean_latent_2x_indices}, clean latents 4x indices: {clean_latent_4x_indices}")
        print(
            f"  Clean latents 2x: {clean_latents_2x.shape if clean_latents_2x is not None else 'None'}, "
            f"Clean latents 4x: {clean_latents_4x.shape if clean_latents_4x is not None else 'None'}"
        )
        print(f"  Image embeddings: {image_embeddings[b].shape}")

        # File path comes from item.latent_cache_path; batch dim already removed above.
        save_latent_cache_framepack(
            item_info=item,
            latent=target_latents,  # Ground truth for this section
            latent_indices=latent_index,  # Indices for the ground truth section
            clean_latents=clean_latents,  # Start frame + history placeholder
            clean_latent_indices=clean_latent_indices,  # Indices for start frame + history placeholder
            clean_latents_2x=clean_latents_2x,  # History placeholder
            clean_latent_2x_indices=clean_latent_2x_indices,  # Indices for history placeholder
            clean_latents_4x=clean_latents_4x,  # History placeholder
            clean_latent_4x_indices=clean_latent_4x_indices,  # Indices for history placeholder
            image_embeddings=image_embeddings[b],
        )


In [None]:
# Encode every dataset item whose cache file(s) are missing, then prune cache files that no
# longer belong to the dataset (unless --keep_cache).

# Prefer the package's helper when available; the previous cell never imported it (NameError for video items).
try:
    from musubi_tuner.fpack_cache_latents import append_section_idx_to_latent_cache_path
except ImportError:
    def append_section_idx_to_latent_cache_path(latent_cache_path: str, section_idx: int) -> str:
        """Return the path with ``-{section_idx:04d}`` appended to its third-from-last ``_``-separated token.

        NOTE(review): assumes video cache paths look like
        ``<base>_<frame_pos>-<frame_count>_<W>x<H>_<arch>.safetensors`` -- confirm against the dataset code.
        """
        tokens = latent_cache_path.split("_")
        tokens[-3] = f"{tokens[-3]}-{section_idx:04d}"
        return "_".join(tokens)


requested_workers = getattr(args, "num_workers", None)
# os.cpu_count() may return None on some platforms.
num_workers = requested_workers if requested_workers is not None else max(1, (os.cpu_count() or 1) - 1)
requested_batch_size = getattr(args, "batch_size", None)  # None: encode each retrieved batch in one call

all_latent_cache_paths = []
for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
    missing_items = []
    for item in batch:
        if item.frame_count is None:
            # image item: a single cache file
            all_latent_cache_paths.append(item.latent_cache_path)
            all_existing = os.path.exists(item.latent_cache_path)
        else:
            # video item: one cache file per section (min 1 section)
            latent_f = (item.frame_count - 1) // 4 + 1
            num_sections = max(1, math.floor((latent_f - 1) / item.fp_latent_window_size))
            all_existing = True
            for sec in range(num_sections):
                p = append_section_idx_to_latent_cache_path(item.latent_cache_path, sec)
                all_latent_cache_paths.append(p)
                all_existing = all_existing and os.path.exists(p)

        if not all_existing:  # if any section cache is missing
            missing_items.append(item)

    if args.skip_existing:
        if not missing_items:  # all sections exist
            continue
        batch = missing_items

    step = requested_batch_size if requested_batch_size else max(1, len(batch))
    for start in range(0, len(batch), step):
        encode_and_save_batch_one_frame(
            vae,
            feature_extractor,
            image_encoder,
            batch[start : start + step],
            one_frame_no_2x=args.one_frame_no_2x,
            one_frame_no_4x=args.one_frame_no_4x,
        )

# normalize paths for comparison
all_latent_cache_paths = {os.path.normpath(p) for p in all_latent_cache_paths}

# remove old cache files not in the dataset
kept_stale = 0
for cache_file in dataset.get_all_latent_cache_files():
    if os.path.normpath(cache_file) in all_latent_cache_paths:
        continue
    if args.keep_cache:
        kept_stale += 1
    else:
        os.remove(cache_file)
        print(f"Removed old cache file: {cache_file}")
if kept_stale:
    print(f"Kept {kept_stale} cache file(s) not in the dataset (--keep_cache)")