In [1]:
import os, random, math
import subprocess
import numpy as np
from tqdm import tqdm
from pathlib import Path
import cv2
import torch
from torchvision import transforms
from PIL import Image

# Clear any inherited `all_proxy` setting; presumably so that requests made later in
# the notebook (e.g. the ollama client) are not routed through a proxy. Setting it once is enough.
os.environ['all_proxy'] = ''

In [2]:
import ollama
import base64
# import decord
import io
from icecream import ic


In [3]:
from constants import DEFAULT_HEIGHT_BUCKETS, DEFAULT_WIDTH_BUCKETS, DEFAULT_FRAME_BUCKETS

def get_frames(inp: str, w: int, h: int, start_sec: float = 0, duration: float = None, f: int = None, fps = None) -> np.ndarray:
    """Decode frames of a video into an RGB uint8 array using the ffmpeg CLI.

    Args:
        inp: input accepted by ``ffmpeg -i`` (usually a file path).
        w, h: output frame size; ffmpeg rescales every frame to ``w x h``.
        start_sec: seek position in seconds (``-ss`` placed before ``-i``).
        duration: if given, decode at most this many seconds (``-t``).
        f: used only when ``duration`` is None; decode at most this many frames.
        fps: if given, output frame rate passed as ``-r``.

    Returns:
        np.ndarray of shape (num_frames, h, w, 3), dtype uint8. It is built with
        ``np.frombuffer`` over the captured bytes and is therefore read-only.

    Raises:
        RuntimeError: ffmpeg exited with a non-zero status; the message contains
            ffmpeg's stderr.
    """
    limit_args = []
    if duration is not None:
        limit_args += ["-t", f"{duration:.2f}"]
    elif f is not None:
        limit_args += ["-frames:v", str(f)]
    if fps is not None:
        limit_args += ["-r", str(fps)]

    cmd = ["ffmpeg", "-nostdin", "-ss", f"{start_sec:.2f}", "-i", inp, *limit_args,
           "-f", "rawvideo", "-pix_fmt", "rgb24", "-s", f"{w}x{h}", "pipe:"]

    # run() waits for ffmpeg and closes both pipes, so no explicit terminate() is needed.
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if proc.returncode:
        # ffmpeg's stderr is not guaranteed to be valid UTF-8; a strict decode would
        # replace the real ffmpeg error with a UnicodeDecodeError.
        msg = proc.stderr.decode("utf-8", errors="replace")
        raise RuntimeError(f"{inp}: ffmpeg error: {msg}")

    return np.frombuffer(proc.stdout, np.uint8).reshape(-1, h, w, 3)  # (frames, h, w, c)
    

