In [1]:
import time
import mlx.core as mx 
import mlx.nn as nn 
import numpy as np
from mnist import mnist
from model_cnn import VAE
from loss import vae_loss
import mlx.optimizers as optim

In [2]:
train_x, train_y, test_x, test_y = map(mx.array, mnist())
train_x = train_x.reshape(-1, 28, 28, 1)
test_x = test_x.reshape(-1, 28, 28, 1)
train_c = mx.zeros((train_y.size, 10))
train_c[mx.arange(len(train_y)), train_y] = 1
test_c = mx.zeros((test_y.size, 10))
test_c[mx.arange(len(test_y)), test_y] = 1

In [3]:
model = VAE(train_x.shape[1:], latent_dim=10, n_cond=10)
mx.eval(model.parameters())
model

VAE(
  (encoder): Encoder(
    (conv1): Conv2d(1, 32, kernel_size=[4], stride=(2, 2), padding=(1, 1), bias=True)
    (conv2): Conv2d(32, 32, kernel_size=[4], stride=(2, 2), padding=(1, 1), bias=True)
    (conv3): Conv2d(32, 32, kernel_size=[4], stride=(2, 2), padding=(2, 2), bias=True)
    (dense1): Linear(input_dims=512, output_dims=256, bias=True)
    (dense2): Linear(input_dims=256, output_dims=256, bias=True)
    (lin_mu): Linear(input_dims=266, output_dims=10, bias=True)
    (lin_logvar): Linear(input_dims=266, output_dims=10, bias=True)
  )
  (decoder): Decoder(
    (dense1): Linear(input_dims=20, output_dims=256, bias=True)
    (dense2): Linear(input_dims=256, output_dims=256, bias=True)
    (dense3): Linear(input_dims=256, output_dims=784, bias=True)
  )
)

In [4]:
loss_and_grad_fn = nn.value_and_grad(model, vae_loss)
optimizer = optim.Adam(learning_rate=2*1e-4)

In [5]:
def batch_iterate(batch_size, X, y):
    perm = mx.array(np.random.permutation(y.shape[0]))
    for s in range(0, y.shape[0], batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]
BATCH_SIZE = 1024
EPOCHS = 200
np.random.seed(42)
for e in range(EPOCHS):
        tic = time.perf_counter()
        for X, c in batch_iterate(BATCH_SIZE, train_x, train_c):
            loss, grads = loss_and_grad_fn(model, X, c)
            optimizer.update(model, grads)
            mx.eval(model.parameters(), optimizer.state)
        # accuracy = eval_fn(model, test_images, test_labels)
        toc = time.perf_counter()
        if (e + 1) % 20 == 0:
            print(
                f"Epoch {e + 1}: Train loss {loss.item():.3f},"
                f" Time {toc - tic:.3f} (s)"
            )

Epoch 20: Train loss 3117.908, Time 2.418 (s)
Epoch 40: Train loss 2308.216, Time 2.454 (s)
Epoch 60: Train loss 1998.906, Time 2.516 (s)
Epoch 80: Train loss 1672.092, Time 2.341 (s)
Epoch 100: Train loss 1660.109, Time 2.537 (s)
Epoch 120: Train loss 1578.101, Time 2.386 (s)
Epoch 140: Train loss 1441.561, Time 2.361 (s)
Epoch 160: Train loss 1417.290, Time 2.326 (s)
Epoch 180: Train loss 1397.883, Time 2.419 (s)
Epoch 200: Train loss 1299.629, Time 2.515 (s)


In [6]:
Z = model.sample_latent(train_x, train_c)
Z.shape

[60000, 10]

In [7]:
from umap import UMAP

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
umap = UMAP()
manifold = umap.fit_transform(np.asarray(Z))

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [9]:
import plotly.express as px 
px.scatter(x=manifold[:, 0], y=manifold[:, 1], color=train_y)

In [12]:
z = mx.random.normal([40, 10])
y = mx.arange(10)
c = mx.zeros((y.size, 10))
c[mx.arange(len(y)), y] = 1
c = mx.concatenate([c]*4, axis=0)
rec = model.decoder(z, c)
rec_image = rec.reshape(-1, 28 ,28)
import plotly.express as px 
px.imshow(rec_image, facet_col=0, color_continuous_scale='gray', facet_col_wrap=10)