In [1]:
import os, sys
sys.path.append("..")

from pathlib import Path
from tqdm import tqdm
from typing import List, Optional
from omegaconf import OmegaConf
import math, json, shutil
import numpy as np
import pandas as pd
from collections import defaultdict
import matplotlib.pyplot as plt
from multiprocessing import Value
from PIL import Image, ImageDraw
from safetensors import safe_open
from safetensors.torch import load_file, save_file
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from transformers import SiglipImageProcessor, SiglipVisionModel, SiglipVisionConfig
import lovely_tensors as lt
lt.monkey_patch()

from musubi_tuner.cache_latents import setup_parser_common, hv_setup_parser, preprocess_contents
from musubi_tuner.dataset.config_utils import BlueprintGenerator, ConfigSanitizer, load_user_config, generate_dataset_group_by_blueprint
from musubi_tuner.dataset.image_video_dataset import DatasetGroup, ImageDataset, ItemInfo, save_latent_cache_framepack, BucketSelector
from musubi_tuner.frame_pack.clip_vision import hf_clip_vision_encode
from musubi_tuner.frame_pack.framepack_utils import load_image_encoders, load_vae, FEATURE_EXTRACTOR_CONFIG, IMAGE_ENCODER_CONFIG
from musubi_tuner.frame_pack.hunyuan import vae_encode
from musubi_tuner.hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from musubi_tuner.fpack_cache_latents import encode_and_save_batch_one_frame, append_section_idx_to_latent_cache_path
from musubi_tuner.hv_train_network import collator_class, load_prompts
from musubi_tuner.utils.bbox_utils import get_bbox_from_mask, get_mask_from_bboxes, draw_bboxes, get_facebbox_from_bbox, get_bbox_from_meta
from musubi_tuner.utils.preproc_utils import get_text_preproc, prepare_control_inputs_for_entity, preproc_mask, postproc_imgs

# Compute device for the VAE and the image encoder (assumes a CUDA GPU is available).
device = torch.device('cuda:0')

# Root of the OpenS2V-Nexus dataset tree. The jsonl / cache maintenance cells further down
# build their paths from `mpath`, so it must be defined for Restart & Run All to work.
mpath = Path('/groups/chenchen/patrick/OpenS2V-Nexus/datasets/')

  from .autonotebook import tqdm as notebook_tqdm


Trying to import sageattention
Failed to import sageattention
model_path is /projects/bffz/ykwon4/musubi-tuner/src/practice/../musubi_tuner/ckpts/hr16/yolox-onnx/yolox_l.torchscript.pt
model_path is /projects/bffz/ykwon4/musubi-tuner/src/practice/../musubi_tuner/ckpts/hr16/DWPose-TorchScript-BatchSize5/dw-ll_ucoco_384_bs5.torchscript.pt

DWPose: Using yolox_l.torchscript.pt for bbox detection and dw-ll_ucoco_384_bs5.torchscript.pt for pose estimation
DWPose: Caching TorchScript module yolox_l.torchscript.pt on ...
DWPose: Caching TorchScript module dw-ll_ucoco_384_bs5.torchscript.pt on ...


Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [2]:
# Load the HunyuanVideo VAE (CausalConv3d chunk size 32, spatial tiling min size 128 -- see the
# log below) and the SigLIP vision encoder used for image embeddings.
VAE_PATH = "/projects/bffz/ykwon4/ComfyUI/models/vae/hunyuan-video-t2v-720p-vae.pt"
IMAGE_ENCODER_PATH = "/projects/bffz/ykwon4/ComfyUI/models/clip_vision/sigclip_vision_patch14_384.safetensors"

vae = load_vae(VAE_PATH, 32, 128, device=device)
vae.to(device)

feature_extractor = SiglipImageProcessor(**FEATURE_EXTRACTOR_CONFIG)

config = SiglipVisionConfig(**IMAGE_ENCODER_CONFIG)
image_encoder = SiglipVisionModel._from_config(config, torch_dtype=torch.float16)
state_dict = load_file(IMAGE_ENCODER_PATH)
image_encoder.load_state_dict(state_dict, strict=True, assign=True)
# Keep the encoder on the same device as the VAE. Without this it stays on the CPU in fp16;
# presumably hf_clip_vision_encode follows the encoder's device for its inputs -- confirm.
image_encoder.to(device)
image_encoder.eval()

print("Model Loaded")

INFO:musubi_tuner.hunyuan_model.vae:Loading 3D VAE model (884-16c-hy) from: /projects/bffz/ykwon4/ComfyUI/models/vae/hunyuan-video-t2v-720p-vae.pt


INFO:musubi_tuner.hunyuan_model.vae:VAE to dtype: torch.float16
INFO:musubi_tuner.frame_pack.framepack_utils:Set chunk_size to 32 for CausalConv3d
INFO:musubi_tuner.frame_pack.framepack_utils:Enabled spatial tiling with min size 128


Model Loaded


In [2]:
# Reproduce the fpack_cache_latents.py command line inside the notebook so the dataset config is
# parsed with the same arguments the caching script receives.
sys.argv = ["fpack_cache_latents.py",
    "--dataset_config", "/projects/bffz/ykwon4/test.toml",
    "--vae", "~/patrick/ComfyUI/models/vae/hunyuan_video_vae_bf16.safetensors",
    "--image_encoder", "~/patrick/ComfyUI/models/clip_vision/sigclip_vision_patch14_384.safetensors",
    "--vae_chunk_size", "32",
    "--vae_spatial_tile_sample_min_size", "128",
    "--skip_existing", "--keep_cache", "--one_frame", "--one_frame_no_2x", "--one_frame_no_4x"
]

