In [None]:
import datetime
import inspect
import os
from omegaconf import OmegaConf

import torch
import torchvision.transforms as transforms

from diffusers import AutoencoderKL, DDIMScheduler

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.models.unet import UNet3DConditionModel
from animatediff.models.sparse_controlnet import SparseControlNetModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline, AnimationPipelineOutput
from animatediff.utils.util import save_videos_grid
from animatediff.utils.util import load_weights
from diffusers.utils.import_utils import is_xformers_available

from einops import rearrange

from pathlib import Path
from PIL import Image
import numpy as np

# Ask PyOpenGL for its EGL platform (headless rendering, no X display).
# NOTE(review): nothing visible in this notebook uses OpenGL, and this is set
# only after all imports above — presumably PyOpenGL picks its platform when it
# is first imported, so this only has an effect if no module above has already
# imported OpenGL. Confirm whether it is still needed.
os.environ['PYOPENGL_PLATFORM'] = 'egl'

In [None]:
# config for AnimateDiff
# Base Stable Diffusion weights, the prompt config to sample from, and the
# inference config used when an example does not name its own.
pretrained_model_path = "./models/StableDiffusion"
config_path = "./configs/prompts/v3/v3-1-T2V.yaml"
inference_config_path = './configs/inference/inference-v3.yaml'

# Timestamped output directory named after the prompt config file, so
# successive runs write to different folders.
time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
savedir = "samples/" + Path(config_path).stem + "-" + time_str

config = OmegaConf.load(config_path)

# Defaults for clip length (frames) and frame width/height; each example in
# the prompt config may override them.
L = 16
W = 512
H = 512

# SD component loading: the text encoder and VAE are moved to CUDA.
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_path, subfolder="text_encoder"
).cuda()
vae = AutoencoderKL.from_pretrained(
    pretrained_model_path, subfolder="vae"
).cuda()

In [None]:
example_idx = 0
example_config = config[example_idx]

# Fill in clip length / frame size from the config-cell defaults when the
# example does not set them.
example_config.W = example_config.get("W", W)
example_config.L = example_config.get("L", L)
example_config.H = example_config.get("H", H)

# Use the example's own inference config if it names one, otherwise the default.
inference_config = OmegaConf.load(example_config.get("inference_config", inference_config_path))

# AnimateDiff loading
unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)).cuda()

# Both stay None unless the example configures a sparse controlnet. Defining
# `controlnet_images` here keeps the sampling cell from raising NameError (or
# reusing a stale tensor from a previous run) for plain text-to-video examples.
controlnet = None
controlnet_images = None
# load controlnet (rgb) and preprocess condition image
if example_config.get("controlnet_path", "") != "":
    assert example_config.get("controlnet_images", "") != ""
    assert example_config.get("controlnet_config", "") != ""

    # NOTE(review): presumably these overrides make the unet config match what
    # SparseControlNetModel.from_unet expects — confirm against that class.
    unet.config.num_attention_heads = 8
    unet.config.projection_class_embeddings_input_dim = None

    controlnet_config = OmegaConf.load(example_config.controlnet_config)
    controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {}))

    print(f"loading controlnet checkpoint from {example_config.controlnet_path} ...")
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    controlnet_state_dict = torch.load(example_config.controlnet_path, map_location="cpu")
    # Weights may be nested under a "controlnet" key; the "animatediff_config"
    # entry (if present) is removed before load_state_dict.
    controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
    controlnet_state_dict.pop("animatediff_config", "")
    controlnet.load_state_dict(controlnet_state_dict)
    controlnet.cuda()

    image_paths = example_config.controlnet_images
    if isinstance(image_paths, str):
        image_paths = [image_paths]

    print("controlnet image paths:")
    for path in image_paths:
        print(path)
    # At most one condition image per generated frame.
    assert len(image_paths) <= example_config.L

    # Bring every condition image to (H, W): scale (1.0, 1.0) keeps the full
    # image area and the aspect-ratio range is pinned to W/H.
    image_transforms = transforms.Compose([
        transforms.RandomResizedCrop(
            (example_config.H, example_config.W), (1.0, 1.0), 
            ratio=(example_config.W/example_config.H, example_config.W/example_config.H)
        ),
        transforms.ToTensor(),
    ])

    if example_config.get("normalize_condition_images", False):
        # Channel-mean "grayscale" replicated to 3 channels, then min-max scaled to [0, 1].
        def image_norm(image):
            image = image.mean(dim=0, keepdim=True).repeat(3,1,1)
            image -= image.min()
            image /= image.max()
            return image
    else:
        image_norm = lambda x: x

    controlnet_images = [image_norm(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths]

    # Save the preprocessed condition frames next to the samples for inspection.
    os.makedirs(os.path.join(savedir, "control_images"), exist_ok=True)
    for i, image in enumerate(controlnet_images):
        Image.fromarray((255. * (image.numpy().transpose(1,2,0))).astype(np.uint8)).save(f"{savedir}/control_images/{i}.png")

    # List of (c, h, w) frames -> (1, f, c, h, w) -> (1, c, f, h, w) on the GPU.
    controlnet_images = torch.stack(controlnet_images).unsqueeze(0).cuda()
    controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w")

    if controlnet.use_simplified_condition_embedding:
        # Encode each condition frame into VAE latents. This cell is not inside
        # the @torch.no_grad() sampler, so without no_grad autograd would record
        # the encoder graph and keep it alive as long as `controlnet_images` lives.
        num_controlnet_images = controlnet_images.shape[2]
        controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w")
        with torch.no_grad():
            # Inputs mapped from [0, 1] to [-1, 1]; 0.18215 is the latent scale
            # constant (presumably the SD v1 VAE scaling factor — confirm).
            controlnet_images = vae.encode(controlnet_images * 2. - 1.).latent_dist.sample() * 0.18215
        controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images)

