In [13]:
from diffusers.utils import export_to_video
from causvid.data import TextDataset
from omegaconf import OmegaConf
# from tqdm import tqdm
from tqdm.notebook import tqdm
import torch
import os
import importlib

# This notebook only runs inference, so gradient tracking is switched off globally.
torch.set_grad_enabled(mode=False)

torch.autograd.grad_mode.set_grad_enabled(mode=False)

In [21]:
# --- Experiment identity and output location ---------------------------------
exp_name = "first_test"
logs_dir = os.path.join("logs", exp_name)
output_folder = logs_dir  # exported videos go into the same directory as the logs
os.makedirs(output_folder, exist_ok=True)

# --- Static paths (relative to the working directory) ------------------------
config_path = "configs/wan_causal_dmd.yaml"
checkpoint_folder = os.path.join("checkpoints", "autoregressive_checkpoint")
prompt_file_path = os.path.join("prompt_files", "prompt1.txt")

# --- Runtime objects ----------------------------------------------------------
device = "cuda:0"
config = OmegaConf.load(config_path)
prompt_dataset = TextDataset(prompt_file_path)

In [10]:
# Build the causal generator, then overwrite its weights with the CausVid checkpoint.
from causvid.models.wan.wan_wrapper import CausalWanDiffusionWrapper

# Per the original author's note, the bare constructor loads the base
# "wan_models/Wan2.1-T2V-1.3B/" model.
generator = CausalWanDiffusionWrapper()

# Only the 'generator' entry of the checkpoint is kept; it is reused later in
# the notebook for the VACE-augmented generator as well.
causvid_ckpt_path = os.path.join(checkpoint_folder, "model.pt")
causvid_state_dict = torch.load(causvid_ckpt_path, map_location=device)["generator"]

# strict=True: every key must match. The result is left as the cell's last
# expression so the match report is displayed.
generator.load_state_dict(causvid_state_dict, strict=True)

<All keys matched successfully>

In [11]:
# prepare text encoder
# Built with its default constructor arguments. Per the original note below it
# stays on the CPU at this point — TODO confirm against WanTextEncoder.
# The instance is later handed to MyInferencePipeline.
from causvid.models.wan.wan_wrapper import WanTextEncoder
text_encoder = WanTextEncoder() # default text encoder, in CPU

In [12]:
# prepare vae
# Built with its default constructor arguments; the instance is later handed to
# MyInferencePipeline together with the generator and the text encoder.
from causvid.models.wan.wan_wrapper import WanVAEWrapper
vae = WanVAEWrapper()

In [38]:
# VACE
# Import the module, reload it so that edits made to vace_wan_model.py since the
# kernel started are picked up, and only then bind VaceWanModel. The order
# matters: a name imported before the reload would keep pointing at the
# pre-reload class object.
import VACE_essentials.vace_wan_model
importlib.reload(VACE_essentials.vace_wan_model)
from VACE_essentials.vace_wan_model import VaceWanModel
# Pretrained VACE checkpoint directory, relative to the working directory.
vace_ckpt = os.path.join("VACE_essentials", "vace_wan_models", "Wan2.1-VACE-1.3B")
# The loaded model's `vace_blocks` are grafted onto the causal generator further down.
vace = VaceWanModel.from_pretrained(vace_ckpt)

In [27]:
# Reload the pipeline module before importing from it, so that edits to
# my_inference.py made during this session are reflected in the class used below.
import causvid.models.wan.my_inference
importlib.reload(causvid.models.wan.my_inference)
from causvid.models.wan.my_inference import MyInferencePipeline

# Wire the generator, text encoder and VAE into one pipeline (bfloat16, on `device`).
my_pipeline = MyInferencePipeline(
    generator=generator,
    text_encoder=text_encoder,
    vae=vae,
    args=config,
    dtype=torch.bfloat16,
    device=device,
)

KV inference with 3 frames per block


In [28]:
# Seed torch's global RNGs so the noise drawn here (and any sampling the
# pipeline performs through those generators) is repeatable when the cell is
# re-run. Without a seed every execution produces different videos.
NOISE_SEED = 0
torch.manual_seed(NOISE_SEED)

# A single noise tensor is drawn once and reused for every prompt.
sampled_noise = torch.randn(
    [1, 21, 16, 60, 104], device=device, dtype=torch.bfloat16
)

