# STDiT Backbone

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
# Import all
import torch
from mmengine.runner import set_random_seed

from opensora.registry import MODELS, SCHEDULERS, build_module
from opensora.utils.misc import to_torch_dtype
from opensora.datasets.aspect import get_image_size, get_num_frames
from opensora.utils.inference_utils import (
    prepare_multi_resolution_info,
    collect_references_batch,
    apply_mask_strategy,
    extract_json_from_prompts,
)
from opensora.schedulers.rf.rectified_flow import timestep_transform, RFlowScheduler

  _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)


In [3]:
print("CUDA available:", torch.cuda.is_available())
print("Number of GPUs:", torch.cuda.device_count())

if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

CUDA available: True
Number of GPUs: 1
GPU: NVIDIA A100 80GB PCIe


## Load models

In [4]:
# Configs
resolution = "144p"
aspect_ratio = "9:16"
num_frames = 51
fps = 24
frame_interval = 1
save_fps = 24

save_dir = "/home/tran/workspace/Open-Sora/save/inference/test"
seed = 42
batch_size = 1
multi_resolution = "STDiT2"
dtype = "fp16"
condition_frame_length = 5
align = 5

model_cfg = dict(
    type="STDiT3-XL/2",
    from_pretrained="hpcai-tech/OpenSora-STDiT-v3",
    qk_norm=True,
    enable_flash_attn=True,
    enable_layernorm_kernel=True,
)
vae_cfg = dict(
    type="OpenSoraVAE_V1_2",
    from_pretrained="hpcai-tech/OpenSora-VAE-v1.2",
    micro_frame_size=17,
    micro_batch_size=4,
)
text_encoder_cfg = dict(
    type="t5", 
    from_pretrained="DeepFloyd/t5-v1_1-xxl", 
    model_max_length=300, 
    dtype=dtype
)

num_sampling_steps = 30
scheduler_cfg = dict(
    type="rflow",
    use_timestep_transform=True,
    num_sampling_steps=num_sampling_steps,
    cfg_scale=7.0,
)

aes = 6.5
flow = None

In [5]:
# Settings
device = torch.device("cuda")
dtype = to_torch_dtype(dtype)

image_size = get_image_size(resolution, aspect_ratio)
num_frames = get_num_frames(num_frames)

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

set_random_seed(seed)

42

In [6]:
# Init text and image encoder
text_encoder = build_module(text_encoder_cfg, MODELS, device=device)
vae = build_module(vae_cfg, MODELS).to(device, dtype).eval()



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
# Build diffusion model
input_size = (num_frames, *image_size)
latent_size = vae.get_latent_size(input_size)

model = (
        build_module(
            model_cfg,
            MODELS,
            input_size=latent_size,
            in_channels=vae.out_channels,
            caption_channels=text_encoder.output_dim,
            model_max_length=text_encoder.model_max_length,
            enable_sequence_parallelism=False,
        )
        .to(device, dtype)
        .eval()
    )
text_encoder.y_embedder = model.y_embedder  # HACK: for classifier-free guidance

model

