In [None]:
# Clone NOVA into the current directory; later cells assume that directory is
# /content (they refer to /content/NOVA).
!git clone https://github.com/baaivision/NOVA.git

In [None]:
# Install NOVA's Python requirements.
!pip install -r /content/NOVA/requirements.txt

In [None]:
# `%cd` changes the kernel's working directory persistently; a `!cd` would run
# in a throwaway subshell and have no lasting effect, so it is not used.
%cd /content/NOVA
# Editable install of the `diffnext` package, which provides NOVAPipeline.
!pip install -e /content/NOVA/diffnext

In [None]:
import os
import gc
import torch
import numpy as np
from diffnext.pipelines import NOVAPipeline
from PIL import Image
import subprocess
import tempfile

class StreamlinedNOVA:
    """Generate frames with a NOVA pipeline and assemble them into an MP4 via FFmpeg.

    NOTE(review): every frame comes from an independent pipeline call with
    ``max_latent_length=1`` and no shared seed or latents, so consecutive frames
    are not temporally coherent. If the loaded checkpoint supports multi-frame
    generation (presumably via ``max_latent_length > 1`` -- verify against the
    NOVA docs), that is likely the better way to obtain an actual video.
    """

    def __init__(self, model_id, device="cuda"):
        """Load the pipeline in float16 and enable model CPU offload.

        Args:
            model_id: Identifier or path passed to ``NOVAPipeline.from_pretrained``.
            device: Stored as ``self.device``; not otherwise used by this class.
        """
        # The allocator config only takes effect if it is set before the CUDA
        # caching allocator initializes, so set it before loading anything.
        # setdefault keeps a value the user already exported.
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
        self.pipe = NOVAPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        self.pipe.enable_model_cpu_offload()
        self.device = device
        torch.cuda.empty_cache()
        gc.collect()

    def generate_frame(self, prompt):
        """Generate one frame for ``prompt`` and return it as a CPU tensor.

        The pipeline output's ``.frames[0]`` is converted to a tensor (from a PIL
        image, an existing tensor, or anything ``torch.tensor`` accepts). All
        size-1 dimensions are then squeezed away.
        NOTE(review): assumes ``.frames[0]`` holds a single image for
        ``max_latent_length=1`` -- confirm the exact shape against diffnext.
        """
        # Inference only: no autograd state. torch.autocast replaces the
        # deprecated torch.cuda.amp.autocast (same float16 default on CUDA).
        with torch.no_grad(), torch.autocast(device_type="cuda"):
            result = self.pipe(
                prompt,
                max_latent_length=1,
                num_inference_steps=20,
                guidance_scale=4.0,
            ).frames[0]

        if isinstance(result, Image.Image):
            result = torch.from_numpy(np.array(result))
        elif not isinstance(result, torch.Tensor):
            result = torch.tensor(result)

        # Remove unit dimensions (e.g. batch / single channel).
        frame_cpu = result.squeeze().cpu().detach()
        torch.cuda.empty_cache()
        return frame_cpu

    @staticmethod
    def _to_uint8_hwc(frame):
        """Convert a numpy image array to a channels-last ``uint8`` array.

        Leading unit axes are dropped. Channels-first ``(1|3, H, W)`` input is
        transposed to channels-last, and grayscale is expanded to 3 channels.
        Non-uint8 data is mapped to [0, 255]: float images already in [0, 1] are
        scaled by 255, data already inside [0, 255] is kept as-is, and anything
        else is min-max stretched (constant frames become black instead of NaN).

        Raises:
            ValueError: if more than 3 dims remain and the leading one is not 1.
        """
        while frame.ndim > 3:
            if frame.shape[0] != 1:
                raise ValueError(f"Cannot reduce frame of shape {frame.shape} to an image")
            frame = frame[0]

        # Channels-first -> channels-last (H, W, C).
        if frame.ndim == 3 and frame.shape[0] in (1, 3):
            frame = np.transpose(frame, (1, 2, 0))

        # Grayscale -> 3 identical channels.
        if frame.ndim == 2:
            frame = np.stack([frame] * 3, axis=-1)
        elif frame.ndim == 3 and frame.shape[-1] == 1:
            frame = np.repeat(frame, 3, axis=-1)

        if frame.dtype != np.uint8:
            values = frame.astype(np.float64)
            lo, hi = float(values.min()), float(values.max())
            if np.issubdtype(frame.dtype, np.floating) and lo >= 0.0 and hi <= 1.0:
                values = values * 255.0
            elif lo < 0.0 or hi > 255.0:
                # Unknown range: stretch, guarding against a zero-width range.
                if hi > lo:
                    values = (values - lo) / (hi - lo) * 255.0
                else:
                    values = np.zeros_like(values)
            frame = np.clip(np.rint(values), 0, 255).astype(np.uint8)
        return frame

    def upscale_frame(self, frame, size=(512, 512)):
        """Convert ``frame`` to a ``uint8`` image and resize it to ``size``.

        Args:
            frame: A torch tensor or numpy array holding one image (channels-first
                or channels-last, grayscale or colour).
            size: Target ``(width, height)`` passed to ``PIL.Image.resize``.

        Returns:
            A numpy ``uint8`` array of shape ``(height, width, channels)``.
        """
        if isinstance(frame, torch.Tensor):
            frame = frame.detach().squeeze().cpu()
            if frame.is_floating_point():
                # numpy has no bfloat16; float32 also avoids float16 rounding.
                frame = frame.float()
            frame = frame.numpy()

        img = Image.fromarray(self._to_uint8_hwc(frame))
        img = img.resize(size, Image.BILINEAR)
        return np.array(img)

    def generate_video_ffmpeg(self, prompt, num_frames=50, video_path="streamlined_video.mp4", fps=8, upscale_size=(512, 512)):
        """Generate ``num_frames`` frames for ``prompt`` and encode them as H.264 MP4.

        Frames that fail are reported and skipped. Saved frames are numbered
        contiguously so FFmpeg's ``frame_%04d.png`` sequence reader does not stop
        at a gap. The temporary frame directory is always removed, even if
        generation or encoding fails.

        Args:
            prompt: Text prompt passed to ``generate_frame``.
            num_frames: Number of frames to attempt.
            video_path: Output file path handed to FFmpeg.
            fps: Input frame rate for the image sequence.
            upscale_size: ``(width, height)`` each frame is resized to.

        Returns:
            ``video_path``.

        Raises:
            RuntimeError: if no frame could be generated.
            subprocess.CalledProcessError: if FFmpeg exits with an error.
        """
        with tempfile.TemporaryDirectory(prefix="temp_frames_") as temp_dir:
            saved_frames = []

            for i in range(num_frames):
                print(f"Gerando frame {i+1}/{num_frames}")
                # Pre-bind so the cleanup below is safe whichever step fails.
                frame = frame_np = None
                try:
                    frame = self.generate_frame(prompt)
                    frame_np = self.upscale_frame(frame, upscale_size)
                    # Number by successful frames so the sequence has no gaps.
                    frame_filename = os.path.join(temp_dir, f"frame_{len(saved_frames):04d}.png")
                    Image.fromarray(frame_np).save(frame_filename)
                    saved_frames.append(frame_filename)
                except Exception as e:
                    # Best-effort: a failed frame is reported and skipped.
                    print(f"Erro no frame {i+1}: {str(e)}")
                finally:
                    del frame, frame_np
                    torch.cuda.empty_cache()
                    gc.collect()

            if not saved_frames:
                raise RuntimeError("Nenhum frame foi gerado corretamente.")

            ffmpeg_cmd = [
                "ffmpeg", "-y", "-framerate", str(fps),
                "-start_number", "0",
                "-i", os.path.join(temp_dir, "frame_%04d.png"),
                # libx264 + yuv420p needs even dimensions; pad odd sizes by 1px.
                "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
                "-c:v", "libx264", "-pix_fmt", "yuv420p",
                video_path,
            ]

            try:
                subprocess.run(ffmpeg_cmd, check=True)
            except subprocess.CalledProcessError:
                print("Erro FFmpeg. Verifique se os frames foram gerados corretamente.")
                raise

        return video_path

if __name__ == "__main__":
    # Demo run: render a 50-frame MP4 for a Portuguese prompt using the
    # BAAI/nova-d48w1024-osp480 checkpoint, then report where it was written.
    model_id, prompt = "BAAI/nova-d48w1024-osp480", "mulher caminhando na cidade"

    streamlined = StreamlinedNOVA(model_id)
    video_file = streamlined.generate_video_ffmpeg(prompt, num_frames=50)

    saved_at = os.path.abspath(video_file)
    print("Vídeo salvo em:", saved_at)