In [4]:
class Captioner():
    """Caption short video clips with a multimodal model served through ollama.

    Frames are subsampled, resized, JPEG/base64-encoded and sent in a single chat
    message together with ``prompt``. Boilerplate such as "The video shows" is
    then removed from the reply.
    """

    def __init__(self, model="minicpm-v:8b-2.6-q5_0", prompt=None):
        self.client = ollama.Client()
        self.model = model
        # default_prompt = """describe this video in this order: camera angle, main subject, make the description short"""
        default_prompt = "describe this video in short"
        self.prompt = prompt or default_prompt
        # Phrases deleted from every caption (attribute name kept for compatibility).
        self.should_remove_phrese = self._default_remove_phrases()

    @staticmethod
    def _default_remove_phrases():
        """Return "In the video" plus every "<The|This> <kind> <verb>" opening, e.g. "This scene depicts"."""
        starts = ["The", "This"]
        kinds = ["video", "image", "scene", "animated sequence"]
        verbs = ["displays", "shows", "features", "is", "depicts", "presents", "showcases", "captures"]
        openings = [f"{s} {k} {v}" for s in starts for k in kinds for v in verbs]
        return ["In the video"] + openings

    @staticmethod
    def pil_to_base64(image):
        """Encode a PIL image as a base64 JPEG string for the ``images`` field of the chat message."""
        byte_stream = io.BytesIO()
        image.save(byte_stream, format='JPEG')
        return base64.b64encode(byte_stream.getvalue()).decode('utf-8')

    def remove_phrese(self, cap):
        """Clean a raw model reply into a single short caption.

        Keeps only the first paragraph and deletes every phrase in
        ``self.should_remove_phrese``. It then tidies what a deletion leaves behind:
        whitespace runs are collapsed and leading separators are dropped, so
        "In the video, a man runs" becomes "a man runs".

        NOTE(review): phrases are removed anywhere in the text, not only at the start.
        """
        # only keep the primary part of the caption
        cap = cap.split("\n\n", 1)[0]
        for phrase in self.should_remove_phrese:
            cap = cap.replace(phrase, "")
        # Removing an opening leaves " a man ..." or ", a man ..."; normalise that.
        cap = " ".join(cap.split())
        return cap.lstrip(",:; ")

    def get_caption(self, frames, size=(640, 320), frame_skip=4):
        """Ask the model to describe a clip and return the cleaned caption.

        Args:
            frames: clip frames. Either a numpy array with one image per entry
                (e.g. the (t, h, w, 3) uint8 output of ``get_frames``), or an
                indexable sequence of torch images, presumably in (h, w, c) layout
                since they are permuted to (c, h, w) for ``ToPILImage``.
            size: (width, height) every sent frame is resized to.
            frame_skip: keep every ``frame_skip``-th frame (4 turns 24 fps into 6 fps).

        Returns:
            str: the reply after ``remove_phrese``.

        Raises:
            ValueError: ``frame_skip`` < 1, or no frames are left to send.
        """
        if frame_skip < 1:
            raise ValueError(f"frame_skip must be >= 1, got {frame_skip}")
        frames = frames[::frame_skip]
        if len(frames) == 0:
            raise ValueError("get_caption() needs at least one frame")

        if isinstance(frames, np.ndarray):
            frames = [Image.fromarray(image).convert("RGB").resize(size) for image in frames]
        else:
            # ToPILImage takes channel-first tensors, hence the (2, 0, 1) permute.
            frames = [transforms.ToPILImage()(image.permute(2, 0, 1)).convert("RGB").resize(size) for image in frames]
        images = [self.pil_to_base64(image) for image in frames]

        response = self.client.chat(
            model=self.model,
            messages=[{
                "role": "user",
                "content": self.prompt,
                "images": images,
            }],
        )
        cap = response["message"]["content"]
        return self.remove_phrese(cap)
        

