In [13]:
from diffusers.utils import export_to_video
from causvid.data import TextDataset
from omegaconf import OmegaConf
# from tqdm import tqdm
from tqdm.notebook import tqdm
import torch
import os
import importlib

# Inference-only notebook: turn autograd off globally so no later cell records a graph.
# In a notebook the returned object is echoed as the cell's output.
torch.set_grad_enabled(False)

torch.autograd.grad_mode.set_grad_enabled(mode=False)

In [238]:
# --- Run identity and hardware ------------------------------------------------
exp_name = "first_test"
device = "cuda:0"

# --- Input locations (relative paths, resolved against the working directory) --
config_path = "configs/wan_causal_dmd.yaml"
checkpoint_folder = os.path.join("checkpoints", "autoregressive_checkpoint")
prompt_file_path = os.path.join("prompt_files", "prompt2.txt")

# --- Output location: generated videos go into the per-experiment log folder ---
logs_dir = os.path.join("logs", exp_name)
output_folder = logs_dir
os.makedirs(logs_dir, exist_ok=True)  # safe to re-run: an existing folder is fine

# --- Load the run configuration and the prompts --------------------------------
config = OmegaConf.load(config_path)
prompt_dataset = TextDataset(prompt_file_path)

In [240]:
# Sanity check: dump the raw prompt strings held by the dataset loaded from prompt_file_path.
print(prompt_dataset.texts)

