In [1]:
import os, sys
sys.path.append("..")

from pathlib import Path
from tqdm import tqdm
import importlib
from dataclasses import asdict
from omegaconf import OmegaConf
from argparse import Namespace
import math, json, glob, shutil
import numpy as np
from multiprocessing import Value
from PIL import Image, ImageDraw

In [2]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from safetensors.torch import load_file, save_file
import lovely_tensors as lt
lt.monkey_patch()

In [3]:
from musubi_tuner.dataset import config_utils
from musubi_tuner.dataset.config_utils import BlueprintGenerator, ConfigSanitizer
# from musubi_tuner.networks import lora_framepack
from musubi_tuner.hv_train_network import collator_class, setup_parser_common, read_config_from_file
from musubi_tuner.fpack_train_network import framepack_setup_parser, FramePackNetworkTrainer

Trying to import sageattention
Failed to import sageattention


INFO:root:Xformers is installed!
INFO:root:Flash Attn is not installed!
INFO:root:Sage Attn is not installed!


In [4]:
sys.argv = [
    "fpack_train_network.py",
    "--dit", "/projects/bffz/ykwon4/ComfyUI/models/diffusion_models/FramePackI2V_HY_bf16.safetensors",
    "--vae", "/projects/bffz/ykwon4/ComfyUI/models/vae/hunyuan-video-t2v-720p-vae.pt",
    "--text_encoder1", "/projects/bffz/ykwon4/ComfyUI/models/text_encoders/llava_llama3_fp16.safetensors",
    "--text_encoder2", "/projects/bffz/ykwon4/ComfyUI/models/text_encoders/clip_l.safetensors",
    "--image_encoder", "/projects/bffz/ykwon4/ComfyUI/models/clip_vision/sigclip_vision_patch14_384.safetensors",
    "--dataset_config", "/projects/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v2.toml",
    "--sdpa", "--mixed_precision", "bf16", "--one_frame",
    "--optimizer_type", "adamw8bit", 
    "--learning_rate", "2e-4", 
    "--gradient_checkpointing",
    "--timestep_sampling", "shift", 
    "--weighting_scheme", "none", 
    "--discrete_flow_shift", "3.0",
    "--max_data_loader_n_workers", "8", 
    "--persistent_data_loader_workers",
    "--network_module", "networks.lora_framepack", 
    "--network_dim", "32",
    "--max_train_epochs", "16", 
    "--save_every_n_epochs", "1", 
    "--seed", "42",
    "--sample_prompts", "/projects/bffz/ykwon4/OpenS2V-Nexus/test3_part2_sample_prompts.txt",
    "--sample_every_n_epochs", "1", 
    "--sample_at_first",
    "--output_dir", "outputs/training/idmask_control_lora", 
    "--output_name", "idmask_control_lora_test1",
    "--logging_dir", "outputs/training/idmask_control_lora/logs", 
    "--log_with", "tensorboard",
    "--remove_embedding", "--use_attention_controlimage_masking"
]

parser = setup_parser_common()
parser = framepack_setup_parser(parser)
args = parser.parse_args()
args = read_config_from_file(args, parser)
args.vae_dtype = "float16"  # fixed
args.dit_dtype = "bfloat16"  # fixed
args.sample_solver = "unipc"  # for sample generation, fixed to unipc

device = torch.device('cuda:0')

trainer = FramePackNetworkTrainer()
trainer.handle_model_specific_args(args)

In [5]:
blueprint_generator = BlueprintGenerator(ConfigSanitizer())
user_config = config_utils.load_user_config(args.dataset_config)
blueprint = blueprint_generator.generate(user_config, args, architecture=trainer.architecture)
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group, training=True)
collator = collator_class(Value("i", 0), Value("i", 0), None)
train_dataloader = DataLoader(
    train_dataset_group, batch_size=1, shuffle=True, collate_fn=collator, num_workers=8, persistent_workers=True    
)
# train_dataset_group.datasets[0].batch_manager[271]

batch = next(iter(train_dataloader))
batch