# set xformers
if is_xformers_available():
    unet.enable_xformers_memory_efficient_attention()
    if controlnet is not None:
        controlnet.enable_xformers_memory_efficient_attention()

pipeline = AnimationPipeline(
    vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
    controlnet=controlnet,
    scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
).to("cuda")

# Apply motion module, domain adapter and image-layer weights named by the example.
pipeline = load_weights(
    pipeline,
    # motion module
    motion_module_path         = example_config.get("motion_module", ""),
    motion_module_lora_configs = example_config.get("motion_module_lora_configs", []),
    # domain adapter
    adapter_lora_path          = example_config.get("adapter_lora_path", ""),
    adapter_lora_scale         = example_config.get("adapter_lora_scale", 1.0),
    # image layers
    dreambooth_model_path      = example_config.get("dreambooth_path", ""),
    lora_model_path            = example_config.get("lora_model_path", ""),
    lora_alpha                 = example_config.get("lora_alpha", 0.8),
).to("cuda")

In [None]:
from typing import Callable, List, Optional, Union

@torch.no_grad()
def anmd_sample(
    pipeline: AnimationPipeline,
    prompt: Union[str, List[str]],
    video_length: Optional[int],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "tensor",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,

    # support controlnet
    controlnet_images: Optional[torch.FloatTensor] = None,
    controlnet_image_index: Optional[List[int]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,

    **kwargs,
):
    """Run an AnimationPipeline denoising loop, optionally guided by a sparse controlnet.

    Drives the pipeline's own helpers (``check_inputs``, ``_encode_prompt``,
    ``prepare_latents``, ``prepare_extra_step_kwargs``, ``decode_latents``) and
    its ``scheduler`` / ``unet`` / ``controlnet`` modules.

    Args:
        pipeline: the assembled AnimationPipeline.
        prompt: a prompt string, or a list of prompts (one per batch item).
        video_length: number of frames to generate.
        height, width: output size in pixels; default to
            ``unet.config.sample_size * vae_scale_factor``.
        num_inference_steps: number of scheduler timesteps.
        guidance_scale: classifier-free guidance weight; guidance is only
            applied when it is > 1.
        negative_prompt: string or list, broadcast to the batch like ``prompt``.
        num_videos_per_prompt: forwarded to prompt encoding and latent preparation.
        eta: forwarded to ``prepare_extra_step_kwargs``.
        generator: forwarded to ``prepare_latents`` and ``prepare_extra_step_kwargs``.
        latents: optional initial latents; their batch size is used when
            ``prompt`` is a string.
        output_type: ``"tensor"`` converts the decoded video via
            ``torch.from_numpy``; otherwise ``decode_latents``' result is returned as is.
        return_dict: wrap the video in ``AnimationPipelineOutput`` when True.
        callback, callback_steps: ``callback(i, t, latents)`` is invoked on
            progress-bar updates where ``i % callback_steps == 0``.
        controlnet_images: 5-D condition tensor with frames on dim 2 (presumably
            (b, c, f, h, w), as built by the loading cell); only used when
            ``pipeline.controlnet`` is set.
        controlnet_image_index: target frame indices for the condition frames,
            in order; defaults to ``[0]``.
        controlnet_conditioning_scale: forwarded to the controlnet.
        **kwargs: accepted and ignored.

    Returns:
        ``AnimationPipelineOutput(videos=video)``, or ``video`` itself when
        ``return_dict`` is False.
    """
    if controlnet_image_index is None:
        controlnet_image_index = [0]
    # A plain list keeps the tensor fancy-indexing below independent of the
    # caller's container type (e.g. an OmegaConf ListConfig from a YAML config).
    controlnet_image_index = list(controlnet_image_index)

    # Default height and width to unet
    height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
    width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor

    # Check inputs. Raise error if not correct
    pipeline.check_inputs(prompt, height, width, callback_steps)

    # Batch size: a prompt list wins, then explicit latents, otherwise 1.
    batch_size = 1
    if latents is not None:
        batch_size = latents.shape[0]
    if isinstance(prompt, list):
        batch_size = len(prompt)

    device = pipeline._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # Encode input prompt (string prompts are repeated to the batch size)
    prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
    if negative_prompt is not None:
        negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
    text_embeddings = pipeline._encode_prompt(
        prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Prepare timesteps
    pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipeline.scheduler.timesteps

    # Prepare latent variables
    num_channels_latents = pipeline.unet.in_channels
    latents = pipeline.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        video_length,
        height,
        width,
        text_embeddings.dtype,
        device,
        generator,
        latents,
    )
    latents_dtype = latents.dtype

    # Prepare extra step kwargs.
    extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)

    # The controlnet condition and mask do not depend on the timestep, so build
    # them once here instead of on every denoising step.
    use_controlnet = getattr(pipeline, "controlnet", None) is not None and controlnet_images is not None
    controlnet_cond = controlnet_conditioning_mask = None
    if use_controlnet:
        assert controlnet_images.dim() == 5
        assert controlnet_images.shape[2] >= len(controlnet_image_index)

        controlnet_images = controlnet_images.to(latents.device)

        # Condition tensor stretched to `video_length` frames (dim 2); frames not
        # listed in `controlnet_image_index` stay zero.
        controlnet_cond_shape = list(controlnet_images.shape)
        controlnet_cond_shape[2] = video_length
        controlnet_cond = torch.zeros(controlnet_cond_shape, device=latents.device)

        # Single-channel mask marking which frames carry a condition.
        controlnet_conditioning_mask_shape = list(controlnet_cond_shape)
        controlnet_conditioning_mask_shape[1] = 1
        controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape, device=latents.device)

        controlnet_cond[:, :, controlnet_image_index] = controlnet_images[:, :, :len(controlnet_image_index)]
        controlnet_conditioning_mask[:, :, controlnet_image_index] = 1

    # Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
    with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)

            down_block_additional_residuals = mid_block_additional_residual = None
            if use_controlnet:
                down_block_additional_residuals, mid_block_additional_residual = pipeline.controlnet(
                    latent_model_input, t,
                    encoder_hidden_states=text_embeddings,
                    controlnet_cond=controlnet_cond,
                    conditioning_mask=controlnet_conditioning_mask,
                    conditioning_scale=controlnet_conditioning_scale,
                    guess_mode=False, return_dict=False,
                )

            # predict the noise residual
            noise_pred = pipeline.unet(
                latent_model_input, t,
                encoder_hidden_states=text_embeddings,
                down_block_additional_residuals=down_block_additional_residuals,
                mid_block_additional_residual=mid_block_additional_residual,
            ).sample.to(dtype=latents_dtype)

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # Post-processing
    video = pipeline.decode_latents(latents)

    # Convert to tensor (decode_latents' result is assumed to be a NumPy array)
    if output_type == "tensor":
        video = torch.from_numpy(video)

    if not return_dict:
        return video

    return AnimationPipelineOutput(videos=video)

