In [None]:
!git clone https://github.com/guoyww/AnimateDiff.git
%cd AnimateDiff
!pip install -r requirements.txt

In [None]:
! curl -O https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_sparsectrl_rgb.ckpt

In [None]:
import os
from typing import Callable, Optional

import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from einops import rearrange
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline

In [None]:
import os
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.io import read_video
from typing import Tuple, Dict

class AnimateDiffDataset(Dataset):
    """Paired conditioning-image / target-video dataset for AnimateDiff fine-tuning.

    ``root_dir`` must contain one sub-folder per sample holding ``image.jpg``
    (conditioning frame) and ``video.mp4`` (target clip); sub-folders missing
    either file are skipped.
    """

    def __init__(
        self,
        root_dir: str,
        image_size: Tuple[int, int] = (512, 512),
        num_frames: int = 16,
        frame_stride: int = 4,
        image_transform: Optional[Callable] = None,
        video_transform: Optional[Callable] = None
    ):
        """
        Args:
            root_dir: Path to root directory containing sample folders.
            image_size: Target size passed to ``transforms.Resize``
                (torchvision order: (height, width)).
            num_frames: Number of evenly spaced frames sampled from each video.
            frame_stride: Stored as ``self.frame_stride`` but NOT used by
                ``__getitem__`` (sampling always spans the whole clip). Kept for
                API compatibility.
            image_transform: Optional callable applied to the conditioning PIL
                image. Defaults to Resize -> ToTensor -> Normalize(0.5, 0.5).
            video_transform: Optional callable applied to each video frame
                (given as a PIL image). Same default as ``image_transform``.
        """
        self.root_dir = root_dir
        self.image_size = image_size
        self.num_frames = num_frames
        self.frame_stride = frame_stride
        self.samples = self._discover_samples()

        # Default transforms if not provided
        self.image_transform = image_transform or self._default_transform(image_size)
        self.video_transform = video_transform or self._default_transform(image_size)

    @staticmethod
    def _default_transform(image_size: Tuple[int, int]) -> Callable:
        """Resize, convert to a float CHW tensor in [0, 1], then map to [-1, 1]."""
        return transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    @staticmethod
    def _sample_frame_indices(total_frames: int, num_frames: int) -> torch.Tensor:
        """Return ``num_frames`` evenly spaced indices covering ``[0, total_frames - 1]``.

        Indices repeat when the clip has fewer than ``num_frames`` frames.
        """
        return torch.linspace(0, total_frames - 1, steps=num_frames).long()

    def _discover_samples(self) -> list:
        """Return ``(folder_name, image_path, video_path)`` for every valid sample folder.

        Folders are visited in sorted order so sample indices are stable across
        runs and filesystems (``os.listdir`` order is arbitrary).
        """
        samples = []
        for folder_name in sorted(os.listdir(self.root_dir)):
            folder_path = os.path.join(self.root_dir, folder_name)
            if os.path.isdir(folder_path):
                image_path = os.path.join(folder_path, "image.jpg")
                video_path = os.path.join(folder_path, "video.mp4")
                if os.path.isfile(image_path) and os.path.isfile(video_path):
                    samples.append((folder_name, image_path, video_path))
        return samples

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict:
        """Load sample ``idx``.

        Returns:
            Dict with ``conditioning_image`` (C, H, W), ``target_frames``
            (T, C, H, W) with T == ``num_frames``, plus folder name and paths.

        Raises:
            ValueError: If the video decodes to zero frames.
        """
        folder_name, image_path, video_path = self.samples[idx]

        # Load and process image; the context manager closes the file handle.
        with Image.open(image_path) as img:
            image = self.image_transform(img.convert("RGB"))

        # Load and process video
        video, _, _ = read_video(video_path, pts_unit='sec')  # (T, H, W, C)
        total_frames = video.shape[0]
        if total_frames == 0:
            # Without this guard linspace(0, -1, ...) yields negative indices
            # and indexing the empty tensor fails with an opaque error.
            raise ValueError(f"No decodable frames in video: {video_path}")

        # Temporal sampling
        video = video[self._sample_frame_indices(total_frames, self.num_frames)]

        # Convert and transform frames
        frames = [
            self.video_transform(Image.fromarray(frame.numpy())) for frame in video
        ]
        video = torch.stack(frames)  # (T, C, H, W)

        return {
            "folder_name": folder_name,
            "conditioning_image": image,  # (C, H, W)
            "target_frames": video,       # (T, C, H, W)
            "image_path": image_path,
            "video_path": video_path
        }

    def get_sample_metadata(self, idx: int) -> Dict:
        """Get metadata without loading actual media"""
        folder_name, image_path, video_path = self.samples[idx]
        return {
            "folder_name": folder_name,
            "image_path": image_path,
            "video_path": video_path
        }