parser = setup_parser_common()
parser = hv_setup_parser(parser)  # VAE
# parser = framepack_setup_parser(parser)
parser.add_argument("--image_encoder", type=str, required=True)
parser.add_argument("--f1", action="store_true")
parser.add_argument("--one_frame", action="store_true")
parser.add_argument("--one_frame_no_2x", action="store_true")
parser.add_argument("--one_frame_no_4x", action="store_true")
args = parser.parse_args()

blueprint_generator = BlueprintGenerator(ConfigSanitizer())
user_config = load_user_config(args.dataset_config)
# Later cells read `train_dataset_group` (bucket inspection, DataLoader) and `dataset`
# (retrieve_latent_cache_batches). Both must be built here so the notebook runs top to bottom.
blueprint = blueprint_generator.generate(user_config, args, architecture='fp')
train_dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group, training=True)
dataset = train_dataset_group.datasets[0]

# Alternative (kept for reference): build the ImageDataset by hand instead of from the TOML config.
# dataset = ImageDataset(
#     resolution=(960, 544), caption_extension='.txt', batch_size=2, num_repeats=1, 
#     enable_bucket=True, bucket_no_upscale=False, 
#     image_directory=None, 
#     image_jsonl_file=str(mpath / "OpenS2V_part1_test3_2_test_v2.jsonl"), 
#     control_directory=None,
#     cache_directory=str(mpath / "test3_2_cache_v2"), 
#     # cache_directory=None,
#     fp_latent_window_size=9, fp_1f_clean_indices=[0], fp_1f_target_index=9, 
#     fp_1f_no_post=True, debug_dataset=False, architecture='fp', 
#     control_resolution=(256,256)
# )

In [4]:
# Bucket histogram: print each bucket key of the first dataset with the number of items in it.
bm = train_dataset_group.datasets[0].batch_manager
for bucket_key, bucket_items in bm.buckets.items():
    print(f"{bucket_key} : {len(bucket_items)}")

(1056, 864, 9, 1, True) : 1690
(1104, 832, 9, 1, True) : 2115
(1472, 624, 9, 1, True) : 1870
(1280, 720, 9, 1, True) : 6954
(1168, 784, 9, 1, True) : 2269
(1504, 608, 9, 1, True) : 1689
(960, 960, 9, 1, True) : 1354
(1072, 848, 9, 1, True) : 41
(720, 1280, 9, 1, True) : 19
(832, 1104, 9, 1, True) : 17
(784, 1168, 9, 1, True) : 20
(848, 1072, 9, 1, True) : 37
(1296, 704, 9, 1, True) : 23


In [19]:
# Scan every batch of one bucket and record the batches whose cached `latents` or
# `target_latent_masks` do not all share one shape (these make torch.stack fail when batching).
res = (1504, 608)
key = (res[0], res[1], 9, 1, True)
bucket = bm.buckets[key]

false_list = []  # indices into `buckets` (below) of batches with mismatched shapes
buckets = [x for x in bm.bucket_batch_indices if x[0] == key]
for idx, (bucket_reso, batch_idx) in tqdm(enumerate(buckets), total=len(buckets)):
    start = batch_idx * bm.batch_size
    end = min(start + bm.batch_size, len(bucket))

    batch_tensor_data = defaultdict(list)
    for item_info in bucket[start:end]:
        sd_latent = load_file(item_info.latent_cache_path)

        for content_key in sd_latent.keys():
            # Map the stored key to its grouped name. "_mask" keys keep their full name; other
            # keys drop the dtype suffix, and "latents_"/"target_latent_" keys also drop FxHxW.
            # Resetting the name every iteration keeps a "_mask" key from being filed under the
            # previous key's name.
            content_key_2 = content_key
            if not content_key.endswith("_mask"):
                content_key_2 = content_key.rsplit("_", 1)[0]  # remove dtype
                if any(content_key_2.startswith(x) for x in ['latents_', 'target_latent_']):
                    content_key_2 = content_key_2.rsplit("_", 1)[0]  # remove FxHxW
            batch_tensor_data[content_key_2].append(sd_latent[content_key])

    latent_shapes = {x.shape for x in batch_tensor_data['latents']}
    mask_shapes = {x.shape for x in batch_tensor_data['target_latent_masks']}
    if len(latent_shapes) > 1 or len(mask_shapes) > 1:
        print(f"{[Path(x.latent_cache_path).stem for x in bucket[start:end]]}")
        false_list.append(idx)
    # try:
    #     for i, key in enumerate(list(batch_tensor_data.keys())):
    #         batch_tensor_data[key] = torch.stack(batch_tensor_data[key])
    #         if key.startswith("latents_clean"):
    #             batch_tensor_data[key] = batch_tensor_data[key][:,:,:bm.control_count_per_image,:,:]
    #         if key.startswith("target_latent_masks"):
    #             batch_tensor_data[key] = batch_tensor_data[key][:,:,:bm.control_count_per_image,:,:]
    #         if key.startswith("clean_latent_bboxes"):
    #             batch_tensor_data[key] = batch_tensor_data[key][:,:bm.control_count_per_image,:]
    # except Exception as e:
    #     print(f"{e} : {[Path(x.latent_cache_path).stem for x in bucket[start:end]]}")
    #     false_list.append(idx)