In [20]:
class VideoFramesDataset(torch.utils.data.Dataset):
    """Find videos on disk and pre-compute per-video training caches.

    ``cache_frames`` cuts every video into consecutive ``num_frames`` clips. For each
    video it writes one ``.pt`` file holding the clips' latents, text embeddings,
    masks and captions. The dataset then indexes those cache files (``self.data``).
    """

    def __init__(
        self, 
        video_dir: str,
        cach_dir: str,
        width: int = 1024,
        height: int = 576,
        num_frames: int = 49, 
        fps: int = 24,
        # upper bound on decoded frames per video (30 s at 24 fps) to avoid huge videos
        get_frames_max: int = 30 * 24,
        prompt_prefix = "freeze time, camera orbit left,",
    ):
        super().__init__()
        assert width in DEFAULT_WIDTH_BUCKETS, f"width only supported in: {DEFAULT_WIDTH_BUCKETS}"
        assert height in DEFAULT_HEIGHT_BUCKETS, f"height only supported in: {DEFAULT_HEIGHT_BUCKETS}"
        assert num_frames in DEFAULT_FRAME_BUCKETS, f"frames should in: {DEFAULT_FRAME_BUCKETS}"

        self.width = width
        self.height = height
        self.num_frames = num_frames
        self.fps = fps
        self.video_dir = video_dir

        # Separate cache directory per clip geometry, e.g. "<cach_dir>_49x768x512".
        self.cach_dir = Path(f"{cach_dir}_{num_frames}x{width}x{height}")
        self.cach_dir.mkdir(parents=True, exist_ok=True)

        self.get_frames_max = get_frames_max
        self.prompt_prefix = prompt_prefix

        self.videos = []  # source video paths
        self.data = []    # cache (.pt) paths; this is what __len__/__getitem__ index

        if video_dir is not None:
            self.load_videos()

    def load_videos(self):
        """Recursively collect ``.mp4``/``.mov`` files (any letter case) under ``video_dir``.

        Files whose name starts with "." are skipped. Paths are sorted so the order
        (and hence the caching order) does not depend on the filesystem.

        Returns:
            list[str]: the video paths, also stored in ``self.videos``.

        Raises:
            AssertionError: no video file was found.
        """
        videos = []
        for root, _dirs, files in os.walk(self.video_dir):
            for file in files:
                if file.lower().endswith(('.mp4', '.mov')) and not file.startswith("."):
                    videos.append(os.path.join(root, file))
        # message reads: "no video files in the target folder"
        assert len(videos) > 0, "目标文件夹内没有视频文件"

        self.videos = sorted(videos)
        print(f"{type(self).__name__} found {len(self.videos)} videos ")
        return self.videos

    def to_tensor(self, data, device="cuda", dtype=torch.bfloat16):
        """Map uint8 frames (t, h, w, c) in [0, 255] to a (1, t, c, h, w) tensor in [-1, 1].

        The scaling is done in float32 rather than numpy's default float64, which
        halves the size of the host-side temporary before the cast to ``dtype``.
        """
        scaled = (data.astype(np.float32) / 255) * 2.0 - 1.0
        # (t, h, w, c) -> (t, c, h, w) -> (1, t, c, h, w)
        return torch.from_numpy(scaled).permute(0, 3, 1, 2).unsqueeze(0).to(device, dtype=dtype)

    def cache_frames(self, to_latent, to_caption, to_embedding, device="cuda"):
        """Write one ``.pt`` cache per video into ``self.cach_dir``.

        Each video is decoded at ``self.fps`` (at most ``get_frames_max`` frames) and
        cut into consecutive ``num_frames`` clips; a trailing remainder is dropped.
        Every clip is captioned, text-encoded and VAE-encoded. The file stores
        ``latents``, ``embedds`` and ``masks`` (each concatenated along dim 0) and the
        list of ``captions``.

        Videos whose cache already exists are not recomputed, but they are registered
        in ``self.data``. Videos that fail to decode, or that are shorter than
        ``num_frames``, are skipped.

        Args:
            to_latent: callable taking the (1, t, c, h, w) tensor from ``to_tensor``;
                its result must have 3 dims.
            to_caption: callable taking the uint8 clip frames and returning a str.
            to_embedding: callable taking the caption and returning (embedding, mask).
            device: device the frame tensor is moved to before ``to_latent``.
        """
        print(f"building caches, video count: {len(self.videos)}")

        for vid in tqdm(self.videos):
            # NOTE(review): the cache name is the bare file stem, so two videos with the
            # same name in different sub-folders map to (and skip on) one cache file.
            dest = os.path.join(self.cach_dir, os.path.basename(vid).rsplit(".", 1)[0] + ".pt")
            ic(dest)
            if os.path.exists(dest):
                # Already built earlier: keep it visible to __len__/__getitem__.
                if dest not in self.data:
                    self.data.append(dest)
                continue
            try:
                video_frames = get_frames(vid, self.width, self.height, 0, f=self.get_frames_max, fps=self.fps)
            except Exception as err:
                # Best effort: report and move on, but let KeyboardInterrupt stop the run.
                print("error file:", vid, err)
                continue
            if len(video_frames) < self.num_frames:
                continue

            latents, embedds, masks, captions = [], [], [], []
            n_clips = len(video_frames) // self.num_frames
            for idx in range(n_clips):
                frames = video_frames[idx * self.num_frames : (idx + 1) * self.num_frames]
                ic(frames.shape)
                # Store exactly the text that gets embedded.
                caption = (self.prompt_prefix + " " + to_caption(frames)).replace("  ", " ")
                ic(caption)
                embedding, mask = to_embedding(caption)
                ic(embedding.shape, mask.shape)
                # to_tensor yields (1, num_frames, 3, height, width)
                latent = to_latent(self.to_tensor(frames, device=device))
                assert latent.ndim == 3, "patched latent should have 3 dims"
                ic(latent.shape)

                captions.append(caption)
                embedds.append(embedding)
                masks.append(mask)
                latents.append(latent)

            torch.save(
                dict(
                    latents=torch.cat(latents, dim=0),
                    embedds=torch.cat(embedds, dim=0),
                    masks=torch.cat(masks, dim=0),
                    captions=captions,
                ),
                dest,
            )
            self.data.append(dest)

        print(f">> cached {len(self.data)} videos")

    def __len__(self):
        """Number of cache files registered in ``self.data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the path of the ``idx``-th cache file.

        Indexes ``self.data``, the same list ``__len__`` counts, rather than ``self.videos``.
        """
        return self.data[idx]
        