In [None]:
import os
import cv2
import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
from typing import Union, List

def create_animediff_dataset(
    video_dirs: Union[str, List[str]],
    output_root: str,
    target_fps: int = 16,
    duration_seconds: int = 4,
    image_size: Tuple[int, int] = (512, 512),
    max_samples: int = None,
    video_extensions: Tuple[str, ...] = ('.mp4', '.avi', '.mov')
) -> AnimateDiffDataset:
    """
    Process video directories to create an AnimateDiff-compatible dataset.

    Each accepted video becomes ``output_root/sample_XXXXXX/`` holding the
    files written by ``process_single_video``. Videos are processed in sorted
    path order so sample ids are reproducible between runs.

    Args:
        video_dirs: Directory, or list of directories, searched recursively.
        output_root: Where to save processed samples (created if missing).
        target_fps: Frames per second for output videos.
        duration_seconds: Length of video clips to extract.
        image_size: Resolution for resizing.
        max_samples: Maximum number of samples to create (None for all;
            0 creates none).
        video_extensions: File suffixes (case-insensitive) treated as videos.

    Returns:
        AnimateDiffDataset over ``output_root`` built with ``image_size``.
    """
    # Create output structure
    output_root = Path(output_root)
    output_root.mkdir(parents=True, exist_ok=True)

    # Collect video files (a single str/Path is accepted as well as a list)
    if isinstance(video_dirs, (str, Path)):
        video_dirs = [video_dirs]
    suffixes = tuple(ext.lower() for ext in video_extensions)

    video_paths = sorted(
        Path(root) / file
        for dir_path in video_dirs
        for root, _, files in os.walk(dir_path)
        for file in files
        if file.lower().endswith(suffixes)
    )

    # Process videos
    samples_created = 0
    for video_path in tqdm(video_paths, desc="Processing videos"):
        # `is not None` so that max_samples=0 means "none", not "unlimited".
        if max_samples is not None and samples_created >= max_samples:
            break

        # The id only advances on success, so a failed video's folder is
        # reused by the next candidate.
        sample_dir = output_root / f"sample_{samples_created:06d}"
        sample_dir.mkdir(exist_ok=True)

        success = process_single_video(
            video_path=video_path,
            output_dir=sample_dir,
            target_fps=target_fps,
            duration_seconds=duration_seconds,
            image_size=image_size
        )

        if success:
            samples_created += 1
        elif not any(sample_dir.iterdir()):
            # Don't leave an empty folder behind if this was the last candidate.
            sample_dir.rmdir()

    print(f"Created {samples_created} valid samples")
    return AnimateDiffDataset(output_root, image_size=image_size)

def process_single_video(
    video_path: Path,
    output_dir: Path,
    target_fps: int,
    duration_seconds: int,
    image_size: Tuple[int, int]
) -> bool:
    """Extract a fixed-length, fixed-fps clip from one video into AnimateDiff format.

    Writes ``output_dir / "image.jpg"`` (first extracted frame) and
    ``output_dir / "video.mp4"`` (``target_fps * duration_seconds`` frames
    encoded at ``target_fps``). Output frame k is taken from the source frame
    nearest to time ``k / target_fps``, so the clip spans ``duration_seconds``
    of source footage whatever the source frame rate is.

    Args:
        video_path: Input video file.
        output_dir: Existing directory that receives the two output files.
        target_fps: Output frame rate.
        duration_seconds: Output clip length in seconds.
        image_size: Target (height, width), matching torchvision's order.

    Returns:
        True if both files were written; False if the video cannot be opened
        or has too few frames.
    """
    required_frames = target_fps * duration_seconds
    if required_frames <= 0:
        return False

    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            return False

        # Get video properties
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if not np.isfinite(original_fps) or original_fps <= 0:
            # Unknown source rate: fall back to 1:1 frame mapping.
            original_fps = float(target_fps)

        # Source frame index for each output frame (timestamp-based sampling).
        step = original_fps / target_fps
        source_indices = [int(k * step + 0.5) for k in range(required_frames)]

        # Skip videos that are too short according to the reported frame count;
        # the read loop below re-checks how many frames were actually decoded.
        if total_frames < source_indices[-1] + 1:
            return False

        dsize = (image_size[1], image_size[0])  # cv2.resize expects (width, height)
        frames = []
        next_out = 0
        frame_idx = 0
        while next_out < required_frames:
            ok, frame = cap.read()
            if not ok:
                break
            if source_indices[next_out] == frame_idx:
                # Convert BGR to RGB and resize
                frame = cv2.resize(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), dsize)
                # Low-fps sources map several output frames to one source frame.
                while next_out < required_frames and source_indices[next_out] == frame_idx:
                    frames.append(frame)
                    next_out += 1
            frame_idx += 1
    finally:
        cap.release()

    if len(frames) < required_frames:
        return False

    output_dir = Path(output_dir)
    # Save first frame as the conditioning image
    Image.fromarray(frames[0]).save(output_dir / "image.jpg")
    # Save video clip
    save_video_as_frames(np.stack(frames), output_dir / "video.mp4", target_fps)
    return True

