In [1]:
import treescope
import torchinfo
from nutils.benchmark import measure_flops, benchmark_model
# basic_interactive_setup() registers Treescope as the default IPython
# pretty-printer itself (and, with autovisualize_arrays=True, enables array
# autovisualization), so a separate register_as_default() call is redundant.
treescope.basic_interactive_setup(autovisualize_arrays=True)

  from .autonotebook import tqdm as notebook_tqdm


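# LDM VAE

Compare the pretrained Stable Diffusion VAE (`stabilityai/sd-vae-ft-mse`) with the local `models.vae.VAE` re-implementation and its ConvNeXt / ConvNeXt-Star block variants: parameter counts, `torchinfo` summaries, runtime and FLOPs.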

In [2]:
from diffusers.models import AutoencoderKL
# Reference model: the pretrained Stable Diffusion VAE. It gets its own name so it
# is not silently shadowed by the local `vae` re-implementations built later.
ldm_vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

num_params = sum(p.numel() for p in ldm_vae.parameters())
num_params_encoder = sum(p.numel() for p in ldm_vae.encoder.parameters())
num_params_decoder = sum(p.numel() for p in ldm_vae.decoder.parameters())
# NOTE(review): in the recorded output, total - encoder - decoder = 92 parameters.
# These presumably belong to the quant_conv / post_quant_conv layers; confirm.
print(f"Total number of parameters: {num_params:,}")
print(f"Number of parameters in the encoder: {num_params_encoder:,}")
print(f"Number of parameters in the decoder: {num_params_decoder:,}")

Total number of parameters: 83,653,863
Number of parameters in the encoder: 34,163,592
Number of parameters in the decoder: 49,490,179


In [3]:
from models.vae import VAE

# Local VAE re-implementation with no explicit block_fn (library default).
# The recorded parameter counts match the pretrained sd-vae-ft-mse exactly.
vae_kwargs = dict(
    in_channels=3,
    out_channels=3,
    channels=128,
    channels_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    attn_resolutions=[],
    dropout=0.0,
    resolution=256,
    z_channels=4,
    spatial_compression=8,
    prior="gaussian",
)
vae = VAE(**vae_kwargs)

num_params = sum(param.numel() for param in vae.parameters())
num_params_encoder = sum(param.numel() for param in vae.encoder.parameters())
num_params_decoder = sum(param.numel() for param in vae.decoder.parameters())
print(f"Total number of parameters: {num_params:,}")
print(f"Number of parameters in the encoder: {num_params_encoder:,}")
print(f"Number of parameters in the decoder: {num_params_decoder:,}")

Wavelet Transform: None
z of shape: (1, 4, 32, 32), dimensions: 4096
Total number of parameters: 83,653,863
Number of parameters in the encoder: 34,163,592
Number of parameters in the decoder: 49,490,179


In [4]:
# Per-layer summary (top-level children only, depth=1) for one 3x256x256 input.
torchinfo.summary(
    model=vae,
    input_size=(1, 3, 256, 256),
    depth=1,
    col_names=("input_size", "output_size", "num_params", "params_percent", "mult_adds"),
)

In [5]:
# Runtime and FLOPs are measured at different batch sizes (2 vs. 1). Each printout
# is labelled with its batch size so the two numbers are not read as the same input.
bench_shape = (2, 3, 256, 256)
flops_shape = (1, 3, 256, 256)

# NOTE(review): the recorded runtime for this model (0.09 ms) is far below the
# ConvNeXt variants' (2.91 ms / 1.18 ms), even though this model has more FLOPs.
# Confirm that benchmark_model warms up and synchronizes CUDA before trusting it.
runtime = benchmark_model(vae, bench_shape, device="cuda")
print(f"Runtime (batch={bench_shape[0]}): {runtime.median:.2f} ms")

flops = measure_flops(vae, flops_shape, device="cuda")
# Two decimals: the unformatted float prints spurious precision (e.g. 1,787.16164096).
print(f"FLOPs (batch={flops_shape[0]}): {flops['forward_total'] / 1e9:,.2f} GFLOPs")

Runtime: 0.09 ms
FLOPs: 1,787.16164096 GFLOPs


In [6]:
import torch
import torch.nn as nn
from einops import rearrange


# NOTE(review): none of the visible cells instantiate these notebook-local classes.
# VAE(block_fn=...) presumably uses its own copies in models.vae; confirm they are
# kept in sync.


class GRN(nn.Module):
    """GRN (Global Response Normalization) layer, as in ConvNeXt V2.

    Expects channels-last input of shape (B, H, W, C). For each sample and channel
    it takes the L2 norm over the two spatial axes, divides it by the mean of those
    norms across channels, and uses the result to rescale ``x``. ``gamma`` and
    ``beta`` start at zero, so a freshly constructed layer is the identity.

    Args:
        dim: Number of channels ``C`` (size of the last axis).
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # Per-sample, per-channel spatial L2 norm: shape (B, 1, 1, C).
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Normalize by the mean over channels; 1e-6 avoids division by zero.
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        # Residual form: learned recalibration added on top of the input.
        return self.gamma * (x * Nx) + self.beta + x


def _channel_proj(in_channels: int, out_channels: int) -> nn.Module:
    """Return a 1x1 conv ``in_channels -> out_channels``, or Identity if they are equal."""
    if in_channels == out_channels:
        return nn.Identity()
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)


class ConvNeXtBlock(nn.Module):
    """ConvNeXt V2-style residual block operating on (B, C, H, W) tensors.

    Residual branch:
        7x7 depthwise conv -> (to channels-last) -> LayerNorm -> Linear(C, 4C)
        -> GELU -> GRN(4C) -> Linear(4C, C) -> (back to channels-first)

    If ``out_channels == in_channels``, the output is ``branch(x) + x``. Otherwise
    the branch output and the input are each projected with their own 1x1 conv
    and the two projections are summed.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; ``None`` means ``in_channels``.
        dropout: Accepted but currently unused.
            NOTE(review): no dropout is applied; confirm this is intended.
        **kwargs: Ignored.
    """

    def __init__(
        self, *, in_channels: int, out_channels: int = None, dropout: float, **kwargs
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels or in_channels

        self.convdw1 = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=7,
            padding=3,
            groups=in_channels,
        )
        self.norm1 = nn.LayerNorm(in_channels)
        self.pwconv1_1 = nn.Linear(in_channels, 4 * in_channels)
        self.act1 = nn.GELU()
        self.gn1 = GRN(4 * in_channels)
        self.pwconv1_2 = nn.Linear(4 * in_channels, in_channels)

        # Use the resolved self.out_channels: the raw argument may be None, and
        # building Conv2d(in_channels, None) would raise.
        self.up_proj = _channel_proj(in_channels, self.out_channels)
        self.nin_shortcut = _channel_proj(in_channels, self.out_channels)

    def forward(self, x):
        h = self.convdw1(x)
        # (B, C, H, W) -> (B, H, W, C): LayerNorm, Linear and GRN act on the last axis.
        h = h.permute(0, 2, 3, 1)
        h = self.norm1(h)
        h = self.pwconv1_1(h)
        h = self.act1(h)
        h = self.gn1(h)
        h = self.pwconv1_2(h)
        # (B, H, W, C) -> (B, C, H, W)
        h = h.permute(0, 3, 1, 2)

        return self.up_proj(h) + self.nin_shortcut(x)


class ConvNeXtStarBlock(nn.Module):
    """ConvNeXt-style residual block with a gated ("star") MLP, on (B, C, H, W) tensors.

    Residual branch:
        7x7 depthwise conv -> (to channels-last) -> LayerNorm -> Linear(C, 4C)
        -> split into halves h1, h2 (2C each) -> ReLU6(h1) * h2 -> GRN(2C)
        -> Linear(2C, C) -> (back to channels-first)

    The shortcut and projection behaviour is the same as in ``ConvNeXtBlock``:
    ``branch(x) + x`` when channel counts match, otherwise the sum of two 1x1-conv
    projections.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; ``None`` means ``in_channels``.
        dropout: Accepted but currently unused.
            NOTE(review): no dropout is applied; confirm this is intended.
        **kwargs: Ignored.
    """

    def __init__(
        self, *, in_channels: int, out_channels: int = None, dropout: float, **kwargs
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels or in_channels

        self.convdw1 = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=7,
            padding=3,
            groups=in_channels,
        )
        self.norm1 = nn.LayerNorm(in_channels)
        # Expands to 4C; the gate below halves it back to 2C.
        self.pwconv1_1 = nn.Linear(in_channels, 4 * in_channels)
        self.act1 = nn.ReLU6()
        self.gn1 = GRN(2 * in_channels)
        self.pwconv1_2 = nn.Linear(2 * in_channels, in_channels)

        # Use the resolved self.out_channels: the raw argument may be None.
        self.up_proj = _channel_proj(in_channels, self.out_channels)
        self.nin_shortcut = _channel_proj(in_channels, self.out_channels)

    def forward(self, x):
        h = self.convdw1(x)
        # (B, C, H, W) -> (B, H, W, C)
        h = h.permute(0, 2, 3, 1)
        h = self.norm1(h)
        h = self.pwconv1_1(h)
        # Gated product of the two halves: (..., 4C) -> (..., 2C).
        h1, h2 = torch.chunk(h, 2, dim=-1)
        h = self.act1(h1) * h2
        h = self.gn1(h)
        h = self.pwconv1_2(h)
        # (B, H, W, C) -> (B, C, H, W)
        h = h.permute(0, 3, 1, 2)

        return self.up_proj(h) + self.nin_shortcut(x)

In [7]:
# Same VAE configuration, but with ConvNeXt blocks (block_fn="convnext").
vae_kwargs = dict(
    in_channels=3,
    out_channels=3,
    channels=128,
    channels_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    attn_resolutions=[],
    dropout=0.0,
    resolution=256,
    z_channels=4,
    spatial_compression=8,
    prior="gaussian",
    block_fn="convnext",
)
vae = VAE(**vae_kwargs)

num_params = sum(param.numel() for param in vae.parameters())
num_params_encoder = sum(param.numel() for param in vae.encoder.parameters())
num_params_decoder = sum(param.numel() for param in vae.decoder.parameters())
print(f"Total number of parameters: {num_params:,}")
print(f"Number of parameters in the encoder: {num_params_encoder:,}")
print(f"Number of parameters in the decoder: {num_params_decoder:,}")

Wavelet Transform: None
z of shape: (1, 4, 32, 32), dimensions: 4096
Total number of parameters: 44,460,007
Number of parameters in the encoder: 16,672,008
Number of parameters in the decoder: 27,787,907


In [8]:
# Per-layer summary (top-level children only, depth=1) for one 3x256x256 input.
torchinfo.summary(
    model=vae,
    input_size=(1, 3, 256, 256),
    depth=1,
    col_names=("input_size", "output_size", "num_params", "params_percent", "mult_adds"),
)

In [9]:
# Runtime and FLOPs are measured at different batch sizes (2 vs. 1). Each printout
# is labelled with its batch size so the two numbers are not read as the same input.
bench_shape = (2, 3, 256, 256)
flops_shape = (1, 3, 256, 256)

runtime = benchmark_model(vae, bench_shape, device="cuda")
print(f"Runtime (batch={bench_shape[0]}): {runtime.median:.2f} ms")

flops = measure_flops(vae, flops_shape, device="cuda")
# Two decimals: the unformatted float prints spurious precision.
print(f"FLOPs (batch={flops_shape[0]}): {flops['forward_total'] / 1e9:,.2f} GFLOPs")

Runtime: 2.91 ms
FLOPs: 1,200.879730688 GFLOPs


In [10]:
# Same VAE configuration, but with the gated ConvNeXt-Star blocks.
vae_kwargs = dict(
    in_channels=3,
    out_channels=3,
    channels=128,
    channels_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    attn_resolutions=[],
    dropout=0.0,
    resolution=256,
    z_channels=4,
    spatial_compression=8,
    prior="gaussian",
    block_fn="convnext-star",
)
vae = VAE(**vae_kwargs)

num_params = sum(param.numel() for param in vae.parameters())
num_params_encoder = sum(param.numel() for param in vae.encoder.parameters())
num_params_decoder = sum(param.numel() for param in vae.decoder.parameters())
print(f"Total number of parameters: {num_params:,}")
print(f"Number of parameters in the encoder: {num_params_encoder:,}")
print(f"Number of parameters in the decoder: {num_params_decoder:,}")

Wavelet Transform: None
z of shape: (1, 4, 32, 32), dimensions: 4096
Total number of parameters: 36,264,423
Number of parameters in the encoder: 13,676,296
Number of parameters in the decoder: 22,588,035


In [11]:
# Per-layer summary (top-level children only, depth=1) for one 3x256x256 input.
torchinfo.summary(
    model=vae,
    input_size=(1, 3, 256, 256),
    depth=1,
    col_names=("input_size", "output_size", "num_params", "params_percent", "mult_adds"),
)

In [12]:
# Runtime and FLOPs are measured at different batch sizes (2 vs. 1). Each printout
# is labelled with its batch size so the two numbers are not read as the same input.
bench_shape = (2, 3, 256, 256)
flops_shape = (1, 3, 256, 256)

runtime = benchmark_model(vae, bench_shape, device="cuda")
print(f"Runtime (batch={bench_shape[0]}): {runtime.median:.2f} ms")

flops = measure_flops(vae, flops_shape, device="cuda")
# Two decimals: the unformatted float prints spurious precision.
print(f"FLOPs (batch={flops_shape[0]}): {flops['forward_total'] / 1e9:,.2f} GFLOPs")

Runtime: 1.18 ms
FLOPs: 1,014.048653312 GFLOPs


In [13]:
# Display the decoder's `conv_out` module (presumably its final output convolution)
# for the last model built above, the block_fn="convnext-star" variant.
vae.decoder.conv_out