for prompt_index in tqdm(range(len(prompt_dataset))):
    # The pipeline is called with a batch containing exactly one prompt.
    prompts = [prompt_dataset[prompt_index]]

    # Take element 0 of the result and reorder its axes before moving it to
    # NumPy.
    # NOTE(review): presumably (frames, channels, H, W) -> (frames, H, W, channels)
    # for export_to_video — confirm against MyInferencePipeline.inference.
    video = my_pipeline.inference(
        noise=sampled_noise,
        text_prompts=prompts
    )[0].permute(0, 2, 3, 1).cpu().numpy()

    # One mp4 per prompt, named by zero-padded prompt index.
    export_to_video(
        video, os.path.join(output_folder, f"my_output_{prompt_index:03d}.mp4"), fps=16)

  0%|          | 0/2 [00:00<?, ?it/s]

In [85]:
# prepare generator (VACE-augmented variant)
# Reload the model module before the wrapper module: if wan_wrapper binds
# VaceCausalWanModel at import time, reloading the wrapper first would leave it
# holding the stale, pre-edit class. The submodule is imported explicitly so the
# reload does not depend on wan_wrapper having imported it as a side effect.
import causvid.models.wan.vace_causal_model
import causvid.models.wan.wan_wrapper
importlib.reload(causvid.models.wan.vace_causal_model)
importlib.reload(causvid.models.wan.wan_wrapper)
from causvid.models.wan.wan_wrapper import VaceCausalWanDiffusionWrapper

generator = VaceCausalWanDiffusionWrapper() # this loads the base "wan_models/Wan2.1-T2V-1.3B/" model
# strict=False: the CausVid checkpoint carries no VACE weights, so those keys
# are reported as missing here and are supplied from the pretrained VACE model
# on the next two lines.
generator.load_state_dict(causvid_state_dict, strict=False)
generator.model.vace_blocks = vace.vace_blocks
# Fixed: this previously assigned `vace.vace_blocks` (the attention-block list)
# to the patch embedding.
generator.model.vace_patch_embedding = vace.vace_patch_embedding

