In [1]:
import jax
import jax.random as random
import jax.numpy as jnp

import equinox as eqx
import equinox.nn as nn

import optax 

from functools import partial

import numpy as np
import matplotlib.pyplot as plt

import tqdm

from vae.data.datasets import build_dataset
from vae.model.vqvae import VQVAE
from vae.model.metrics import compute_code_usage

## 1. Usage Metric

In [2]:
key = jax.random.key(0)
shape = (16, 4, 4)
minval = 0
maxval = 1024

samples = 16

track = jnp.zeros(maxval)

for _ in range(samples):
    idx = jax.random.randint(key, shape, minval=minval, maxval=maxval)
    track = track.at[idx.flatten()].set(1.0)

print(f"{jnp.sum(track) / maxval:.3f}")

0.220


## 2. Test on VQVAE

In [3]:
dataloader, num_classes, _, _ = build_dataset(
    "CIFAR10", 
    "/Users/anton/source/vae/vae/train",
    is_train=False
)

vqvae = VQVAE(
    key=jax.random.PRNGKey(0),
    in_channels=3,
    num_embeddings=512,
    embedding_dim=64,
    ch=128,
    ch_mult=(1, 1, 2, 2, 4),
    num_res_blocks=2,
    beta_commit=0.25,
    ema_decay=0.99,
    epsilon=1e-5
)

In [4]:
prop, ent = compute_code_usage(vqvae, dataloader)
print(f"proportion {prop:.3f}, entropy {ent:.3f}")

measuring code on test: 100%|██████████| 313/313 [01:41<00:00,  3.09it/s]


proportion 0.781, entropy 5.123


## utils

In [5]:
def entropy(dist): return sum(jax.scipy.special.entr(dist))

track = jnp.zeros(6)
zero_vec = jnp.zeros(6)
idx1 = jnp.arange(3)
idx2 = jnp.arange(1,3)
idx3 = jnp.arange(2,4)

for idx in [idx1, idx2, idx3]:
    track += zero_vec.at[idx.flatten()].set(1.0)
track /= sum(track)
track

Array([0.14285715, 0.2857143 , 0.42857143, 0.14285715, 0.        ,
       0.        ], dtype=float32)

In [6]:
entropy(track)

Array(1.2770343, dtype=float32)

In [7]:
jnp.sum(track != 0)

Array(4, dtype=int32)