In [None]:
import torch 

# Saved forward-pass inputs captured from the two pipelines being compared.
SI_INPUTS_PATH = "/home/tosinkuye/apex/si_inputs1.pt"
DENOISE_INPUTS_PATH = "/home/tosinkuye/apex/denoise_inputs1.pt"

inputs_1 = torch.load(SI_INPUTS_PATH)
inputs_2 = torch.load(DENOISE_INPUTS_PATH)
# Only 'latents' and 'ip_image' are read from inputs_2 later; drop 'attention_kwargs'
# if present (the default avoids a KeyError when the capture did not include it).
inputs_2.pop('attention_kwargs', None)


def describe_inputs(inputs):
    """Print each entry's name with its shape and dtype (or its type name if it is not a tensor)."""
    for name, value in inputs.items():
        if isinstance(value, torch.Tensor):
            print(name, value.shape, value.dtype)
        else:
            print(name, type(value).__name__)


describe_inputs(inputs_1)
describe_inputs(inputs_2)



In [None]:
from src.transformer.wan.base.model import WanTransformer3DModel
from src.converters.transformer_converters import WanTransformerConverter

# Stand-In checkpoint with the extra (IP-branch) weights; the reference DiT is patched
# from the same file via set_stand_in(..., model_path=extra_path).
extra_path = '/mnt/localssd/apex-diffusion/components/BowenXue/Stand-In/resolve/main/Stand-In_wan2.1_T2V_14B_ver1.0.ckpt'
extra_weights = torch.load(extra_path)

converter = WanTransformerConverter()
# NOTE(review): the return value is ignored, so convert() presumably rewrites keys in place — confirm.
converter.convert(extra_weights)

model = WanTransformer3DModel.from_pretrained("/mnt/localssd/apex-diffusion/components/Wan-AI_Wan2.1-T2V-14B-Diffusers/transformer", torch_dtype=torch.bfloat16)
# IP projections are created with device='meta'; load_state_dict(assign=True) then
# attaches the checkpoint tensors in place of those placeholders.
model.init_ip_projections(device='meta', dtype=torch.bfloat16)
load_result = model.load_state_dict(extra_weights, strict=False, assign=True)

# strict=False silently ignores key mismatches. missing_keys is presumably dominated by the
# base weights already loaded by from_pretrained, so check the two failure modes that matter:
# checkpoint keys that matched nothing, and tensors that were never replaced and remain on meta
# (which would otherwise only surface as an opaque error inside model.to("cuda")).
if load_result.unexpected_keys:
    print(f"{len(load_result.unexpected_keys)} checkpoint keys were not loaded:", load_result.unexpected_keys[:10])
still_meta = [name for name, tensor in model.state_dict().items() if tensor.is_meta]
assert not still_meta, f"{len(still_meta)} tensors were never loaded (still on meta): {still_meta[:10]}"

model.config.ip_adapter = True
model.to("cuda")

In [None]:
import sys
from glob import glob

# Make the Stand-In reference repository importable. The membership check keeps
# repeated runs of this cell from appending the same sys.path entry again.
STAND_IN_REPO = "/home/tosinkuye/apex/Stand-In"
if STAND_IN_REPO not in sys.path:
    sys.path.append(STAND_IN_REPO)
from models import ModelManager
from models.set_condition_branch import set_stand_in
from dataclasses import dataclass
from typing import Any

BASE_MODEL_DIR = "/mnt/localssd/checkpoints/base_model"
# glob() yields files in filesystem order; sort so the load order is reproducible.
paths = sorted(glob(f"{BASE_MODEL_DIR}/*.safetensors"))
print(paths)
assert paths, f"no .safetensors checkpoints found in {BASE_MODEL_DIR}"

model_manager = ModelManager()
model_manager.load_model(
    paths,
    device="cuda",
    torch_dtype=torch.bfloat16,
)

# NOTE(review): index=2 selects one specific loaded 'wan_video_dit' entry — confirm it is the intended one.
dit = model_manager.fetch_model('wan_video_dit', index=2)


@dataclass
class FakePipe:
    """Minimal pipeline stand-in exposing only a `dit` attribute.

    set_stand_in is handed this object, so presumably it only needs `.dit`
    — TODO confirm against Stand-In's set_condition_branch.
    """

    dit: Any


fake_pipe = FakePipe(dit=dit)
# Patch the reference DiT with the same Stand-In checkpoint (extra_path) that was loaded into `model`.
set_stand_in(fake_pipe, model_path=extra_path)

In [None]:
from models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
timestep = inputs_1['timestep']
# All-zero timestep used for the IP (reference-image) branch.
timestep_ip = torch.zeros_like(timestep)

# Reference path: Stand-In DiT time embedding, split into (6, dit.dim) chunks.
with torch.no_grad():
    print(dit.time_embedding[0].weight.dtype)
    t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
    t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
    t_ip = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep_ip))
    t_mod_ip = dit.time_projection(t_ip).unflatten(1, (6, dit.dim))
    print(t_mod_ip.shape, t_mod.shape)