In [21]:
import argparse
from yaml import load, dump, Loader, Dumper
from tqdm import tqdm

from ltx_video_lora import *

# ------------------- 
config_file = "./configs/ltx.yaml"
device = "cuda"
dtype = torch.bfloat16
# NOTE(review): machine-specific absolute paths; consider moving them into the YAML config.
VIDEO_DIR = "/home/eisneim/Videos"
CACHE_DIR = "/home/eisneim/www/ml/video_gen/ltx_training/data/ltxv_disney"
# ------------------- 

# Parse the YAML config and expose its keys as attributes (e.g. args.id_token).
# The `with` block closes the file handle instead of leaving it open.
# NOTE(review): `Loader` is PyYAML's full loader; fine for this trusted local file,
# but prefer yaml.safe_load for anything that could come from elsewhere.
with open(config_file, "r") as config_fh:
    config_dict = load(config_fh, Loader=Loader)
args = argparse.Namespace(**config_dict)


# ----------- prepare models -------------
dataset = VideoFramesDataset(
    video_dir=VIDEO_DIR,
    cach_dir=CACHE_DIR,
    width=768,
    height=512,
    num_frames=49,
    prompt_prefix=args.id_token,
)

captioner = Captioner()



VideoFramesDataset found 24 videos 


In [17]:
# Load the text-conditioning models and the VAE, moving the trainable-side modules to device/dtype.
cond_models = load_condition_models()
tokenizer = cond_models["tokenizer"]
text_encoder = cond_models["text_encoder"].to(device, dtype=dtype)
vae = load_latent_models()["vae"].to(device, dtype=dtype)
# Optional VAE memory-reduction switches (currently off): vae.enable_tiling(), vae.enable_slicing()

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [18]:
def to_latent(frames_tensor):
    """Encode a clip with the global ``vae`` and return its latents on the CPU.

    ``frames_tensor`` must be laid out as (b, f, c, h, w) with c == 3; its shape is
    logged via ``ic``. The result is ``prepare_latents(...)["latents"]``, computed
    without autograd and moved to the CPU.
    """
    ic(frames_tensor.shape)
    n_channels = frames_tensor.size(2)
    assert n_channels == 3, f"frames should be in shape: (b, f, c, h, w) provided: {frames_tensor.shape}"
    with torch.no_grad():
        encoded = prepare_latents(
            vae=vae,
            image_or_video=frames_tensor,
            device=device,
            dtype=dtype,
        )
        return encoded["latents"].cpu()
        
