In [None]:
import os
import sys

# Expose only CUDA device index 1 to this kernel. This is set here, before
# torch is imported in the next cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Make both the parent directory and the current directory importable.
for extra_path in ('..', '.'):
    sys.path.append(extra_path)

# The notebook expects the `hyvideo` package directory to be visible from the
# working directory; when it is not, step up one level.
if not os.path.exists('hyvideo'):
    os.chdir('../')

# Show the working directory that every relative path below resolves against.
print(os.getcwd())
# IPython autoreload, mode 2: re-import edited modules automatically before
# each cell execution, so changes to local source files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import torch
from torch.utils.data import DataLoader
from loguru import logger
import matplotlib.pyplot as plt
from hyvideo.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from hyvideo.vae import load_vae
from hyvideo.utils.file_utils import save_videos_grid
from hyvideo.inference import HunyuanVideoSampler
from pathlib import Path
from datetime import datetime
import time
from hyvideo.config import *
from VGDFR.hunyuan_vgdfr import DyLatentMergeModRoPEGenSampler,mod_rope_forward


In [None]:
import shlex

# Command line that would normally be passed to the HunyuanVideo inference
# script, kept as a single readable string.
cli_args_text = """--video-size 544 960 --video-length 65 --infer-steps 50 --prompt cat. --flow-reverse --use-cpu-offload --save-path ./results --seed 3"""

# Tokenise with shell rules rather than a naive split(" "): quoted multi-word
# values (e.g. --prompt "a cat on a sofa") stay together and repeated spaces do
# not produce empty tokens that argparse would reject.
string_args = shlex.split(cli_args_text)
print(string_args)

def parse_args_new(namespace=None,string_args=None):
    """Build the HunyuanVideo argument parser and parse an explicit token list.

    Args:
        namespace: Optional existing namespace object that argparse populates
            instead of creating a new one.
        string_args: List of command-line tokens to parse. When None, argparse
            falls back to ``sys.argv[1:]``, which inside a notebook kernel is
            the kernel's own command line, so callers should normally pass an
            explicit list.

    Returns:
        The parsed namespace after it has been passed through
        ``sanity_check_args``.
    """
    # Imported explicitly: the name was previously only reachable as a side
    # effect of `from hyvideo.config import *`.
    import argparse

    parser = argparse.ArgumentParser(description="HunyuanVideo inference script")

    # Each helper registers one argument group and returns the parser.
    for add_group in (
        add_network_args,
        add_extra_models_args,
        add_denoise_schedule_args,
        add_inference_args,
        add_parallel_args,
    ):
        parser = add_group(parser)

    args = parser.parse_args(string_args,namespace=namespace)
    return sanity_check_args(args)

# Parse the token list prepared above into the namespace that the rest of the
# notebook reads its settings from, and display it for reference.
args = parse_args_new(namespace=None, string_args=string_args)
print(args)


In [None]:
# Resolve the checkpoint root from the parsed arguments and fail early with a
# clear message if it is missing, instead of failing somewhere inside model
# loading.
models_root_path = Path(args.model_base)
if not models_root_path.exists():
    raise FileNotFoundError(f"Model root directory does not exist: {models_root_path}")

# Build the VGDFR sampler from the pretrained weights.
hunyuan_video_sampler = DyLatentMergeModRoPEGenSampler.from_pretrained(models_root_path, args=args)

In [None]:
# Build the evaluation prompt list from the first lines of every prompt file
# in the VBench directory, then persist it to eval2/prompts.txt.
prompt_path="VBench/prompts/augmented_prompts/gpt_enhanced_prompts/prompts_per_dimension_longer"
prompts_per_file = 2
all_prompts=[]
# os.listdir returns entries in arbitrary order; sorting makes the prompt list
# (and therefore the generated file set) reproducible across runs and machines.
for file in sorted(os.listdir(prompt_path)):
    file_path = os.path.join(prompt_path,file)
    if not os.path.isfile(file_path):
        continue  # skip sub-directories
    with open(file_path, encoding="utf-8") as f:
        head_lines = f.readlines()[:prompts_per_file]
    # Drop the line terminator so it neither reaches the model prompt nor the
    # output file names derived from the prompt; ignore blank lines.
    all_prompts.extend(line.strip() for line in head_lines if line.strip())
print(all_prompts)
print(len(all_prompts))

