ConvTranspose2d 19.6x slower than cuDNN on MI355X for VAE decoder shapes #3952

@sunway513

Summary

torch.nn.ConvTranspose2d on MI355X (gfx950, MIOpen) is 19.6x slower than cuDNN on B300 for VAE decoder upsample shapes used in diffusion models (FLUX, SD3.5, SDXL).

Measurements

| Shape | MI355X (MIOpen) | B300 (cuDNN 9.19) | Ratio |
| --- | --- | --- | --- |
| Cin=512, Cout=256, k=3, s=2, H=128, W=128 | 1.349 ms / 28.7 TFLOPS | 0.069 ms / 562.7 TFLOPS | 19.6x slower |

Tested with PyTorch 2.11, BF16, cudnn.benchmark=True on both sides.
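
As a sanity check on the table, the throughput figures can be recomputed from the latencies. The sketch below is not the original benchmark script; the helper names are hypothetical, and it assumes batch size 1 and a FLOP count taken over the 128x128 input grid (one multiply-add per input pixel, channel pair, and kernel tap), which is the convention that reproduces the 28.7 TFLOPS figure.

```python
# Back-of-envelope check of the table row (a sketch, not the original benchmark).
# Shape values come from the table; the helper names are hypothetical.

def conv_transpose_flops(n, c_in, c_out, h_in, w_in, k):
    """Approximate FLOPs: one multiply-add per (input pixel, c_in, c_out, kernel tap)."""
    return 2 * n * c_in * c_out * h_in * w_in * k * k

def tflops(flops, ms):
    """Convert a FLOP count and a latency in milliseconds to TFLOPS."""
    return flops / (ms * 1e-3) / 1e12

flops = conv_transpose_flops(n=1, c_in=512, c_out=256, h_in=128, w_in=128, k=3)
mi355x_ms, b300_ms = 1.349, 0.069

print(f"FLOPs per call: {flops:.3e}")
print(f"MI355X: {tflops(flops, mi355x_ms):.1f} TFLOPS")
print(f"B300:   {tflops(flops, b300_ms):.1f} TFLOPS")
print(f"Latency ratio: {mi355x_ms / b300_ms:.1f}x")
```

The MI355X figure and the 19.6x ratio match the table. The B300 figure comes out near 560 TFLOPS rather than 562.7, presumably because the 0.069 ms latency shown in the table is rounded.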

Context

This shape is the VAE decoder upsample layer present in ALL diffusion models (FLUX, SD3.5, SDXL, Sana, HunyuanVideo, Wan2.1, etc.). VAE decode is 60-70% Conv2D and is often the e2e latency bottleneck for image generation.

Standard (non-transposed) Conv2d on MI355X is competitive at 0.78x relative to cuDNN. The transposed convolution path, by contrast, appears severely underoptimized.

Reproduce

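import time
import torch

# Matches the reported setup (cudnn.benchmark=True on both sides).
torch.backends.cudnn.benchmark = True

conv_t = torch.nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1, bias=False).cuda().bfloat16()
x = torch.randn(1, 512, 128, 128, dtype=torch.bfloat16, device='cuda')
for _ in range(100):  # warm-up
    conv_t(x)
torch.cuda.synchronize(); t = time.perf_counter()
for _ in range(200):  # timed iterations
    conv_t(x)
torch.cuda.synchronize(); ms = (time.perf_counter() - t) / 200 * 1000
# 2 * N * Cin * Cout * Hin * Win * k * k (approximate; counted over the 128x128 input, consistent with the table)
flops = 2 * 1 * 512 * 256 * 128 * 128 * 3 * 3
print(f'{flops / ms * 1e-9:.1f} TFLOPS ({ms:.3f}ms)')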
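
For reference, the layer arguments in the script (stride=2, padding=1, output_padding=1) map the 128x128 input to a 256x256 output. A minimal sketch of that arithmetic, using the output-size formula from the PyTorch ConvTranspose2d documentation; the helper name is hypothetical:

```python
def conv_transpose2d_out_size(size_in, kernel, stride, padding, output_padding, dilation=1):
    """Output length along one spatial dimension of a transposed convolution."""
    return (size_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# Shape from the repro script: H = W = 128, k = 3, s = 2, p = 1, output_padding = 1.
h_out = conv_transpose2d_out_size(128, kernel=3, stride=2, padding=1, output_padding=1)
print(h_out)  # 256

# Dropping output_padding would give an odd-sized output instead of an exact 2x upsample.
print(conv_transpose2d_out_size(128, kernel=3, stride=2, padding=1, output_padding=0))  # 255
```

This is why output_padding=1 is needed here: without it the layer would not produce an exact 2x upsample.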

Impact

Blocks competitive VAE decode performance on MI355X for all diffusion model inference.

cc @sunway513