In [None]:
# Sample preparation: broadcast a single negative prompt / seed to every prompt.
prompts      = example_config.prompt
n_prompts    = list(example_config.n_prompt) * len(prompts) if len(example_config.n_prompt) == 1 else example_config.n_prompt

random_seeds = example_config.get("seed", [-1])
random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

# Seed actually used for each prompt, recorded in the loop below.
example_config.random_seed = []

# Sampling
samples = []
for sample_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
    print(prompt)
    # manually set random seed for reproduction; -1 means draw a fresh seed
    if random_seed != -1:
        torch.manual_seed(random_seed)
    else:
        torch.seed()
    example_config.random_seed.append(torch.initial_seed())

    print(f"current seed: {torch.initial_seed()}")
    print(f"sampling {prompt} ...")
    sample = anmd_sample(
        pipeline,
        prompt,
        negative_prompt     = n_prompt,
        num_inference_steps = example_config.steps,
        guidance_scale      = example_config.guidance_scale,
        width               = example_config.W,
        height              = example_config.H,
        video_length        = example_config.L,

        # Condition frames only exist when the loading cell built a controlnet;
        # guarding on `controlnet` avoids a NameError for plain T2V examples.
        controlnet_images      = controlnet_images if controlnet is not None else None,
        controlnet_image_index = example_config.get("controlnet_image_indexs", [0]),
    ).videos

    samples.append(sample)

    # File name: index plus the first 10 words of the prompt ("/" removed).
    prompt_slug = "-".join(prompt.replace("/", "").split(" ")[:10])
    save_path = f"{savedir}/sample/{sample_idx}-{prompt_slug}.gif"
    save_videos_grid(sample, save_path)
    print(f"save to {save_path}")