# Make sure the output directory exists and write exactly one prompt per line,
# even when a source line had no trailing newline.
os.makedirs("eval2", exist_ok=True)
with open("eval2/prompts.txt","w", encoding="utf-8") as f:
    f.writelines(f"{prompt_line}\n" for prompt_line in all_prompts)


In [None]:
# 540p raw
# Generate one video per prompt at 960x544 and save both the encoded .mp4 and
# the raw tensor under "<save_path>/9/".
seed=3
args.infer_steps=50
width,height=960,544
video_length=97
save_path = args.save_path
raw_dir = f"{save_path}/9"
# Only a non-distributed run or local rank 0 writes results to disk.
is_main_process = 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0

for prompt in all_prompts:
    denoise_args_tuple,generator = hunyuan_video_sampler.prepare_denoise_data(
        prompt=prompt,
        height=height,
        width=width,
        video_length=video_length,
        seed=seed,
        negative_prompt=args.neg_prompt,
        infer_steps=args.infer_steps,
        guidance_scale=args.cfg_scale,
        num_videos_per_prompt=args.num_videos,
        flow_shift=args.flow_shift,
        batch_size=args.batch_size,
        embedded_guidance_scale=args.embedded_cfg_scale
    )
    # NOTE(review): merge_t=-1 presumably disables latent merging so that this
    # run is the unmodified baseline -- TODO confirm against the pipeline.
    latents=hunyuan_video_sampler.pipeline.forward_with_latent_merge(*denoise_args_tuple,merge_t=-1,sim_threshold=0.5)
    samples=hunyuan_video_sampler.pipeline.latent_to_pixel(latents,generator)

    if not is_main_process:
        continue

    # Save samples
    os.makedirs(raw_dir, exist_ok=True)
    # File-name stem from the prompt: strip the line terminator, cap the
    # length, and remove path separators.
    prompt_tag = prompt.strip()[:100].replace('/','')
    for i, sample in enumerate(samples):
        sample = sample.unsqueeze(0)
        # The first sample keeps the historical name; further samples of the
        # same prompt get an index suffix instead of overwriting it.
        suffix = "" if i == 0 else f"_{i}"
        file_name=f"raw_seed{seed}_{prompt_tag}{suffix}"
        raw_save_path = f"{raw_dir}/{file_name}.mp4"
        tensor_save_path = f"{raw_dir}/{file_name}.pt"
        save_videos_grid(sample, raw_save_path, fps=12)
        logger.info(f'Sample save to: {raw_save_path}')
        torch.save(sample, tensor_save_path)
        # Report the tensor path (the original logged the .mp4 path here).
        logger.info(f'tensor save to: {tensor_save_path}')
            

In [None]:
from hyvideo.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from hyvideo.vae.unet_causal_3d_blocks import UpsampleCausal3D,DownsampleCausal3D,CausalConv3d
from hyvideo.vae import load_vae
# Load a separate VAE instance to serve as the pipeline's compression module.
compression_module, _, s_ratio, t_ratio = load_vae(
    vae_type="884-16c-hy",
    vae_precision="fp16",
    logger=logger,
    vae_path="ckpts/hunyuan-video-t2v-720p/vae",
    device="cuda",
)
compression_module = compression_module.eval()
compression_module.enable_tiling()

# Decoder: for every UpsampleCausal3D layer whose name contains "_blocks.0"
# and whose second upsample factor is > 1, keep the first factor and force the
# other two to 1. The original tuple is stashed on the layer as
# `_raw_upsample_factor`.
# NOTE(review): factors are presumably ordered (time, height, width), i.e. the
# spatial upsampling of that stage is switched off -- TODO confirm.
# (Alternative selection tried earlier: every stage except "_blocks.2".)
for name, module in compression_module.decoder.named_modules():
    if not isinstance(module, UpsampleCausal3D):
        continue
    original_factor = module.upsample_factor
    if original_factor[1] <= 1 or "_blocks.0" not in name:
        continue
    module._raw_upsample_factor = original_factor
    module.upsample_factor = (original_factor[0], 1, 1)
    print(f"UpsampleCausal3D {name}: convert {module._raw_upsample_factor} to {module.upsample_factor}")