def save_video_as_frames(frames: np.ndarray, output_path: Path, fps: int):
    """Encode a sequence of RGB frames into a single ``mp4v`` video file.

    Despite the name, this writes one video file, not individual images.

    Args:
        frames: Sequence/array of (H, W, 3) RGB frames; each is converted to
            BGR for OpenCV before writing.
        output_path: Destination file.
        fps: Frame rate stored in the container.

    Raises:
        ValueError: If ``frames`` is empty.
        OSError: If OpenCV cannot open a writer for ``output_path``.
    """
    if len(frames) == 0:
        raise ValueError("save_video_as_frames: no frames to write")

    h, w = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h))
    if not writer.isOpened():
        raise OSError(f"Could not open video writer for {output_path}")

    try:
        for frame in frames:
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()

In [None]:
# Process raw videos and create dataset
dataset = create_animediff_dataset(
    video_dirs=["path/to/video_folder1", "path/to/video_folder2"],
    output_root="processed_dataset",
    target_fps=16,
    duration_seconds=4,
    image_size=(512, 512),
    max_samples=1000  # Optional: limit number of samples
)
if len(dataset) == 0:
    # DataLoader(shuffle=True) rejects an empty dataset with a less helpful error.
    raise RuntimeError("No valid samples were created; check `video_dirs` and clip length.")

# Defined here so this cell runs on a fresh kernel; the model-loading cell
# below assigns the same value.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

# Smoke-check a single batch: shapes and device transfer. Actual fine-tuning
# is done by `train(...)`, defined further down.
batch = next(iter(dataloader))
conditioning_images = batch["conditioning_image"].to(device)  # (B, C, H, W)
target_frames = batch["target_frames"].to(device)            # (B, T, C, H, W)
print(f"conditioning_images: {tuple(conditioning_images.shape)}, "
      f"target_frames: {tuple(target_frames.shape)}")

In [None]:
# Verify dataset structure
def verify_dataset(dataset: "AnimateDiffDataset"):
    """Load every sample and assert tensor shapes and on-disk file presence.

    Expected shapes come from ``dataset.image_size`` and ``dataset.num_frames``
    (falling back to (512, 512) and 16 when absent), so datasets built with
    non-default settings are checked against their own configuration.

    Raises:
        AssertionError: For the first sample whose shapes or paths are wrong.
    """
    exp_h, exp_w = getattr(dataset, "image_size", (512, 512))
    exp_t = getattr(dataset, "num_frames", 16)
    expected_image = (3, exp_h, exp_w)
    expected_video = (exp_t, 3, exp_h, exp_w)

    for i in range(len(dataset)):
        sample = dataset[i]
        metadata = dataset.get_sample_metadata(i)

        image_shape = tuple(sample["conditioning_image"].shape)
        video_shape = tuple(sample["target_frames"].shape)
        assert image_shape == expected_image, (
            f"sample {i}: conditioning_image shape {image_shape}, expected {expected_image}"
        )
        assert video_shape == expected_video, (
            f"sample {i}: target_frames shape {video_shape}, expected {expected_video}"
        )
        assert Path(metadata["image_path"]).exists(), f"sample {i}: missing {metadata['image_path']}"
        assert Path(metadata["video_path"]).exists(), f"sample {i}: missing {metadata['video_path']}"

    print("Dataset verification passed!")