Some weights of VaceCausalWanModel were not initialized from the model checkpoint at wan_models/Wan2.1-T2V-1.3B/ and are newly initialized: ['vace_blocks.4.cross_attn.norm_q.weight', 'vace_blocks.8.cross_attn.k.weight', 'vace_blocks.8.self_attn.norm_k.weight', 'vace_blocks.4.cross_attn.o.weight', 'vace_blocks.1.cross_attn.v.weight', 'vace_blocks.9.cross_attn.o.weight', 'vace_blocks.0.self_attn.norm_q.weight', 'vace_blocks.12.cross_attn.k.bias', 'vace_blocks.1.self_attn.o.weight', 'vace_blocks.10.modulation', 'vace_blocks.10.cross_attn.q.bias', 'vace_blocks.5.self_attn.k.bias', 'vace_blocks.5.cross_attn.norm_q.weight', 'vace_blocks.12.cross_attn.norm_k.weight', 'vace_blocks.7.cross_attn.k.weight', 'vace_blocks.13.cross_attn.q.bias', 'vace_blocks.11.self_attn.k.bias', 'vace_blocks.11.cross_attn.norm_q.weight', 'vace_blocks.11.ffn.0.weight', 'vace_blocks.14.self_attn.v.weight', 'vace_blocks.13.cross_attn.v.bias', 'vace_blocks.8.cross_attn.v.weight', 'vace_blocks.7.self_attn.k.weight', 'va

_IncompatibleKeys(missing_keys=['model.vace_blocks.0.modulation', 'model.vace_blocks.0.self_attn.q.weight', 'model.vace_blocks.0.self_attn.q.bias', 'model.vace_blocks.0.self_attn.k.weight', 'model.vace_blocks.0.self_attn.k.bias', 'model.vace_blocks.0.self_attn.v.weight', 'model.vace_blocks.0.self_attn.v.bias', 'model.vace_blocks.0.self_attn.o.weight', 'model.vace_blocks.0.self_attn.o.bias', 'model.vace_blocks.0.self_attn.norm_q.weight', 'model.vace_blocks.0.self_attn.norm_k.weight', 'model.vace_blocks.0.norm3.weight', 'model.vace_blocks.0.norm3.bias', 'model.vace_blocks.0.cross_attn.q.weight', 'model.vace_blocks.0.cross_attn.q.bias', 'model.vace_blocks.0.cross_attn.k.weight', 'model.vace_blocks.0.cross_attn.k.bias', 'model.vace_blocks.0.cross_attn.v.weight', 'model.vace_blocks.0.cross_attn.v.bias', 'model.vace_blocks.0.cross_attn.o.weight', 'model.vace_blocks.0.cross_attn.o.bias', 'model.vace_blocks.0.cross_attn.norm_q.weight', 'model.vace_blocks.0.cross_attn.norm_k.weight', 'model.vac

In [88]:
# Sanity check: dump the module tree of the VACE blocks currently attached to
# the generator.
print(generator.model.vace_blocks)

ModuleList(
  (0): VaceWanAttentionBlock(
    (norm1): WanLayerNorm((1536,), eps=1e-06, elementwise_affine=False)
    (self_attn): WanSelfAttention(
      (q): Linear(in_features=1536, out_features=1536, bias=True)
      (k): Linear(in_features=1536, out_features=1536, bias=True)
      (v): Linear(in_features=1536, out_features=1536, bias=True)
      (o): Linear(in_features=1536, out_features=1536, bias=True)
      (norm_q): WanRMSNorm()
      (norm_k): WanRMSNorm()
    )
    (norm3): WanLayerNorm((1536,), eps=1e-06, elementwise_affine=True)
    (cross_attn): WanT2VCrossAttention(
      (q): Linear(in_features=1536, out_features=1536, bias=True)
      (k): Linear(in_features=1536, out_features=1536, bias=True)
      (v): Linear(in_features=1536, out_features=1536, bias=True)
      (o): Linear(in_features=1536, out_features=1536, bias=True)
      (norm_q): WanRMSNorm()
      (norm_k): WanRMSNorm()
    )
    (norm2): WanLayerNorm((1536,), eps=1e-06, elementwise_affine=False)
    (ffn): S

In [86]:
# Rebuild the pipeline around the current `generator`. The module is reloaded
# first so that edits to my_inference.py made during this session take effect.
import causvid.models.wan.my_inference
importlib.reload(causvid.models.wan.my_inference)
from causvid.models.wan.my_inference import MyInferencePipeline

# NOTE(review): the recorded run of this cell raised NotImplementedError
# ("Cannot copy out of meta tensor"). Presumably some generator parameters were
# still on the meta device when the pipeline was built — confirm that every
# newly initialised weight has been replaced with real tensors before this call.
my_pipeline = MyInferencePipeline(
    generator=generator,
    text_encoder=text_encoder,
    vae=vae,
    args=config,
    dtype=torch.bfloat16,
    device=device,
)

NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.

In [83]:
# Seed torch's global RNGs so the noise drawn here (and any sampling the
# pipeline performs through those generators) is repeatable when the cell is
# re-run. Without a seed every execution produces different videos.
NOISE_SEED = 0
torch.manual_seed(NOISE_SEED)

# A single noise tensor is drawn once and reused for every prompt.
sampled_noise = torch.randn(
    [1, 21, 16, 60, 104], device=device, dtype=torch.bfloat16
)

for prompt_index in tqdm(range(len(prompt_dataset))):
    # The pipeline is called with a batch containing exactly one prompt.
    prompts = [prompt_dataset[prompt_index]]

    # Take element 0 of the result and reorder its axes before moving it to
    # NumPy.
    # NOTE(review): presumably (frames, channels, H, W) -> (frames, H, W, channels)
    # for export_to_video — confirm against MyInferencePipeline.inference.
    video = my_pipeline.inference(
        noise=sampled_noise,
        text_prompts=prompts
    )[0].permute(0, 2, 3, 1).cpu().numpy()

    # Distinct file prefix: the earlier sampling loop in this notebook writes
    # my_output_XXX.mp4 into the same folder, and reusing that name here would
    # silently overwrite those results.
    export_to_video(
        video, os.path.join(output_folder, f"my_vace_output_{prompt_index:03d}.mp4"), fps=16)

  0%|          | 0/2 [00:00<?, ?it/s]

AttributeError: 'VaceCausalWanModel' object has no attribute '_forward_vace'