In [1]:
import treescope
import torchinfo
from nutils.benchmark import measure_flops, benchmark_model
# Make treescope the default renderer for notebook outputs.
# NOTE(review): basic_interactive_setup may already perform this registration — confirm and drop if redundant.
treescope.register_as_default()
# Interactive setup; autovisualize_arrays=True asks treescope to render array values visually.
treescope.basic_interactive_setup(autovisualize_arrays=True)

W1215 19:59:22.699000 34260 site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.


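# LDM VAE

This section compares two things. First, it compares the parameter counts of the pretrained `stabilityai/sd-vae-ft-mse` autoencoder (diffusers) and the local `models.vae.VAE` reimplementation. Second, it compares the parameter counts and CPU runtime of that reimplementation with its default blocks versus with a ConvNeXt-style block (`block_fn=ConvNeXtBlock`).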

In [16]:
from diffusers.models import AutoencoderKL
# Reference autoencoder from diffusers. It gets its own name so that the local `vae`
# reimplementation built in the following cells never shadows it. Otherwise, re-running
# this cell followed by a later summary/benchmark cell would silently target the wrong model.
ldm_vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
num_params = sum(p.numel() for p in ldm_vae.parameters())
num_params_encoder = sum(p.numel() for p in ldm_vae.encoder.parameters())
num_params_decoder = sum(p.numel() for p in ldm_vae.decoder.parameters())
# NOTE(review): the total exceeds encoder + decoder by 92 parameters. These are presumably
# the 1x1 quant_conv / post_quant_conv layers that sit outside both submodules; confirm.
print(f"Total number of parameters: {num_params:,}")
print(f"Number of parameters in the encoder: {num_params_encoder:,}")
print(f"Number of parameters in the decoder: {num_params_decoder:,}")

Total number of parameters: 83,653,863
Number of parameters in the encoder: 34,163,592
Number of parameters in the decoder: 49,490,179


In [17]:
from models.vae import VAE

# Local VAE with its default blocks, configured with the same hyperparameters
# (128 base channels, multipliers 1/2/4/4, 4 latent channels, 8x spatial compression).
vae_config = dict(
    in_channels=3,
    out_channels=3,
    channels=128,
    channels_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    attn_resolutions=[],
    dropout=0.0,
    resolution=256,
    z_channels=4,
    spatial_compression=8,
    prior="gaussian",
)
vae = VAE(**vae_config)

# Parameter counts for the whole model, then the encoder and decoder submodules.
num_params, num_params_encoder, num_params_decoder = (
    sum(p.numel() for p in module.parameters())
    for module in (vae, vae.encoder, vae.decoder)
)
print(f"Total number of parameters: {num_params:,}")
print(f"Number of parameters in the encoder: {num_params_encoder:,}")
print(f"Number of parameters in the decoder: {num_params_decoder:,}")

Wavelet Transform: None
z of shape: (1, 4, 32, 32), dimensions: 4096
Total number of parameters: 83,653,863
Number of parameters in the encoder: 34,163,592
Number of parameters in the decoder: 49,490,179


In [18]:
# Overview of the top-level modules only (depth=1) for a single 3x256x256 input.
summary_columns = (
    "input_size",
    "output_size",
    "num_params",
    "params_percent",
    "mult_adds",
)
torchinfo.summary(vae, input_size=(1, 3, 256, 256), depth=1, col_names=summary_columns)

In [20]:
# Median runtime (reported in ms) of the default-block VAE on one 3x256x256 input.
# `.cpu()` makes the weight placement explicit and is a no-op if `vae` is already on
# the CPU. The ConvNeXt benchmark cell also passes `vae.cpu()`, so both runtimes are
# measured the same way.
runtime = benchmark_model(vae.cpu(), (1, 3, 256, 256), device="cpu")
print(f"Runtime: {runtime.median:.2f} ms")

# FLOP counting is opt-in: flip this flag rather than (un)commenting code.
# NOTE(review): measure_flops runs with device="meta"; confirm it does not relocate `vae`,
# otherwise re-running the benchmark above afterwards would not time real CPU weights.
MEASURE_FLOPS = False
if MEASURE_FLOPS:
    flops = measure_flops(vae, (1, 3, 256, 256), device="meta")
    print(f"FLOPs: {flops['forward_total'] / 1e9:,.2f} GFLOPs")

Runtime: 3.74 ms


In [6]:
import torch
import torch.nn as nn
from einops import rearrange


class GRN(nn.Module):
    """Global Response Normalization (GRN) layer.

    Expects channels-last input of shape (B, H, W, C); `gamma`/`beta` are shaped
    (1, 1, 1, C) to broadcast over it. For each channel it takes the L2 norm over
    the spatial dims (1, 2), divides by that norm's mean across channels, and uses
    the result to rescale the input:

        out = gamma * (x * Nx) + beta + x

    `gamma` and `beta` start at zero, so the layer is the identity at initialization.

    Args:
        dim: Number of channels C (size of the last dimension).
        eps: Added to the cross-channel mean to avoid division by zero.
    """

    def __init__(self, dim, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # Per-channel spatial L2 norm, shape (B, 1, 1, C). torch.linalg.vector_norm is
        # the supported replacement for the deprecated torch.norm(p=2, dim=...).
        gx = torch.linalg.vector_norm(x, ord=2, dim=(1, 2), keepdim=True)
        # Relative importance of each channel versus the channel average.
        nx = gx / (gx.mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * (x * nx) + self.beta + x


class ConvNeXtBlock(nn.Module):
    """Two stacked ConvNeXt-style sub-blocks, usable as a VAE `block_fn`.

    Input and output are NCHW tensors. Each sub-block applies:
    7x7 depthwise conv -> LayerNorm -> Linear(C, 4C) -> GELU -> GRN -> Linear(4C, C_out),
    with the LayerNorm/Linear/GRN steps running channels-last.

    The first sub-block keeps `in_channels` and is added back residually. The second
    projects to `out_channels`. Its residual path goes through a 1x1 `nin_shortcut`
    convolution when the channel count changes, and is the identity otherwise.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; falls back to `in_channels` when None.
        dropout: Dropout probability, applied once before the final projection.
        **kwargs: Accepted and ignored so callers can pass the same keyword arguments
            they pass to other block types. NOTE(review): confirm which extras the VAE passes.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: "int | None" = None,
        dropout: float,
        **kwargs,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels or in_channels

        # Sub-block 1: channel-preserving; its output is added back to the input.
        self.convdw1 = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=7,
            padding=3,
            groups=in_channels,
        )
        self.norm1 = nn.LayerNorm(in_channels)
        self.pwconv1_1 = nn.Linear(in_channels, 4 * in_channels)
        self.act1 = nn.GELU()
        self.gn1 = GRN(4 * in_channels)
        self.pwconv1_2 = nn.Linear(4 * in_channels, in_channels)

        # Sub-block 2: may change the channel count to self.out_channels.
        self.convdw2 = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=7,
            padding=3,
            groups=in_channels,
        )
        self.norm2 = nn.LayerNorm(in_channels)
        self.pwconv2_1 = nn.Linear(in_channels, 4 * in_channels)
        self.act2 = nn.GELU()
        self.gn2 = GRN(4 * in_channels)
        # Use the resolved width: the raw `out_channels` argument may be None.
        self.pwconv2_2 = nn.Linear(4 * in_channels, self.out_channels)
        # `dropout` used to be accepted but ignored. With dropout=0.0 this module
        # leaves outputs unchanged, and it adds no parameters.
        self.dropout = nn.Dropout(dropout)

        self.nin_shortcut = (
            nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1)
            if in_channels != self.out_channels
            else nn.Identity()
        )

    def _branch(self, x, dwconv, norm, pwconv_in, act, grn, pwconv_out, drop=None):
        """Run one sub-block (without its residual add) on an NCHW tensor."""
        h = dwconv(x)
        # LayerNorm, Linear and GRN operate on the last dim, so go channels-last.
        h = rearrange(h, "b c h w -> b h w c")
        h = norm(h)
        h = pwconv_in(h)
        h = act(h)
        h = grn(h)
        if drop is not None:
            h = drop(h)
        h = pwconv_out(h)
        return rearrange(h, "b h w c -> b c h w")

    def forward(self, x):
        x = x + self._branch(
            x, self.convdw1, self.norm1, self.pwconv1_1, self.act1, self.gn1, self.pwconv1_2
        )
        h = self._branch(
            x,
            self.convdw2,
            self.norm2,
            self.pwconv2_1,
            self.act2,
            self.gn2,
            self.pwconv2_2,
            self.dropout,
        )
        # Match the residual to the (possibly different) output channel count.
        return h + self.nin_shortcut(x)

In [12]:
# Same VAE hyperparameters as the default-block model, but each block is a ConvNeXtBlock.
convnext_vae_config = dict(
    in_channels=3,
    out_channels=3,
    channels=128,
    channels_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    attn_resolutions=[],
    dropout=0.0,
    resolution=256,
    z_channels=4,
    spatial_compression=8,
    prior="gaussian",
    block_fn=ConvNeXtBlock,
)
vae = VAE(**convnext_vae_config)

# Parameter counts for the whole model, then the encoder and decoder submodules.
num_params, num_params_encoder, num_params_decoder = (
    sum(p.numel() for p in module.parameters())
    for module in (vae, vae.encoder, vae.decoder)
)
print(f"Total number of parameters: {num_params:,}")
print(f"Number of parameters in the encoder: {num_params_encoder:,}")
print(f"Number of parameters in the decoder: {num_params_decoder:,}")

Wavelet Transform: None
z of shape: (1, 4, 32, 32), dimensions: 4096
Total number of parameters: 77,031,143
Number of parameters in the encoder: 28,987,656
Number of parameters in the decoder: 48,043,395


In [13]:
# Top-level module overview (depth=1) of the ConvNeXt-block VAE for one 3x256x256 input.
summary_columns = (
    "input_size",
    "output_size",
    "num_params",
    "params_percent",
    "mult_adds",
)
torchinfo.summary(vae, input_size=(1, 3, 256, 256), depth=1, col_names=summary_columns)

In [15]:
# Median runtime (reported in ms) of the ConvNeXt-block VAE on one 3x256x256 input.
# `.cpu()` keeps the weights on the same device the benchmark is asked to use.
runtime = benchmark_model(vae.cpu(), (1, 3, 256, 256), device="cpu")
print(f"Runtime: {runtime.median:.2f} ms")

# FLOP counting is opt-in: flip this flag rather than (un)commenting code.
# NOTE(review): measure_flops runs with device="meta"; confirm it does not relocate `vae`.
MEASURE_FLOPS = False
if MEASURE_FLOPS:
    flops = measure_flops(vae, (1, 3, 256, 256), device="meta")
    print(f"FLOPs: {flops['forward_total'] / 1e9:,.2f} GFLOPs")

Runtime: 47.43 ms