# Ported path: the chunk width is inferred (-1) from the ported model's own output
# rather than borrowed from the reference via dit.dim.
with torch.no_grad():
    condition_embedder = model.condition_embedder
    temb = condition_embedder.time_embedder(condition_embedder.timesteps_proj(timestep)).type_as(t_mod)
    timestep_proj = condition_embedder.time_proj(condition_embedder.act_fn(temb)).unflatten(1, (6, -1))
    temb_ip = condition_embedder.time_embedder(condition_embedder.timesteps_proj(timestep_ip)).type_as(t_mod_ip)
    timestep_proj_ip = condition_embedder.time_proj(condition_embedder.act_fn(temb_ip)).unflatten(1, (6, -1))
    print(timestep_proj.shape, timestep_proj_ip.shape)

    # Report the numerical gap to the reference (in float32), not just the shapes.
    for label, ported, reference in (("timestep", timestep_proj, t_mod), ("timestep_ip", timestep_proj_ip, t_mod_ip)):
        if ported.shape == reference.shape:
            print(f"{label}: max abs diff", (ported.float() - reference.float()).abs().max())
        else:
            print(f"{label}: shape mismatch", tuple(ported.shape), tuple(reference.shape))


In [None]:
hidden_states = inputs_1['latents']
hidden_states_ip = inputs_1['ip_image']

hidden_states_2 = inputs_2['latents']
hidden_states_ip_2 = inputs_2['ip_image']

# Both capture files must hold the same latents / IP latents for the comparisons to mean anything.
torch.testing.assert_close(hidden_states, hidden_states_2)
torch.testing.assert_close(hidden_states_ip, hidden_states_ip_2)


# Reference path: dit.patchify returns the token sequence plus its (f, h, w) grid.
with torch.no_grad():
    pe_x, (f, h, w) = dit.patchify(hidden_states)
    pe_x_ip, (f_ip, h_ip, w_ip) = dit.patchify(hidden_states_ip)
    print(pe_x.shape, f, h, w)
    print(pe_x_ip.shape, f_ip, h_ip, w_ip)

# Ported path: keep the un-flattened conv output so the model's OWN grid size is
# reported (previously the reference's f, h, w were printed here).
with torch.no_grad():
    pe_grid = model.patch_embedding(hidden_states_2)
    pe_grid_ip = model.patch_embedding(hidden_states_ip_2)
    pe_hidden_states = pe_grid.flatten(2).transpose(1, 2)
    pe_hidden_states_ip = pe_grid_ip.flatten(2).transpose(1, 2)
    print(pe_hidden_states.shape, *pe_grid.shape[2:])
    print(pe_hidden_states_ip.shape, *pe_grid_ip.shape[2:])

    # Report how far the ported token sequences are from the reference ones.
    for label, ported, reference in (("latents", pe_hidden_states, pe_x), ("ip", pe_hidden_states_ip, pe_x_ip)):
        if ported.shape == reference.shape:
            print(f"patch embedding ({label}): max abs diff", (ported.float() - reference.float()).abs().max())
        else:
            print(f"patch embedding ({label}): shape mismatch", tuple(ported.shape), tuple(reference.shape))
    
    
    

In [None]:
context = inputs_1['context']

# Reference text embedding (Stand-In DiT).
with torch.no_grad():
    context_embeddings = dit.text_embedding(context)
    print(context_embeddings.shape)

# Ported text embedding, fed the same `context`, so the two outputs are directly comparable.
with torch.no_grad():
    encoder_hidden_states = model.condition_embedder.text_embedder(context)
    print(encoder_hidden_states.shape)
    if encoder_hidden_states.shape == context_embeddings.shape:
        print("text embedding: max abs diff", (encoder_hidden_states.float() - context_embeddings.float()).abs().max())
    else:
        print("text embedding: shape mismatch", tuple(encoder_hidden_states.shape), tuple(context_embeddings.shape))
    
    

In [None]:

def _rope_grid(f_rows, h_rows, w_rows):
    """Build a flattened 3-D RoPE table from per-axis frequency rows.

    Args:
        f_rows, h_rows, w_rows: 2-D (positions, channels) slices of the reference
            DiT's per-axis frequency tables (dit.freqs[0], [1], [2]).

    Returns:
        Tensor of shape (F * H * W, 1, C_f + C_h + C_w). Every (f, h, w) grid cell gets
        the concatenation of its three per-axis rows, in f-major / w-minor order.
    """
    nf, nh, nw = f_rows.shape[0], h_rows.shape[0], w_rows.shape[0]
    return torch.cat(
        [
            f_rows[:, None, None, :].expand(nf, nh, nw, -1),
            h_rows[None, :, None, :].expand(nf, nh, nw, -1),
            w_rows[None, None, :, :].expand(nf, nh, nw, -1),
        ],
        dim=-1,
    ).reshape(nf * nh * nw, 1, -1)