['two rugged nordic warriors in detailed medieval armor, bare-knuckle boxing inside a wooden fighting ring, surrounded by torchlight and roaring spectators, cinematic composition, cold misty air, sparks and sweat flying, dramatic side lighting, ultra-detailed leather and steel textures, tense atmosphere, epic fantasy realism, cinematic photography, shallow depth of field, volumetric fog, inspired by Game of Thrones, high-contrast chiaroscuro lighting, moody blue-gray tones, filmic realism, masterpiece digital art by Karla Ortiz and Ruan Jia, 8k resolution', 'a massive polar bear and a giant grizzly bear locked in brutal combat inside a wooden fighting ring, snow swirling around them, blood and fur flying, cinematic low-angle shot, cold blue lighting, frosty breath, shards of ice, roaring spectators in furs, ancient nordic arena carved from ice and wood, epic fantasy realism, moody atmosphere, volumetric fog, dramatic side lighting, masterpiece digital art by Karla Ortiz and Ruan Jia, i

In [10]:
# prepare generator
from causvid.models.wan.wan_wrapper import CausalWanDiffusionWrapper

generator = CausalWanDiffusionWrapper()  # this loads the base "wan_models/Wan2.1-T2V-1.3B/" model

# Deserialize the CausVid checkpoint on the CPU rather than on `device`: only the
# 'generator' entry is kept, and this state dict stays alive for reuse by later
# cells, so mapping the whole file onto the GPU would just pin accelerator memory.
# load_state_dict copies values into the module's own parameters, so the source
# tensors do not have to live on the same device as the model.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted
# source (consider weights_only=True if the checkpoint holds nothing but tensors).
causvid_checkpoint_path = os.path.join(checkpoint_folder, "model.pt")
causvid_state_dict = torch.load(causvid_checkpoint_path, map_location="cpu")["generator"]

# strict=True: fail loudly if the checkpoint and the wrapper disagree on any key.
# Left as the last expression so the key-matching report is shown as the cell output.
generator.load_state_dict(causvid_state_dict, strict=True)

<All keys matched successfully>

In [11]:
# prepare text encoder
# Built with default arguments; it is handed to MyInferencePipeline further down.
from causvid.models.wan.wan_wrapper import WanTextEncoder
text_encoder = WanTextEncoder() # default text encoder; per the original author it lives on the CPU

In [12]:
# prepare vae
# Built with default arguments; it is handed to MyInferencePipeline further down.
from causvid.models.wan.wan_wrapper import WanVAEWrapper
vae = WanVAEWrapper()

In [127]:
# VACE
# Development convenience: import the module, force a reload so edits to
# vace_wan_model.py are picked up without restarting the kernel, then re-bind the
# class name from the freshly reloaded module. The order of these three lines matters.
import VACE_essentials.vace_wan_model
importlib.reload(VACE_essentials.vace_wan_model)
from VACE_essentials.vace_wan_model import VaceWanModel
# Local folder expected to hold the pretrained VACE weights.
vace_ckpt = os.path.join("VACE_essentials", "vace_wan_models", "Wan2.1-VACE-1.3B")
# `vace` is used later as the source of the vace_blocks / vace_patch_embedding weights.
vace = VaceWanModel.from_pretrained(vace_ckpt)

in dim?  96
patch size? (1, 2, 2)
patch size?  (1, 2, 2)
OrderedDict([('weight', tensor(..., device='meta', size=(1536, 96, 1, 2, 2))), ('bias', tensor(..., device='meta', size=(1536,)))])


In [27]:
# Re-import the pipeline module so that edits to my_inference.py take effect
# without a kernel restart, then re-bind the class from the reloaded module.
import causvid.models.wan.my_inference
importlib.reload(causvid.models.wan.my_inference)
from causvid.models.wan.my_inference import MyInferencePipeline

# Assemble the pipeline from the components prepared in the cells above.
my_pipeline = MyInferencePipeline(
    generator=generator,
    text_encoder=text_encoder,
    vae=vae,
    args=config,
    dtype=torch.bfloat16,
    device=device,
)

KV inference with 3 frames per block


In [28]:
# Seed the global RNG (CPU and CUDA) before sampling so the starting noise -- and any
# random draws made afterwards from the global RNG -- are the same on every run of
# this cell.
noise_seed = 0
torch.manual_seed(noise_seed)

# Initial noise handed to the pipeline.
# NOTE(review): axes are presumably (batch, frames, channels, height, width) in latent
# space -- confirm against MyInferencePipeline.inference.
sampled_noise = torch.randn(
    [1, 21, 16, 60, 104], device=device, dtype=torch.bfloat16
)

# The same noise tensor is reused for every prompt; one video file is written per prompt.
for prompt_index in tqdm(range(len(prompt_dataset)), desc="generating"):
    prompts = [prompt_dataset[prompt_index]]  # inference() is given a one-element list

    # Take the first item of the result and move its axis 1 to the end before
    # converting to NumPy (presumably frames x C x H x W -> frames x H x W x C; confirm).
    video = my_pipeline.inference(
        noise=sampled_noise,
        text_prompts=prompts
    )[0].permute(0, 2, 3, 1).cpu().numpy()

    export_to_video(
        video, os.path.join(output_folder, f"my_output_{prompt_index:03d}.mp4"), fps=16)

  0%|          | 0/2 [00:00<?, ?it/s]

In [235]:
# prepare generator
# Reload the model definition BEFORE the wrapper module. wan_wrapper presumably imports
# from vace_causal_model, so reloading the wrapper first would re-bind it to the stale
# class and the cell would have to be run twice to see an edit. The explicit import
# also guarantees the submodule is loaded before importlib.reload touches it.
import causvid.models.wan.vace_causal_model
import causvid.models.wan.wan_wrapper
importlib.reload(causvid.models.wan.vace_causal_model)
importlib.reload(causvid.models.wan.wan_wrapper)
from causvid.models.wan.wan_wrapper import VaceCausalWanDiffusionWrapper


def _copy_vace_weights(model, vace_model, attr_name):
    """Copy the weights of ``vace_model.<attr_name>`` into ``model.<attr_name>``.

    Args:
        model: the generator's inner model that receives the weights.
        vace_model: the pretrained VACE model that provides them.
        attr_name: name of the submodule present on both models.

    Returns:
        The report returned by ``load_state_dict(..., strict=True)``.

    Raises:
        RuntimeError: raised by ``load_state_dict`` if the two submodules do not have
            exactly matching keys.
    """
    target = getattr(model, attr_name)
    source = getattr(vace_model, attr_name)
    first_param = next(target.parameters(), None)
    # Parameters on the meta device have no storage to copy into: allocate
    # (uninitialised) memory on the model's device first, then fill it below.
    if first_param is not None and first_param.is_meta:
        target.to_empty(device=model.device)
    return target.load_state_dict(source.state_dict(), strict=True)


generator = VaceCausalWanDiffusionWrapper()  # this loads the base "wan_models/Wan2.1-T2V-1.3B/" model

# strict=False: the VACE-specific weights are filled in separately below (presumably
# the CausVid checkpoint does not contain them). Keep the report instead of
# discarding it, so any OTHER mismatch is surfaced rather than silently ignored.
load_report = generator.load_state_dict(causvid_state_dict, strict=False)
non_vace_missing = [key for key in load_report.missing_keys if "vace_" not in key]
if non_vace_missing or load_report.unexpected_keys:
    print(f"WARNING: unexpected load mismatch; missing={non_vace_missing}, "
          f"unexpected={list(load_report.unexpected_keys)}")

# copy over key parts from original vace
# (weights are copied rather than assigning vace's modules directly, which would make
# the generator share module objects with `vace`)
_copy_vace_weights(generator.model, vace, "vace_blocks")
# Left as the last expression so the key-matching report is shown as the cell output.
_copy_vace_weights(generator.model, vace, "vace_patch_embedding")

Some weights of VaceCausalWanModel were not initialized from the model checkpoint at wan_models/Wan2.1-T2V-1.3B/ and are newly initialized: ['vace_blocks.4.cross_attn.norm_q.weight', 'vace_blocks.8.cross_attn.k.weight', 'vace_blocks.8.self_attn.norm_k.weight', 'vace_blocks.4.cross_attn.o.weight', 'vace_blocks.1.cross_attn.v.weight', 'vace_blocks.9.cross_attn.o.weight', 'vace_blocks.0.self_attn.norm_q.weight', 'vace_blocks.12.cross_attn.k.bias', 'vace_blocks.1.self_attn.o.weight', 'vace_blocks.10.modulation', 'vace_blocks.10.cross_attn.q.bias', 'vace_blocks.5.self_attn.k.bias', 'vace_blocks.5.cross_attn.norm_q.weight', 'vace_blocks.12.cross_attn.norm_k.weight', 'vace_blocks.7.cross_attn.k.weight', 'vace_blocks.13.cross_attn.q.bias', 'vace_blocks.11.self_attn.k.bias', 'vace_blocks.11.cross_attn.norm_q.weight', 'vace_blocks.11.ffn.0.weight', 'vace_blocks.14.self_attn.v.weight', 'vace_blocks.13.cross_attn.v.bias', 'vace_blocks.8.cross_attn.v.weight', 'vace_blocks.7.self_attn.k.weight', 'va

in dim?  96
patch size? (1, 2, 2)
patch size?  (1, 2, 2)
OrderedDict([('weight', tensor(..., device='meta', size=(1536, 96, 1, 2, 2))), ('bias', tensor(..., device='meta', size=(1536,)))])


<All keys matched successfully>

In [236]:
# Pick up any edits to my_inference.py without restarting the kernel, then
# re-bind the pipeline class from the reloaded module.
import causvid.models.wan.my_inference
importlib.reload(causvid.models.wan.my_inference)
from causvid.models.wan.my_inference import MyInferencePipeline

# Rebuild the pipeline around whatever `generator` currently refers to; the text
# encoder, VAE and config are the ones created earlier in the notebook.
my_pipeline = MyInferencePipeline(
    generator=generator,
    text_encoder=text_encoder,
    vae=vae,
    args=config,
    dtype=torch.bfloat16,
    device=device,
)

KV inference with 3 frames per block


In [241]:
# Seed the global RNG (CPU and CUDA) before sampling so the starting noise -- and any
# random draws made afterwards from the global RNG -- are the same on every run of
# this cell.
noise_seed = 0
torch.manual_seed(noise_seed)

# Initial noise handed to the pipeline.
# NOTE(review): axes are presumably (batch, frames, channels, height, width) in latent
# space -- confirm against MyInferencePipeline.inference.
sampled_noise = torch.randn(
    [1, 21, 16, 60, 104], device=device, dtype=torch.bfloat16
)

# The same noise tensor is reused for every prompt; one video file is written per prompt.
for prompt_index in tqdm(range(len(prompt_dataset)), desc="generating (vace)"):
    prompts = [prompt_dataset[prompt_index]]  # inference() is given a one-element list

    # Take the first item of the result and move its axis 1 to the end before
    # converting to NumPy (presumably frames x C x H x W -> frames x H x W x C; confirm).
    video = my_pipeline.inference(
        noise=sampled_noise,
        text_prompts=prompts
    )[0].permute(0, 2, 3, 1).cpu().numpy()

    export_to_video(
        video, os.path.join(output_folder, f"vace_output_{prompt_index:03d}.mp4"), fps=16)

  0%|          | 0/4 [00:00<?, ?it/s]

num_blocks?? 7
t:  tensor([[1000, 1000, 1000]], device='cuda:0')
e.shape:  torch.Size([3, 1536])
e0.shape:  torch.Size([1, 3, 6, 1536])
e0.shape:  torch.Size([1, 3, 6, 1536])
seq_lens:  tensor([4680])
current_start:  0
item in vace_context:  torch.Size([96, 21, 60, 104])
x.shape:  torch.Size([1, 4680, 1536])
u shape after vace patch embedding:  torch.Size([1, 1536, 21, 30, 52])
u shape after reshaping:  torch.Size([1, 32760, 1536])
c shape:  torch.Size([1, 4680, 1536])
self.before_proj:  Linear(in_features=1536, out_features=1536, bias=True)
c shape:  torch.Size([1, 4680, 1536])
self.before_proj(c) shape:  torch.Size([1, 4680, 1536])
x.shape:  torch.Size([1, 4680, 1536])
current_end:  4680
t:  tensor([[757, 757, 757]], device='cuda:0')
e.shape:  torch.Size([3, 1536])
e0.shape:  torch.Size([1, 3, 6, 1536])
e0.shape:  torch.Size([1, 3, 6, 1536])
seq_lens:  tensor([4680])
current_start:  0
item in vace_context:  torch.Size([96, 21, 60, 104])
x.shape:  torch.Size([1, 4680, 1536])
u shape a