In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append('../owl-vaes')
sys.path.append('../')
import tarfile
import io
import torch

# Tar shard to inspect; tensors are read from it below as
# "<base_name>.<suffix>.pt" members.
path = "/home/developer/workspace/data/vast/0000.tar"

def process_tensor_file(tar, base_name, suffix):
    """Load the object stored as ``{base_name}.{suffix}.pt`` in an open tar archive.

    Args:
        tar: An open ``tarfile.TarFile`` to read the member from.
        base_name: Leading part of the member name (e.g. ``'0002'``).
        suffix: Middle part of the member name (e.g. ``'wan'`` or ``'dcae'``).

    Returns:
        Whatever ``torch.load`` produces for that member, or ``None`` when the
        member is missing, is not a regular file, or cannot be deserialized.
        Failures are reported through ``warnings.warn`` instead of being
        swallowed silently.
    """
    import warnings

    member_name = f"{base_name}.{suffix}.pt"
    try:
        f = tar.extractfile(member_name)
        if f is None:
            # extractfile() yields None for members that are not regular files.
            return None
        # Close the extracted file object as soon as its bytes are in memory.
        with f:
            tensor_data = f.read()
        return torch.load(io.BytesIO(tensor_data))
    except Exception as exc:
        # Stay best-effort (callers expect None on failure), but say what went
        # wrong, and let KeyboardInterrupt / SystemExit propagate.
        warnings.warn(
            f"Could not load {member_name!r} from tar archive: {exc!r}",
            stacklevel=2,
        )
        return None

# Read both latent tensors for sample '0002' out of the tar shard.
with tarfile.open(path, 'r') as tar:
    mean_wan = process_tensor_file(tar, '0002', 'wan')
    mean_dcae = process_tensor_file(tar, '0002', 'dcae')

# process_tensor_file returns None when a member is missing or unreadable.
# Fail here with an actionable message instead of an AttributeError on
# NoneType in the prints below.
missing_latents = [
    f"0002.{kind}.pt"
    for kind, latent in (('wan', mean_wan), ('dcae', mean_dcae))
    if latent is None
]
if missing_latents:
    raise RuntimeError(f"Could not load {missing_latents} from {path}")

# Quick sanity check of shapes, dtypes and value ranges.
print(mean_wan.shape, mean_dcae.shape)
print(mean_wan.dtype, mean_dcae.dtype)
print(mean_wan.min(), mean_wan.max())
print(mean_dcae.min(), mean_dcae.max())

In [None]:
# Toggle: True decodes `mean_wan` with the Wan VAE, False decodes `mean_dcae`
# with the DCAE checkpoint. Either way the result ends up in `recon_vid`.
check_wan = True
from diffusers import AutoencoderKLWan
from owl_vaes.configs import ResNetConfig
from owl_vaes.models.dcae import DCAE
import torch

# Resolve the target device once so both branches honour the CPU fallback
# instead of calling .cuda() unconditionally.
device = "cuda" if torch.cuda.is_available() else "cpu"

if check_wan:
    # Add a leading batch dimension when the latent is 4-D.
    if len(mean_wan.shape) == 4:
        mean_wan = mean_wan.unsqueeze(0)
        print(mean_wan.shape)
    vae = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.1-VACE-14B-diffusers",
        subfolder="vae",
        torch_dtype=torch.bfloat16  # Optional: use half precision
    )
    # Only decode() is called below, so the encoder is dropped.
    vae.encoder = None
    # Move to GPU if available, otherwise stay on CPU.
    vae = vae.to(device)
    vae.compile()
    with torch.no_grad():
        # Swap dims 1 and 2 before decoding and swap them back afterwards
        # (presumably frame/channel axes; TODO confirm against the stored layout).
        out = vae.decode(mean_wan.bfloat16().to(device).permute(0,2,1,3,4))
        recon_vid = out.sample.permute(0,2,1,3,4)
        # Drop the batch dimension again.
        recon_vid = recon_vid[0]
else:
    cfg = ResNetConfig(
        sample_size=[360,640],
        channels=3,
        latent_size=4,
        latent_channels=128,
        noise_decoder_inputs=0.0,
        ch_0=256,
        ch_max=2048,
        encoder_blocks_per_stage = [4, 4, 4, 4, 4, 4, 4],
        decoder_blocks_per_stage = [4, 4, 4, 4, 4, 4, 4]
    )
    vae = DCAE(cfg)
    # map_location="cpu" lets the checkpoint load on hosts without the GPU it
    # was saved from; load_state_dict copies the values into the module.
    vae.load_state_dict(
        torch.load("/home/developer/workspace/models/cod_128x.pt", map_location="cpu")
    )
    # Only the decoder is used below, so the encoder is dropped.
    vae.encoder = None
    vae.bfloat16().to(device)
    print(mean_dcae.shape)
    # Process in batches to avoid memory issues
    batch_size = 4
    recon_vid_list = []

    for i in range(0, mean_dcae.shape[0], batch_size):
        batch = mean_dcae[i:i+batch_size].bfloat16().to(device)
        with torch.no_grad():
            batch_recon = vae.decoder(batch)
        # Move each decoded batch to the CPU before collecting it.
        recon_vid_list.append(batch_recon.cpu())

    recon_vid = torch.cat(recon_vid_list, dim=0)

In [None]:
from lat2lat.utils import create_video_visualization, export_video_as_gif
from IPython.display import display, HTML

# Show the decoded clip two ways: as an inline player and as a GIF on disk.
print("Creating video visualizations...")

# Step 1: build the animation and embed it in the notebook output.
print("\n1. Creating animated playback...")
anim, video_np = create_video_visualization(
    recon_vid,
    "Reconstructed Video Playback",
)
playback_html = HTML(anim.to_jshtml())
display(playback_html)

# Step 2: write the frames returned above to a GIF file.
print("\n2. Exporting as GIF...")
export_video_as_gif(video_np, "reconstructed_video.gif", fps=30)