# Stable Video Diffusion Video to Video


In [None]:
import os
import cv2
import gc
import torch
from diffusers import (
    StableDiffusionControlNetPipeline,
    StableVideoDiffusionPipeline,
    ControlNetModel,
)
from diffusers.utils import export_to_gif, make_image_grid
from PIL import Image as PILImage

from controlnet_aux import CannyDetector
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from diffusers.blip import blip_decoder
from IPython.display import display, Image

# --- Run configuration -------------------------------------------------------
# Folder of source clips to read and the comparison GIF to write.
input_path = "/data/noah/inference/simulation/svd_test/input"
output_path = "/data/noah/inference/simulation/svd_test/output/output.gif"

# BLIP + SD ControlNet are moved to `device`; the SVD pipeline to `vid_device`.
device, vid_device = "cuda:0", "cuda:1"

# Leading frames kept per clip, and the generation resolution in pixels.
num_frames = 12
height, width = 512, 768

# A single CPU generator, seeded once and shared by every pipeline call so the
# whole run is reproducible. `manual_seed` seeds the generator in place.
generator = torch.Generator()
generator.manual_seed(100)
import numpy as np


def extract_frame(video_path, max_frames=None, low_threshold=100, high_threshold=200):
    """Decode a video into RGB PIL frames and their Canny edge maps.

    Args:
        video_path: Path of a video file readable by ``cv2.VideoCapture``.
        max_frames: Stop after this many frames instead of decoding the whole
            clip, which also skips edge detection on frames a caller would
            discard. ``None`` (default) keeps every frame.
        low_threshold: Lower Canny hysteresis threshold.
        high_threshold: Upper Canny hysteresis threshold.

    Returns:
        ``(images, edges)``: two equally long lists of ``PIL.Image`` objects,
        one edge map per decoded frame. Both lists are empty when the first
        ``cap.read()`` fails (e.g. the file is not a decodable video).
    """
    images = []
    edges = []
    canny = CannyDetector()
    cap = cv2.VideoCapture(video_path)
    try:
        while max_frames is None or len(images) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes to BGR; PIL images are built from RGB arrays.
            frame = PILImage.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            images.append(frame)
            edges.append(
                canny(
                    frame,
                    detect_resolution=frame.height,
                    image_resolution=frame.height,
                    low_threshold=low_threshold,
                    high_threshold=high_threshold,
                )
            )
    finally:
        # Release the decoder even if edge detection raises mid-clip.
        cap.release()
    return images, edges


def load_demo_image(image, image_size, device):
    """Turn a PIL image into a batched, normalized tensor for the BLIP captioner.

    The image is resized to a square ``image_size x image_size`` with bicubic
    interpolation (aspect ratio is NOT preserved), converted to a float tensor,
    and normalized with the OpenAI CLIP mean/std that BLIP preprocessing uses.

    Args:
        image: Input ``PIL.Image``.
        image_size: Side length of the square model input.
        device: Torch device to move the result to.

    Returns:
        Tensor of shape ``(1, C, image_size, image_size)`` on ``device``
        (C is 3 for an RGB input).
    """
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ]
    )
    # unsqueeze(0) adds the batch dimension expected by model.generate.
    return transform(image).unsqueeze(0).to(device)


# BLIP captioner: captions the first frame of each clip to build the SD prompt.
model = blip_decoder(
    pretrained="/data/noah/ckpt/pretrain_ckpt/BLIP/model_large_caption.pth", image_size=512, vit="large"
)
model.eval()
model = model.to(device)

# NOTE(review): this module-level detector is not used in the visible code
# (extract_frame builds its own CannyDetector); kept in case other cells use it.
canny = CannyDetector()

# Stable Diffusion base checkpoint plus a LoRA (presumably a detail enhancer,
# judging by its file name) for the keyframe generator.
model_id = "/data/noah/ckpt/pretrain_ckpt/StableDiffusion/rv"
lora_id = "/data/noah/ckpt/pretrain_ckpt/StableDiffusion/lora_detail"
lora_name = "add_detail.safetensors"