STDiT3(
  (pos_embed): PositionEmbedding2D()
  (rope): RotaryEmbedding()
  (x_embedder): PatchEmbed3D(
    (proj): Conv3d(4, 1152, kernel_size=(1, 2, 2), stride=(1, 2, 2))
  )
  (t_embedder): TimestepEmbedder(
    (mlp): Sequential(
      (0): Linear(in_features=256, out_features=1152, bias=True)
      (1): SiLU()
      (2): Linear(in_features=1152, out_features=1152, bias=True)
    )
  )
  (fps_embedder): SizeEmbedder(
    (mlp): Sequential(
      (0): Linear(in_features=256, out_features=1152, bias=True)
      (1): SiLU()
      (2): Linear(in_features=1152, out_features=1152, bias=True)
    )
  )
  (t_block): Sequential(
    (0): SiLU()
    (1): Linear(in_features=1152, out_features=6912, bias=True)
  )
  (y_embedder): CaptionEmbedder(
    (y_proj): Mlp(
      (fc1): Linear(in_features=4096, out_features=1152, bias=True)
      (act): GELU(approximate='tanh')
      (drop1): Dropout(p=0, inplace=False)
      (norm): Identity()
      (fc2): Linear(in_features=1152, out_features=1152, bi

## Preprocess Input

In [8]:
# Text conditioning
prompts = ["A bear climbing a tree"]
model_args = text_encoder.encode(prompts)
model_args["y"].shape, model_args["mask"].shape

(torch.Size([1, 1, 300, 4096]), torch.Size([1, 300]))

In [9]:
# Classifier-free guidance
n = len(prompts)
y_null = text_encoder.null(n)
model_args["y"] = torch.cat([model_args["y"], y_null], 0)

In [10]:
# Prepare additional arguments
additional_args = prepare_multi_resolution_info(
    multi_resolution, n, image_size, num_frames, fps, device, dtype
)
model_args.update(additional_args)
model_args

{'y': tensor([[[[-0.0812, -0.1660,  0.0464,  ...,  0.0174, -0.0797,  0.0227],
           [-0.0566, -0.1606,  0.0829,  ...,  0.1131,  0.0443,  0.0917],
           [-0.0417, -0.0914,  0.0007,  ...,  0.0492, -0.1201,  0.0847],
           ...,
           [ 0.0494, -0.1984,  0.0881,  ...,  0.0006,  0.0547, -0.0087],
           [ 0.0494, -0.1984,  0.0881,  ...,  0.0006,  0.0547, -0.0087],
           [ 0.0494, -0.1984,  0.0881,  ...,  0.0006,  0.0547, -0.0087]]],
 
 
         [[[-0.0060, -0.0104, -0.0017,  ..., -0.0063,  0.0064,  0.0035],
           [ 0.0576, -0.0359,  0.0204,  ...,  0.0500, -0.0530, -0.0042],
           [ 0.0576, -0.0359,  0.0204,  ...,  0.0500, -0.0530, -0.0042],
           ...,
           [ 0.0576, -0.0359,  0.0204,  ...,  0.0500, -0.0530, -0.0042],
           [ 0.0576, -0.0359,  0.0204,  ...,  0.0500, -0.0530, -0.0042],
           [ 0.0576, -0.0359,  0.0204,  ...,  0.0500, -0.0530, -0.0042]]]],
        device='cuda:0', dtype=torch.float16),
 'mask': tensor([[1, 1, 1, 1, 1

In [11]:
# Timesteps
num_timesteps = 1000

timesteps = [(1.0 - i / num_sampling_steps) * num_timesteps for i in range(num_sampling_steps)]
timesteps = [timestep_transform(t, additional_args, num_timesteps=num_timesteps) for t in timesteps]
timesteps

[tensor([1000.], device='cuda:0'),
 tensor([976.8082], device='cuda:0'),
 tensor([953.1246], device='cuda:0'),
 tensor([928.9335], device='cuda:0'),
 tensor([904.2182], device='cuda:0'),
 tensor([878.9617], device='cuda:0'),
 tensor([853.1458], device='cuda:0'),
 tensor([826.7519], device='cuda:0'),
 tensor([799.7603], device='cuda:0'),
 tensor([772.1504], device='cuda:0'),
 tensor([743.9008], device='cuda:0'),
 tensor([714.9890], device='cuda:0'),
 tensor([685.3915], device='cuda:0'),
 tensor([655.0834], device='cuda:0'),
 tensor([624.0389], device='cuda:0'),
 tensor([592.2310], device='cuda:0'),
 tensor([559.6310], device='cuda:0'),
 tensor([526.2089], device='cuda:0'),
 tensor([491.9334], device='cuda:0'),
 tensor([456.7712], device='cuda:0'),
 tensor([420.6876], device='cuda:0'),
 tensor([383.6458], device='cuda:0'),
 tensor([345.6071], device='cuda:0'),
 tensor([306.5309], device='cuda:0'),
 tensor([266.3739], device='cuda:0'),
 tensor([225.0908], device='cuda:0'),
 tensor([182.63

In [12]:
# Noise input
z = torch.randn(n, vae.out_channels, *latent_size, device=device, dtype=dtype)
z.shape

torch.Size([1, 4, 15, 18, 32])

In [13]:
# Image conditioning
refs = ["/remote/vast0/tran/workspace/Open-Sora/save/references/sample.jpg"]
mask_strategy = [""]

prompts, refs, ms = extract_json_from_prompts(prompts, refs, mask_strategy)
refs = collect_references_batch(refs, vae, image_size)
mask = apply_mask_strategy(z, refs, ms, 0, align=align)
mask.shape

torch.Size([1, 15])

In [14]:
# Init noise added
noise_added = torch.zeros_like(mask, dtype=torch.bool)
noise_added = noise_added | (mask == 1)

## Real Forward

In [15]:
# Init scheduler
scheduler = RFlowScheduler(
    num_timesteps=num_timesteps,
    num_sampling_steps=num_sampling_steps,
    use_discrete_timesteps=False,
    use_timestep_transform=True,
)
scheduler

<opensora.schedulers.rf.rectified_flow.RFlowScheduler at 0x7f4f35681e50>

In [16]:

# Prepare mask
t = timesteps[0]

mask_t = mask * num_timesteps
x0 = z.clone()
x_noise = scheduler.add_noise(x0, torch.randn_like(x0), t)

mask_t_upper = mask_t >= t.unsqueeze(1)
model_args["x_mask"] = mask_t_upper.repeat(2, 1)
mask_add_noise = mask_t_upper & ~noise_added

z = torch.where(mask_add_noise[:, None, :, None, None], x_noise, x0)
noise_added = mask_t_upper

In [17]:
# Prepare data
z_in = torch.cat([z, z], 0)
t_in = torch.cat([t, t], 0)

z_in.shape, t_in.shape

(torch.Size([2, 4, 15, 18, 32]), torch.Size([2]))

In [18]:
# Unpack model args
y = model_args["y"]
mask = model_args["mask"]
x_mask = model_args["x_mask"]
fps = model_args["fps"]
height = model_args["height"]
width = model_args["width"]

y.shape, mask.shape, x_mask.shape, fps.shape, height.shape, width.shape

(torch.Size([2, 1, 300, 4096]),
 torch.Size([1, 300]),
 torch.Size([2, 15]),
 torch.Size([1]),
 torch.Size([1]),
 torch.Size([1]))

In [19]:
# First real iteration
true_pred = model(
    z_in, t_in, y, 
    mask=mask, x_mask=x_mask, 
    fps=fps, height=height, 
    width=width
)
true_pred.shape

torch.Size([2, 8, 15, 18, 32])

## ONNX Conversion

In [20]:
ONNX_FILEPATH = "/home/tran/workspace/Open-Sora/tensorrt/resources/stdit3.onnx"

dynamic_axes = {
    "z_in": {
        0: "2batchsize",
        2: "frames",
        3: "height",
        4: "width",
    },
    "t_in": {
        0: "2batchsize",
    },
    "y": {
        0: "2batchsize",
    },
    "mask": {
        0: "batchsize",
    },
    "x_mask": {
        0: "2batchsize",
        1: "frames",
    }
}

input_names = [
    "z_in", "t_in", "y", 
    "mask", "x_mask", 
    "fps", "height", "width"
]

inputs = (
    z_in, t_in, y, 
    mask, x_mask, 
    fps, height, width
)

output_names = ["output"]

torch.onnx.export(
    model,
    inputs,
    ONNX_FILEPATH,
    export_params=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes
)

  if T % self.patch_size[0] != 0:
  if H % self.patch_size[1] != 0:
  if W % self.patch_size[2] != 0:
  S = torch.tensor(H * W)
  S = torch.tensor(H * W)
  resolution_sq = (height[0].item() * width[0].item()) ** 0.5
  if s.shape[0] != bs:
  assert s.shape[0] == bs
  if mask.shape[0] != y.shape[0]:
  y_lens = mask.sum(dim=1).tolist()
  if W % self.patch_size[2] != 0:
  if H % self.patch_size[1] != 0:
  if D % self.patch_size[0] != 0:


  if enable_flash_attn:
  if not enable_flash_attn:
  max_seqlen = max(max_seqlen, seqlen)
  min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
  seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
  if kv_seqlen is None or q_seqlen == kv_seqlen:
  if any(x.shape[0] != 1 for x in qkv):
  self.query.shape == (B, Mq, H, K)
  and self.key.shape == (B, Mkv, H, key_embed_dim)
  and self.value.shape == (B, Mkv, H, Kv)
  if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv:
  if max(K, Kv) > cls.SUPPORTED_MAX_K:
  if x.shape[-1] % alignment != 0:
  if inp.query.numel() > 0 and inp.key.numel() > 0:


RuntimeError: _Map_base::at