100%|██████████| 212/212 [04:20<00:00,  1.23s/it]


In [None]:
# For every flagged batch, print the cache files whose `latents` spatial size differs from the
# size the rest of the bucket uses. Uncomment the unlink lines to delete those caches.
EXPECTED_LATENT_HW = [56, 144]  # majority latent (H, W) observed for the bucket chosen above; update if `res` changes

for idx in false_list:
    bucket_reso, batch_idx = buckets[idx]
    batch_bucket = bm.buckets[bucket_reso]  # the bucket this batch index belongs to
    start = batch_idx * bm.batch_size
    end = min(start + bm.batch_size, len(batch_bucket))

    for item_info in batch_bucket[start:end]:
        sd_latent = load_file(item_info.latent_cache_path)

        for content_key in sd_latent.keys():
            # Same key normalisation as the scan cell; "_mask" keys keep their full name.
            content_key_2 = content_key
            if not content_key.endswith("_mask"):
                content_key_2 = content_key.rsplit("_", 1)[0]  # remove dtype
                if any(content_key_2.startswith(x) for x in ['latents_', 'target_latent_']):
                    content_key_2 = content_key_2.rsplit("_", 1)[0]  # remove FxHxW
            if content_key_2 == 'latents':
                if list(sd_latent[content_key].shape[-2:]) != EXPECTED_LATENT_HW:
                    print(item_info.latent_cache_path)
                    # Path(item_info.latent_cache_path).unlink()
                    # Path(item_info.text_encoder_output_cache_path).unlink()

In [57]:
# Rebuild one batch from its cache files and display the per-item tensors grouped by key.
idx = 1986

bucket_reso, batch_idx = bm.bucket_batch_indices[idx]
bucket = bm.buckets[bucket_reso]
# Compute `start` before `end`. In a single tuple assignment the right-hand side is evaluated
# before any name is bound, so `end` would read a stale `start` left over from another cell.
start = batch_idx * bm.batch_size
end = min(start + bm.batch_size, len(bucket))

batch_tensor_data = defaultdict(list)
for item_info in bucket[start:end]:
    sd_latent = load_file(item_info.latent_cache_path)

    for content_key in sd_latent.keys():
        content_key_2 = content_key  # "_mask" keys keep their full name
        if not content_key.endswith("_mask"):
            content_key_2 = content_key.rsplit("_", 1)[0]  # remove dtype
            if any(content_key_2.startswith(x) for x in ['latents_', 'target_latent_']):
                content_key_2 = content_key_2.rsplit("_", 1)[0]  # remove FxHxW
        batch_tensor_data[content_key_2].append(sd_latent[content_key])

batch_tensor_data

