In [None]:
import torch
import treescope
from models.gan import DinoPatchDiscriminator
from torchmetrics.image import (
    StructuralSimilarityIndexMeasure,
    PeakSignalNoiseRatio,
    FrechetInceptionDistance,
    LearnedPerceptualImagePatchSimilarity,
)

# Make treescope the default renderer for objects shown as cell outputs
# (e.g. the bare `vae` / `encoder` expressions further down).
treescope.register_as_default()

In [None]:
# Use the GPU when one is available so this cell also runs on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The module and its input must live on the same device: previously only the
# input was moved with `.cuda()`, which raises a device-mismatch error unless
# the constructor relocates the module on its own.
model = DinoPatchDiscriminator('small').to(device)

# Smoke test: one random (1, 3, 256, 256) batch through the discriminator.
model(torch.randn(1, 3, 256, 256).to(device))

In [None]:
from layers.layers2d import Encoder, Decoder
from models.vae import VAE

In [None]:
# Instantiate the VAE under test; keyword arguments are grouped by concern.
vae = VAE(
    # input / output
    resolution=256,
    in_channels=3,
    out_channels=3,
    # network size
    channels=128,
    channels_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    attn_resolutions=[32],
    dropout=0.0,
    # latent
    z_channels=4,
    spatial_compression=4,
    # wavelet settings
    wavelet="db4",
    maxlevel=2,
)

# Bare expression so the notebook displays the constructed module.
vae

In [None]:
# Smoke test on one random (1, 3, 256, 256) batch. Call the module itself
# rather than `.forward` so any registered hooks run, matching how the other
# models in this notebook are invoked.
vae(torch.randn(1, 3, 256, 256))

In [None]:
# Stand-alone encoder/decoder pair built with matching hyperparameters.
# The two calls are kept separate so each gets its own list objects.
encoder = Encoder(
    # input
    resolution=256,
    in_channels=3,
    # network size
    channels=128,
    channels_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    attn_resolutions=[32],
    dropout=0.1,
    # latent
    z_channels=4,
    spatial_compression=8,
)

decoder = Decoder(
    # output
    resolution=256,
    out_channels=3,
    # network size
    channels=128,
    channels_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    attn_resolutions=[32],
    dropout=0.1,
    # latent
    z_channels=4,
    spatial_compression=8,
)

In [None]:
# Bare expression: display the encoder built in the previous cell.
encoder

In [None]:
# Bare expression: display the decoder built earlier in the notebook.
decoder

In [None]:
# Shape check: decode a random (1, 4, 8, 8) latent and show the output shape.
# Call the module itself rather than `.forward` so any registered hooks run.
decoder(torch.randn(1, 4, 8, 8)).shape

In [1]:
import torch
import einops
import treescope

from models.vqvae import VQVAE
from layers.quantizers import FSQuantizer
# Make treescope the default renderer for cell outputs.
treescope.register_as_default()
# Interactive setup with automatic array visualisation turned on.
# NOTE(review): basic_interactive_setup appears to register treescope as the
# default renderer itself, which would make the call above redundant — confirm
# against the treescope docs before removing either line.
treescope.basic_interactive_setup(autovisualize_arrays=True)


In [2]:
# Check that one einops pattern merges the last two axes for any size of `c`:
# (1, 1, 8) -> (1, 8) and (1, 2, 8) -> (1, 16).
flatten_last_two = "... c d -> ... (c d)"

a = torch.randn((1, 1, 8))
b = torch.randn((1, 2, 8))

for sample in (a, b):
    print(einops.rearrange(sample, flatten_last_two).shape)

torch.Size([1, 8])
torch.Size([1, 16])


In [None]:
# FSQuantizer with four level counts, a 256-wide input and one codebook.
quantizer = FSQuantizer(
    levels=[7, 5, 5, 5],
    input_dim=256,
    num_codebooks=1,
)

In [None]:
# Quantize one random (1, 256, 8, 8) feature map; the quantizer returns a pair
# that is unpacked as (codes, indices).
quantizer_input = torch.randn(1, 256, 8, 8)
codes, indices = quantizer(quantizer_input)

In [None]:
# Map the indices back through the quantizer. The result is a pair; note that
# unpacking it rebinds `codes`, replacing the value from the quantization cell.
decoded_pair = quantizer.indices_to_codes(indices)
proj_codes, codes = decoded_pair

In [2]:
# Instantiate the VQ-VAE under test; keyword arguments are grouped by concern.
vq = VQVAE(
    # input / output
    resolution=256,
    in_channels=3,
    out_channels=3,
    # network size
    channels=128,
    channels_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    attn_resolutions=[32],
    dropout=0.0,
    # latent
    z_channels=256,
    spatial_compression=4,
    # wavelet settings
    wavelet="db4",
    maxlevel=2,
    # quantizer
    quantization="fsq",
    levels=[8, 5, 5, 5],
    embedding_dim=4,
    num_codebooks=1,
)

Wavelet Transform: db4
z of shape: (1, 256, 16, 16), dimensions: 65536


In [4]:
# Bare expression: display the VQ-VAE constructed in the cell above.
vq

In [3]:
# Smoke test: run one random (1, 3, 256, 256) batch through the VQ-VAE and
# display whatever the model returns.
vq_input = torch.randn(1, 3, 256, 256)
vq(vq_input)