INFO:musubi_tuner.dataset.image_video_dataset:load image jsonl from /projects/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v2.jsonl
INFO:musubi_tuner.dataset.image_video_dataset:loaded 25813 images
INFO:musubi_tuner.dataset.image_video_dataset:found 25813 images with 1 control images per image in JSONL data
INFO:musubi_tuner.dataset.image_video_dataset:found 25813 masks with 1 control masks per image in JSONL data
INFO:musubi_tuner.dataset.image_video_dataset:found 25813 metadata with 1 bbox paths per image in JSONL data
INFO:musubi_tuner.dataset.config_utils:[Dataset 0]
  is_image_dataset: True
  resolution: (960, 544)
  batch_size: 16
  num_repeats: 1
  caption_extension: ".txt"
  enable_bucket: False
  bucket_no_upscale: False
  cache_directory: "/projects/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v2_cache"
  debug_dataset: False
  item_name_type: "basename"  
    image_directory: "None"
    image_jsonl_file: "/projects/bffz/ykwon4/OpenS2V-Nexus/test3_part2_v2.jsonl"
    fp_latent_window_size:

{'clean_latent_indices': tensor[16, 1] i64 [38;2;127;127;127mall_zeros[0m,
 'latent_indices': tensor[16, 1] i64 x∈[3, 3] μ=3.000 σ=0.,
 'clean_latent_bboxes': tensor[16, 1, 4] n=64 x∈[0.006, 0.997] μ=0.473 σ=0.309,
 'image_embeddings': tensor[16, 729, 1152] f16 n=13436928 (26Mb) x∈[-61.750, 108.312] μ=0.023 σ=1.882,
 'latents': tensor[16, 16, 1, 90, 160] f16 n=3686400 (7.0Mb) x∈[-5.094, 5.004] μ=0.008 σ=1.153,
 'latents_clean': tensor[16, 16, 1, 64, 64] f16 n=1048576 (2Mb) x∈[-4.746, 4.652] μ=-0.019 σ=1.084,
 'target_latent_masks': tensor[16, 1, 1, 90, 160] f16 n=230400 (0.4Mb) x∈[0., 1.000] μ=0.350 σ=0.473,
 'llama_vec': tensor[16, 512, 4096] f16 n=33554432 (64Mb) x∈[-36.531, 10.781] μ=-0.001 σ=0.394,
 'llama_attention_mask': tensor[16, 512] bool n=8192 (8Kb) x∈[False, True] μ=0.267 σ=0.442,
 'clip_l_pooler': tensor[16, 768] n=12288 (48Kb) x∈[-5.064, 7.806] μ=-0.107 σ=0.996}

In [6]:
for batch in tqdm(train_dataloader):
    pass

100%|██████████| 1614/1614 [33:53<00:00,  1.26s/it]


In [None]:
# prepare dtype
weight_dtype = torch.bfloat16
dit_dtype = torch.bfloat16
dit_weight_dtype = torch.bfloat16
# get embedding for sampling images
vae_dtype = torch.float16
sample_parameters = trainer.process_sample_prompts(args, Namespace(**{'device': device}), args.sample_prompts)

# Load VAE model for sampling images: VAE is loaded to cpu to save gpu memory
vae = trainer.load_vae(args, vae_dtype=vae_dtype, vae_path=args.vae)
vae.requires_grad_(False)
vae.eval()
vae.to(device)

transformer = trainer.load_transformer(
    Namespace(**{'device': device}), args, args.dit, "torch", args.split_attn, device, dit_weight_dtype
)
transformer.eval()
transformer.requires_grad_(False)

# # apply network to DiT
# network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True)
weights_sd = load_file("/home/yo564250/workspace/whisperer/related/framepackbase/musubi-tuner/outputs/training/idmask_control_lora/idmask_control_lora_test3-000008.safetensors")
module = lora_framepack.create_arch_network_from_weights(
    1.0, weights_sd, unet=transformer, for_inference=True
)
module.merge_to(None, transformer, weights_sd, weight_dtype, "cpu")

transformer.enable_gradient_checkpointing()
# network.enable_gradient_checkpointing()  # may have no effect

In [None]:
with torch.inference_mode():
    latents = batch["latents"]
    noise = torch.randn_like(latents)

    # get_noisy_model_input_and_timesteps
    logits_norm = torch.randn(8, device=device)
    t = logits_norm.sigmoid()
    t = (t * 3.0) / (1 + (3.0 - 1) * t)
    timesteps = t * 1000.0
    t = t.view(-1, 1, 1, 1, 1)
    noisy_model_input = (1 - t) * latents + t * noise
    timesteps += 1
    # compute_loss_weighting_for_sd3
    weighting = None

    

In [39]:
with torch.no_grad():
    latents = batch['latents'].to(device, dtype=dit_dtype) #B C T H W
    latent_indices = batch['latent_indices']
    clean_latents = batch['latents_clean'].to(device, dtype=dit_dtype) #B C n h w
    clean_latent_indices = batch['clean_latent_indices'] #B 1
    clean_latent_bboxes = batch['clean_latent_bboxes'] #B n 4

    print(latents.shape)
    hidden_states = transformer.x_embedder.proj(latents)
    B, C, T, H, W = hidden_states.shape
    print(hidden_states.shape)
    hidden_states = hidden_states.flatten(2).transpose(1, 2)
    print(hidden_states.shape)

    rope_freqs = transformer.rope(
        frame_indices=latent_indices, height=H, width=W, 
        device=hidden_states.device)
    print(rope_freqs.shape)
    rope_freqs = rope_freqs.flatten(2).transpose(1, 2)
    print(rope_freqs.shape)

    print(clean_latents.shape)
    N = clean_latents.shape[2]
    clean_latents = [clean_latents[:,:,[i]] for i in range(N)]
    clean_latent_indices = [clean_latent_indices[:,[i]] for i in range(clean_latent_indices.shape[1])]
    if len(clean_latent_indices) != len(clean_latents):
        clean_latent_indices = [clean_latent_indices[0]] * N
    clean_latent_bboxes = [clean_latent_bboxes[:,i] for i in range(clean_latent_bboxes.shape[1])]
    
    processed_clean_latents, clean_latent_rope_freqs = [], []
    for i, clean_latent in enumerate(clean_latents):
        clean_latent = clean_latent.to(hidden_states)
        clean_latent = transformer.clean_x_embedder.proj(clean_latent)
        clean_latent_index = clean_latent_indices[i]
        clean_latent_bbox = clean_latent_bboxes[i]

        if clean_latent.shape[0] != B:
            clean_latent = repeat_to_batch_size(clean_latent, B)
        if clean_latent_index.shape[0] != B:
            clean_latent_index = repeat_to_batch_size(clean_latent_index, B)
        _, _, _, clean_H, clean_W = clean_latent.shape
        if clean_latent_bbox.shape[0] != B:
            clean_latent_bbox = repeat_to_batch_size(clean_latent_bbox, B)

        clean_latent_rope_freq = []
        for b in range(B):
            cb_rel = clean_latent_bbox[b]
            if cb_rel.mean() > 0.0:
                cb_0, cb_1, cb_2, cb_3 = int(cb_rel[0]*W), int(cb_rel[1]*H), int(cb_rel[2]*W), int(cb_rel[3]*H)
                cb_rope_freq = transformer.rope(
                    frame_indices=clean_latent_index[[b]], 
                    height=cb_3, width=cb_2,
                    start_height=cb_1, start_width=cb_0,
                    step_H=(cb_3 - cb_1) / clean_H,
                    step_W=(cb_2 - cb_0) / clean_W,
                    device=clean_latent.device)

                clean_latent_rope_freq.append(cb_rope_freq)
        clean_latent_rope_freq = torch.cat(clean_latent_rope_freq, dim=0)

        clean_latent = clean_latent.flatten(2).transpose(1, 2)
        clean_latent_rope_freq = clean_latent_rope_freq.flatten(2).transpose(1, 2)
        print(clean_latent.shape)
        print(clean_latent_rope_freq.shape)

        processed_clean_latents.append(clean_latent)
        clean_latent_rope_freqs.append(clean_latent_rope_freq)

    processed_clean_latents = torch.cat(processed_clean_latents, dim=1)
    clean_latent_rope_freqs = torch.cat(clean_latent_rope_freqs, dim=1)
    print(processed_clean_latents.shape)
    print(clean_latent_rope_freqs.shape)

torch.Size([8, 16, 1, 68, 120])
torch.Size([8, 3072, 1, 34, 60])
torch.Size([8, 2040, 3072])
torch.Size([8, 256, 1, 34, 60])
torch.Size([8, 2040, 256])
torch.Size([8, 16, 2, 32, 32])
torch.Size([8, 256, 3072])
torch.Size([8, 256, 256])
torch.Size([8, 256, 3072])
torch.Size([3, 256, 256])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 8 but got size 3 for tensor number 1 in the list.

In [19]:
load_file('/groups/chenchen/patrick/OpenS2V-Nexus/datasets/test3_2_cache_v2/_0ccfGEpIwY_segment_16_step1-0-85_step2-0-85_step4_step5_step6_1280x0720_fp.safetensors')

{'clean_latent_bboxes_float32': tensor[1, 4] x∈[0.029, 0.608] μ=0.370 σ=0.252 [[0.342, 0.029, 0.608, 0.500]],
 'clean_latent_indices_int64': tensor[1] i64 [0],
 'image_embeddings_float16': tensor[729, 1152] f16 n=839808 (1.6Mb) x∈[-60.719, 90.688] μ=0.026 σ=1.979,
 'latent_indices_int64': tensor[1] i64 [9],
 'latents_1x68x120_float16': tensor[16, 1, 68, 120] f16 n=130560 (0.2Mb) x∈[-4.133, 4.762] μ=0.327 σ=1.521,
 'latents_clean_1x68x120_float16': tensor[16, 1, 32, 32] f16 n=16384 (32Kb) x∈[-4.199, 3.943] μ=0.018 σ=1.051,
 'target_latent_masks_1x68x120_float16': tensor[1, 1, 68, 120] f16 n=8160 (16Kb) x∈[0., 1.000] μ=0.307 σ=0.455}

In [None]:
def process_input_hidden_states(
    self,
    latents,
    latent_indices=None,
    clean_latents=None,
    clean_latent_indices=None,
    clean_latent_bboxes=None, # [B, N, 4]
):
    hidden_states = self.x_embedder.proj(latents)
    B, C, T, H, W = hidden_states.shape

    if latent_indices is None:
        latent_indices = repeat_to_batch_size(torch.arange(0, T).unsqueeze(0), B)
    if latent_indices.shape[0] != B:
        latent_indices = repeat_to_batch_size(latent_indices, B)

    hidden_states = hidden_states.flatten(2).transpose(1, 2)

    rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
    rope_freqs = rope_freqs.flatten(2).transpose(1, 2)

    if clean_latents is not None and clean_latent_indices is not None:
        processed_clean_latents, clean_latent_rope_freqs = [], []
        if type(clean_latents) != list:
            clean_latents = [clean_latents[:,:,[i]] for i in range(clean_latents.shape[2])]
            clean_latent_indices = [clean_latent_indices[:,[i]] for i in range(clean_latent_indices.shape[1])]
        if len(clean_latent_indices) != len(clean_latents):
            clean_latent_indices = [clean_latent_indices[0]] * len(clean_latents)
        if clean_latent_bboxes is None:
            clean_latent_bboxes = [torch.tensor([[0.0, 0.0, 1.0, 1.0]], device=hidden_states.device).repeat(B,1)] * N
        else:
            clean_latent_bboxes = [clean_latent_bboxes[:,i] for i in range(clean_latent_bboxes.shape[1])]

        for i, clean_latent in enumerate(clean_latents):
            clean_latent = clean_latent.to(hidden_states)
            clean_latent = self.clean_x_embedder.proj(clean_latent)
            clean_latent_index = clean_latent_indices[i]
            clean_latent_bbox = clean_latent_bboxes[:,i]

            if clean_latent.shape[0] != B:
                clean_latent = repeat_to_batch_size(clean_latent, B)
            if clean_latent_index.shape[0] != B:
                clean_latent_index = repeat_to_batch_size(clean_latent_index, B)
            _, _, _, clean_H, clean_W = clean_latent.shape

            if clean_latent_bbox.shape[0] != B:
                clean_latent_bbox = repeat_to_batch_size(clean_latent_bbox, B)

            clean_latent_rope_freq = []
            for b in range(B):
                cb = [
                    int(clean_latent_bbox[b,0]*W), int(clean_latent_bbox[b,1]*H), 
                    int(clean_latent_bbox[b,2]*W), int(clean_latent_bbox[b,3]*H)
                ]
                cb_rope_freq = self.rope(
                    frame_indices=clean_latent_index[[b]], 
                    height=cb[3], width=cb[2],
                    start_height=cb[1], start_width=cb[0],
                    step_H=(cb[3] - cb[1]) / clean_H,
                    step_W=(cb[2] - cb[0]) / clean_W,
                    device=clean_latent.device)

                clean_latent_rope_freq.append(cb_rope_freq)
            clean_latent_rope_freq = torch.cat(clean_latent_rope_freq, dim=0)

            clean_latent = clean_latent.flatten(2).transpose(1, 2)
            clean_latent_rope_freq = clean_latent_rope_freq.flatten(2).transpose(1, 2)
            
            processed_clean_latents.append(clean_latent)
            clean_latent_rope_freqs.append(clean_latent_rope_freq)

        processed_clean_latents = torch.cat(processed_clean_latents, dim=1)
        clean_latent_rope_freqs = torch.cat(clean_latent_rope_freqs, dim=1)

        # logger.info(f"Clean Latent Rope Freq Shape: {clean_latent_rope_freq.shape}")
        # logger.info(f"Rope Freq Shape: {rope_freqs.shape}")
        hidden_states = torch.cat([processed_clean_latents, hidden_states], dim=1)
        rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)

    return hidden_states, rope_freqs

In [4]:
with torch.inference_mode():
    trainer.sample_image_inference(
        Namespace(**{'device': device}), args, transformer, dit_dtype, vae, 
        ".", sample_parameters[1], 0, 0
    )

INFO:musubi_tuner.hv_train_network:prompt: Three individuals seated closely together in what appears to be a casual indoor setting. The person in the center is wearing a gray hoodie with pink accents and has light-colored hair. To the left, another individual is dressed in a red shirt with a graphic design, and to the right, a person with long dark hair is wearing a light-colored top. The background includes a wall with a colorful mural or artwork, and the room has a modern, cozy ambiance with soft lighting. The individuals are engaged in conversation, with the central figure speaking and the others listening and reacting with smiles and nods. The camera remains stationary, capturing the scene from a medium shot perspective.
INFO:musubi_tuner.hv_train_network:height: 720
INFO:musubi_tuner.hv_train_network:width: 1280
INFO:musubi_tuner.hv_train_network:frame count: 1
INFO:musubi_tuner.hv_train_network:sample steps: 25
INFO:musubi_tuner.hv_train_network:guidance scale: 10.0
INFO:musubi_t

INFO:musubi_tuner.fpack_train_network:Encoding control image: /groups/chenchen/patrick/OpenS2V-Nexus/datasets/test3_2/RH9DTExtz1s_segment_55_step1-0-73_step2-0-73_step4_step5_step6/source_facecrop_0.png
INFO:musubi_tuner.fpack_train_network:Encoding entity mask: /groups/chenchen/patrick/OpenS2V-Nexus/datasets/test3_2/RH9DTExtz1s_segment_55_step1-0-73_step2-0-73_step4_step5_step6/target_bodmask_0.png
INFO:musubi_tuner.fpack_train_network:Set index for clean latent 1x: ['0']
INFO:musubi_tuner.fpack_train_network:Set index for target: 9
INFO:musubi_tuner.fpack_train_network:No clean_latents_2x
INFO:musubi_tuner.fpack_train_network:No clean_latents_4x
INFO:musubi_tuner.fpack_train_network:One frame inference. clean_latent: torch.Size([1, 16, 1, 34, 25]) latent_indices: tensor[1, 1] i64 [[9]], clean_latent_indices: tensor[1, 1] i64 [[0]], num_frames: 1


  0%|          | 0/25 [00:00<?, ?it/s]

INFO:musubi_tuner.fpack_train_network:Waiting for 5 seconds to finish block swap
INFO:musubi_tuner.fpack_generate_video:Decoding video...
INFO:musubi_tuner.fpack_generate_video:Bulk decoding or one frame inference
INFO:musubi_tuner.fpack_generate_video:Decoded. Pixel shape torch.Size([1, 3, 1, 720, 1280])
