In [None]:
from pathlib import Path

import cv2
import torch
from diffusers import AutoencoderKL, ControlNetModel
from diffusers.pipelines import RAVEPipeline
from diffusers.utils import export_to_gif
from IPython.display import display, Image
from PIL import Image as PILImage

# Everything in this notebook runs on the first CUDA GPU: `device` is reused by
# every `.to(...)` call below, and set_device also makes it the current CUDA device.
device = "cuda:0"
torch.cuda.set_device(device=device)


def extract_frame(video_path, max_frames=None):
    """Decode a video into a list of RGB PIL images.

    Args:
        video_path: Path (or other source string) passed to ``cv2.VideoCapture``.
        max_frames: Optional cap on the number of frames to decode. ``None``
            (the default) decodes every frame, matching the original behavior.
            When set, decoding stops early instead of reading the whole video.

    Returns:
        List of ``PIL.Image.Image`` frames in RGB channel order, in read order.

    Raises:
        OSError: if OpenCV cannot open ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            # Without this check an unreadable path silently yields [] and the
            # caller later fails with an unrelated IndexError.
            raise OSError(f"Could not open video: {video_path!r}")

        frames = []
        while max_frames is None or len(frames) < max_frames:
            ok, frame_bgr = cap.read()
            if not ok:
                break
            # OpenCV returns BGR arrays; convert to RGB before wrapping in PIL.
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            frames.append(PILImage.fromarray(frame_rgb))
    finally:
        # Release the capture handle even if decoding raises.
        cap.release()
    return frames


# Load the fine-tuned ControlNet checkpoint and the sd-vae-ft-mse VAE in fp16 on `device`.
controlnet = ControlNetModel.from_pretrained(
    "/home/sckim/Dataset/control_ad/checkpoint-20000/controlnet", torch_dtype=torch.float16
).to(device)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to(device)
model_id = "/home/sckim/Dataset/sd"
# Load the remaining pipeline components directly in fp16 as well. Without
# `torch_dtype` they default to fp32 and are only cast by the `.to(...)` below,
# which needlessly doubles peak memory during loading.
pipe = RAVEPipeline.from_pretrained(
    model_id, controlnet=controlnet, vae=vae, torch_dtype=torch.float16
)

# diffusers' sliced / tiled VAE processing, to lower peak GPU memory.
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
# Optional FreeU re-weighting, currently disabled:
# pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
pipe = pipe.to(device=device, dtype=torch.float16)

video_path = "/home/sckim/Dataset/incabin_sample/reference_vid/sample_2.mp4"
video_out_path = "/home/sckim/Dataset/incabin_sample/out/sample_rave_2.gif"
# Only the first 16 frames of the source clip are edited.
frames = extract_frame(video_path)[:16]
if not frames:
    # Fail with a clear message instead of an IndexError on frames[0].
    raise RuntimeError(f"No frames could be decoded from {video_path!r}")
display(frames[0])


# Text conditioning for the video edit.
prompt = "car is driving on the road, outdoor, best quality, extremely detailed, clearness, naturalness, film grain, crystal clear, photo with color, actuality"
negative_prompt = "cartoon, anime, painting, disfigured, immature, blur, picture, 3D, render, semi-realistic, drawing, poorly drawn, bad anatomy, wrong anatomy, gray scale, worst quality, low quality, sketch"
# NOTE(review): semantics of controlnet_processor_id / grid_size / use_shuffling
# are defined by RAVEPipeline (not visible here) — confirm before tuning them.
result_frames = pipe(
    video=frames,
    controlnet_processor_id="canny",
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=512,
    width=768,
    guidance_scale=7.5,
    strength=0.75,
    grid_size=2,
    use_shuffling=False,
    vae_batch_size=4,
    num_inference_steps=30,
).frames
# NOTE(review): assumes `.frames` is a flat list of PIL images — confirm for RAVEPipeline.

# Frames are generated at the requested 768x512; scale them back to the source size.
new_w, new_h = frames[0].width, frames[0].height
result_frames = [f.resize((new_w, new_h)) for f in result_frames]

# export_to_gif saves via PIL, which does not create missing parent directories.
Path(video_out_path).parent.mkdir(parents=True, exist_ok=True)
export_to_gif(result_frames, video_out_path)

display(Image(video_out_path))