In [None]:
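# Optional: fetch the checkpoint from the Hugging Face Hub instead of using a local copy.
# `filename` must be the file's path inside the repo (no "./" prefix). hf_hub_download returns the
# path of the cached file; the next cell hard-codes ltxv_model_path, so drop that line if you use this.
# from huggingface_hub import hf_hub_download

# ltxv_model_path = hf_hub_download(
#     repo_id="Lightricks/LTX-Video",
#     filename="ltxv-2b-0.9.8-distilled.safetensors",
#     repo_type="model",
# )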

In [None]:
from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder

device = 'cuda'

# Local single-file checkpoint (see the optional download cell above for fetching it).
ltxv_model_path = './ltxv-2b-0.9.8-distilled.safetensors'
vae = CausalVideoAutoencoder.from_pretrained(ltxv_model_path)
vae.to(device)
# The VAE is only used for inference in this notebook; switch to eval mode explicitly so any
# train-only behaviour the architecture may have (dropout etc.) is disabled.
vae.eval()
print("-")

In [None]:
from torchvision.io import read_video
from einops import rearrange
import torch

vid_path = '/workspace/sixteen128.mp4'
# The frame count must be a multiple of 8 plus 1 (65 = 8 * 8 + 1).
NUM_FRAMES = 65

video_frames, audio, info = read_video(vid_path)
print("info : ", info)

print(video_frames.shape, video_frames.dtype)  # (T, H, W, C)  # T: frame count, C=3
if video_frames.shape[0] < NUM_FRAMES:
    # Slicing would silently return fewer frames, which in general is no longer 8k + 1.
    raise ValueError(f"{vid_path} has {video_frames.shape[0]} frames; need at least {NUM_FRAMES}")

# Trim to the first NUM_FRAMES frames *before* duplicating along the batch axis, moving to the GPU
# and casting to float32, so the unused tail of the clip never occupies GPU memory.
video_frames = rearrange(
    video_frames[:NUM_FRAMES].unsqueeze(dim=0).tile(2, 1, 1, 1, 1),  # batch of 2 identical clips
    'b t h w c -> b c t h w',
).to(device).to(torch.float32)
print(video_frames.shape, video_frames[0][0][0][0])

# uint8 range [0, 255] -> [-1, 1].
video_frames = (video_frames / 255) * 2 - 1

In [None]:
# Encode the clip without autograd and keep the mode of the latent distribution.
with torch.no_grad():
    latent = vae.encode(video_frames).latent_dist.mode()
print(latent.shape)

# Return cached GPU blocks left over from the encoder pass.
torch.cuda.empty_cache()

In [None]:
import numpy as np
import os
import torch

LATENT_DIR = "/workspace/AVE_Dataset/AVE_latents/"

# os.listdir order is arbitrary; sort so the "first" latent file is the same on every run.
fss = sorted(os.listdir(LATENT_DIR))
if not fss:
    raise FileNotFoundError(f"no latent files found in {LATENT_DIR}")
print(fss[0])

# NOTE(review): assumes each file holds a single array (np.save / .npy); an .npz archive has no .shape.
npz = np.load(os.path.join(LATENT_DIR, fss[0]))
print(npz.shape)
# Add a batch axis and match the VAE's parameter dtype so decode does not fail on e.g. float64/float16 input.
tt = torch.tensor(npz).unsqueeze(dim=0).to(device=device, dtype=next(vae.parameters()).dtype)
print(tt.shape)

timestep = torch.ones(tt.shape[0], device=device) * 0.1
# NOTE(review): unused below; its pixel-space shape suggests it may have been meant as target_shape.
ts = torch.randn(((1, 3, 240, 128, 128)))

# Reconstruction only: without no_grad the decoder's autograd graph would be kept in GPU memory.
# NOTE(review): target_shape is the latent tensor's own shape here, while the other decode call
# passes a pixel-space (B, 3, T, H, W) shape - confirm which the decoder expects.
with torch.no_grad():
    reconstructed_videos = vae.decode(
        tt, target_shape=tt.shape, timestep=timestep
    ).sample

print(reconstructed_videos.shape)

In [None]:
# One decoder timestep value per batch element.
timestep = torch.ones(video_frames.shape[0], device=device) * 0.1

# Decode only the first 4 latent frames. Reconstruction needs no gradients, so run under no_grad to
# avoid holding the decoder's activation graph in GPU memory.
# NOTE(review): target_shape asks for 49 pixel frames, but if the temporal compression is 8x
# causal (input must be 8k + 1 frames), 4 latent frames map to (4 - 1) * 8 + 1 = 25 frames - confirm
# which is intended and whether decode uses target_shape beyond validation.
with torch.no_grad():
    reconstructed_videos = vae.decode(
        latent[:, :, :4, :, :], target_shape=video_frames[:, :, :49, :, :].shape, timestep=timestep
    ).sample

print(reconstructed_videos.shape)

In [None]:
import torchvision
from torchvision.io import read_video, write_video
from einops import rearrange

output_path = './reencoded_video12816test_2.mp4'

# First batch element in channels-last (T, H, W, C) layout, mapped from [-1, 1] to [0, 255].
# write_video casts the float array straight to uint8, which truncates and does not clamp, so any
# decoder overshoot outside [-1, 1] would wrap around; clamp first and round instead of truncating.
recon_videos = (
    (rearrange(reconstructed_videos, "b c t h w -> b t h w c")[0].detach().cpu().clamp(-1, 1) / 2 + 0.5) * 255
).round()

write_video(
    filename=output_path,
    video_array=recon_videos,      # shape: (T, H, W, C)
    fps=int(24),
    video_codec='libx264',         # optional
    options={"crf": "18"}          # optional: quality setting
)

In [None]:
# RGB values of pixel (row 0, col 0) in frame 0 of the reconstruction, on the [0, 255] scale.
recon_videos[0, 0, 0]

In [None]:
# The same pixel as recon_videos[0][0][0]: batch 0, all channels, frame 0, row 0, col 0.
# video_frames is (B, C, T, H, W) in [-1, 1], so index channels explicitly and map back to [0, 255]
# to make the two values directly comparable.
(video_frames[0, :, 0, 0, 0] + 1) * 127.5

In [None]:
import torch.nn as nn

# 3D conv that collapses a 4x4 spatial grid to 1x1 while leaving the time axis untouched:
#   kernel_size=(1, 4, 4), stride=(1, 4, 4)
conv3d = nn.Conv3d(
    in_channels=128,
    out_channels=128,
    kernel_size=(1, 4, 4),
    stride=(1, 4, 4),
    padding=0,
).to(device)

pooled = conv3d(latent)  # (B, C, T, H // 4, W // 4)
# Only a 4x4 latent grid reduces to 1x1 here; anything else would leave extra spatial axes behind.
assert pooled.shape[-2:] == (1, 1), f"expected a 4x4 latent grid, got conv output {tuple(pooled.shape)}"
# Drop just the two spatial singleton axes; a bare .squeeze() would also drop the batch axis when B == 1.
y = pooled.squeeze(-1).squeeze(-1)  # (B, C, T)
print(y.shape)

In [None]:
import torch.nn as nn

class GRN(nn.Module):
    """Global Response Normalization (ConvNeXt-V2) for channel-last sequences.

    Takes the L2 norm of each channel over axis 1, divides it by the mean of those norms across
    the last axis, and uses the ratio to rescale the input. ``gamma`` and ``beta`` start at zero,
    so a freshly constructed layer is the identity.
    """

    def __init__(self, dim):
        super().__init__()
        shape = (1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(*shape))
        self.beta = nn.Parameter(torch.zeros(*shape))

    def forward(self, x):
        # Per-channel energy aggregated over axis 1 (the sequence axis for (B, N, D) input).
        channel_norm = x.norm(p=2, dim=1, keepdim=True)
        relative = channel_norm / (channel_norm.mean(dim=-1, keepdim=True) + 1e-6)
        scaled = x * relative
        return self.gamma * scaled + self.beta + x

class ConvNeXtV2Block(nn.Module):
    """Residual ConvNeXt-V2 block over a (batch, seq_len, dim) sequence.

    Depthwise Conv1d (kernel 7, length-preserving padding for the given dilation) -> LayerNorm ->
    Linear up to ``intermediate_dim`` -> GELU -> GRN -> Linear back to ``dim``, plus a skip
    connection, so the output has the same shape as the input.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        kernel = 7
        self.dwconv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernel,
            padding=dilation * (kernel - 1) // 2,  # keeps seq_len unchanged
            dilation=dilation,
            groups=dim,  # one filter per channel -> depthwise
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convolutions written as Linear layers on the channel-last tensor.
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` of shape (bs, seq_len, dim); returns the same shape."""
        # Conv1d is channels-first: (b, n, d) -> (b, d, n) and back again.
        mixed = self.dwconv(x.permute(0, 2, 1)).permute(0, 2, 1)
        hidden = self.act(self.pwconv1(self.norm(mixed)))
        update = self.pwconv2(self.grn(hidden))
        return x + update

# 128 channels to match the conv3d output, with a 64-wide hidden layer.
cnv = ConvNeXtV2Block(dim=128, intermediate_dim=64).to(device)

In [None]:
# The block expects (batch, seq_len, dim), so swap y's last two axes before the call.
out = cnv(y.transpose(1, 2))
print(out.shape)

In [None]:
import torch.nn as nn

class UpsampleHalveChannel(nn.Module):
    """2x length upsampler for channel-last sequences.

    A ConvTranspose1d (kernel 4, stride 2, padding 1) maps (B, L, in_ch) to (B, 2L, out_ch),
    followed by LayerNorm over out_ch and GELU. With out_ch = in_ch / 2 (e.g. 128 -> 64) it
    doubles the length while halving the channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose1d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.norm = nn.LayerNorm(out_ch)
        self.act = nn.GELU()

    def forward(self, x):
        # ConvTranspose1d is channels-first; output length is (L - 1) * 2 - 2 * 1 + 4 = 2L.
        upsampled = self.up(x.permute(0, 2, 1)).permute(0, 2, 1)  # [B, 2L, out_ch]
        return self.act(self.norm(upsampled))

# 128 -> 64 channels, 2x sequence length.
upconv = UpsampleHalveChannel(in_ch=128, out_ch=64).to(device)

In [None]:
# Shape after upsampling the channel-last view of y.
upconv(y.transpose(1, 2)).shape

In [None]:
# Plain linear 128 -> 64 projection, kept only for the parameter-count comparison.
layers = nn.Linear(in_features=128, out_features=64)

In [None]:
# Parameter counts for the candidate modules.
# NOTE(review): the cnv figure is multiplied by 4 (presumably estimating a stack of four blocks),
# which its label does not mention - confirm.
for label, module, multiplier in (("conv3d", conv3d, 1), ("cnv", cnv, 4), ("layers", layers, 1)):
    total_params = sum(p.numel() for p in module.parameters()) * multiplier
    print(f"{label} parameters: {total_params:,}")