def to_embedding(caption):
    """Encode ``caption`` with the global tokenizer and text encoder.

    Returns:
        (prompt_embeds, prompt_attention_mask), both moved to the CPU as bfloat16.
        NOTE(review): the attention mask is cast to bfloat16 too; presumably what
        the training code expects — confirm.
    """
    with torch.no_grad():
        conditions = prepare_conditions(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            prompt=caption,
        )
        embeds, attention_mask = (
            conditions[key].to("cpu", dtype=torch.bfloat16)
            for key in ("prompt_embeds", "prompt_attention_mask")
        )
    return embeds, attention_mask



In [None]:
# Caption, text-encode and VAE-encode every video; videos whose cache file exists are skipped.
dataset.cache_frames(
    to_latent=to_latent,
    to_caption=captioner.get_caption,
    to_embedding=to_embedding,
)

In [23]:
def _unpack_latents(
        latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
    ) -> torch.Tensor:
    # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
    # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
    # what happens in the `_pack_latents` method.
    batch_size = latents.size(0)
    latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
    latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
    return latents

  data = torch.load(file)


torch.Size([2688, 128])

In [32]:

file = "/home/eisneim/www/ml/video_gen/ltx_training/data/ltxv_disney_49x768x512/v0200fg10000c2hqshct1rmi4vupplog.pt"
# The cache holds only a dict of tensors plus a list of str captions, so the restricted
# weights_only loader suffices (and silences the torch.load warning seen in the output).
data = torch.load(file, map_location="cpu", weights_only=True)
ll = data["latents"][0]  # packed latent of the first clip: (seq_len, channels)

# Latent grid for a 49x768x512 clip: 7 latent frames x (512 // 32) rows x (768 // 32) cols.
# NOTE(review): 7 == (49 - 1) // 8 + 1, i.e. assumes 8x temporal / 32x spatial VAE compression — confirm.
latent_f, latent_h, latent_w = 7, 512 // 32, 768 // 32
assert ll.shape[0] == latent_f * latent_h * latent_w, f"unexpected packed latent shape {tuple(ll.shape)}"

lt = _unpack_latents(ll.unsqueeze(dim=0).to(device, dtype=dtype), latent_f, latent_h, latent_w)
print(ll.shape)
lt.shape

torch.Size([2688, 128])


  data = torch.load(file)


torch.Size([1, 128, 7, 16, 24])

In [8]:
# Reload the VAE (same call as the model-loading cell) and move it to device/dtype.
vae = load_latent_models()["vae"]
vae = vae.to(device, dtype=dtype)

In [43]:
from diffusers.utils import export_to_video
from diffusers.video_processor import VideoProcessor
# Decode the unpacked latent grid `lt` back to pixels and write it to disk for inspection.
# NOTE(review): if prepare_latents normalised the latents before caching, they may need
# de-normalising before decode — confirm against prepare_latents.
with torch.no_grad():
    video = vae.decode(lt, return_dict=False)[0]
pcc = VideoProcessor(vae_scale_factor=32)
vv = pcc.postprocess_video(video)[0]
# Clips were decoded at the dataset's default 24 fps; write at that rate so playback
# speed matches the source instead of export_to_video's own default.
RECON_FPS = 24
export_to_video(vv, "data/test_rec3.mp4", fps=RECON_FPS)

'data/test_rec3.mp4'

In [42]:
# Unpack the first cached clip into a (1, C, F, H, W) latent grid.
lt = _unpack_latents(
    ll.unsqueeze(dim=0).to(device, dtype=dtype),
    num_frames=7,
    height=512 // 32,
    width=768 // 32,
)
lt.shape

torch.Size([1, 128, 7, 16, 24])

'test.mp4'

In [13]:
# Shape smoke test: encode a random (b, f, c, h, w) clip and inspect the resulting latent shape.
tt = torch.randn((1, 49, 3, 512, 768))
out = to_latent(tt)
out.shape

ic| frames_tensor.shape: torch.Size([1, 49, 3, 512, 768])


torch.Size([1, 256, 7, 16, 24])

torch.Size([1, 2688, 128])