# Sanity-check every processed sample. This loads every image and video, so it
# can be slow on large datasets.
verify_dataset(dataset)

In [None]:
# Placeholder: point this at a diffusers-format model folder containing the
# tokenizer/, text_encoder/, vae/, unet/ and scheduler/ subfolders loaded below.
pretrained_model_path = "path/to/pretrained/animtediff/weights"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 16 fps * 4 s, matching create_animediff_dataset(target_fps=16, duration_seconds=4) above.
num_frames = 16 * 4
height, width = 512, 512

tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
# NOTE(review): AnimateDiff's own scripts appear to build the 3D UNet with
# `UNet3DConditionModel.from_pretrained_2d(..., unet_additional_kwargs=...)` from
# a 2D SD UNet; confirm plain `from_pretrained` works with these weights.
unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")

# NOTE(review): the v3_sd15_sparsectrl_rgb.ckpt downloaded at the top of the
# notebook is not loaded anywhere in this cell. TODO wire in the motion module /
# SparseCtrl weights if image conditioning is intended.
pipeline = AnimationPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
).to(device)

In [None]:
class VideoGenerationDataset(Dataset):
    """Dataset over standalone video files; each clip's first frame is its conditioning image.

    Items are dicts with ``conditioning_image``, ``target_frames`` and an empty
    ``prompt``.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir: Directory scanned (non-recursively) for ``.mp4`` files,
                matched case-insensitively.
            transform: Optional callable applied to the conditioning frame and
                to each target frame individually; the transformed target
                frames are stacked along a new leading axis.
        """
        self.root_dir = root_dir
        self.transform = transform
        # Sorted so that indices are stable across runs and filesystems.
        self.video_files = sorted(
            f for f in os.listdir(root_dir) if f.lower().endswith('.mp4')
        )

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, idx):
        video_path = os.path.join(self.root_dir, self.video_files[idx])
        frames = load_and_process_video(video_path)  # expected shape [T, H, W, C]
        if frames is None or len(frames) == 0:
            raise ValueError(f"No frames loaded from {video_path}")

        # Extract first frame as conditioning image
        conditioning_image = frames[0]
        target_frames = frames

        if self.transform:
            conditioning_image = self.transform(conditioning_image)
            target_frames = torch.stack([self.transform(f) for f in target_frames])

        return {
            "conditioning_image": conditioning_image,
            "target_frames": target_frames,
            "prompt": ""  # Add text prompts if available
        }

def load_and_process_video(path):
    """Decode a video file into a tensor of shape [num_frames, height, width, channels].

    Uses ``torchvision.io.read_video`` (imported earlier in this notebook),
    which returns frames in (T, H, W, C) layout by default.

    Raises:
        ValueError: If no frames could be decoded.
    """
    frames, _, _ = read_video(str(path), pts_unit="sec")  # (T, H, W, C)
    if frames.shape[0] == 0:
        raise ValueError(f"No decodable frames in video: {path}")
    return frames