# Fine-tuned ControlNet, fed the first frame's Canny map in the generation loop.
controlnet = ControlNetModel.from_pretrained("/data/noah/ckpt/finetuning/Control_SD_AD", torch_dtype=torch.float16).to(
    device
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    model_id,
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
# FreeU: s1/s2 scale UNet skip features, b1/b2 scale backbone features.
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
pipe.load_lora_weights(lora_id, weight_name=lora_name)
pipe = pipe.to(device)

# Stable Video Diffusion image-to-video pipeline (checkpoint dir `svd_xt`) on
# its own GPU; it animates the SD keyframe produced for each clip.
vid_pipe = StableVideoDiffusionPipeline.from_pretrained(
    "/data/noah/ckpt/pretrain_ckpt/StableDiffusion/svd_xt", torch_dtype=torch.float16
)
vid_pipe = vid_pipe.to(vid_device)

In [None]:
# Per-clip collections for the comparison figures.
result_images = []  # SD ControlNet keyframe per clip, resized to the edge-map size
input_images = []  # first source frame per clip
control_images = []  # first Canny map per clip

result_vid_images = []  # SVD frames per clip, resized to the edge-map size
input_vid_images = []  # first `num_frames` source frames per clip
control_vid_images = []  # matching Canny maps per clip

# os.listdir order is arbitrary; sort so grid columns are reproducible.
for video_name in sorted(os.listdir(input_path)):
    video_path = os.path.join(input_path, video_name)
    # extract_frame returns an (images, edges) tuple, so truncate each list.
    # Slicing the tuple itself never limited the frame count.
    images, edges = extract_frame(video_path)
    images, edges = images[:num_frames], edges[:num_frames]
    if not images:
        # Non-video or unreadable file: nothing to caption or animate.
        print(f"Skipping {video_name}: no frames could be decoded")
        continue

    blip_image = load_demo_image(image=images[0], image_size=height, device=device)
    with torch.no_grad():
        caption = model.generate(blip_image, sample=True, top_p=0.9, max_length=20, min_length=5)[0]
        print(caption)

    input_images.append(images[0])
    control_images.append(edges[0])
    input_vid_images.append(images)
    control_vid_images.append(edges)

    # NOTE(review): `<lora:add-detail:-1>` is WebUI prompt syntax. diffusers does
    # not appear to parse it, so it likely reaches the text encoder as plain
    # tokens; LoRA strength comes from load_lora_weights. Confirm intent.
    prompt = "{}, outdoor, best quality, extremely detailed, clearness, naturalness, film grain, crystal clear, photo with color, actuality, <lora:add-detail:-1>".format(
        caption
    )
    negative_prompt = "cartoon, anime, painting, disfigured, immature, blur, picture, 3D, render, semi-realistic, drawing, poorly drawn, bad anatomy, wrong anatomy, gray scale, worst quality, low quality, sketch"

    # Caption + first-frame edges -> one stylized keyframe.
    result_image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=edges[0],
        height=height,
        width=width,
        guidance_scale=7,
        num_inference_steps=40,
        controlnet_conditioning_scale=0.75,
        generator=generator,
    ).images[0]

    # Keyframe -> clip with as many frames as the (truncated) source clip.
    frames = vid_pipe(
        image=result_image,
        height=height,
        width=width,
        num_frames=len(images),
        num_inference_steps=25,
        decode_chunk_size=8,
        motion_bucket_id=230,
        noise_aug_strength=0,
        min_guidance_scale=1.0,
        max_guidance_scale=3.0,
        generator=generator,
    ).frames[0]

    # Resize outputs to the edge-map resolution so they line up in the grids.
    new_w, new_h = edges[0].width, edges[0].height
    result_image = result_image.resize((new_w, new_h))
    result_images.append(result_image)
    frames = [f.resize((new_w, new_h)) for f in frames]
    result_vid_images.append(frames)