defaultdict(list,
            {'clean_latent_indices': [tensor[1] i64 [0],
              tensor[1] i64 [0],
              tensor[1] i64 [0],
              tensor[1] i64 [0],
              tensor[1] i64 [0],
              tensor[1] i64 [0],
              tensor[1] i64 [0],
              tensor[1] i64 [0]],
             'latent_indices': [tensor[1] i64 [3],
              tensor[1] i64 [3],
              tensor[1] i64 [3],
              tensor[1] i64 [3],
              tensor[1] i64 [3],
              tensor[1] i64 [3],
              tensor[1] i64 [3],
              tensor[1] i64 [3]],
             'clean_latent_bboxes': [tensor[1, 4] x∈[0.025, 0.992] μ=0.419 σ=0.447 [[0.107, 0.025, 0.551, 0.992]],
              tensor[1, 4] x∈[0.050, 0.922] μ=0.469 σ=0.383 [[0.280, 0.050, 0.625, 0.922]],
              tensor[1, 4] x∈[0.119, 0.911] μ=0.571 σ=0.358 [[0.454, 0.119, 0.798, 0.911]],
              tensor[1, 4] x∈[0.103, 0.794] μ=0.467 σ=0.336 [[0.263, 0.103, 0.706, 0.794]],
              tenso

In [5]:
# Cross-check the training jsonl against the cache directory: an item is kept only when exactly
# CACHE_FILES_PER_ITEM cache files exist for its key.
image_jsonl_file = Path(user_config['datasets'][0]['image_jsonl_file'])
cache_directory = Path(user_config['datasets'][0]['cache_directory'])
CACHE_FILES_PER_ITEM = 2  # presumably the latent cache + the text-encoder cache -- confirm

# Skip blank lines (e.g. a trailing newline); json.loads("") would raise.
image_jsonls = [json.loads(x) for x in image_jsonl_file.read_text().split("\n") if x.strip()]
cache_files = defaultdict(list)
for x in cache_directory.glob("*.safetensors"):
    # Drop the last two "_"-separated tokens to recover the item key,
    # e.g. "<key>_1520x0608_fp.safetensors" -> "<key>".
    tokens = x.name.split("_")
    item_key = "_".join(tokens[:-2])
    cache_files[item_key].append(x)

remain_jsonl = []
throwaway_jsonl = []
for image_jsonl in image_jsonls:
    item_key = Path(image_jsonl['image_path']).parent.name
    if len(cache_files.get(item_key, [])) == CACHE_FILES_PER_ITEM:
        remain_jsonl.append(image_jsonl)
    else:
        throwaway_jsonl.append(image_jsonl)
print(len(remain_jsonl), len(throwaway_jsonl))

# image_jsonl_file_v2 = (image_jsonl_file.parent / f"{image_jsonl_file.stem}_v2.jsonl")
# image_jsonl_file_v2.write_text("\n".join([json.dumps(x) for x in remain_jsonl]))

# images_v2 = [Path(image_jsonl['image_path']).parent.name for image_jsonl in remain_jsonl]
# for item_key, cache_list in cache_files.items():
#     if item_key not in images_v2:
#         print(cache_list)
#         for cache in cache_list:
#             cache.unlink()
print(sum([len(x) for _, x in cache_files.items()]) / CACHE_FILES_PER_ITEM)

18102 0
18102.0


In [None]:
# Pull one 4-item batch from the dataset, run preprocess_contents on it, print the resulting
# tensors, then draw the first item's target mask with its clean-latent bbox.
_, batch = next(iter(dataset.retrieve_latent_cache_batches(4)))
(_, _, image, contents, content_masks, target_masks, clean_latent_bboxes) = preprocess_contents(batch)
for tensor_summary in (image, contents, target_masks, clean_latent_bboxes):
    print(tensor_summary)

sample_idx = 0
item_dir = Path(batch[sample_idx].item_key).parent
print(item_dir.name)
meta_path = item_dir / "meta.yaml"
meta = OmegaConf.load(meta_path)
print(meta['target_body'], meta['width'], meta['height'])

bboxes = get_bbox_from_meta(meta_path, 2)
print(bboxes)

# Collapse the mask's last axis with max() to get a single boolean plane for display.
mask = target_masks[sample_idx, 0].permute(1, 2, 0).cpu().numpy().astype(bool).max(-1)
face_bbox = clean_latent_bboxes[sample_idx, 0].numpy()
mask_preview = Image.fromarray(mask).convert("RGB").resize((960, 544))
draw_bboxes(mask_preview, face_bbox)

In [None]:
# Encode the preprocessed batch: VAE latents for the target image and for each control frame
# (masked by the content masks), then SigLIP embeddings of every item's embed image.
with torch.no_grad():
    image = image.to(vae.device, dtype=vae.dtype)  # B, C, H, W
    contents = contents.to(vae.device, dtype=vae.dtype)  # B, C, F, H, W
    target_masks = target_masks.to(vae.device, dtype=vae.dtype)

    # VAE encode: we need to encode one frame at a time because VAE encoder has stride=4 for the time dimension except for the first frame.
    target_latent = vae_encode(image, vae).to("cpu")  # B, C, 1, H/8, W/8
    clean_latents = [vae_encode(contents[:, :, idx : idx + 1], vae).to("cpu") for idx in range(contents.shape[2])]
    clean_latents = torch.cat(clean_latents, dim=2)  # B, C, F, H/8, W/8

    # apply alphas to latents
    for b, item in enumerate(batch):
        for i, content_mask in enumerate(content_masks[b]):
            if content_mask is not None:
                # apply mask to the latents
                clean_latents[b : b + 1, :, i : i + 1] *= content_mask

    # Vision encoding per item (once), from each item's embed image.
    # images = [item.control_content[0] for item in batch]  # list of [H, W, C]
    images = [item.embed_content for item in batch]

    # Encode with the image encoder (already inside the no_grad context above).
    image_embeddings = []
    for embed_image in images:
        # Separate loop name: reusing `image` here would overwrite the VAE input tensor, and
        # re-running this cell would then call `.to(...)` on the last embed image.
        if embed_image.shape[-1] == 4:
            embed_image = embed_image[..., :3]  # drop the 4th (alpha) channel
        image_encoder_output = hf_clip_vision_encode(embed_image, feature_extractor, image_encoder)
        image_embeddings.append(image_encoder_output.last_hidden_state)
    image_embeddings = torch.cat(image_embeddings, dim=0)  # B, LEN, 1152
    image_embeddings = image_embeddings.to("cpu")  # Save memory

In [4]:
# dataset.set_seed(0)
# dataset.prepare_for_training()
# # for _, batch in tqdm(dataset.retrieve_latent_cache_batches(4)):
#     # items = batch
# train_dataset_group = DatasetGroup([dataset])

# The collator receives two shared-memory int counters (both starting at 0) and no dataset;
# presumably the current epoch / step -- confirm against collator_class.
shared_counter_0 = Value("i", 0)
shared_counter_1 = Value("i", 0)
collator = collator_class(shared_counter_0, shared_counter_1, None)

loader_kwargs = dict(
    batch_size=1,
    shuffle=True,
    collate_fn=collator,
    num_workers=2,
    persistent_workers=True,
)
train_dataloader = DataLoader(train_dataset_group, **loader_kwargs)

In [5]:
# Smoke test: iterate the whole DataLoader so any collation error surfaces.
for batch in tqdm(train_dataloader):
    pass  # inspect batch['latents'] / batch['latents_clean'] here when debugging

  1%|          | 18/2274 [00:57<1:41:25,  2.70s/it]

stack expects each tensor to be equal size, but got [16, 1, 56, 144] at entry 0 and [16, 1, 58, 140] at entry 5


INFO:root:ItemInfo(item_key=sInLfhpnLh0_segment_8_step1-64-215_step2-64-151_step4_step5_step6, caption=, original_size=(1520, 608), bucket_size=(1504, 608, 9, 1, True), frame_count=None, latent_cache_path=/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/sInLfhpnLh0_segment_8_step1-64-215_step2-64-151_step4_step5_step6_1520x0608_fp.safetensors, content=None)
INFO:root:['/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/sIjZtzBoBbM_segment_16_step1-0-49_step2-0-49_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/pJsXYQJv3es_segment_15_step1-0-65_step2-0-65_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/XiQCDv4IMNc_segment_53_step1-10-139_step2-10-129_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/sInLfhpnLh0_segment_4_step1-0-183_step2-0-183_step4_step5_step6_1520x0

stack expects each tensor to be equal size, but got [1, 2, 56, 144] at entry 0 and [1, 2, 58, 140] at entry 5


INFO:root:ItemInfo(item_key=sInLfhpnLh0_segment_8_step1-64-215_step2-64-151_step4_step5_step6, caption=, original_size=(1520, 608), bucket_size=(1504, 608, 9, 1, True), frame_count=None, latent_cache_path=/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/sInLfhpnLh0_segment_8_step1-64-215_step2-64-151_step4_step5_step6_1520x0608_fp.safetensors, content=None)
INFO:root:['/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/sIjZtzBoBbM_segment_16_step1-0-49_step2-0-49_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/pJsXYQJv3es_segment_15_step1-0-65_step2-0-65_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/XiQCDv4IMNc_segment_53_step1-10-139_step2-10-129_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/sInLfhpnLh0_segment_4_step1-0-183_step2-0-183_step4_step5_step6_1520x0

list indices must be integers or slices, not tuple
2240 2248 6
ItemInfo(item_key=KuI-h24wLq4_segment_0_step1-0-145_step2-0-145_step4_step5_step6, caption=, original_size=(1520, 608), bucket_size=(1504, 608, 9, 1, True), frame_count=None, latent_cache_path=/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/KuI-h24wLq4_segment_0_step1-0-145_step2-0-145_step4_step5_step6_1520x0608_fp.safetensors, content=None)
{'clean_latent_indices': tensor[8, 1] i64 [38;2;127;127;127mall_zeros[0m [[0], [0], [0], [0], [0], [0], [0], [0]], 'latent_indices': tensor[8, 1] i64 x∈[3, 3] μ=3.000 σ=0. [[3], [3], [3], [3], [3], [3], [3], [3]], 'clean_latent_bboxes': tensor[8, 1, 4] n=32 x∈[-0.017, 0.996] μ=0.499 σ=0.326, 'image_embeddings': tensor[8, 729, 1152] f16 n=6718464 (13Mb) x∈[-62.250, 102.625] μ=0.023 σ=1.926, 'latents': [tensor[16, 1, 56, 144] f16 n=129024 (0.2Mb) x∈[-2.871, 4.746] μ=0.023 σ=1.125, tensor[16, 1, 56, 144] f16 n=129024 (0.2Mb) x∈[-4.656, 3.752] μ=0.118 σ=1.102, tensor[16

  2%|▏         | 42/2274 [02:13<1:51:57,  3.01s/it]

stack expects each tensor to be equal size, but got [16, 1, 56, 144] at entry 0 and [16, 1, 58, 140] at entry 3


INFO:root:ItemInfo(item_key=a8UnNvKK3Ew_segment_3_step1-0-87_step2-0-87_step4_step5_step6, caption=, original_size=(1520, 608), bucket_size=(1504, 608, 9, 1, True), frame_count=None, latent_cache_path=/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/a8UnNvKK3Ew_segment_3_step1-0-87_step2-0-87_step4_step5_step6_1520x0608_fp.safetensors, content=None)
INFO:root:['/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/pBe6xEPi42k_segment_66_step1-0-125_step2-0-125_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/sEc9Smax2jc_segment_1_step1-10-73_step2-10-63_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/m19eAq4TU_A_segment_0_step1-158-369_step2-158-211_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/41BSrVeb-h0_segment_67_step1-0-57_step2-0-57_step4_step5_step6_1424x0576_f

stack expects each tensor to be equal size, but got [1, 2, 56, 144] at entry 0 and [1, 2, 58, 140] at entry 3


INFO:root:ItemInfo(item_key=a8UnNvKK3Ew_segment_3_step1-0-87_step2-0-87_step4_step5_step6, caption=, original_size=(1520, 608), bucket_size=(1504, 608, 9, 1, True), frame_count=None, latent_cache_path=/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/a8UnNvKK3Ew_segment_3_step1-0-87_step2-0-87_step4_step5_step6_1520x0608_fp.safetensors, content=None)
INFO:root:['/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/pBe6xEPi42k_segment_66_step1-0-125_step2-0-125_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/sEc9Smax2jc_segment_1_step1-10-73_step2-10-63_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/m19eAq4TU_A_segment_0_step1-158-369_step2-158-211_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/41BSrVeb-h0_segment_67_step1-0-57_step2-0-57_step4_step5_step6_1424x0576_f

list indices must be integers or slices, not tuple
176 184 6
ItemInfo(item_key=Pooj2tBUBEM_segment_537_step1-0-75_step2-0-75_step4_step5_step6, caption=, original_size=(1520, 608), bucket_size=(1504, 608, 9, 1, True), frame_count=None, latent_cache_path=/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/Pooj2tBUBEM_segment_537_step1-0-75_step2-0-75_step4_step5_step6_1520x0608_fp.safetensors, content=None)
{'clean_latent_indices': tensor[8, 1] i64 [38;2;127;127;127mall_zeros[0m [[0], [0], [0], [0], [0], [0], [0], [0]], 'latent_indices': tensor[8, 1] i64 x∈[3, 3] μ=3.000 σ=0. [[3], [3], [3], [3], [3], [3], [3], [3]], 'clean_latent_bboxes': tensor[8, 1, 4] n=32 x∈[0.004, 0.993] μ=0.495 σ=0.313, 'image_embeddings': tensor[8, 729, 1152] f16 n=6718464 (13Mb) x∈[-62.562, 103.688] μ=0.024 σ=1.929, 'latents': [tensor[16, 1, 56, 144] f16 n=129024 (0.2Mb) x∈[-4.293, 4.453] μ=0.034 σ=1.043, tensor[16, 1, 56, 144] f16 n=129024 (0.2Mb) x∈[-3.898, 4.297] μ=-0.354 σ=1.565, tensor[16, 

  4%|▍         | 99/2274 [05:24<1:45:08,  2.90s/it]

stack expects each tensor to be equal size, but got [16, 1, 56, 144] at entry 0 and [16, 1, 58, 140] at entry 3


INFO:root:ItemInfo(item_key=Hj6mXN5u3DM_segment_13_step1-0-469_step2-0-469_step4_step5_step6, caption=, original_size=(1520, 608), bucket_size=(1504, 608, 9, 1, True), frame_count=None, latent_cache_path=/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/Hj6mXN5u3DM_segment_13_step1-0-469_step2-0-469_step4_step5_step6_1520x0608_fp.safetensors, content=None)
INFO:root:['/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/PJGkKLXiJVE_segment_2_step1-0-75_step2-0-75_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/JBh9ixSWQJg_segment_25_step1-2-119_step2-2-117_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/WNlsokkSoMg_segment_38_step1-0-163_step2-0-163_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/_Ua-d9OeUOg_segment_45_step1-0-53_step2-0-53_step4_step5_step6_1408x0576_

stack expects each tensor to be equal size, but got [1, 2, 56, 144] at entry 0 and [1, 2, 58, 140] at entry 3


INFO:root:ItemInfo(item_key=Hj6mXN5u3DM_segment_13_step1-0-469_step2-0-469_step4_step5_step6, caption=, original_size=(1520, 608), bucket_size=(1504, 608, 9, 1, True), frame_count=None, latent_cache_path=/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/Hj6mXN5u3DM_segment_13_step1-0-469_step2-0-469_step4_step5_step6_1520x0608_fp.safetensors, content=None)
INFO:root:['/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/PJGkKLXiJVE_segment_2_step1-0-75_step2-0-75_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/JBh9ixSWQJg_segment_25_step1-2-119_step2-2-117_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/WNlsokkSoMg_segment_38_step1-0-163_step2-0-163_step4_step5_step6_1520x0608_fp.safetensors', '/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/_Ua-d9OeUOg_segment_45_step1-0-53_step2-0-53_step4_step5_step6_1408x0576_

list indices must be integers or slices, not tuple
536 544 6
ItemInfo(item_key=YE4DVdcn43w_segment_5_step1-18-113_step2-18-95_step4_step5_step6, caption=, original_size=(1520, 608), bucket_size=(1504, 608, 9, 1, True), frame_count=None, latent_cache_path=/work/hdd/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v4_cache_selected/YE4DVdcn43w_segment_5_step1-18-113_step2-18-95_step4_step5_step6_1520x0608_fp.safetensors, content=None)


  4%|▍         | 100/2274 [05:27<1:43:39,  2.86s/it]

{'clean_latent_indices': tensor[8, 1] i64 [38;2;127;127;127mall_zeros[0m [[0], [0], [0], [0], [0], [0], [0], [0]], 'latent_indices': tensor[8, 1] i64 x∈[3, 3] μ=3.000 σ=0. [[3], [3], [3], [3], [3], [3], [3], [3]], 'clean_latent_bboxes': tensor[8, 1, 4] n=32 x∈[0.033, 0.996] μ=0.520 σ=0.339, 'image_embeddings': tensor[8, 729, 1152] f16 n=6718464 (13Mb) x∈[-62.625, 101.688] μ=0.023 σ=1.873, 'latents': [tensor[16, 1, 56, 144] f16 n=129024 (0.2Mb) x∈[-2.627, 2.959] μ=-0.001 σ=0.912, tensor[16, 1, 56, 144] f16 n=129024 (0.2Mb) x∈[-4.629, 4.789] μ=-0.004 σ=1.475, tensor[16, 1, 56, 144] f16 n=129024 (0.2Mb) x∈[-3.875, 4.438] μ=-0.008 σ=0.924, tensor[16, 1, 58, 140] f16 n=129920 (0.2Mb) x∈[-2.717, 3.980] μ=0.233 σ=1.174, tensor[16, 1, 56, 144] f16 n=129024 (0.2Mb) x∈[-4.406, 3.551] μ=-0.311 σ=1.399, tensor[16, 1, 56, 144] f16 n=129024 (0.2Mb) x∈[-3.566, 4.332] μ=0.033 σ=1.056, tensor[16, 1, 56, 144] f16 n=129024 (0.2Mb) x∈[-2.895, 3.598] μ=-0.009 σ=0.885, tensor[16, 1, 56, 144] f16 n=129024 

  5%|▍         | 105/2274 [05:45<1:58:58,  3.29s/it]


KeyboardInterrupt: 

In [None]:
# Write `clean_latent_indices_int64 = [0]` into each listed item's fp cache file, in place.
train_els = [json.loads(x) for x in
    (mpath / "OpenS2V_part1_test3_2_test_v2.jsonl").read_text().split('\n')
    if x.strip()  # skip blank lines such as a trailing newline
]
for el in tqdm(train_els):
    name = Path(el['meta']).parent.name
    tmp_paths = list(mpath.glob(f'test3_2_cache_v2/{name}_*_fp.safetensors'))
    if len(tmp_paths) > 0:
        cache_path = tmp_paths[0]
        # load_file/save_file only round-trip the tensors; carry the header metadata over
        # explicitly so rewriting the file does not drop it.
        with safe_open(str(cache_path), framework="pt") as f:
            cache_metadata = f.metadata()
        cache = load_file(cache_path)
        cache['clean_latent_indices_int64'] = torch.tensor([0], dtype=torch.int64)
        save_file(cache, cache_path, metadata=cache_metadata)

In [None]:
# Recompute `clean_latent_bboxes_float32` for every fp cache file from its target-latent mask and
# write it back in place, keeping the file's safetensors header metadata.
dset_path = Path("/groups/chenchen/patrick/OpenS2V-Nexus/datasets")
dset_name = "test3_2_cache_v2"
cache_dir = dset_path / dset_name
FP_SUFFIX = "_fp.safetensors"
# Strip the exact suffix; split("_fp") would also cut item keys that contain "_fp".
candidates = sorted([x.name[: -len(FP_SUFFIX)] for x in cache_dir.glob(f"*{FP_SUFFIX}")])
print(len(candidates))

list_with_bbox = []
for name in tqdm(candidates):
    cache_path = cache_dir / f"{name}{FP_SUFFIX}"
    with safe_open(str(cache_path), framework="pt") as f:
        cache_metadata = f.metadata()
    control_kwargs = load_file(cache_path)

    entity_key = [k for k in control_kwargs.keys() if "target_latent_masks_" in k][0]
    clean_key = [k for k in control_kwargs.keys() if "latents_clean_" in k][0]
    # First entry of the mask tensor, moved to channels-last, first channel as a bool (H, W) plane.
    entity_mask = control_kwargs[entity_key][0].permute(1,2,0).cpu().numpy().astype(bool)[...,0]
    clean_latents = control_kwargs[clean_key]
    w, h = entity_mask.shape[1], entity_mask.shape[0]
    clean_w, clean_h = clean_latents.shape[3], clean_latents.shape[2]

    # face_bbox = [
    #     bbox[0], bbox[1], 
    #     min((bbox[0]*entity_mask.shape[1]+clean_w)/entity_mask.shape[1], 1.0),
    #     min((bbox[1]*entity_mask.shape[0]+clean_h)/entity_mask.shape[1], 1.0),
    # ]
    bbox = get_bbox_from_mask(entity_mask)
    face_bbox = get_facebbox_from_bbox(bbox, clean_w, clean_h, w, h, full_width=False)
    clean_latent_bboxes = torch.tensor([face_bbox]).float()
    control_kwargs["clean_latent_bboxes_float32"] = clean_latent_bboxes
    # draw_bboxes(Image.fromarray(entity_mask).convert("RGB").resize((960,544)), [face_bbox])

    save_file(control_kwargs, cache_path, metadata=cache_metadata)
    # The bbox key was just written, so every processed item is recorded.
    list_with_bbox.append(name)
print(len(list_with_bbox))

In [None]:
# Recompute `clean_latent_bboxes_float32` for every listed item from its meta.yaml entity box and
# write it into the item's fp cache file (header metadata preserved).
buckset_selector = BucketSelector([960, 544], True, False, "fp")
CONTROL_H, CONTROL_W = 256, 256  # control image size in pixels
ENTITY_KEYS = [0]  # ids looked up in meta['target_body']

train_els = [json.loads(x) for x in
    (mpath / "OpenS2V_part1_test3_2_test_v2.jsonl").read_text().split('\n')
    if x.strip()  # skip blank lines such as a trailing newline
]

for i, batch in enumerate(tqdm(train_els)):
    # train_els[i]['meta'] = str(Path(batch['image_path']).parent / "meta.yaml")
    name = Path(batch['meta']).parent.name
    meta = OmegaConf.load(batch['meta'])

    image_size = Image.open(batch['image_path']).convert("RGB").size  # PIL size is (W, H)
    bucket_reso = buckset_selector.get_bucket_resolution(image_size)  # (W, H), same order as image_size
    target_body = meta['target_body']
    # Iterate the ids themselves: the old `for x in [keys]` looked up the key "[0]" and always
    # fell back to the zero box. Try the string id first, then the raw id.
    entity_bboxes = [target_body.get(str(k), target_body.get(k, [0.0, 0.0, 0.0, 0.0])) for k in ENTITY_KEYS]
    entity_bboxes = [[x[0]/meta['width'], x[1]/meta['height'], x[2]/meta['width'], x[3]/meta['height']] for x in entity_bboxes]
    # [x1, y1, x1 + control width / bucket width, y1 + control height / bucket height]:
    # x is normalised by width and y by height, matching the entity box above.
    clean_latent_bboxes = torch.tensor([
        [bbox[0], bbox[1], bbox[0] + CONTROL_W / bucket_reso[0], bbox[1] + CONTROL_H / bucket_reso[1]]
        for bbox in entity_bboxes
    ]).float()

    # Cache files are named after the original image size, e.g. "<name>_1280x0720_fp.safetensors".
    control_cache_path = mpath / f'test3_2_cache_v2/{name}_{image_size[0]:04d}x{image_size[1]:04d}_fp.safetensors'
    with safe_open(str(control_cache_path), framework="pt") as f:
        cache_metadata = f.metadata()
    control_cache = load_file(control_cache_path)
    control_cache['clean_latent_bboxes_float32'] = clean_latent_bboxes
    save_file(control_cache, control_cache_path, metadata=cache_metadata)


# (mpath / "OpenS2V_part1_test3_2_test_v2.jsonl").write_text(
#     "\n".join([json.dumps(x) for x in train_els])
# )

In [None]:
# Build one prompt line (caption + --i/--ci/--em image paths + fixed option string) for a few
# test items and write them to test3_2_sample_prompts.txt.
SAMPLE_INDICES = [0, 10]  # which jsonl entries to turn into sample prompts

test_els = [json.loads(x) for x in
    (mpath / "OpenS2V_part1_test3_2_test.jsonl").read_text().split('\n')
    if x.strip()  # skip blank lines such as a trailing newline
]
samples = []
for i in SAMPLE_INDICES:
    name = Path(test_els[i]['image_path']).parent.name
    batch_path = mpath / f"test3_2/{name}"
    meta = OmegaConf.load(batch_path / "meta.yaml")

    prompt = meta['cap'][0]
    control_image_path = batch_path / "source_facecrop_0.png"
    mask_path = batch_path / "target_bodymask_0.png"
    of = "--of target_index=9,control_index=0,no_2x,no_4x,no_post --d 1111 --f 1 --s 25 --w 1280 --h 720"
    sample = f"{prompt} --i {control_image_path.absolute()} --ci {control_image_path.absolute()} --em {mask_path.absolute()} {of}"
    samples.append(sample)

# Trailing ";" suppresses the character count that write_text returns.
(mpath / "test3_2_sample_prompts.txt").write_text("\n".join(samples));

In [None]:
# num_workers = max(1, os.cpu_count() - 1)

# all_latent_cache_paths = []
# for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
#     filtered_batch = []
#     for item in batch:
#         if item.frame_count is None:
#             all_latent_cache_paths.append(item.latent_cache_path)
#             all_existing = os.path.exists(item.latent_cache_path)
#         else:
#             latent_f = (item.frame_count - 1) // 4 + 1
#             num_sections = max(1, math.floor((latent_f - 1) / item.fp_latent_window_size))  # min 1 section
#             all_existing = True
#             for sec in range(num_sections):
#                 p = append_section_idx_to_latent_cache_path(item.latent_cache_path, sec)
#                 all_latent_cache_paths.append(p)
#                 all_existing = all_existing and os.path.exists(p)

#         if not all_existing:  # if any section cache is missing
#             filtered_batch.append(item)

#     if len(filtered_batch) == 0:  # all sections exist
#         continue

#     encode_and_save_batch_one_frame(
#         vae, feature_extractor, image_encoder, filtered_batch, 
#         vanilla_sampling = False,
#         one_frame_no_2x = True,
#         one_frame_no_4x = True,
#     )
    
# # normalize paths
# all_latent_cache_paths = [os.path.normpath(p) for p in all_latent_cache_paths]
# all_latent_cache_paths = set(all_latent_cache_paths)

In [None]:

# test_clean1 = [json.loads(x) for x in 
#     (mpath / "OpenS2V_part1_test3_2_train.jsonl").read_text().split('\n')
# ]
# test_clean1 = np.delete(test_clean1, errorneous)

# (mpath / "OpenS2V_part1_test3_2_test_2.jsonl").write_text(
#     "\n".join([json.dumps(x) for x in test_clean1])
# )

# control_imgs_sizes = []
# for i, test_clean in enumerate(tqdm(test_clean1)):
#     control_img_sizes = [Image.open(v).size for k,v in test_clean.items() if 'control_path' in k]
#     control_imgs_sizes.extend(control_img_sizes)
# print(np.mean([list(x) for x in control_imgs_sizes], axis=0))

# control_imgs_sizes = []
# errorneous = []
# for i, test_clean in enumerate(tqdm(test_clean1)):
#     try:
#         control_img_sizes = [Image.open(v).size for k,v in test_clean.items() if 'control_path' in k]
#         control_imgs_sizes.extend(control_img_sizes)
#     except Exception as e:
#         print(f"Error processing {test_clean['image_path']}: {e}")
#         errorneous.append(i)

# for i in erroneous:
# for test_clean in test_clean1:

# train_els = [json.loads(x) for x in 
#     (mpath / "OpenS2V_part1_test3_2_test.jsonl").read_text().split('\n')
# ]
# for i, batch in enumerate(train_els):
#     train_els[i]['meta'] = str(Path(batch['image_path']).parent / "meta.yaml")

# (mpath / "OpenS2V_part1_test3_2_test_v2.jsonl").write_text(
#     "\n".join([json.dumps(x) for x in train_els])
# )