In [None]:
def train(model, dataset, epochs=10, batch_size=2, lr=1e-5):
    """Fine-tune ``model.unet`` with the epsilon-prediction diffusion (MSE) loss.

    Args:
        model: The ``AnimationPipeline``; its ``vae``, ``text_encoder``,
            ``tokenizer``, ``unet`` and ``scheduler`` are used. Only UNet
            parameters are optimized; VAE and text encoder run without grad.
        dataset: Yields dicts with ``target_frames`` (T, C, H, W),
            ``conditioning_image`` and optionally ``prompt`` (missing prompts
            become ""). Frames are presumably in [-1, 1], as produced by
            AnimateDiffDataset's default transforms — confirm for other datasets.
        epochs: Number of passes over ``dataset``.
        batch_size: DataLoader batch size.
        lr: AdamW learning rate.

    Returns:
        List with the mean training loss of each epoch.
    """
    vae, text_encoder, tokenizer = model.vae, model.text_encoder, model.tokenizer
    unet, scheduler = model.unet, model.scheduler
    device = next(unet.parameters()).device

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.AdamW(unet.parameters(), lr=lr)
    num_train_timesteps = scheduler.config.num_train_timesteps

    unet.train()
    epoch_losses = []
    for epoch in range(epochs):
        running_loss, num_batches = 0.0, 0
        for batch in dataloader:
            conditioning_images = batch["conditioning_image"].to(device)
            target_frames = batch["target_frames"].to(device)  # (B, T, C, H, W)
            # Real batch size: the last batch may be smaller than `batch_size`.
            bsz, video_length = target_frames.shape[:2]

            with torch.no_grad():
                # The VAE is a 2D image model: encode all frames as one image
                # batch, then regroup per video.
                # NOTE(review): (B, C, T, h, w) is the latent layout AnimateDiff's
                # training script appears to feed its 3D UNet — confirm.
                pixels = rearrange(target_frames, "b t c h w -> (b t) c h w")
                latents = vae.encode(pixels).latent_dist.sample()
                latents = rearrange(latents, "(b t) c h w -> b c t h w", t=video_length)
                latents = latents * 0.18215  # Scaling factor

                # CLIP needs token ids, not raw strings.
                prompts = batch.get("prompt")
                if prompts is None:
                    prompts = [""] * bsz
                input_ids = tokenizer(
                    list(prompts),
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids.to(device)
                encoder_hidden_states = text_encoder(input_ids)[0]

            # Sample noise and timesteps, then add noise to latents
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, num_train_timesteps, (bsz,), device=device).long()
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)

            # NOTE(review): confirm UNet3DConditionModel.forward accepts
            # `conditioning_images`; AnimateDiff v3 presumably injects image
            # conditioning through SparseCtrl residuals instead.
            model_output = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
                conditioning_images=conditioning_images,
            ).sample

            loss = torch.nn.functional.mse_loss(model_output.float(), noise.float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        mean_loss = running_loss / max(num_batches, 1)
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1}, Loss: {mean_loss}")

    return epoch_losses

In [None]:
def generate_video_from_image(pipeline, image, prompt="", num_frames=64):
    """Generate a clip with ``pipeline``, passing ``image`` as the conditioning frame.

    Args:
        pipeline: ``AnimationPipeline`` already moved to ``device``.
        image: PIL image (converted to RGB) or an HWC array accepted by
            ``transforms.ToTensor``; resized to the notebook's (height, width).
        prompt: Text prompt.
        num_frames: Number of frames to generate.

    Returns:
        uint8 numpy array of shape (T, H, W, C).
    """
    # Force 3 channels: RGBA / greyscale PIL images would otherwise produce
    # 4- or 1-channel tensors.
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    image = transforms.ToTensor()(image).unsqueeze(0).to(device)
    image = transforms.Resize((height, width))(image)

    # Run inference
    with torch.no_grad():
        # NOTE(review): AnimateDiff's AnimationPipeline.__call__ appears to take
        # `video_length` (not `num_frames`) and return `.videos`. Also confirm
        # `conditioning_image` is actually consumed; v3 image conditioning
        # presumably goes through SparseCtrl's `controlnet_images`.
        frames = pipeline(
            prompt=prompt,
            video_length=num_frames,
            conditioning_image=image,
            guidance_scale=7.5,
            num_inference_steps=50
        ).videos

    # Post-process output: first video in the batch, (C, T, H, W) -> (T, H, W, C).
    frames = rearrange(frames[0], "c t h w -> t h w c").cpu().numpy()
    # Clip before casting so slight overshoots past 1.0 don't wrap around in uint8.
    return (np.clip(frames, 0.0, 1.0) * 255).round().astype(np.uint8)

In [None]:
# Save the entire pipeline, then reload it from the SAME directory.
# NOTE(review): confirm AnimationPipeline.from_pretrained can re-import the
# custom UNet3DConditionModel recorded in the saved model_index.json.
PIPELINE_SAVE_DIR = "path/to/saved/model"

pipeline.save_pretrained(PIPELINE_SAVE_DIR)

# Load saved model
pipeline = AnimationPipeline.from_pretrained(PIPELINE_SAVE_DIR).to(device)

In [None]:
import imageio
from PIL import Image

# Reuses the `pipeline` built (and optionally reloaded) in the cells above.
# Constructing `AnimationPipeline(...)` literally would pass the Ellipsis object
# as a component and fail.

FINE_TUNE = False  # set True to fine-tune `pipeline` on `dataset` before generating
if FINE_TUNE:
    train(pipeline, dataset)

# Generate video from image; the context manager closes the file handle.
with Image.open("input.jpg") as img:
    input_image = img.convert("RGB")

generated_video = generate_video_from_image(
    pipeline=pipeline,
    image=input_image,
    prompt="high quality, cinematic, 4K resolution",
    num_frames=64
)

# Save video
imageio.mimwrite("output.mp4", generated_video, fps=16)