with torch.no_grad():
    offset = 1
    # Video tokens: table rows start at `offset` along every axis.
    freqs_video = _rope_grid(
        dit.freqs[0][offset : f + offset],
        dit.freqs[1][offset : h + offset],
        dit.freqs[2][offset : w + offset],
    ).to(hidden_states.device)

    # IP tokens: every IP frame reuses temporal row 0 (generalized from a view that only
    # worked for f_ip == 1), and the spatial rows continue after the video grid's range
    # so they do not overlap the video rows.
    freqs_ip = _rope_grid(
        dit.freqs[0][0:1].expand(f_ip, -1),
        dit.freqs[1][h + offset : h + offset + h_ip],
        dit.freqs[2][w + offset : w + offset + w_ip],
    ).to(hidden_states_ip.device)

    # A frequency table shorter than the grid would otherwise yield a silently smaller table.
    assert freqs_video.shape[0] == f * h * w, "reference RoPE table too short for the video grid"
    assert freqs_ip.shape[0] == f_ip * h_ip * w_ip, "reference RoPE table too short for the IP grid"
    freqs = torch.cat([freqs_video, freqs_ip], dim=0)

# Ported path: video and IP rotary tables concatenated along dim 2.
with torch.no_grad():
    rotary_emb_video = model.rope(hidden_states)
    rotary_emb_ip = model.rope(hidden_states, hidden_states_ip, time_index=0)
    rotary_emb = torch.cat([rotary_emb_video, rotary_emb_ip], dim=2)



print(freqs.shape, rotary_emb.shape)
    
    
    
    

In [None]:
# Drop rotary_emb's leading dim and swap the next two so it lines up with freqs' (N, 1, C) layout.
torch.testing.assert_close(freqs, torch.transpose(torch.squeeze(rotary_emb, 0), 0, 1))

In [None]:
# First transformer block of each implementation; these are compared on captured kwargs below.
dit_block, model_block = dit.blocks[0], model.blocks[0]

In [None]:
# Captured keyword arguments for one block call: one set for the reference DiT block,
# one for the ported model block.
model_inputs = torch.load("/home/tosinkuye/apex/all_transformer_inputs.pt")
dit_kwargs, model_kwargs = model_inputs["dit_kwargs"], model_inputs["model_kwargs"]

In [None]:

# Run one transformer block from each implementation on its captured kwargs.
# NOTE(review): a previous version of this cell split the outputs into (x, x_ip): the
# reference output was unpacked as a pair, and the ported output was sliced along dim 1
# at -ip_hidden_states_len. The outputs are now compared whole.
with torch.no_grad():
    dit_block_output = dit_block(**dit_kwargs)
    # Number of IP tokens in the ported block's input; kept for splitting its output if needed.
    ip_hidden_states_len = model_kwargs['hidden_states_ip'].shape[1]
    model_block_output = model_block(**model_kwargs)


In [None]:
def _as_diffable(t):
    """Upcast half-precision / bool tensors to float32 so the difference itself is not rounded."""
    if t.dtype in (torch.float16, torch.bfloat16, torch.bool):
        return t.float()
    return t


def _report_max_abs_diff(actual, expected, label=""):
    """Print max |actual - expected| without letting broadcasting hide a shape mismatch."""
    if not (isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor)):
        print(f"{label}cannot diff {type(actual).__name__} against {type(expected).__name__}")
        return
    if actual.shape != expected.shape:
        print(f"{label}shape mismatch: {tuple(actual.shape)} vs {tuple(expected.shape)}")
        return
    if actual.numel() == 0:
        # .max() on an empty tensor raises, so report it instead.
        print(f"{label}both tensors are empty")
        return
    diff = _as_diffable(actual.detach()) - _as_diffable(expected.detach()).to(actual.device)
    print(f"{label}{diff.abs().max()}")


def compare(x1, x2):
    """Print the max absolute difference between two outputs, then assert they match.

    Args:
        x1: Output under test ("actual"): a tensor, or a tuple/list of tensors.
        x2: Reference output ("expected"), with the same structure as ``x1``.

    Raises:
        AssertionError: From ``torch.testing.assert_close`` when structure, shape, dtype,
            device or values differ beyond its default tolerances.
    """
    if isinstance(x1, (tuple, list)) and isinstance(x2, (tuple, list)):
        # Block forwards may return several tensors; report each pair separately.
        for i, (a, b) in enumerate(zip(x1, x2)):
            _report_max_abs_diff(a, b, label=f"[{i}] ")
    else:
        _report_max_abs_diff(x1, x2)
    torch.testing.assert_close(x1, x2)


# Ported block output is the value under test; the reference DiT block output is the expected value.
compare(x1=model_block_output, x2=dit_block_output)