From c2dbf92f01dec9f1c8d8e98daf11b66f0a602b9d Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Sat, 25 Oct 2025 23:47:07 +0200 Subject: [PATCH 01/18] start implementing the __init__ and __call__ funcs of MagiPipeline --- .../pipelines/magi1/pipeline_magi1_t2v.py | 176 ++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py new file mode 100644 index 000000000000..84efcc2dfaf7 --- /dev/null +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -0,0 +1,176 @@ +# Copyright 2025 The SandAI Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import Magi1LoraLoaderMixin +from ...models import AutoencoderKLMagi1, Magi1Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline + + +# TODO: Write example_doc_string +EXAMPLE_DOC_STRING = """ + Examples: + ```python + ``` +""" + +class Magi1Pipeline(DiffusionPipeline, Magi1LoraLoaderMixin): + r""" + Pipeline for text-to-video generation using Magi1. + + Reference: https://github.com/SandAI-org/MAGI-1 + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`Magi1Transformer3DModel`]): + Conditional Transformer to denoise the input latents. + vae ([`AutoencoderKLMagi1`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A flow matching scheduler with Euler discretization, using SD3-style time resolution transform. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + #TODO: Add _callback_tensor_inputs and _optional_components? 
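+    # Tentative sketch (unverified): other diffusers video pipelines (e.g. Wan) typically declare
+    #     _optional_components = []
+    #     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    # The exact lists for MAGI-1 still need to be decided.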
+ + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: Magi1Transformer3DModel, + vae: AutoencoderKLMagi1, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + vae=vae, + scheduler=scheduler, + ) + + # TODO: Add attributes + + # TODO: Fix default values of the parameters + # TODO: Double-check if all parameters are needed/included + # TODO: Double-check output type default (both in param and in docstring) + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = 720, + width: Optional[int] = 1280, + num_frames: int = 96, + num_inference_steps: int = 32, + guidance_scale: float = 7.5, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + max_sequence_length: int = 800, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `720`): + The height in pixels of the generated video. + width (`int`, defaults to `1280`): + The width in pixels of the generated video. + num_frames (`int`, defaults to `96`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `32`): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + guidance_scale (`float`, defaults to `7.5`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to + the text `prompt`, usually at the expense of lower video quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. 
Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, negative_prompt_embeds will be generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between `"latent"`, `"pt"`, or `"np"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`Magi1PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `800`): + The maximum sequence length for the text encoder. Sequences longer than this will be truncated. MAGI-1 + uses a max length of 800 tokens. + + Examples: + + Returns: + [`~Magi1PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`Magi1PipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated videos. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + From 3bd3b29e75b4f8c5b07ee34a8b39152ba86f1686 Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Sun, 26 Oct 2025 14:52:05 +0100 Subject: [PATCH 02/18] check_inputs --- .../pipelines/magi1/pipeline_magi1_t2v.py | 77 ++++++++++++++++++- 1 file changed, 75 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index 84efcc2dfaf7..6795473cc2a9 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -76,6 +76,57 @@ def __init__( # TODO: Add attributes + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + """Checks the validity of the inputs.""" + + # Check prompt and prompt_embeds + if prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + # Check negative_prompt and negative_prompt_embeds + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + if negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # Check height and width + # TODO: Why 16? + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + # Check callback_on_step_end_tensor_inputs + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # TODO: Fix default values of the parameters # TODO: Double-check if all parameters are needed/included # TODO: Double-check output type default (both in param and in docstring) @@ -85,8 +136,8 @@ def __call__( self, prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, - height: Optional[int] = 720, - width: Optional[int] = 1280, + height: int = 720, + width: int = 1280, num_frames: int = 96, num_inference_steps: int = 32, guidance_scale: float = 7.5, @@ -101,6 +152,7 @@ def __call__( callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 800, ): r""" @@ -174,3 +226,24 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + # TODO: Come back here later + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] # TODO: Check if linter complains here + \ No newline at end of file From 9cdcda7cac870402c56c954cb7977559415a1df6 Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Sun, 26 Oct 2025 14:58:00 +0100 Subject: [PATCH 03/18] improve typing of check_inputs --- .../pipelines/magi1/pipeline_magi1_t2v.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index 6795473cc2a9..a20e41245a40 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -79,13 +79,13 @@ def __init__( def check_inputs( self, - prompt, - negative_prompt, - height, - width, - prompt_embeds=None, - negative_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, + prompt: Optional[Union[str, List[str]]], + negative_prompt: Optional[Union[str, List[str]]], + height: int, + width: int, + prompt_embeds: Optional[torch.Tensor], + negative_prompt_embeds: Optional[torch.Tensor], + callback_on_step_end_tensor_inputs: List[str], ): """Checks the validity of the inputs.""" From bc512daa623caf587b3d7827c8f49c4c2ea7b1b0 Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Sun, 26 Oct 2025 15:25:27 +0100 Subject: [PATCH 04/18] encode_prompt --- .../pipelines/magi1/pipeline_magi1_t2v.py | 106 +++++++++++++++++- 1 file changed, 102 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index a20e41245a40..f5f244e9ae3f 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -77,6 +77,91 @@ def __init__( # TODO: Add attributes + def encode_prompt( + self, + prompt: Optional[Union[str, List[str]]], + negative_prompt: Optional[Union[str, List[str]]], + do_classifier_free_guidance: bool, + num_videos_per_prompt: int, + prompt_embeds: Optional[torch.Tensor], + negative_prompt_embeds: Optional[torch.Tensor], + max_sequence_length: int, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + ): + r"""Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + # TODO: Can we provide different prompts for different chunks? + # If so, how are we gonna support that? + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + prompt_mask = None + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + negative_mask = None + return prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask + def check_inputs( self, prompt: Optional[Union[str, List[str]]], @@ -87,7 +172,7 @@ def check_inputs( negative_prompt_embeds: Optional[torch.Tensor], callback_on_step_end_tensor_inputs: List[str], ): - """Checks the validity of the inputs.""" + r"""Checks the validity of the inputs.""" # Check prompt and prompt_embeds if prompt is None and prompt_embeds is None: @@ -141,7 +226,7 @@ def __call__( num_frames: int = 96, num_inference_steps: int = 32, guidance_scale: float = 7.5, - num_videos_per_prompt: Optional[int] = 1, + num_videos_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, @@ -179,7 +264,7 @@ def __call__( of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, usually at the expense of lower video quality. - num_videos_per_prompt (`int`, *optional*, defaults to 1): + num_videos_per_prompt (`int`, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make @@ -246,4 +331,17 @@ def __call__( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # TODO: Check if linter complains here - \ No newline at end of file + + device = self._execution_device + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask = self.encode_prompt( + prompt, + negative_prompt, + self.do_classifier_free_guidance, + num_videos_per_prompt, + prompt_embeds, + negative_prompt_embeds, + max_sequence_length, + device, + self.text_encoder.dtype + ) \ No newline at end of file From 67e9bdc47e20a02377fc7c6ab68db23e3be7cb86 Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Sun, 26 Oct 2025 15:41:04 +0100 Subject: [PATCH 05/18] _get_t5_prompt_embeds --- .../pipelines/magi1/pipeline_magi1_t2v.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index f5f244e9ae3f..e3853238cc35 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import html +import re from typing import Any, Callable, Dict, List, Optional, Union +import ftfy import torch from transformers import AutoTokenizer, UMT5EncoderModel @@ -25,6 +28,23 @@ from ..pipeline_utils import DiffusionPipeline + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + # TODO: Write example_doc_string EXAMPLE_DOC_STRING = """ Examples: @@ -77,6 +97,53 @@ def __init__( # TODO: Add attributes + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int, + max_sequence_length: int, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + ): + # TODO: Double check if MAGI-1 does some special handling during prompt encoding + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # TODO: Debug if repeating mask is necessary because it's not used in any other pipeline + # Repeat mask the same way as embeddings and keep size [B*num, L] + mask = mask.repeat(1, num_videos_per_prompt) + mask = mask.view(batch_size * num_videos_per_prompt, -1).to(device) + + return prompt_embeds, mask + def encode_prompt( self, prompt: Optional[Union[str, List[str]]], From 
5041c67f2bf7ae500eab429927a1aa351293dcdf Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Tue, 28 Oct 2025 00:32:14 +0100 Subject: [PATCH 06/18] use T5EncoderModel --- src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index e3853238cc35..b9bb5af98e5a 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -18,7 +18,7 @@ import ftfy import torch -from transformers import AutoTokenizer, UMT5EncoderModel +from transformers import AutoTokenizer, T5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import Magi1LoraLoaderMixin @@ -61,10 +61,10 @@ class Magi1Pipeline(DiffusionPipeline, Magi1LoraLoaderMixin): Args: tokenizer ([`T5Tokenizer`]): Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), - specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + specifically the [DeepFloyd/t5-v1_1-xxl](https://huggingface.co/DeepFloyd/t5-v1_1-xxl) variant. text_encoder ([`T5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + the [DeepFloyd/t5-v1_1-xxl](https://huggingface.co/DeepFloyd/t5-v1_1-xxl) variant. transformer ([`Magi1Transformer3DModel`]): Conditional Transformer to denoise the input latents. vae ([`AutoencoderKLMagi1`]): @@ -79,7 +79,7 @@ class Magi1Pipeline(DiffusionPipeline, Magi1LoraLoaderMixin): def __init__( self, tokenizer: AutoTokenizer, - text_encoder: UMT5EncoderModel, + text_encoder: T5EncoderModel, transformer: Magi1Transformer3DModel, vae: AutoencoderKLMagi1, scheduler: FlowMatchEulerDiscreteScheduler, From e9ea7a6015e2d8bdca69c129c61cb8e17c3b2b8d Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Tue, 28 Oct 2025 02:36:32 +0100 Subject: [PATCH 07/18] handle special tokens and negative prompt embeddings --- .../pipelines/magi1/pipeline_magi1_t2v.py | 260 ++++++++++++++---- 1 file changed, 213 insertions(+), 47 deletions(-) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index b9bb5af98e5a..c7366e4d4c2c 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -13,10 +13,13 @@ # limitations under the License. 
import html +import math +import os import re -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import ftfy +import numpy as np import torch from transformers import AutoTokenizer, T5EncoderModel @@ -28,6 +31,101 @@ from ..pipeline_utils import DiffusionPipeline +SPECIAL_TOKEN_PATH = os.getenv("SPECIAL_TOKEN_PATH", "example/assets/special_tokens.npz") +SPECIAL_TOKEN = np.load(SPECIAL_TOKEN_PATH) +CAPTION_TOKEN = torch.tensor(SPECIAL_TOKEN["caption_token"].astype(np.float16)) +LOGO_TOKEN = torch.tensor(SPECIAL_TOKEN["logo_token"].astype(np.float16)) +TRANS_TOKEN = torch.tensor(SPECIAL_TOKEN["other_tokens"][:1].astype(np.float16)) +HQ_TOKEN = torch.tensor(SPECIAL_TOKEN["other_tokens"][1:2].astype(np.float16)) +STATIC_FIRST_FRAMES_TOKEN = torch.tensor(SPECIAL_TOKEN["other_tokens"][2:3].astype(np.float16)) # static first frames +DYNAMIC_FIRST_FRAMES_TOKEN = torch.tensor(SPECIAL_TOKEN["other_tokens"][3:4].astype(np.float16)) # dynamic first frames +BORDERNESS_TOKEN = torch.tensor(SPECIAL_TOKEN["other_tokens"][4:5].astype(np.float16)) +DURATION_TOKEN_LIST = [torch.tensor(SPECIAL_TOKEN["other_tokens"][i : i + 1].astype(np.float16)) for i in range(0 + 7, 8 + 7)] +THREE_D_MODEL_TOKEN = torch.tensor(SPECIAL_TOKEN["other_tokens"][15:16].astype(np.float16)) +TWO_D_ANIME_TOKEN = torch.tensor(SPECIAL_TOKEN["other_tokens"][16:17].astype(np.float16)) + + + +SPECIAL_TOKEN_DICT = { + "CAPTION_TOKEN": CAPTION_TOKEN, + "LOGO_TOKEN": LOGO_TOKEN, + "TRANS_TOKEN": TRANS_TOKEN, + "HQ_TOKEN": HQ_TOKEN, + "STATIC_FIRST_FRAMES_TOKEN": STATIC_FIRST_FRAMES_TOKEN, + "DYNAMIC_FIRST_FRAMES_TOKEN": DYNAMIC_FIRST_FRAMES_TOKEN, + "BORDERNESS_TOKEN": BORDERNESS_TOKEN, + "THREE_D_MODEL_TOKEN": THREE_D_MODEL_TOKEN, + "TWO_D_ANIME_TOKEN": TWO_D_ANIME_TOKEN, +} + +def _pad_special_token(special_token: torch.Tensor, txt_feat: torch.Tensor, attn_mask: torch.Tensor = None): + _device = txt_feat.device + _dtype = txt_feat.dtype + N, C, _, D = txt_feat.size() + txt_feat = torch.cat( + [special_token.unsqueeze(0).unsqueeze(0).to(_device).to(_dtype).expand(N, C, -1, D), txt_feat], dim=2 + )[:, :, :800, :] + if attn_mask is not None: + attn_mask = torch.cat([torch.ones(N, C, 1, dtype=_dtype, device=_device), attn_mask], dim=-1)[:, :, :800] + return txt_feat, attn_mask + + + +def pad_special_token(special_token_keys: List[str], caption_embs: torch.Tensor, emb_masks: torch.Tensor): + device = caption_embs.device + for special_token_key in special_token_keys: + if special_token_key == "DURATION_TOKEN": + new_caption_embs, new_emb_masks = [], [] + num_chunks = caption_embs.size(1) + for i in range(num_chunks): + chunk_caption_embs, chunk_emb_masks = _pad_special_token( + DURATION_TOKEN_LIST[min(num_chunks - i - 1, 7)].to(device), + caption_embs[:, i : i + 1], + emb_masks[:, i : i + 1], + ) + new_caption_embs.append(chunk_caption_embs) + new_emb_masks.append(chunk_emb_masks) + caption_embs = torch.cat(new_caption_embs, dim=1) + emb_masks = torch.cat(new_emb_masks, dim=1) + else: + special_token = SPECIAL_TOKEN_DICT.get(special_token_key) + if special_token is not None: + caption_embs, emb_masks = _pad_special_token(special_token.to(device), caption_embs, emb_masks) + return caption_embs, emb_masks + +def get_special_token_keys( + use_static_first_frames_token: bool, + use_dynamic_first_frames_token: bool, + use_borderness_token: bool, + use_hq_token: bool, + use_three_d_model_token: bool, + use_two_d_anime_token: bool, + use_duration_token: bool, +): + 
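+    # Note: the order of the keys returned here matters. `pad_special_token` prepends each
+    # key's embedding in turn (capping the sequence at 800 tokens), so the token for the last
+    # enabled key ends up at the front of the text sequence.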
special_token_keys = [] + if use_static_first_frames_token: + special_token_keys.append("STATIC_FIRST_FRAMES_TOKEN") + if use_dynamic_first_frames_token: + special_token_keys.append("DYNAMIC_FIRST_FRAMES_TOKEN") + if use_borderness_token: + special_token_keys.append("BORDERNESS_TOKEN") + if use_hq_token: + special_token_keys.append("HQ_TOKEN") + if use_three_d_model_token: + special_token_keys.append("THREE_D_MODEL_TOKEN") + if use_two_d_anime_token: + special_token_keys.append("TWO_D_ANIME_TOKEN") + if use_duration_token: + special_token_keys.append("DURATION_TOKEN") + return special_token_keys + +def get_negative_special_token_keys( + use_negative_special_tokens: bool, +): + if use_negative_special_tokens: + return ["CAPTION_TOKEN", "LOGO_TOKEN", "TRANS_TOKEN", "BORDERNESS_TOKEN"] + return [] + def basic_clean(text): text = ftfy.fix_text(text) @@ -94,6 +192,8 @@ def __init__( scheduler=scheduler, ) + self.temporal_downscale_factor = 4 # TODO: Double check this value + self.chunk_width = 6 # TODO: Double check this value # TODO: Add attributes @@ -110,6 +210,7 @@ def _get_t5_prompt_embeds( dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt + # Just keep the clean function consistent with other pipelines prompt = [prompt_clean(u) for u in prompt] batch_size = len(prompt) @@ -127,21 +228,22 @@ def _get_t5_prompt_embeds( prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 - ) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - - # TODO: Debug if repeating mask is necessary because it's not used in any other pipeline - # Repeat mask the same way as embeddings and keep size [B*num, L] - mask = mask.repeat(1, num_videos_per_prompt) - mask = mask.view(batch_size * num_videos_per_prompt, -1).to(device) - + # TODO: IDK why we need the code below, seems redundant to me, double check later. 
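+        # Probably not redundant: the text encoder returns nonzero hidden states at padded
+        # positions, and this block zeroed them out (trim to seq_lens, re-pad with zeros) and
+        # duplicated embeddings/mask per `num_videos_per_prompt`, as in other diffusers pipelines.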
+ # prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + # prompt_embeds = torch.stack( + # [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + # ) + + # # duplicate text embeddings for each generation per prompt, using mps friendly method + # _, seq_len, _ = prompt_embeds.shape + # prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + # prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # # TODO: Debug if repeating mask is necessary because it's not used in any other pipeline + # # Repeat mask the same way as embeddings and keep size [B*num, L] + # mask = mask.repeat(1, num_videos_per_prompt) + # mask = mask.view(batch_size * num_videos_per_prompt, -1).to(device) + prompt_embeds = prompt_embeds.float() return prompt_embeds, mask def encode_prompt( @@ -151,11 +253,13 @@ def encode_prompt( do_classifier_free_guidance: bool, num_videos_per_prompt: int, prompt_embeds: Optional[torch.Tensor], + prompt_mask: Optional[torch.Tensor], negative_prompt_embeds: Optional[torch.Tensor], + negative_prompt_mask: Optional[torch.Tensor], max_sequence_length: int, device: Optional[torch.device], dtype: Optional[torch.dtype], - ): + ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Encodes the prompt into text encoder hidden states. Args: @@ -172,10 +276,14 @@ def encode_prompt( prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. + prompt_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. device: (`torch.device`, *optional*): torch device dtype: (`torch.dtype`, *optional*): @@ -201,33 +309,7 @@ def encode_prompt( ) else: prompt_mask = None - - if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - - negative_prompt_embeds, negative_mask = self._get_t5_prompt_embeds( - prompt=negative_prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - else: - negative_mask = None - return prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask + return prompt_embeds, prompt_mask def check_inputs( self, @@ -236,7 +318,9 @@ def check_inputs( height: int, width: int, prompt_embeds: Optional[torch.Tensor], + prompt_mask: Optional[torch.Tensor], negative_prompt_embeds: Optional[torch.Tensor], + negative_prompt_mask: Optional[torch.Tensor], callback_on_step_end_tensor_inputs: List[str], ): r"""Checks the validity of the inputs.""" @@ -254,6 +338,12 @@ def check_inputs( if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + # Check prompt_embeds and prompt_mask + if prompt_embeds is not None and prompt_mask is None: + raise ValueError("Must provide `prompt_mask` when specifying `prompt_embeds`.") + if prompt_embeds is None and prompt_mask is not None: + raise ValueError("Must provide `prompt_embeds` when specifying `prompt_mask`.") + # Check negative_prompt and negative_prompt_embeds if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -265,6 +355,28 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + # Check negative_prompt_embeds and negative_prompt_mask + if negative_prompt_embeds is not None and negative_prompt_mask is None: + raise ValueError("Must provide `negative_prompt_mask` when specifying `negative_prompt_embeds`.") + if negative_prompt_embeds is None and negative_prompt_mask is not None: + raise ValueError("Must provide `negative_prompt_embeds` when specifying `negative_prompt_mask`.") + + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_mask is not None and negative_prompt_mask is not None: + if prompt_mask.shape != negative_prompt_mask.shape: + raise ValueError( + "`prompt_mask` and `negative_prompt_mask` must have the same shape when passed directly, but" + f" got: `prompt_mask` {prompt_mask.shape} != `negative_prompt_mask`" + f" {negative_prompt_mask.shape}." + ) + # Check height and width # TODO: Why 16? 
if height % 16 != 0 or width % 16 != 0: @@ -278,6 +390,9 @@ def check_inputs( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 # TODO: Fix default values of the parameters # TODO: Double-check if all parameters are needed/included @@ -297,7 +412,9 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, + prompt_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -306,6 +423,14 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 800, + use_static_first_frames_token: bool = False, + use_dynamic_first_frames_token: bool = False, + use_borderness_token: bool = False, + use_hq_token: bool = False, + use_three_d_model_token: bool = False, + use_two_d_anime_token: bool = False, + use_duration_token: bool = False, + use_negative_special_tokens: bool = False, ): r""" The call function to the pipeline for generation. @@ -343,9 +468,13 @@ def __call__( prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. + prompt_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input argument. + negative_prompt_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. output_type (`str`, *optional*, defaults to `"np"`): The output format of the generated video. Choose between `"latent"`, `"pt"`, or `"np"`. return_dict (`bool`, *optional*, defaults to `True`): @@ -389,6 +518,7 @@ def __call__( callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale # TODO: Come back here later # 2. Define call parameters @@ -401,14 +531,50 @@ def __call__( device = self._execution_device # 3. 
Encode input prompt - prompt_embeds, negative_prompt_embeds, prompt_mask, negative_mask = self.encode_prompt( + prompt_embeds, prompt_mask = self.encode_prompt( prompt, negative_prompt, self.do_classifier_free_guidance, num_videos_per_prompt, prompt_embeds, + prompt_mask, negative_prompt_embeds, + negative_prompt_mask, max_sequence_length, device, - self.text_encoder.dtype - ) \ No newline at end of file + self.text_encoder.dtype # TODO: double check what is passed here + ) + + num_infer_chunks = math.ceil((num_frames // self.temporal_downscale_factor) / self.chunk_width) + prompt_embeds = prompt_embeds.unsqueeze(1).repeat(1, num_infer_chunks, 1, 1) + prompt_mask = prompt_mask.unsqueeze(1).repeat(1, num_infer_chunks, 1) + special_token_keys = get_special_token_keys( + use_static_first_frames_token=use_static_first_frames_token, + use_dynamic_first_frames_token=use_dynamic_first_frames_token, + use_borderness_token=use_borderness_token, + use_hq_token=use_hq_token, + use_three_d_model_token=use_three_d_model_token, + use_two_d_anime_token=use_two_d_anime_token, + use_duration_token=use_duration_token) + prompt_embeds, prompt_mask = pad_special_token(special_token_keys, prompt_embeds, prompt_mask) + if self.do_classifier_free_guidance: + if negative_prompt_embeds is None: + # TODO: Load negative prompt embeds, they are learned + # null_caption_embedding = model.y_embedder.null_caption_embedding.unsqueeze(0) + # Creating zeros for negative prompt embeds for now + negative_prompt_embeds = torch.zeros(prompt_embeds.size(0), prompt_embeds.size(2), prompt_embeds.size(3)).to(prompt_embeds.device) + negative_mask = torch.zeros_like(prompt_mask) + negative_prompt_embeds = negative_prompt_embeds.unsqueeze(1).repeat(1, num_infer_chunks, 1, 1) + special_negative_token_keys = get_negative_special_token_keys( + use_negative_special_tokens=use_negative_special_tokens, + ) + negative_prompt_embeds, _ = pad_special_token(special_negative_token_keys, negative_prompt_embeds, None) + negative_token_length = 50 + negative_mask[:, :, :negative_token_length] = 1 + negative_mask[:, :, negative_token_length:] = 0 + if prompt_mask.sum() == 0: + prompt_embeds = torch.cat([negative_prompt_embeds, negative_prompt_embeds]) + prompt_mask = torch.cat([negative_mask, negative_mask], dim=0) + else: + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds]) + prompt_mask = torch.cat([prompt_mask, negative_mask], dim=0) \ No newline at end of file From f7ed43d7a16bc4e1c1c4e8b457e6bea2ad052c16 Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Tue, 28 Oct 2025 02:40:28 +0100 Subject: [PATCH 08/18] fix input params to check_inputs --- src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index c7366e4d4c2c..4b67811a315d 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -514,7 +514,9 @@ def __call__( height, width, prompt_embeds, + prompt_mask, negative_prompt_embeds, + negative_prompt_mask, callback_on_step_end_tensor_inputs, ) From f1dc11c10c26d8db15b2643eb9e1016b6d5d5142 Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Tue, 28 Oct 2025 12:38:38 +0100 Subject: [PATCH 09/18] fix default value for special tokens --- src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py 
b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index 4b67811a315d..2be5050eb6ec 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -426,10 +426,10 @@ def __call__( use_static_first_frames_token: bool = False, use_dynamic_first_frames_token: bool = False, use_borderness_token: bool = False, - use_hq_token: bool = False, + use_hq_token: bool = True, use_three_d_model_token: bool = False, use_two_d_anime_token: bool = False, - use_duration_token: bool = False, + use_duration_token: bool = True, use_negative_special_tokens: bool = False, ): r""" From 35d5f8945bd24fb3e41e318a40b5c18f60215750 Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Tue, 28 Oct 2025 12:41:12 +0100 Subject: [PATCH 10/18] update special token naming --- .../pipelines/magi1/pipeline_magi1_t2v.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index 2be5050eb6ec..47c38f7b280a 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -98,8 +98,8 @@ def get_special_token_keys( use_dynamic_first_frames_token: bool, use_borderness_token: bool, use_hq_token: bool, - use_three_d_model_token: bool, - use_two_d_anime_token: bool, + use_3d_model_token: bool, + use_2d_anime_token: bool, use_duration_token: bool, ): special_token_keys = [] @@ -111,9 +111,9 @@ def get_special_token_keys( special_token_keys.append("BORDERNESS_TOKEN") if use_hq_token: special_token_keys.append("HQ_TOKEN") - if use_three_d_model_token: + if use_3d_model_token: special_token_keys.append("THREE_D_MODEL_TOKEN") - if use_two_d_anime_token: + if use_2d_anime_token: special_token_keys.append("TWO_D_ANIME_TOKEN") if use_duration_token: special_token_keys.append("DURATION_TOKEN") @@ -427,8 +427,8 @@ def __call__( use_dynamic_first_frames_token: bool = False, use_borderness_token: bool = False, use_hq_token: bool = True, - use_three_d_model_token: bool = False, - use_two_d_anime_token: bool = False, + use_3d_model_token: bool = False, + use_2d_anime_token: bool = False, use_duration_token: bool = True, use_negative_special_tokens: bool = False, ): @@ -555,8 +555,8 @@ def __call__( use_dynamic_first_frames_token=use_dynamic_first_frames_token, use_borderness_token=use_borderness_token, use_hq_token=use_hq_token, - use_three_d_model_token=use_three_d_model_token, - use_two_d_anime_token=use_two_d_anime_token, + use_3d_model_token=use_3d_model_token, + use_2d_anime_token=use_2d_anime_token, use_duration_token=use_duration_token) prompt_embeds, prompt_mask = pad_special_token(special_token_keys, prompt_embeds, prompt_mask) if self.do_classifier_free_guidance: From d34ed67301db08e8603d67f3a9a98fb98bdc9bdf Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Wed, 29 Oct 2025 05:08:07 +0100 Subject: [PATCH 11/18] fix bugs --- .../pipelines/magi1/pipeline_magi1_t2v.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index 47c38f7b280a..4445049a6f13 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -59,14 +59,12 @@ } def _pad_special_token(special_token: torch.Tensor, txt_feat: torch.Tensor, attn_mask: torch.Tensor = None): - _device = txt_feat.device - _dtype = 
txt_feat.dtype N, C, _, D = txt_feat.size() txt_feat = torch.cat( - [special_token.unsqueeze(0).unsqueeze(0).to(_device).to(_dtype).expand(N, C, -1, D), txt_feat], dim=2 + [special_token.unsqueeze(0).unsqueeze(0).to(txt_feat.device).to(txt_feat.dtype).expand(N, C, -1, D), txt_feat], dim=2 )[:, :, :800, :] if attn_mask is not None: - attn_mask = torch.cat([torch.ones(N, C, 1, dtype=_dtype, device=_device), attn_mask], dim=-1)[:, :, :800] + attn_mask = torch.cat([torch.ones(N, C, 1, dtype=attn_mask.dtype, device=attn_mask.device), attn_mask], dim=-1)[:, :, :800] return txt_feat, attn_mask @@ -192,8 +190,9 @@ def __init__( scheduler=scheduler, ) - self.temporal_downscale_factor = 4 # TODO: Double check this value + self.temporal_downscale_factor = 4 # TODO: Read it from model (vae) config self.chunk_width = 6 # TODO: Double check this value + self._callback_tensor_inputs = ["latents"] # extend as needed # TODO: Add attributes @@ -228,21 +227,21 @@ def _get_t5_prompt_embeds( prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - # TODO: IDK why we need the code below, seems redundant to me, double check later. - # prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - # prompt_embeds = torch.stack( - # [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 - # ) - - # # duplicate text embeddings for each generation per prompt, using mps friendly method - # _, seq_len, _ = prompt_embeds.shape - # prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - # prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - - # # TODO: Debug if repeating mask is necessary because it's not used in any other pipeline - # # Repeat mask the same way as embeddings and keep size [B*num, L] - # mask = mask.repeat(1, num_videos_per_prompt) - # mask = mask.view(batch_size * num_videos_per_prompt, -1).to(device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + # TODO: Debug if repeating mask is necessary because it's not used in any other pipeline + # Repeat mask the same way as embeddings and keep size [B*num, L] + mask = mask.repeat(1, num_videos_per_prompt) + mask = mask.view(batch_size * num_videos_per_prompt, -1).to(device) + # TODO: I think prompt_embeds are already float32, but double check prompt_embeds = prompt_embeds.float() return prompt_embeds, mask @@ -309,6 +308,7 @@ def encode_prompt( ) else: prompt_mask = None + # TODO: Also handle if negative prompt is provided (though the default is learned embeddings in MAGI-1) return prompt_embeds, prompt_mask def check_inputs( From 594858ed0cefd2c146ea6d996cf63c6e9c92654a Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Wed, 29 Oct 2025 05:22:13 +0100 Subject: [PATCH 12/18] support negative prompts --- .../pipelines/magi1/pipeline_magi1_t2v.py | 40 ++++++++++++++----- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py 
b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index 4445049a6f13..ecf626eeaf76 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -249,7 +249,6 @@ def encode_prompt( self, prompt: Optional[Union[str, List[str]]], negative_prompt: Optional[Union[str, List[str]]], - do_classifier_free_guidance: bool, num_videos_per_prompt: int, prompt_embeds: Optional[torch.Tensor], prompt_mask: Optional[torch.Tensor], @@ -258,7 +257,7 @@ def encode_prompt( max_sequence_length: int, device: Optional[torch.device], dtype: Optional[torch.dtype], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ): r"""Encodes the prompt into text encoder hidden states. Args: @@ -268,8 +267,6 @@ def encode_prompt( The prompt or prompts not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - do_classifier_free_guidance (`bool`, *optional*): - Whether to use classifier free guidance or not. num_videos_per_prompt (`int`): Number of videos that should be generated per prompt. torch device to place the resulting embeddings on prompt_embeds (`torch.Tensor`, *optional*): @@ -298,7 +295,7 @@ def encode_prompt( else: batch_size = prompt_embeds.shape[0] - if prompt_embeds is None: + if prompt is not None: prompt_embeds, prompt_mask = self._get_t5_prompt_embeds( prompt=prompt, num_videos_per_prompt=num_videos_per_prompt, @@ -306,10 +303,32 @@ def encode_prompt( device=device, dtype=dtype, ) - else: - prompt_mask = None - # TODO: Also handle if negative prompt is provided (though the default is learned embeddings in MAGI-1) - return prompt_embeds, prompt_mask + + # Negative prompt embeddings are learned for MAGI-1 + # However, we still provide the option to pass them in + if self.do_classifier_free_guidance: + if negative_prompt is not None: + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + negative_prompt_embeds, negative_prompt_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + return prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask def check_inputs( self, @@ -533,10 +552,9 @@ def __call__( device = self._execution_device # 3. 
Encode input prompt - prompt_embeds, prompt_mask = self.encode_prompt( + prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask = self.encode_prompt( prompt, negative_prompt, - self.do_classifier_free_guidance, num_videos_per_prompt, prompt_embeds, prompt_mask, From 7a6e76412273dc28f2f26665d4a6444c65b018f5 Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Wed, 29 Oct 2025 05:42:30 +0100 Subject: [PATCH 13/18] extract a function from prepending special tokens --- .../pipelines/magi1/pipeline_magi1_t2v.py | 105 ++++++++++++------ 1 file changed, 71 insertions(+), 34 deletions(-) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index ecf626eeaf76..bc45435589a0 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -195,6 +195,60 @@ def __init__( self._callback_tensor_inputs = ["latents"] # extend as needed # TODO: Add attributes + + def _build_text_pack( + self, + prompt_embeds: torch.Tensor, + prompt_mask: torch.Tensor, + negative_prompt_embeds: Optional[torch.Tensor], + negative_prompt_mask: Optional[torch.Tensor], + num_infer_chunks: int, + use_static_first_frames_token: bool, + use_dynamic_first_frames_token: bool, + use_borderness_token: bool, + use_hq_token: bool, + use_3d_model_token: bool, + use_2d_anime_token: bool, + use_duration_token: bool, + use_negative_special_tokens: bool, + ): + """ + Expand to chunk dim and prepend special tokens in MAGI order. + """ + prompt_embeds = prompt_embeds.unsqueeze(1).repeat(1, num_infer_chunks, 1, 1) + prompt_mask = prompt_mask.unsqueeze(1).repeat(1, num_infer_chunks, 1) + special_token_keys = get_special_token_keys( + use_static_first_frames_token=use_static_first_frames_token, + use_dynamic_first_frames_token=use_dynamic_first_frames_token, + use_borderness_token=use_borderness_token, + use_hq_token=use_hq_token, + use_3d_model_token=use_3d_model_token, + use_2d_anime_token=use_2d_anime_token, + use_duration_token=use_duration_token, + ) + prompt_embeds, prompt_mask = pad_special_token(special_token_keys, prompt_embeds, prompt_mask) + if self.do_classifier_free_guidance: + if negative_prompt_embeds is None: + # TODO: Load negative prompt embeds, they are learned + # null_caption_embedding = model.y_embedder.null_caption_embedding.unsqueeze(0) + # Creating zeros for negative prompt embeds for now + negative_prompt_embeds = torch.zeros(prompt_embeds.size(0), prompt_embeds.size(2), prompt_embeds.size(3)).to(prompt_embeds.device) + negative_mask = torch.zeros_like(prompt_mask) + negative_prompt_embeds = negative_prompt_embeds.unsqueeze(1).repeat(1, num_infer_chunks, 1, 1) + special_negative_token_keys = get_negative_special_token_keys( + use_negative_special_tokens=use_negative_special_tokens, + ) + negative_prompt_embeds, _ = pad_special_token(special_negative_token_keys, negative_prompt_embeds, None) + negative_token_length = 50 + negative_mask[:, :, :negative_token_length] = 1 + negative_mask[:, :, negative_token_length:] = 0 + if prompt_mask.sum() == 0: + prompt_embeds = torch.cat([negative_prompt_embeds, negative_prompt_embeds]) + prompt_mask = torch.cat([negative_mask, negative_mask], dim=0) + else: + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds]) + prompt_mask = torch.cat([prompt_mask, negative_mask], dim=0) + return prompt_embeds, prompt_mask def _get_t5_prompt_embeds( self, @@ -203,7 +257,7 @@ def _get_t5_prompt_embeds( max_sequence_length: int, device: 
Optional[torch.device], dtype: Optional[torch.dtype], - ): + ) -> Tuple[torch.Tensor, torch.Tensor]: # TODO: Double check if MAGI-1 does some special handling during prompt encoding device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -257,7 +311,7 @@ def encode_prompt( max_sequence_length: int, device: Optional[torch.device], dtype: Optional[torch.dtype], - ): + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: r"""Encodes the prompt into text encoder hidden states. Args: @@ -566,35 +620,18 @@ def __call__( ) num_infer_chunks = math.ceil((num_frames // self.temporal_downscale_factor) / self.chunk_width) - prompt_embeds = prompt_embeds.unsqueeze(1).repeat(1, num_infer_chunks, 1, 1) - prompt_mask = prompt_mask.unsqueeze(1).repeat(1, num_infer_chunks, 1) - special_token_keys = get_special_token_keys( - use_static_first_frames_token=use_static_first_frames_token, - use_dynamic_first_frames_token=use_dynamic_first_frames_token, - use_borderness_token=use_borderness_token, - use_hq_token=use_hq_token, - use_3d_model_token=use_3d_model_token, - use_2d_anime_token=use_2d_anime_token, - use_duration_token=use_duration_token) - prompt_embeds, prompt_mask = pad_special_token(special_token_keys, prompt_embeds, prompt_mask) - if self.do_classifier_free_guidance: - if negative_prompt_embeds is None: - # TODO: Load negative prompt embeds, they are learned - # null_caption_embedding = model.y_embedder.null_caption_embedding.unsqueeze(0) - # Creating zeros for negative prompt embeds for now - negative_prompt_embeds = torch.zeros(prompt_embeds.size(0), prompt_embeds.size(2), prompt_embeds.size(3)).to(prompt_embeds.device) - negative_mask = torch.zeros_like(prompt_mask) - negative_prompt_embeds = negative_prompt_embeds.unsqueeze(1).repeat(1, num_infer_chunks, 1, 1) - special_negative_token_keys = get_negative_special_token_keys( - use_negative_special_tokens=use_negative_special_tokens, - ) - negative_prompt_embeds, _ = pad_special_token(special_negative_token_keys, negative_prompt_embeds, None) - negative_token_length = 50 - negative_mask[:, :, :negative_token_length] = 1 - negative_mask[:, :, negative_token_length:] = 0 - if prompt_mask.sum() == 0: - prompt_embeds = torch.cat([negative_prompt_embeds, negative_prompt_embeds]) - prompt_mask = torch.cat([negative_mask, negative_mask], dim=0) - else: - prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds]) - prompt_mask = torch.cat([prompt_mask, negative_mask], dim=0) \ No newline at end of file + prompt_embeds, prompt_mask = self._build_text_pack( + prompt_embeds, + prompt_mask, + negative_prompt_embeds, + negative_prompt_mask, + num_infer_chunks, + use_static_first_frames_token, + use_dynamic_first_frames_token, + use_borderness_token, + use_hq_token, + use_3d_model_token, + use_2d_anime_token, + use_duration_token, + use_negative_special_tokens, + ) \ No newline at end of file From b88bc12572b4a3424fdd651d8a46fb5a64f0e824 Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Wed, 29 Oct 2025 05:48:11 +0100 Subject: [PATCH 14/18] fix typo --- src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index bc45435589a0..100babac2cc3 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -233,21 +233,21 @@ def _build_text_pack( # 
null_caption_embedding = model.y_embedder.null_caption_embedding.unsqueeze(0) # Creating zeros for negative prompt embeds for now negative_prompt_embeds = torch.zeros(prompt_embeds.size(0), prompt_embeds.size(2), prompt_embeds.size(3)).to(prompt_embeds.device) - negative_mask = torch.zeros_like(prompt_mask) + negative_prompt_mask = torch.zeros_like(prompt_mask) negative_prompt_embeds = negative_prompt_embeds.unsqueeze(1).repeat(1, num_infer_chunks, 1, 1) special_negative_token_keys = get_negative_special_token_keys( use_negative_special_tokens=use_negative_special_tokens, ) negative_prompt_embeds, _ = pad_special_token(special_negative_token_keys, negative_prompt_embeds, None) negative_token_length = 50 - negative_mask[:, :, :negative_token_length] = 1 - negative_mask[:, :, negative_token_length:] = 0 + negative_prompt_mask[:, :, :negative_token_length] = 1 + negative_prompt_mask[:, :, negative_token_length:] = 0 if prompt_mask.sum() == 0: prompt_embeds = torch.cat([negative_prompt_embeds, negative_prompt_embeds]) - prompt_mask = torch.cat([negative_mask, negative_mask], dim=0) + prompt_mask = torch.cat([negative_prompt_mask, negative_prompt_mask], dim=0) else: prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds]) - prompt_mask = torch.cat([prompt_mask, negative_mask], dim=0) + prompt_mask = torch.cat([prompt_mask, negative_prompt_mask], dim=0) return prompt_embeds, prompt_mask def _get_t5_prompt_embeds( From 84cd56413ab266493e52878dc01cfaff6db15290 Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Wed, 29 Oct 2025 06:09:29 +0100 Subject: [PATCH 15/18] prepare latents --- .../pipelines/magi1/pipeline_magi1_t2v.py | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index 100babac2cc3..3c59bb037bf7 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -28,6 +28,7 @@ from ...models import AutoencoderKLMagi1, Magi1Transformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -190,7 +191,10 @@ def __init__( scheduler=scheduler, ) - self.temporal_downscale_factor = 4 # TODO: Read it from model (vae) config + # TODO: Double check if they are really read from config + self.temporal_downscale_factor = getattr(self.vae.config, "scale_factor_temporal", 4) + self.spatial_downscale_factor = getattr(self.vae.config, "scale_factor_spatial", 8) + self.num_channels_latents = self.transformer.config.in_channels self.chunk_width = 6 # TODO: Double check this value self._callback_tensor_inputs = ["latents"] # extend as needed # TODO: Add attributes @@ -462,6 +466,35 @@ def check_inputs( raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + num_chunks: int, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ): + if latents is not None: + return latents.to(device=device, 
dtype=dtype) + shape = ( + batch_size, + num_channels_latents, + num_chunks * self.chunk_width, + height // self.spatial_downscale_factor, + width // self.spatial_downscale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents @property def do_classifier_free_guidance(self): @@ -634,4 +667,16 @@ def __call__( use_2d_anime_token, use_duration_token, use_negative_special_tokens, + ) + + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + self.num_channels_latents, + height, + width, + num_infer_chunks, + prompt_embeds.dtype, + device, + generator, + latents, ) \ No newline at end of file From a4b8824ba4b49b467b13188a0f62b69c91de24f9 Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Wed, 29 Oct 2025 08:18:01 +0100 Subject: [PATCH 16/18] implement the streaming structure --- .../pipelines/magi1/pipeline_magi1_t2v.py | 286 +++++++++++++++++- 1 file changed, 272 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index 3c59bb037bf7..fba4146c1046 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -16,7 +16,7 @@ import math import os import re -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union import ftfy import numpy as np @@ -142,6 +142,59 @@ def prompt_clean(text): text = whitespace_clean(basic_clean(text)) return text + +def _compute_sequences(chunk_num: int, window_size: int, chunk_offset: int = 0): + """ + Replicates MAGI's generate_sequences(...). Returns four int lists: + clip_start, clip_end, t_start, t_end (all length = chunk_num + window_size - 1 - offset). + """ + start_index = chunk_offset + end_index = chunk_num + window_size - 1 + clip_start = [max(chunk_offset, i - window_size + 1) for i in range(start_index, end_index)] + clip_end = [min(chunk_num, i + 1) for i in range(start_index, end_index)] + t_start = [max(0, i - chunk_num + 1) for i in range(start_index, end_index)] + t_end = [min(window_size, i - chunk_offset + 1) if i - chunk_offset < window_size else window_size + for i in range(start_index, end_index)] + return clip_start, clip_end, t_start, t_end + +class ARWindow: + """ + Sliding window over temporal chunks, identical to MAGI's step/stage mapping. 
+ + - chunk_num: number of chunks + - window_size: number of chunks per stage + - chunk_offset: prefix offset (0 for T2V) + """ + def __init__(self, chunk_num: int, window_size: int, chunk_offset: int = 0): + self.chunk_num = chunk_num + self.window_size = window_size + self.offset = chunk_offset + (self.clip_start, + self.clip_end, + self.t_start, + self.t_end) = _compute_sequences(chunk_num, window_size, chunk_offset) + + def total_forward_steps(self, num_steps: int) -> int: + per_stage = num_steps // self.window_size + return per_stage * (self.chunk_num + self.window_size - 1 - self.offset) + + def status(self, step: int, num_steps: int) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]: + """ + Returns: + (per_stage, stage, idx), + (chunk_start, chunk_end, t_start, t_end) + """ + per_stage = num_steps // self.window_size + stage, idx = divmod(step, per_stage) + return (per_stage, stage, idx), ( + self.clip_start[stage], + self.clip_end[stage], + self.t_start[stage], + self.t_end[stage], + ) + + + # TODO: Write example_doc_string EXAMPLE_DOC_STRING = """ Examples: @@ -196,6 +249,7 @@ def __init__( self.spatial_downscale_factor = getattr(self.vae.config, "scale_factor_spatial", 8) self.num_channels_latents = self.transformer.config.in_channels self.chunk_width = 6 # TODO: Double check this value + self.window_size = 4 # TODO: Double check this value self._callback_tensor_inputs = ["latents"] # extend as needed # TODO: Add attributes @@ -206,7 +260,7 @@ def _build_text_pack( prompt_mask: torch.Tensor, negative_prompt_embeds: Optional[torch.Tensor], negative_prompt_mask: Optional[torch.Tensor], - num_infer_chunks: int, + chunk_num: int, use_static_first_frames_token: bool, use_dynamic_first_frames_token: bool, use_borderness_token: bool, @@ -219,8 +273,8 @@ def _build_text_pack( """ Expand to chunk dim and prepend special tokens in MAGI order. 
""" - prompt_embeds = prompt_embeds.unsqueeze(1).repeat(1, num_infer_chunks, 1, 1) - prompt_mask = prompt_mask.unsqueeze(1).repeat(1, num_infer_chunks, 1) + prompt_embeds = prompt_embeds.unsqueeze(1).repeat(1, chunk_num, 1, 1) + prompt_mask = prompt_mask.unsqueeze(1).repeat(1, chunk_num, 1) special_token_keys = get_special_token_keys( use_static_first_frames_token=use_static_first_frames_token, use_dynamic_first_frames_token=use_dynamic_first_frames_token, @@ -238,7 +292,7 @@ def _build_text_pack( # Creating zeros for negative prompt embeds for now negative_prompt_embeds = torch.zeros(prompt_embeds.size(0), prompt_embeds.size(2), prompt_embeds.size(3)).to(prompt_embeds.device) negative_prompt_mask = torch.zeros_like(prompt_mask) - negative_prompt_embeds = negative_prompt_embeds.unsqueeze(1).repeat(1, num_infer_chunks, 1, 1) + negative_prompt_embeds = negative_prompt_embeds.unsqueeze(1).repeat(1, chunk_num, 1, 1) special_negative_token_keys = get_negative_special_token_keys( use_negative_special_tokens=use_negative_special_tokens, ) @@ -473,7 +527,7 @@ def prepare_latents( num_channels_latents: int, height: int, width: int, - num_chunks: int, + chunk_num: int, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -484,7 +538,7 @@ def prepare_latents( shape = ( batch_size, num_channels_latents, - num_chunks * self.chunk_width, + chunk_num * self.chunk_width, height // self.spatial_downscale_factor, width // self.spatial_downscale_factor, ) @@ -496,6 +550,114 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) return latents + def _decode_chunk_to_frames(self, latents_chunk: torch.Tensor) -> np.ndarray: + """ + Decode a single latent chunk to video frames. + + Args: + latents_chunk: Latent tensor of shape [1, C, Tc, H_lat, W_lat] + + Returns: + Frames as uint8 numpy array of shape [Tc, H, W, 3] + """ + scale = getattr(self.vae.config, "scaling_factor", 1.0) + x = latents_chunk / scale + + # Decode through VAE + pixels = self.vae.decode(x).sample # [1, 3, Tc, H, W], expected in [-1, 1] + + # Convert to uint8 [0, 255] + pixels = (pixels.clamp(-1, 1) + 1) / 2 + pixels = (pixels * 255).round().to(torch.uint8) + + # Rearrange to [Tc, H, W, 3] + frames = pixels.squeeze(0).permute(1, 2, 3, 0).contiguous().cpu().numpy() + return frames + + def _denoise_loop_generator( + self, + latents: torch.Tensor, + prompt_embeds: torch.Tensor, + prompt_mask: torch.Tensor, + num_inference_steps: int, + chunk_num: int, + device: torch.device, + callback_on_step_end: Optional[Callable] = None, + callback_on_step_end_tensor_inputs: List[str] = None, + ) -> Generator[Dict[str, Any], None, None]: + """ + Generator that performs the denoising loop and yields clean chunks as they complete. 
+ + Yields: + Dictionary with: + - "chunk_idx": int, index of the completed chunk + - "latents": torch.Tensor, clean latent chunk [1, C, chunk_width, H_lat, W_lat] + """ + window = ARWindow(chunk_num=chunk_num, window_size=self.window_size, chunk_offset=0) + total_steps = window.total_forward_steps(num_inference_steps) + denoise_counts = torch.zeros(chunk_num, dtype=torch.int32, device="cpu") + + # TODO: Initialize scheduler timesteps + # self.scheduler.set_timesteps(num_inference_steps, device=device) + # timesteps = self.scheduler.timesteps + + with self.progress_bar(total=total_steps) as pbar: + for step in range(total_steps): + (per_stage, stage, idx), (c_start, c_end, t_start, t_end) = window.status(step, num_inference_steps) + + # TODO: Implement the actual denoising step + # 1. Extract the window of latents for current stage + # latent_window = latents[:, :, c_start * self.chunk_width : c_end * self.chunk_width] + + # 2. Extract corresponding prompt embeddings + # y_window = prompt_embeds[:, c_start:c_end] + # mask_window = prompt_mask[:, c_start:c_end] + + # 3. Get timesteps for current stage + # t = timesteps[???] # Need to map stage/idx to timestep + + # 4. Forward through transformer + # noise_pred = self.transformer( + # latent_window, + # timestep=t, + # encoder_hidden_states=y_window, + # encoder_attention_mask=mask_window, + # **self._attention_kwargs or {}, + # ) + + # 5. Scheduler step to update latents + # latent_window = self.scheduler.step(noise_pred, t, latent_window).prev_sample + + # 6. Write back to main latents + # latents[:, :, c_start * self.chunk_width : c_end * self.chunk_width] = latent_window + + # 7. Update denoise counts and check for clean chunks + for c in range(c_start, c_end): + denoise_counts[c] += 1 + if denoise_counts[c] == num_inference_steps: + # Extract clean chunk (only conditional part if using CFG) + chunk_start = c * self.chunk_width + chunk_end = (c + 1) * self.chunk_width + + if self.do_classifier_free_guidance: + # Take only the conditional part (first half of batch) + clean_chunk = latents[0:1, :, chunk_start:chunk_end].detach() + else: + clean_chunk = latents[:, :, chunk_start:chunk_end].detach() + + yield {"chunk_idx": int(c), "latents": clean_chunk} + + # Callback support + if callback_on_step_end is not None: + callback_kwargs = {} + if callback_on_step_end_tensor_inputs: + for k in callback_on_step_end_tensor_inputs: + if k == "latents": + callback_kwargs[k] = latents + callback_on_step_end(self, step, None, callback_kwargs) + + pbar.update() + @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @@ -537,6 +699,7 @@ def __call__( use_2d_anime_token: bool = False, use_duration_token: bool = True, use_negative_special_tokens: bool = False, + stream_chunks: bool = False, ): r""" The call function to the pipeline for generation. @@ -601,13 +764,25 @@ def __call__( max_sequence_length (`int`, defaults to `800`): The maximum sequence length for the text encoder. Sequences longer than this will be truncated. MAGI-1 uses a max length of 800 tokens. + stream_chunks (`bool`, defaults to `False`): + Whether to stream chunks as they are generated. If `True`, this method returns a generator that yields + dictionaries containing `{"chunk_idx": int, "latents": torch.Tensor}` or + `{"chunk_idx": int, "frames": np.ndarray}` depending on `output_type`. If `False`, the method returns + the complete video after all chunks are generated. 
Examples: Returns: - [`~Magi1PipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`Magi1PipelineOutput`] is returned, otherwise a `tuple` is returned where - the first element is a list with the generated videos. + If `stream_chunks=False`: + [`~Magi1PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`Magi1PipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated videos. + + If `stream_chunks=True`: + Generator yielding dictionaries with: + - `chunk_idx` (int): Index of the chunk + - `latents` (torch.Tensor): If output_type is "latent" or "pt" + - `frames` (np.ndarray): If output_type is "np" """ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): @@ -652,13 +827,13 @@ def __call__( self.text_encoder.dtype # TODO: double check what is passed here ) - num_infer_chunks = math.ceil((num_frames // self.temporal_downscale_factor) / self.chunk_width) + chunk_num = math.ceil((num_frames // self.temporal_downscale_factor) / self.chunk_width) prompt_embeds, prompt_mask = self._build_text_pack( prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask, - num_infer_chunks, + chunk_num, use_static_first_frames_token, use_dynamic_first_frames_token, use_borderness_token, @@ -674,9 +849,92 @@ def __call__( self.num_channels_latents, height, width, - num_infer_chunks, + chunk_num, prompt_embeds.dtype, device, generator, latents, - ) \ No newline at end of file + ) + + # Store attention kwargs for use in denoising loop + self._attention_kwargs = attention_kwargs + + # Execute denoising and handle streaming vs non-streaming + if stream_chunks: + # Streaming mode: Return generator that yields decoded chunks + def stream_generator(): + for event in self._denoise_loop_generator( + latents=latents, + prompt_embeds=prompt_embeds, + prompt_mask=prompt_mask, + num_inference_steps=num_inference_steps, + chunk_num=chunk_num, + device=device, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ): + chunk_idx = event["chunk_idx"] + chunk_latents = event["latents"] + + # Decode or keep as latents based on output_type + if output_type == "latent": + yield {"chunk_idx": chunk_idx, "latents": chunk_latents} + elif output_type == "pt": + # Decode but keep as torch tensor + frames = self._decode_chunk_to_frames(chunk_latents) + frames_pt = torch.from_numpy(frames).to(device) + yield {"chunk_idx": chunk_idx, "frames": frames_pt} + else: # output_type == "np" + # Decode to numpy frames + frames = self._decode_chunk_to_frames(chunk_latents) + yield {"chunk_idx": chunk_idx, "frames": frames} + + return stream_generator() + + else: + # Non-streaming mode: Collect all chunks, then decode at the end + collected_chunks = [] + + for event in self._denoise_loop_generator( + latents=latents, + prompt_embeds=prompt_embeds, + prompt_mask=prompt_mask, + num_inference_steps=num_inference_steps, + chunk_num=chunk_num, + device=device, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ): + collected_chunks.append(event["latents"]) + + if len(collected_chunks) == 0: + raise RuntimeError("No chunks were produced during generation.") + + # Concatenate all chunks + full_latents = torch.cat(collected_chunks, dim=2) # [1, C, T_lat, H_lat, W_lat] + + # Handle output format + if output_type == "latent": + output = full_latents + else: + # Decode all chunks + all_frames = [] + for chunk_latents 
in collected_chunks: + frames = self._decode_chunk_to_frames(chunk_latents) + all_frames.append(frames) + + # Concatenate frames along time dimension + video_frames = np.concatenate(all_frames, axis=0) # [T, H, W, 3] + + if output_type == "pt": + output = torch.from_numpy(video_frames).to(device) + else: # output_type == "np" + output = video_frames + + if not return_dict: + return (output,) + + # TODO: Return proper output type (Magi1PipelineOutput) + return {"frames": output} if output_type != "latent" else {"latents": output} + + From 0f3dbbaea5baca74eff397836e0d528199d0e677 Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Sat, 1 Nov 2025 01:12:28 +0100 Subject: [PATCH 17/18] better variable naming and cleaner docstrings --- .../pipelines/magi1/pipeline_magi1_t2v.py | 78 ++++++++++--------- 1 file changed, 42 insertions(+), 36 deletions(-) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index fba4146c1046..f92fefddff2e 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -143,26 +143,26 @@ def prompt_clean(text): return text -def _compute_sequences(chunk_num: int, window_size: int, chunk_offset: int = 0): - """ - Replicates MAGI's generate_sequences(...). Returns four int lists: - clip_start, clip_end, t_start, t_end (all length = chunk_num + window_size - 1 - offset). - """ +def _generate_sequences(chunk_num: int, window_size: int, chunk_offset: int = 0): + """Compute global and local (window) start/end indices for each stage of a sliding window over chunked frames + replicating MAGI's `generate_sequences()` to drive the auto-regressive denoising schedule.""" start_index = chunk_offset end_index = chunk_num + window_size - 1 clip_start = [max(chunk_offset, i - window_size + 1) for i in range(start_index, end_index)] - clip_end = [min(chunk_num, i + 1) for i in range(start_index, end_index)] - t_start = [max(0, i - chunk_num + 1) for i in range(start_index, end_index)] - t_end = [min(window_size, i - chunk_offset + 1) if i - chunk_offset < window_size else window_size - for i in range(start_index, end_index)] + clip_end = [min(chunk_num, i + 1) for i in range(start_index, end_index)] + t_start = [max(0, i - chunk_num + 1) for i in range(start_index, end_index)] + t_end = [ + min(window_size, i - chunk_offset + 1) if i - chunk_offset < window_size else window_size + for i in range(start_index, end_index) + ] return clip_start, clip_end, t_start, t_end class ARWindow: """ Sliding window over temporal chunks, identical to MAGI's step/stage mapping. 
- - chunk_num: number of chunks - - window_size: number of chunks per stage + - chunk_num: number of total chunks + - window_size: max number of chunks processed in one stage - chunk_offset: prefix offset (0 for T2V) """ def __init__(self, chunk_num: int, window_size: int, chunk_offset: int = 0): @@ -172,25 +172,31 @@ def __init__(self, chunk_num: int, window_size: int, chunk_offset: int = 0): (self.clip_start, self.clip_end, self.t_start, - self.t_end) = _compute_sequences(chunk_num, window_size, chunk_offset) + self.t_end) = _generate_sequences(chunk_num, window_size, chunk_offset) - def total_forward_steps(self, num_steps: int) -> int: - per_stage = num_steps // self.window_size - return per_stage * (self.chunk_num + self.window_size - 1 - self.offset) + def calc_total_inference_steps(self, num_inference_steps: int) -> int: + steps_per_stage = num_inference_steps // self.window_size + return steps_per_stage * (self.chunk_num + self.window_size - 1 - self.offset) - def status(self, step: int, num_steps: int) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]: + def status(self, global_step_idx: int, num_inference_steps: int) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]: """ + Get current stage and step indices, along with chunk indices. Returns: - (per_stage, stage, idx), - (chunk_start, chunk_end, t_start, t_end) + - steps_per_stage: How many steps are in each stage + - stage_idx: Current stage index + - local_step_idx: Local (to the stage) step index within that stage + - clip_start: Global starting chunk index included in the current stage + - clip_end: Global ending chunk index (exclusive) included in the current stage + - t_start: Local (to the window) starting chunk index within the current stage + - t_end: Local (to the window) ending chunk index (exclusive) within the current stage """ - per_stage = num_steps // self.window_size - stage, idx = divmod(step, per_stage) - return (per_stage, stage, idx), ( - self.clip_start[stage], - self.clip_end[stage], - self.t_start[stage], - self.t_end[stage], + steps_per_stage = num_inference_steps // self.window_size + stage_idx, local_step_idx = divmod(global_step_idx, steps_per_stage) + return (steps_per_stage, stage_idx, local_step_idx), ( + self.clip_start[stage_idx], + self.clip_end[stage_idx], + self.t_start[stage_idx], + self.t_end[stage_idx], ) @@ -594,39 +600,39 @@ def _denoise_loop_generator( - "latents": torch.Tensor, clean latent chunk [1, C, chunk_width, H_lat, W_lat] """ window = ARWindow(chunk_num=chunk_num, window_size=self.window_size, chunk_offset=0) - total_steps = window.total_forward_steps(num_inference_steps) + total_inference_steps = window.calc_total_inference_steps(num_inference_steps) denoise_counts = torch.zeros(chunk_num, dtype=torch.int32, device="cpu") # TODO: Initialize scheduler timesteps # self.scheduler.set_timesteps(num_inference_steps, device=device) # timesteps = self.scheduler.timesteps - with self.progress_bar(total=total_steps) as pbar: - for step in range(total_steps): - (per_stage, stage, idx), (c_start, c_end, t_start, t_end) = window.status(step, num_inference_steps) + with self.progress_bar(total=total_inference_steps) as pbar: + for global_step_idx in range(total_inference_steps): + (steps_per_stage, stage_idx, local_step_idx), (c_start, c_end, t_start, t_end) = window.status(global_step_idx, num_inference_steps) # TODO: Implement the actual denoising step # 1. 
Extract the window of latents for current stage - # latent_window = latents[:, :, c_start * self.chunk_width : c_end * self.chunk_width] + latent_window = latents[:, :, c_start * self.chunk_width : c_end * self.chunk_width] # 2. Extract corresponding prompt embeddings - # y_window = prompt_embeds[:, c_start:c_end] - # mask_window = prompt_mask[:, c_start:c_end] + prompt_embeds_window = prompt_embeds[:, c_start:c_end] + prompt_mask_window = prompt_mask[:, c_start:c_end] # 3. Get timesteps for current stage # t = timesteps[???] # Need to map stage/idx to timestep # 4. Forward through transformer - # noise_pred = self.transformer( + # velocity = self.transformer( # latent_window, # timestep=t, - # encoder_hidden_states=y_window, - # encoder_attention_mask=mask_window, + # encoder_hidden_states=prompt_embeds_window, + # encoder_attention_mask=prompt_mask_window, # **self._attention_kwargs or {}, # ) # 5. Scheduler step to update latents - # latent_window = self.scheduler.step(noise_pred, t, latent_window).prev_sample + # latent_window = self.scheduler.step(velocity, t, latent_window).prev_sample # 6. Write back to main latents # latents[:, :, c_start * self.chunk_width : c_end * self.chunk_width] = latent_window From 31ae0ff288b5b3cf93596e447bf42b3d6106a329 Mon Sep 17 00:00:00 2001 From: Tuna Tuncer Date: Sun, 2 Nov 2025 18:29:50 +0100 Subject: [PATCH 18/18] implement the scheduler logic --- .../pipelines/magi1/pipeline_magi1_t2v.py | 75 ++++++++++++++----- 1 file changed, 56 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py index f92fefddff2e..2e7e0d0497ea 100644 --- a/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py +++ b/src/diffusers/pipelines/magi1/pipeline_magi1_t2v.py @@ -31,6 +31,8 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +# TODO: Load special tokens in a function instead of at the module level +# TODO: Add docstrings to functions taken from original MAGI-1 repo SPECIAL_TOKEN_PATH = os.getenv("SPECIAL_TOKEN_PATH", "example/assets/special_tokens.npz") SPECIAL_TOKEN = np.load(SPECIAL_TOKEN_PATH) @@ -322,12 +324,12 @@ def _get_t5_prompt_embeds( device: Optional[torch.device], dtype: Optional[torch.dtype], ) -> Tuple[torch.Tensor, torch.Tensor]: - # TODO: Double check if MAGI-1 does some special handling during prompt encoding device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt - # Just keep the clean function consistent with other pipelines + # The original prompt clean functionality is more complex however, + # we just keep the clean function consistent with other pipelines prompt = [prompt_clean(u) for u in prompt] batch_size = len(prompt) @@ -403,8 +405,6 @@ def encode_prompt( dtype: (`torch.dtype`, *optional*): torch dtype """ - # TODO: Can we provide different prompts for different chunks? - # If so, how are we gonna support that? device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt @@ -556,6 +556,37 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) return latents + def prepare_magi_timesteps(self, num_steps: int, shift: float = 3.0, device: torch.device = None) -> torch.Tensor: + """ + Prepare timesteps following Magi's schedule: + 1. Linear spacing + 2. Square + 3. 
Apply shift transform with shift_inv + + Args: + num_steps: Number of denoising steps + shift: Shift parameter (default 3.0 from Magi) + device: Device to place tensor on + + Returns: + Timesteps tensor of shape [num_steps + 1] + """ + t = torch.linspace(0, 1, num_steps + 1, device=device) + t = t ** 2 + shift_inv = 1.0 / shift + t = shift_inv * t / (1 + (shift_inv - 1) * t) + return t + + def get_timesteps(self, timesteps, steps_per_stage, t_start, t_end, local_step_idx): + t_indices = [] + for i in range(t_start, t_end): + t_indices.append(i * steps_per_stage + local_step_idx) + t_indices.reverse() # AR windowing + timestep_window = timesteps[t_indices] + # TODO: Implement has_clean_t + return timestep_window + + def _decode_chunk_to_frames(self, latents_chunk: torch.Tensor) -> np.ndarray: """ Decode a single latent chunk to video frames. @@ -603,9 +634,7 @@ def _denoise_loop_generator( total_inference_steps = window.calc_total_inference_steps(num_inference_steps) denoise_counts = torch.zeros(chunk_num, dtype=torch.int32, device="cpu") - # TODO: Initialize scheduler timesteps - # self.scheduler.set_timesteps(num_inference_steps, device=device) - # timesteps = self.scheduler.timesteps + timesteps = self.prepare_magi_timesteps(num_inference_steps, shift=3.0, device=device) with self.progress_bar(total=total_inference_steps) as pbar: for global_step_idx in range(total_inference_steps): @@ -614,28 +643,36 @@ def _denoise_loop_generator( # TODO: Implement the actual denoising step # 1. Extract the window of latents for current stage latent_window = latents[:, :, c_start * self.chunk_width : c_end * self.chunk_width] + B, C, T_window, H, W = latent_window.shape # 2. Extract corresponding prompt embeddings prompt_embeds_window = prompt_embeds[:, c_start:c_end] prompt_mask_window = prompt_mask[:, c_start:c_end] # 3. Get timesteps for current stage - # t = timesteps[???] # Need to map stage/idx to timestep + t_window = self.get_timesteps(timesteps, steps_per_stage, t_start, t_end, local_step_idx) # [num chunks in window] # 4. Forward through transformer - # velocity = self.transformer( - # latent_window, - # timestep=t, - # encoder_hidden_states=prompt_embeds_window, - # encoder_attention_mask=prompt_mask_window, - # **self._attention_kwargs or {}, - # ) - + velocity = self.transformer( + latent_window, + timestep=t_window.unsqueeze(0).repeat(B, 1), + encoder_hidden_states=prompt_embeds_window, + encoder_attention_mask=prompt_mask_window, + **self._attention_kwargs or {}, + ) + # 5. Scheduler step to update latents - # latent_window = self.scheduler.step(velocity, t, latent_window).prev_sample + t_window_next = self.get_timesteps(timesteps, steps_per_stage, t_start, t_end, local_step_idx + 1) + delta_t_window = t_window_next - t_window + latent_window = latent_window.reshape(B, C, -1, self.chunk_width, H, W) + velocity = velocity.reshape(B, C, -1, self.chunk_width, H, W) + assert latent_window.size(2) == delta_t_window.size(0) + latent_window = latent_window + velocity * delta_t_window.reshape(1, 1, -1, 1, 1, 1) + latent_window = latent_window.reshape(B, C, T_window, H, W) + # 6. Write back to main latents - # latents[:, :, c_start * self.chunk_width : c_end * self.chunk_width] = latent_window + latents[:, :, c_start * self.chunk_width : c_end * self.chunk_width] = latent_window # 7. 
Update denoise counts and check for clean chunks for c in range(c_start, c_end): @@ -660,7 +697,7 @@ def _denoise_loop_generator( for k in callback_on_step_end_tensor_inputs: if k == "latents": callback_kwargs[k] = latents - callback_on_step_end(self, step, None, callback_kwargs) + callback_on_step_end(self, global_step_idx, None, callback_kwargs) pbar.update()
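
To make the stage/step bookkeeping of the auto-regressive window concrete, here is a small standalone sketch. The helpers below mirror `_generate_sequences`, `prepare_magi_timesteps`, and the index reversal in `get_timesteps` from the patches above; the toy values for `chunk_num`, `window_size`, and `num_inference_steps`, and the assumption that the reversed index list pairs the oldest chunk in a window with the most advanced timestep, are illustrative only and not part of the pipeline API.

```python
# Illustrative sketch (assumed values, not pipeline API): trace which chunks and
# which timesteps each global denoising step touches under MAGI's sliding window.
import torch


def generate_sequences(chunk_num, window_size, chunk_offset=0):
    # Mirrors _generate_sequences in the patch: per-stage global chunk range
    # [clip_start, clip_end) and window-local slot range [t_start, t_end).
    start, end = chunk_offset, chunk_num + window_size - 1
    clip_start = [max(chunk_offset, i - window_size + 1) for i in range(start, end)]
    clip_end = [min(chunk_num, i + 1) for i in range(start, end)]
    t_start = [max(0, i - chunk_num + 1) for i in range(start, end)]
    t_end = [min(window_size, i - chunk_offset + 1) for i in range(start, end)]
    return clip_start, clip_end, t_start, t_end


def magi_timesteps(num_steps, shift=3.0):
    # Mirrors prepare_magi_timesteps: linear grid, squared, then shift transform.
    t = torch.linspace(0, 1, num_steps + 1) ** 2
    shift_inv = 1.0 / shift
    return shift_inv * t / (1 + (shift_inv - 1) * t)


chunk_num, window_size, num_inference_steps = 4, 4, 8  # assumed toy values
steps_per_stage = num_inference_steps // window_size
clip_start, clip_end, t_start, t_end = generate_sequences(chunk_num, window_size)
timesteps = magi_timesteps(num_inference_steps)

total_steps = steps_per_stage * (chunk_num + window_size - 1)
for global_step in range(total_steps):
    stage, local_step = divmod(global_step, steps_per_stage)
    cs, ce = clip_start[stage], clip_end[stage]
    ts, te = t_start[stage], t_end[stage]
    # Reversed, as in get_timesteps: under the stated assumption, the oldest chunk
    # in the window gets the most advanced timestep, the newest chunk the least.
    t_indices = [i * steps_per_stage + local_step for i in range(ts, te)][::-1]
    t_values = [round(float(timesteps[i]), 3) for i in t_indices]
    print(f"step {global_step:2d} | stage {stage} | chunks {list(range(cs, ce))} | t {t_values}")
```

Running this shows that each chunk passes through all `num_inference_steps` timesteps over `window_size` stages, which is exactly the condition `denoise_counts[c] == num_inference_steps` that `_denoise_loop_generator` uses to decide a chunk is clean and can be yielded.

At the user-facing level, a hedged usage sketch of the pipeline as it stands in this patch series follows. The checkpoint id is a placeholder, the returned container is the provisional dict from the last patch (a proper `Magi1PipelineOutput` is still marked TODO), and several TODOs remain in `__call__` (learned null-caption embeddings, special-token loading), so this reflects the intended interface rather than a finished feature.

```python
# Hypothetical usage sketch; "sand-ai/MAGI-1-diffusers" is a placeholder repo id.
import torch
from diffusers import Magi1Pipeline

pipe = Magi1Pipeline.from_pretrained("sand-ai/MAGI-1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A red panda drinking tea by a window"

# Non-streaming: all chunks are collected and decoded at the end.
result = pipe(prompt=prompt, num_frames=96, num_inference_steps=32, output_type="np")
print(result["frames"].shape)  # expected (num_frames, H, W, 3), uint8

# Streaming: chunks are yielded as soon as they are fully denoised.
for event in pipe(prompt=prompt, num_frames=96, num_inference_steps=32, stream_chunks=True, output_type="np"):
    print(f"chunk {event['chunk_idx']} ready with frames of shape {event['frames'].shape}")
```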