# Encoder: same idea on the wrapped conv of every CausalConv3d whose name
# contains "_blocks.2" and whose second stride is 2. The original stride is
# stashed on the conv as `_raw_stride`.
# (Alternative selection tried earlier: every stage except "_blocks.0".)
for name, module in compression_module.encoder.named_modules():
    if not isinstance(module, CausalConv3d):
        continue
    module = module.conv  # operate on the inner conv held by the wrapper
    original_stride = module.stride
    if original_stride[1] != 2 or "_blocks.2" not in name:
        continue
    module._raw_stride = original_stride
    module.stride = (original_stride[0], 1, 1)
    print(f"Downsample Conv3d {name}: convert {module._raw_stride} to {module.stride}")

# Override the tiling threshold, then hand the patched VAE to the pipeline.
# print(compression_module.tile_latent_min_size)
compression_module.tile_latent_min_size = 64
hunyuan_video_sampler.pipeline.compression_module = compression_module

In [None]:
# VGDFR sweep: generate every prompt once per similarity threshold and store
# the results under "<save_path>/vgdfr/th_<threshold>/".
seed=3
args.infer_steps=50
num_quick_inference_steps=5
# NOTE(review): the original cell assigned 960x544 and immediately overwrote
# it with 500x300, so 500x300 is what actually ran. Kept here to preserve
# behaviour; switch to 960,544 if this is meant to match the 540p raw baseline.
width,height=500,300
video_length=97
save_path = args.save_path
# Only a non-distributed run or local rank 0 writes results to disk.
is_main_process = 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0

for dlfr_sim_threshold in [0.6,0.7,0.8,0.9]:
    out_dir = f"{save_path}/vgdfr/th_{dlfr_sim_threshold}"
    for prompt in all_prompts:
        samples = hunyuan_video_sampler.predict(
            prompt=prompt,
            height=height,
            width=width,
            video_length=video_length,
            seed=seed,
            negative_prompt=args.neg_prompt,
            infer_steps=args.infer_steps,
            guidance_scale=args.cfg_scale,
            num_videos_per_prompt=args.num_videos,
            flow_shift=args.flow_shift,
            batch_size=args.batch_size,
            num_quick_inference_steps=num_quick_inference_steps,
            dlfr_sim_threshold=dlfr_sim_threshold,
            embedded_guidance_scale=args.embedded_cfg_scale
        )['samples']

        if not is_main_process:
            continue

        # Save samples
        os.makedirs(out_dir, exist_ok=True)
        # File-name stem from the prompt: strip the line terminator, cap the
        # length, and remove path separators.
        prompt_tag = prompt.strip()[:100].replace('/','')
        for i, sample in enumerate(samples):
            sample = sample.unsqueeze(0)
            # The first sample keeps the historical name; further samples of
            # the same prompt get an index suffix instead of overwriting it.
            suffix = "" if i == 0 else f"_{i}"
            file_name=f"raw_seed{seed}_{prompt_tag}{suffix}"
            raw_save_path = f"{out_dir}/{file_name}.mp4"
            save_videos_grid(sample, raw_save_path, fps=12)
            torch.save(sample, f"{out_dir}/{file_name}.pt")
            # TODO: optionally append hunyuan_video_sampler.pipeline.log_dlfr_t
            # for this file to f"{out_dir}/log_dlfr_t.txt" (was commented out).
            

  7%|▋         | 3/46 [00:02<00:35,  1.22it/s]

In [None]:
# Recovery cell: re-save the most recent `samples` after the sweep cell above
# was interrupted. It relies on `samples`, `prompt`, `dlfr_sim_threshold`,
# `seed` and `save_path` still holding the values left behind by that cell, so
# it only works when run directly after it in the same kernel.
if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0:
    out_dir = f"{save_path}/vgdfr/th_{dlfr_sim_threshold}"
    os.makedirs(out_dir, exist_ok=True)
    # File-name stem from the prompt: strip the line terminator, cap the
    # length, and remove path separators.
    prompt_tag = prompt.strip()[:100].replace('/','')
    for i, sample in enumerate(samples):
        sample = sample.unsqueeze(0)
        # The first sample keeps the historical name; further samples of the
        # same prompt get an index suffix instead of overwriting it.
        suffix = "" if i == 0 else f"_{i}"
        file_name=f"raw_seed{seed}_{prompt_tag}{suffix}"
        raw_save_path = f"{out_dir}/{file_name}.mp4"
        save_videos_grid(sample, raw_save_path, fps=12)
        torch.save(sample, f"{out_dir}/{file_name}.pt")