if not result_images:
    raise RuntimeError(f"No decodable videos found in {input_path}")

# Static overview: first source frames / first Canny maps / SD keyframes.
num_video = len(result_images)
image_grid = make_image_grid(input_images, rows=1, cols=num_video)
control_grid = make_image_grid(control_images, rows=1, cols=num_video)
result_grid = make_image_grid(result_images, rows=1, cols=num_video)
result_grid = make_image_grid([image_grid, control_grid, result_grid], rows=3, cols=1)
display(result_grid)

# Animated comparison: one GIF frame per time step, with the source clips on
# top and the SVD output below. The per-clip lists are indexed directly
# instead of being stacked with np.array, so clips of different resolutions
# work and the lists stay intact if the cell is re-run. Stop at the shortest
# clip so no index runs past its end.
num_gif_frames = min(len(clip) for clip in input_vid_images + result_vid_images)
grids = []
for idx in range(num_gif_frames):
    sub_input_images = [clip[idx] for clip in input_vid_images]
    sub_result_images = [clip[idx] for clip in result_vid_images]

    image_grid = make_image_grid(sub_input_images, rows=1, cols=num_video)
    result_grid = make_image_grid(sub_result_images, rows=1, cols=num_video)
    grid = make_image_grid([image_grid, result_grid], rows=2, cols=1)
    grids.append(grid.resize((grid.width // 2, grid.height // 2)))

os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
export_to_gif(grids, output_path)
display(Image(output_path))

# Stable Video Diffusion ControlNet


In [None]:
import os
import cv2
import gc
import torch
from diffusers import (
    StableDiffusionControlNetPipeline,
    StableVideoDiffusionControlNetPipeline,
    ControlNetSVDModel,
    ControlNetModel,
)
from diffusers.utils import export_to_gif, make_image_grid
from PIL import Image as PILImage

from controlnet_aux import CannyDetector
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from diffusers.blip import blip_decoder
from IPython.display import display, Image

# --- Run configuration (ControlNet variant) ----------------------------------
# Folder of source clips to read and the comparison GIF to write.
input_path = "/data/noah/inference/simulation/svd_test/input"
output_path = "/data/noah/inference/simulation/svd_test/output/output.gif"

# Every model in this section (BLIP, SD ControlNet, SVD ControlNet) is moved here.
device = "cuda:2"

# Leading frames kept per clip, and the generation resolution in pixels.
num_frames = 8
height, width = 512, 768

# One CPU generator, seeded once (in place) and reused by every pipeline call.
generator = torch.Generator()
generator.manual_seed(100)
import numpy as np


def extract_frame(video_path, max_frames=None, low_threshold=50, high_threshold=100):
    """Decode a video into RGB PIL frames and their Canny edge maps.

    Args:
        video_path: Path of a video file readable by ``cv2.VideoCapture``.
        max_frames: Stop after this many frames instead of decoding the whole
            clip, which also skips edge detection on frames a caller would
            discard. ``None`` (default) keeps every frame.
        low_threshold: Lower Canny hysteresis threshold.
        high_threshold: Upper Canny hysteresis threshold.

    Returns:
        ``(images, edges)``: two equally long lists of ``PIL.Image`` objects,
        one edge map per decoded frame. Both lists are empty when the first
        ``cap.read()`` fails (e.g. the file is not a decodable video).
    """
    images = []
    edges = []
    canny = CannyDetector()
    cap = cv2.VideoCapture(video_path)
    try:
        while max_frames is None or len(images) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes to BGR; PIL images are built from RGB arrays.
            frame = PILImage.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            images.append(frame)
            edges.append(
                canny(
                    frame,
                    detect_resolution=frame.height,
                    image_resolution=frame.height,
                    low_threshold=low_threshold,
                    high_threshold=high_threshold,
                )
            )
    finally:
        # Release the decoder even if edge detection raises mid-clip.
        cap.release()
    return images, edges


def load_demo_image(image, image_size, device):
    """Turn a PIL image into a batched, normalized tensor for the BLIP captioner.

    The image is resized to a square ``image_size x image_size`` with bicubic
    interpolation (aspect ratio is NOT preserved), converted to a float tensor,
    and normalized with the OpenAI CLIP mean/std that BLIP preprocessing uses.

    Args:
        image: Input ``PIL.Image``.
        image_size: Side length of the square model input.
        device: Torch device to move the result to.

    Returns:
        Tensor of shape ``(1, C, image_size, image_size)`` on ``device``
        (C is 3 for an RGB input).
    """
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ]
    )
    # unsqueeze(0) adds the batch dimension expected by model.generate.
    return transform(image).unsqueeze(0).to(device)


# BLIP captioner: captions the first frame of each clip to build the SD prompt.
model = blip_decoder(
    pretrained="/data/noah/ckpt/pretrain_ckpt/BLIP/model_large_caption.pth", image_size=512, vit="large"
)
model.eval()
model = model.to(device)

# NOTE(review): this module-level detector is not used in the visible code
# (extract_frame builds its own CannyDetector); kept in case other cells use it.
canny = CannyDetector()

# Stable Diffusion base checkpoint plus a LoRA (presumably a detail enhancer,
# judging by its file name) for the keyframe generator.
model_id = "/data/noah/ckpt/pretrain_ckpt/StableDiffusion/rv"
lora_id = "/data/noah/ckpt/pretrain_ckpt/StableDiffusion/lora_detail"
lora_name = "add_detail.safetensors"

# Fine-tuned ControlNet, fed the first frame's Canny map in the keyframe loop.
controlnet = ControlNetModel.from_pretrained("/data/noah/ckpt/finetuning/Control_SD_AD", torch_dtype=torch.float16).to(
    device
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    model_id,
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
# FreeU: s1/s2 scale UNet skip features, b1/b2 scale backbone features.
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
pipe.load_lora_weights(lora_id, weight_name=lora_name)
pipe = pipe.to(device)

# Per-clip collections. result_vid_images is only created here; it is filled
# later by the SVD ControlNet step.
result_images = []  # SD ControlNet keyframe per clip, resized to the edge-map size
input_images = []  # first source frame per clip
control_images = []  # first Canny map per clip

result_vid_images = []
input_vid_images = []  # first `num_frames` source frames per clip
control_vid_images = []  # matching Canny maps per clip

# os.listdir order is arbitrary; sort so grid columns are reproducible.
for video_name in sorted(os.listdir(input_path)):
    video_path = os.path.join(input_path, video_name)
    # extract_frame returns an (images, edges) tuple, so truncate each list.
    # Slicing the tuple itself never limited the frame count.
    images, edges = extract_frame(video_path)
    images, edges = images[:num_frames], edges[:num_frames]
    if not images:
        # Non-video or unreadable file: nothing to caption or stylize.
        print(f"Skipping {video_name}: no frames could be decoded")
        continue

    blip_image = load_demo_image(image=images[0], image_size=height, device=device)
    with torch.no_grad():
        caption = model.generate(blip_image, sample=True, top_p=0.9, max_length=20, min_length=5)[0]
        print(caption)

    input_images.append(images[0])
    control_images.append(edges[0])
    input_vid_images.append(images)
    control_vid_images.append(edges)

    # NOTE(review): `<lora:add-detail:-1>` is WebUI prompt syntax. diffusers does
    # not appear to parse it, so it likely reaches the text encoder as plain
    # tokens; LoRA strength comes from load_lora_weights. Confirm intent.
    prompt = "{}, outdoor, best quality, extremely detailed, clearness, naturalness, film grain, crystal clear, photo with color, actuality, <lora:add-detail:-1>".format(
        caption
    )
    negative_prompt = "cartoon, anime, painting, disfigured, immature, blur, picture, 3D, render, semi-realistic, drawing, poorly drawn, bad anatomy, wrong anatomy, gray scale, worst quality, low quality, sketch"

    # Caption + first-frame edges -> one stylized keyframe.
    result_image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=edges[0],
        height=height,
        width=width,
        guidance_scale=6.5,
        num_inference_steps=40,
        controlnet_conditioning_scale=0.75,
        generator=generator,
    ).images[0]

    # Store at the edge-map resolution so it lines up in the grids.
    new_w, new_h = edges[0].width, edges[0].height
    result_image = result_image.resize((new_w, new_h))
    result_images.append(result_image)

In [None]:
# SVD ControlNet: video generation conditioned on each clip's list of
# per-frame Canny maps (control_vid_images, filled in the previous cell).
controlnet_svd = ControlNetSVDModel.from_pretrained(
    "/data/noah/ckpt/finetuning/SVD_CON_AD_SEQ/checkpoint-40000/controlnet", torch_dtype=torch.float16
).to(device)

vid_pipe = StableVideoDiffusionControlNetPipeline.from_pretrained(
    "/data/noah/ckpt/finetuning/SVD_CON_AD_SEQ", controlnet=controlnet_svd, torch_dtype=torch.float16
)
vid_pipe = vid_pipe.to(device)

# Start empty so re-running this cell does not append a second copy of every
# clip; the list is otherwise only initialised in the previous cell.
result_vid_images = []

for keyframe, control_frames in zip(result_images, control_vid_images):
    # NOTE(review): the keyframe is resized to (width, height), but the control
    # frames are passed at edge-map resolution. Presumably the pipeline resizes
    # them internally; confirm.
    frames = vid_pipe(
        image=keyframe.resize((width, height)),
        controlnet_condition=control_frames,
        height=height,
        width=width,
        num_frames=len(control_frames),
        num_inference_steps=25,
        min_guidance_scale=1.0,
        max_guidance_scale=3.0,
        motion_bucket_id=250,
        noise_aug_strength=0,
        controlnet_cond_scale=1.0,
        # In stock SVD, decode_chunk_size is the number of frames per VAE decode.
        # 1 minimises memory (an earlier setting was 8).
        decode_chunk_size=1,
        generator=generator,
    ).frames[0]

    # Back to the keyframe (edge-map) resolution so frames align in the grids.
    new_w, new_h = keyframe.width, keyframe.height
    frames = [f.resize((new_w, new_h)) for f in frames]
    result_vid_images.append(frames)

if not result_images:
    raise RuntimeError(f"No decodable videos found in {input_path}")

# Static overview: first source frames / first Canny maps / SD keyframes.
num_video = len(result_images)
image_grid = make_image_grid(input_images, rows=1, cols=num_video)
control_grid = make_image_grid(control_images, rows=1, cols=num_video)
result_grid = make_image_grid(result_images, rows=1, cols=num_video)
result_grid = make_image_grid([image_grid, control_grid, result_grid], rows=3, cols=1)
display(result_grid)

# Animated comparison: one GIF frame per time step, with rows for source,
# Canny condition and SVD output. The per-clip lists are indexed directly
# instead of being converted to numpy arrays, so mixed resolutions work and
# control_vid_images stays a list of PIL frames if this cell is re-run.
# Stop at the shortest clip so no index runs past its end.
num_gif_frames = min(len(clip) for clip in input_vid_images + control_vid_images + result_vid_images)
grids = []
for idx in range(num_gif_frames):
    sub_input_images = [clip[idx] for clip in input_vid_images]
    sub_control_images = [clip[idx] for clip in control_vid_images]
    sub_result_images = [clip[idx] for clip in result_vid_images]

    image_grid = make_image_grid(sub_input_images, rows=1, cols=num_video)
    condition_grid = make_image_grid(sub_control_images, rows=1, cols=num_video)
    result_grid = make_image_grid(sub_result_images, rows=1, cols=num_video)
    grid = make_image_grid([image_grid, condition_grid, result_grid], rows=3, cols=1)
    grids.append(grid.resize((grid.width // 2, grid.height // 2)))

os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
export_to_gif(grids, output_path)
display(Image(output_path))