In [None]:
!pip install diffusers transformers accelerate

Collecting diffusers
  Downloading diffusers-0.29.2-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
Collecting accelerate
  Downloading accelerate-0.31.0-py3-none-any.whl (309 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m309.4/309.4 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.10.0->accelerate)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.10.0->accelerate)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.10.0->accelerate)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.10.0->accelerate)
  Using cached nvidia_cudnn_cu

In [None]:
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import export_to_video

# Generate a short clip from a text prompt with the ModelScope 1.7B
# text-to-video checkpoint, flatten the pipeline output into a plain list of
# frames, and write it to a video file.

# Tunables, kept together so they are easy to find and change.
SEED = 42                  # fixed seed so "Restart & Run All" reproduces the same clip
NUM_INFERENCE_STEPS = 120  # denoising steps passed to the pipeline


def _flatten_frames(frames):
    """Flatten pipeline output into a flat list of single frames.

    Each item of ``frames`` must expose ``.shape`` and be either:

    * 4-D -- a clip; it is iterated along its first axis and every
      resulting 3-D slice is appended as one frame, or
    * 3-D -- already a single frame; appended as-is.

    A 4-D clip that contains exactly one frame is valid and yields that one
    frame (it is not treated as an error).

    Raises:
        ValueError: if an item is neither 3-D nor 4-D.
    """
    flat = []
    for item in frames:
        rank = len(item.shape)
        if rank == 4:  # clip: one entry per slice along the first axis
            flat.extend(item)
        elif rank == 3:  # single frame
            flat.append(item)
        else:
            raise ValueError(f"Unexpected frame shape: {item.shape}")
    return flat


# Initialize the pipeline (fp16 weights, DPM-Solver multistep scheduler).
pipe = DiffusionPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

# Generate video frames. An explicit generator makes the sampling reproducible.
prompt = "dog walking"
generator = torch.Generator(device="cpu").manual_seed(SEED)
video_frames = pipe(
    prompt,
    num_inference_steps=NUM_INFERENCE_STEPS,
    generator=generator,
).frames

# Debug: shape of the first element of the output. NOTE(review): depending on
# the diffusers version this element is either a whole clip (4-D) or a single
# frame (3-D); _flatten_frames handles both.
print("Shape of the first frame:", video_frames[0].shape)

# Flatten the batch of frames
correct_frames = _flatten_frames(video_frames)
if not correct_frames:
    # Fail here with a clear message instead of inside the video writer.
    raise ValueError("Pipeline returned no frames to export")

# Debug: Check the number of frames
print("Number of frames after flattening:", len(correct_frames))

# Export frames to video
video_path = export_to_video(correct_frames)

# Print video path
print('Video Path:', video_path)
