Add MAGI-1: Autoregressive Video Generation at Scale #11713
Conversation
…te attention mechanism accordingly. Updated initialization parameters and reshaping logic.
…tering and equal split ratio. Add utility functions for resizing and cropping images while preserving aspect ratio. Enhance 3D rotary positional embeddings: Adds `center_grid_hw_indices` and `equal_split_ratio` parameters to the 3D rotary positional embedding function for more flexible configuration. The `center_grid_hw_indices` option centers the spatial grid indices around zero. The `equal_split_ratio` parameter provides an alternative way to divide the embedding dimension equally among the temporal and spatial axes. Updates the Magi1 VAE to utilize these new embedding features, introducing helper functions to prepare the embeddings dynamically based on input tensor dimensions.
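For reference, a minimal sketch of how a 3D RoPE helper could expose these two options; `rope_3d_freqs` and the default split below are illustrative, only the `center_grid_hw_indices` and `equal_split_ratio` names come from the commit above.

```python
import torch

# Sketch only: divide the embedding dim across the (t, h, w) axes and optionally
# center the spatial grid indices around zero, as described above.
def rope_3d_freqs(dim, t, h, w, theta=10000.0, center_grid_hw_indices=False, equal_split_ratio=None):
    if equal_split_ratio is not None:
        # Alternative split: each axis gets dim // equal_split_ratio channels.
        dim_t = dim_h = dim_w = dim // equal_split_ratio
    else:
        dim_t, dim_h, dim_w = dim // 2, dim // 4, dim // 4

    def axis_freqs(axis_dim, positions):
        inv_freq = 1.0 / (theta ** (torch.arange(0, axis_dim, 2).float() / axis_dim))
        return torch.outer(positions.float(), inv_freq)

    t_idx, h_idx, w_idx = torch.arange(t), torch.arange(h), torch.arange(w)
    if center_grid_hw_indices:
        h_idx = h_idx - h // 2
        w_idx = w_idx - w // 2
    return axis_freqs(dim_t, t_idx), axis_freqs(dim_h, h_idx), axis_freqs(dim_w, w_idx)
```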
Replaces the initial causal 3D convolution in the encoder with a standard `Conv3d` patch embedding layer. This simplifies the model and makes its input processing more consistent with Diffusion Transformer (DiT) architectures. Additionally, this change:
- Removes the unused `Magi1CausalConv3d` class.
- Updates the attention mechanism to use the standard `scaled_dot_product_attention`.
- Sets the default for `sample_posterior` to `True` in the forward pass.
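As a rough illustration of the patch-embedding front end (the channel count and patch size below are made up, not the model's actual config):

```python
import torch
import torch.nn as nn

# A plain Conv3d with kernel_size == stride == patch_size turns the video into
# non-overlapping spatio-temporal patches, DiT-style.
patch_size = (4, 8, 8)                                   # (t, h, w), illustrative
patch_embed = nn.Conv3d(3, 768, kernel_size=patch_size, stride=patch_size)

video = torch.randn(1, 3, 16, 256, 256)                  # (B, C, T, H, W)
tokens = patch_embed(video).flatten(2).transpose(1, 2)   # (1, 4*32*32, 768) token sequence
```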
Removes the feature caching logic (`feat_cache`, `feat_idx`) from the encoder, decoder, and their sub-modules. This change significantly simplifies the forward pass implementation by removing stateful cache management. Additionally, this commit replaces the custom `Magi1RMS_norm` with a standard `nn.LayerNorm` and updates several custom causal convolution layers to use standard `nn.Linear` or `nn.Conv3d` layers.
Moves the positional embedding and dropout layers from the main autoencoder class into the decoder module. This improves encapsulation as the embedding is only used within the decoder. The decoder's forward pass is updated to apply the positional embedding and to remove the class token before the final output convolution. Additionally, `quant_conv` is renamed to `quant_linear` to accurately reflect the layer type.
Updates the `Magi1Decoder3d` from a convolutional design to a Transformer-like structure that operates on patches. This change replaces the initial convolutional and middle blocks with a linear projection layer, positional embeddings, and a class token. The logic for these components is moved from the parent `AutoencoderKLMagi1` model into the decoder for better encapsulation.
Removes several custom modules, including `Magi1ResidualBlock`, `Magi1Resample`, and `Magi1UpBlock`. Replaces the previous `Magi1MidBlock` with a more standard transformer-style `Magi1Block`. This change simplifies the overall VAE architecture by consolidating complex, specialized blocks into a more conventional design.
Replaces the custom `Magi1AttentionBlock` with the more generic `diffusers.Attention` module, combined with a new `Magi1AttnProcessor2_0`. This change aligns the implementation with standard library patterns and leverages PyTorch 2.0's `scaled_dot_product_attention` for improved efficiency. The `Magi1Block` is also refactored into a more conventional transformer block structure using `Attention` and `FeedForward` modules.
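For context, a generic SDPA-based processor in the diffusers style looks roughly like the sketch below; this is not the actual `Magi1AttnProcessor2_0`, just the pattern it follows.

```python
import torch.nn.functional as F

class SketchAttnProcessor2_0:
    """Illustrative processor built around PyTorch 2.0's scaled_dot_product_attention."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        batch, seq_len, _ = query.shape
        head_dim = query.shape[-1] // attn.heads
        query, key, value = (
            t.view(batch, -1, attn.heads, head_dim).transpose(1, 2) for t in (query, key, value)
        )

        out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        out = out.transpose(1, 2).reshape(batch, seq_len, attn.heads * head_dim)
        out = attn.to_out[0](out)   # linear projection
        out = attn.to_out[1](out)   # dropout
        return out
```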
Refactors the Magi1 VAE decoder to use a more standard transformer-based architecture. This change replaces the previous U-Net-like upsampling blocks with a series of standard transformer blocks, each containing self-attention and a feed-forward network. The custom rotary positional embedding logic and its helper functions have been removed, and the attention processor is simplified to work with the standard `Attention` module. This simplifies the overall model implementation.
Replaces the previous convolutional U-Net style encoder with a Vision Transformer (ViT) based implementation. This new architecture processes the input by dividing it into patches, adding positional embeddings, and then passing the sequence through a series of transformer blocks. The attention processor is also updated to support attention masks, and the model's configuration is adjusted to accommodate the new transformer-specific parameters.
Removes complex and unused parameters from the Magi1 VAE, encoder, and decoder modules. This change refactors the model to use a more standard Transformer architecture, eliminating the previous U-Net-like structure with dimension multipliers and residual blocks. The configuration is now more direct, improving clarity and maintainability.
Simplifies the initialization of the Magi1 VAE, encoder, and decoder. Reorders constructor parameters for clarity and removes unused arguments. The spatial and temporal compression ratios are now derived directly from the `patch_size` configuration, making the relationship more explicit. The pipeline is updated to use these new VAE attributes.
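In code, the derivation is as simple as the following (patch size values are illustrative, assuming a `(t, h, w)` layout):

```python
# Assuming a (t, h, w) patch_size in the config; the values below are made up.
patch_size = (4, 8, 8)
temporal_compression_ratio = patch_size[0]   # 4x along time
spatial_compression_ratio = patch_size[1]    # 8x along height/width (square patches)
```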
Simplifies the model architecture by removing the quantization and post-quantization convolution layers. This streamlines the `encode` and `decode` methods. The decoder is also updated to process the entire latent tensor at once, removing the previous frame-by-frame processing loop. Additionally, this change updates an import path for the `timm` library and renames an internal variable for consistency.
Updates the conversion script for the MAGI-1 VAE to correctly handle its Vision Transformer (ViT) based architecture. The state dictionary mapping is rewritten to align with the ViT structure. This includes adding logic to split the original checkpoint's combined QKV weights into separate query, key, and value tensors for the `diffusers` model. The model class and its configuration are also updated to reflect the appropriate ViT parameters, ensuring a correct conversion.
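The QKV split amounts to something like the sketch below; the key names and the fused `(3 * dim, dim)` layout are assumptions for illustration, not the script's exact mapping.

```python
import torch

def split_fused_qkv(old_sd, old_key, new_prefix, new_sd):
    # Assumes the original checkpoint stores a fused (3 * dim, dim) QKV weight.
    q, k, v = torch.chunk(old_sd[old_key], 3, dim=0)
    new_sd[f"{new_prefix}.to_q.weight"] = q
    new_sd[f"{new_prefix}.to_k.weight"] = k
    new_sd[f"{new_prefix}.to_v.weight"] = v
```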
Renames the Magi autoencoder class to align with the "MAGI-1" model name. This refactoring improves consistency and clarity throughout the codebase, including documentation and tests.
Aligns the model naming with the source paper, "MAGI-1". This change refactors the model class, associated files, tests, and documentation to use the `Magi1` prefix for better clarity and consistency.
…ross multiple files
Improve compatibility by handling various PyTorch checkpoint formats. The loader now correctly extracts the state dictionary when it is nested under common keys like "model" or "state_dict". Ensure consistent loading of sharded safetensors by sorting the checkpoint files before merging them.
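A hedged sketch of that loading behavior (the function name and exact key list are illustrative):

```python
import torch
from safetensors.torch import load_file

def load_original_checkpoint(paths):
    state_dict = {}
    # Sort sharded files so merging is deterministic across runs.
    for path in sorted(paths):
        shard = load_file(path) if path.endswith(".safetensors") else torch.load(path, map_location="cpu")
        state_dict.update(shard)
    # Unwrap common nesting keys produced by various training frameworks.
    for key in ("model", "state_dict"):
        if key in state_dict and isinstance(state_dict[key], dict):
            state_dict = state_dict[key]
    return state_dict
```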
The test code is as follows. You can use any video for testing.

```python
from diffusers import AutoencoderKLMagi1
from diffusers.utils import export_to_video, load_video
import torch
from PIL import Image
import numpy as np
from torchvision import transforms

video_path = "curry_vs_thunder.mp4"
vae = AutoencoderKLMagi1.from_pretrained("sand-ai/MAGI-1", subfolder="ckpt/vae", torch_dtype=torch.bfloat16).to("cuda")

image_list = load_video(video_path)[:100]
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
frame_list = []
for image in image_list:
    frame_list.append(transform(image).to("cuda"))
input_video = torch.stack(frame_list).to("cuda")

output_frame_list = []
with torch.no_grad(), torch.cuda.amp.autocast():
    input_video = input_video.unsqueeze(0).transpose(1, 2)
    output_video, posterior = vae(input_video)
    output_video = output_video.squeeze(0)

for i in range(output_video.shape[1]):
    output_frame = output_video[:, i, :, :].permute(1, 2, 0).cpu().detach().numpy()
    # Denormalize from [-1, 1] to [0, 1]
    output_frame = (output_frame + 1.0) * 0.5
    # Convert to uint8
    output_frame = (output_frame.clip(0, 1) * 255).astype(np.uint8)
    output_frame = Image.fromarray(output_frame)
    output_frame_list.append(output_frame)

export_to_video(output_frame_list, "output.mp4", fps=25)
```
```python
if self.use_rope:
    assert feat_shape is not None
    q, k, v = qkv.chunk(3, dim=2)
    rope_emb = cache_rotary_emb(feat_shape=feat_shape, dim=C // self.num_heads, device=x.device, dtype=x.dtype)
    sin_emb = rope_emb[0].unsqueeze(0).unsqueeze(2)
    cos_emb = rope_emb[1].unsqueeze(0).unsqueeze(2)
    print(q.shape, sin_emb.shape)
    q[:, 1:, :] = apply_rot_embed(q[:, 1:, :], sin_emb, cos_emb).bfloat16()
    k[:, 1:, :] = apply_rot_embed(k[:, 1:, :], sin_emb, cos_emb).bfloat16()
    q, k, v = map(lambda t: t.squeeze(2).transpose(1, 2), (q, k, v))
    x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop_rate)
    x = x.transpose(1, 2)
else:
    q, k, v = map(lambda t: t.squeeze(2).transpose(1, 2), (q, k, v))
    x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop_rate)
    x = x.transpose(1, 2)
```

Solved.
Thanks for the suggestions.
Right, I forgot to replace …
At first, I was trying to port the positional embedding-related calculations, but then I realized that the shared model's VAE doesn't use RoPE by default.
Rather than …
Rather than …
Rather than …
In order to use …
I removed training-related parts such as …
I have also been trying to follow …
WDYT?
I think it would be good to make the code follow the style of …
Dynamically resize positional embeddings in the encoder and decoder to match the input's latent shape using trilinear interpolation. This change allows the model to process inputs of varying height, width, and frame count in a single forward pass, removing the previous iterative encoding logic.
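Conceptually, the resizing step looks like the sketch below; the reference grid size and helper name are made up for illustration.

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, src_thw, target_thw):
    # pos_embed: (1, T*H*W, C) learned for a reference latent grid `src_thw`.
    t, h, w = src_thw
    channels = pos_embed.shape[-1]
    grid = pos_embed.reshape(1, t, h, w, channels).permute(0, 4, 1, 2, 3)   # (1, C, T, H, W)
    grid = F.interpolate(grid, size=target_thw, mode="trilinear", align_corners=False)
    return grid.permute(0, 2, 3, 4, 1).reshape(1, -1, channels)

pos = torch.randn(1, 4 * 32 * 32, 768)
resized = resize_pos_embed(pos, (4, 32, 32), (6, 16, 16))   # -> (1, 6*16*16, 768)
```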
I was only focusing on the decoding part because inference is the priority.
…toencoderKLMagi1.decode
Introduces a `ManualLayerNorm` class to provide an explicit, self-contained layer normalization. The VAE attention processor is updated to use this new manual normalization, removing the previous dependency.
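Something along these lines, as a sketch; the real class may differ in eps or affine handling.

```python
import torch
import torch.nn as nn

class ManualLayerNorm(nn.Module):
    """Explicit layer norm over the last dim, without learnable weight/bias."""

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return (x - mean) / torch.sqrt(var + self.eps)
```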
Updates the `Magi1VAEAttnProcessor2_0` to remove the additive residual connection. This change simplifies the attention block's forward pass.
…i1VAEAttnProcessor2_0
…ndling in attention mechanism
Replaces the `trunc_normal_` function from the `timm` library with the equivalent `torch.nn.init.trunc_normal_`. This change removes an external dependency and also cleans up an unused import.
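i.e. the drop-in replacement is simply (the tensor shape and std here are illustrative):

```python
import torch
import torch.nn as nn

weight = torch.empty(768, 768)
# Same default truncation bounds (a=-2, b=2) as timm.models.layers.trunc_normal_.
nn.init.trunc_normal_(weight, std=0.02)
```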
My modification to the original repo to compare fairly:

```diff
diff --git a/inference/model/vae/vae_module.py b/inference/model/vae/vae_module.py
index eb4501b..777ee8b 100644
--- a/inference/model/vae/vae_module.py
+++ b/inference/model/vae/vae_module.py
@@ -19,8 +19,8 @@ from typing import List, Optional, Tuple
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import rearrange
-from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
 from timm.models.layers import to_2tuple, trunc_normal_
 
 ###################################################
@@ -282,6 +282,10 @@ class Attention(nn.Module):
         qkv = self.qkv_norm(qkv)
         q, k, v = qkv.chunk(3, dim=2)
 
+        q = q.squeeze(2).transpose(1, 2)  # B, num_heads, N, C // num_heads
+        k = k.squeeze(2).transpose(1, 2)
+        v = v.squeeze(2).transpose(1, 2)
+
         if self.use_rope:
             assert feat_shape is not None
             q, k, v = qkv.chunk(3, dim=2)
@@ -291,11 +295,11 @@ class Attention(nn.Module):
             print(q.shape, sin_emb.shape)
             q[:, 1:, :] = apply_rot_embed(q[:, 1:, :], sin_emb, cos_emb).bfloat16()
             k[:, 1:, :] = apply_rot_embed(k[:, 1:, :], sin_emb, cos_emb).bfloat16()
-            x = flash_attn_func(q, k, v, dropout_p=self.attn_drop_rate)
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop_rate)
         else:
-            x = flash_attn_qkvpacked_func(qkv=qkv.bfloat16(), dropout_p=self.attn_drop_rate)
+            x = F.scaled_dot_product_attention(q.bfloat16(), k.bfloat16(), v.bfloat16(), dropout_p=self.attn_drop_rate)
             # x = v
-        x = x.reshape(B, N, C)
+        x = x.transpose(1, 2).reshape(B, N, C)
         # import ipdb; ipdb.set_trace()
         x = self.proj(x)
         x = self.proj_drop(x)
```

Comparison of the VAEs:

```python
!pip uninstall diffusers -qy
!git clone https://github.com/tolgacangoz/diffusers.git
%cd diffusers
!git switch add-magi-1
!pip install -e . -q
```

```python
import os, random, numpy as np, torch

def fix_seed(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

fix_seed(0)

from diffusers import AutoencoderKLMagi1
from diffusers.utils import export_to_video, load_video
import torch
from PIL import Image
import numpy as np
from torchvision import transforms

video_path = "video1.mp4"
vae_diffusers = AutoencoderKLMagi1.from_pretrained(
    "tolgacangoz/MAGI-1-T2V-4.5B-distill-Diffusers",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
).to("cuda").eval()

image_list = load_video(video_path)[:16]
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
frame_list = []
for image in image_list:
    frame_list.append(transform(image).to("cuda"))
input_video = torch.stack(frame_list).to("cuda")
output_frame_list = []
input_video = input_video.unsqueeze(0).to(torch.bfloat16).permute(0, 2, 1, 3, 4)
```

```python
!git clone https://github.com/tolgacangoz/MAGI-1.git
%cd MAGI-1
!git switch diffusers
```

```python
from inference.model.vae.vae_model import ViTVAE
import torch

vae_original = ViTVAE.from_pretrained('sand-ai/MAGI-1', subfolder='ckpt/vae').eval().to('cuda', torch.bfloat16)

with torch.no_grad():
    posterior = vae_diffusers.encode(input_video).latent_dist
    z_diffusers = posterior.mode()
with torch.no_grad():
    z_original = vae_original.encode(input_video, sample_posterior=False)
torch.equal(z_diffusers, z_original)  # `True`

with torch.no_grad():
    out_diffusers = vae_diffusers.decode(z_diffusers).sample
with torch.no_grad():
    out_original = vae_original.decode(z_original)
torch.equal(out_diffusers, out_original)  # `True`
```
Refactors the MAGI-1 conversion script for clarity and correctness. This removes redundant conditional checks and ensures all necessary weights from the original checkpoint are properly mapped to the Diffusers format. The model definition is also updated to more accurately reflect the original architecture, including adjusting the default feed-forward network dimension and removing unnecessary biases from the patch embedding and final projection layers.
… state dict conversion
Introduces a `CaptionEmbedder` module to handle caption projections for both cross-attention and AdaLN conditioning. This replaces the previous time projection logic. The new module also adds support for classifier-free guidance by implementing caption dropout.
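The caption-dropout part of that behaves roughly like the sketch below; the class and attribute names are illustrative, not the PR's `CaptionEmbedder`.

```python
import torch
import torch.nn as nn

class CaptionDropout(nn.Module):
    """Randomly replaces caption embeddings with a learned null embedding during training."""

    def __init__(self, dim: int, drop_prob: float = 0.1):
        super().__init__()
        self.drop_prob = drop_prob
        self.null_embedding = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, caption_emb: torch.Tensor) -> torch.Tensor:
        if self.training and self.drop_prob > 0:
            # Per-sample mask: with probability drop_prob, swap in the null embedding.
            drop = torch.rand(caption_emb.shape[0], 1, 1, device=caption_emb.device) < self.drop_prob
            caption_emb = torch.where(drop, self.null_embedding.to(caption_emb.dtype), caption_emb)
        return caption_emb
```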
Applies a hyperbolic tangent function to the time embedding tensor before it is used for scale and shift adjustments.
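In isolation, the step is just the following; the tensor size and the split into two halves are illustrative assumptions.

```python
import torch

temb = torch.randn(2, 1536)            # time embedding, illustrative size
temb = torch.tanh(temb)                # bound values to (-1, 1) before modulation
scale, shift = temb.chunk(2, dim=-1)   # then consumed as scale/shift terms
```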
…r time embedder dtype retrieval
This refactors the text conditioning logic in the Magi-1 transformer to align more closely with the PixArt-Alpha architecture. The separate `CaptionEmbedder` is removed, and the `PixArtAlphaTextProjection` is used directly on the text embeddings. The final layer normalization is updated from a standard `LayerNorm` to an `AdaLayerNorm`. This change simplifies the model by removing the explicit `scale_shift_table` and allows for direct conditioning. The conversion script is updated to correctly map the original checkpoint weights to the new architecture, including adding an identity mapping for a missing projection layer to ensure compatibility.
Thanks for the opportunity to fix #11519!
Original repo: https://github.com/SandAI-org/MAGI-1
AutoencoderKLMagi1
Magi1Transformer3DModel
MAGI-1-Diffusers
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.