In [1]:
import time

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import plotly.express as px
from tqdm import tqdm
from umap import UMAP

from loss import vae_loss
from mnist import mnist
from model_cnn import VAE

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def one_hot(labels, num_classes=10):
    """Return a (labels.size, num_classes) array with a 1 at each label's column.

    `labels` is a 1-D MLX array of integer class ids; every other entry of the
    result is 0.
    """
    encoded = mx.zeros((labels.size, num_classes))
    encoded[mx.arange(len(labels)), labels] = 1
    return encoded


# mnist() yields (train images, train labels, test images, test labels);
# convert each split to an MLX array.
train_x, train_y, test_x, test_y = (mx.array(split) for split in mnist())

# Reshape the images to (N, 28, 28, 1); the trailing singleton axis is
# presumably the channel axis expected by the Conv2d encoder -- confirm in model_cnn.
train_x, test_x = (images.reshape(-1, 28, 28, 1) for images in (train_x, test_x))

# One-hot versions of the digit labels (10 classes).
train_c, test_c = one_hot(train_y), one_hot(test_y)

In [3]:
# Quick look at the shape of the training-label array.
train_y.shape

[60000]

In [4]:
# Per-sample input shape: everything after the batch axis of train_x.
input_shape = train_x.shape[1:]
model = VAE(input_shape, latent_dim=10)

# MLX evaluates lazily; force the parameters to be materialised now.
mx.eval(model.parameters())

# Last expression of the cell: show the module structure.
model

VAE(
  (encoder): Encoder(
    (conv1): Conv2d(1, 32, kernel_size=[4], stride=(2, 2), padding=(1, 1), bias=True)
    (conv2): Conv2d(32, 32, kernel_size=[4], stride=(2, 2), padding=(1, 1), bias=True)
    (conv3): Conv2d(32, 32, kernel_size=[4], stride=(2, 2), padding=(2, 2), bias=True)
    (dense1): Linear(input_dims=512, output_dims=256, bias=True)
    (dense2): Linear(input_dims=256, output_dims=256, bias=True)
    (lin_mu): Linear(input_dims=256, output_dims=10, bias=True)
    (lin_logvar): Linear(input_dims=256, output_dims=10, bias=True)
  )
  (decoder): Decoder(
    (dense1): Linear(input_dims=10, output_dims=256, bias=True)
    (dense2): Linear(input_dims=256, output_dims=256, bias=True)
    (dense3): Linear(input_dims=256, output_dims=784, bias=True)
  )
)

In [5]:
# Adam step size (same value as 2*1e-4, written as one literal).
LEARNING_RATE = 2e-4

# Wrap vae_loss so that a single call returns both the loss value and the
# gradients for `model` (it is called as `loss_and_grad_fn(model, X)` later).
loss_and_grad_fn = nn.value_and_grad(model, vae_loss)
optimizer = optim.Adam(learning_rate=LEARNING_RATE)

In [6]:
def batch_iterate(batch_size, X, y):
    """Yield shuffled mini-batches ``(X[ids], y[ids])`` covering one pass over the data.

    A fresh random permutation of ``range(y.shape[0])`` is drawn from NumPy's
    global RNG on every call, then walked in consecutive windows of
    ``batch_size`` indices. The final batch is smaller when the number of
    samples is not a multiple of ``batch_size``.
    """
    n_samples = y.shape[0]
    shuffled_ids = mx.array(np.random.permutation(n_samples))
    for start in range(0, n_samples, batch_size):
        stop = start + batch_size
        batch_ids = shuffled_ids[start:stop]
        yield X[batch_ids], y[batch_ids]
BATCH_SIZE = 1024
EPOCHS = 200
SEED = 42
LOG_EVERY = 20  # print one progress line every LOG_EVERY epochs

# NumPy's RNG drives the batch shuffling in batch_iterate. MLX's RNG is seeded
# as well so that any MLX-side randomness during training (presumably the
# latent sampling inside the VAE -- confirm in model_cnn) is repeatable too.
np.random.seed(SEED)
mx.random.seed(SEED)

# Ceiling division: the final, smaller batch counts as a batch. Passing the
# float `n / BATCH_SIZE` as tqdm's `total` under-counts the iterations, so the
# bar's progress fraction exceeds 1 on the last batch and tqdm emits a warning.
n_batches = (train_x.shape[0] + BATCH_SIZE - 1) // BATCH_SIZE

for e in range(EPOCHS):
    tic = time.perf_counter()
    # The one-hot labels are only passed to drive the batching; the VAE loss
    # is computed from the images alone, so the label batch is discarded.
    for X, _ in tqdm(batch_iterate(BATCH_SIZE, train_x, train_c),
                     total=n_batches,
                     desc=f"Training process for epoch {e + 1}",
                     leave=False):
        loss, grads = loss_and_grad_fn(model, X)
        optimizer.update(model, grads)
        # MLX is lazy: force the parameter/optimizer-state update to run now.
        mx.eval(model.parameters(), optimizer.state)
    toc = time.perf_counter()
    if (e + 1) % LOG_EVERY == 0:
        # NOTE: `loss` is the loss of the last batch of the epoch only
        # (not an average over the epoch).
        print(
            f"Epoch {e + 1}: Train loss {loss.item():.3f},"
            f" Time {toc - tic:.3f} (s)"
        )

                                                                                    

Epoch 20: Train loss 3298.037, Time 2.375 (s)


                                                                                    

Epoch 40: Train loss 2281.728, Time 2.507 (s)


                                                                                    

Epoch 60: Train loss 1956.166, Time 2.397 (s)


                                                                                    

Epoch 80: Train loss 1688.627, Time 2.455 (s)


                                                                                     

Epoch 100: Train loss 1713.417, Time 2.428 (s)


                                                                                     

Epoch 120: Train loss 1611.070, Time 2.452 (s)


                                                                                     

Epoch 140: Train loss 1493.925, Time 2.427 (s)


  full_bar = Bar(frac,
                                                                                     

Epoch 160: Train loss 1479.731, Time 2.430 (s)


                                                                                     

Epoch 180: Train loss 1460.610, Time 2.499 (s)


                                                                                     

Epoch 200: Train loss 1376.512, Time 2.449 (s)




In [7]:
# Re-display the model's module structure after the training loop.
model

VAE(
  (encoder): Encoder(
    (conv1): Conv2d(1, 32, kernel_size=[4], stride=(2, 2), padding=(1, 1), bias=True)
    (conv2): Conv2d(32, 32, kernel_size=[4], stride=(2, 2), padding=(1, 1), bias=True)
    (conv3): Conv2d(32, 32, kernel_size=[4], stride=(2, 2), padding=(2, 2), bias=True)
    (dense1): Linear(input_dims=512, output_dims=256, bias=True)
    (dense2): Linear(input_dims=256, output_dims=256, bias=True)
    (lin_mu): Linear(input_dims=256, output_dims=10, bias=True)
    (lin_logvar): Linear(input_dims=256, output_dims=10, bias=True)
  )
  (decoder): Decoder(
    (dense1): Linear(input_dims=10, output_dims=256, bias=True)
    (dense2): Linear(input_dims=256, output_dims=256, bias=True)
    (dense3): Linear(input_dims=256, output_dims=784, bias=True)
  )
)

In [8]:
# Latent codes for the whole training set, produced by the model's
# sample_latent method (see model_cnn for how they are drawn).
Z = model.sample_latent(train_x)
# Show the shape of the latent-code array.
Z.shape

[60000, 10]

In [9]:
# Hand the latent codes to UMAP as a NumPy array.
latent_np = np.asarray(Z)

# UMAP with its default settings; columns 0 and 1 of `manifold` are plotted
# in the following cell.
umap = UMAP()
manifold = umap.fit_transform(latent_np)

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [10]:
# Digits are class labels, not quantities: convert them to strings so plotly
# assigns one discrete colour per digit instead of a continuous colour scale.
digit_labels = np.asarray(train_y).astype(str)

px.scatter(
    x=manifold[:, 0],
    y=manifold[:, 1],
    color=digit_labels,
    # Keep the legend in numeric order rather than order of first appearance.
    category_orders={"color": [str(d) for d in range(10)]},
    labels={"x": "UMAP 1", "y": "UMAP 2", "color": "digit"},
    title="UMAP projection of the VAE latent codes (training set)",
)

In [12]:
N_SAMPLES = 40   # number of images to generate
LATENT_DIM = 10  # must match the latent_dim the VAE was built with
GRID_COLS = 10   # images per row in the facet grid

# Draw latent vectors from a standard normal and decode them.
z = mx.random.normal([N_SAMPLES, LATENT_DIM])
rec = model.decoder(z)

# The decoder's last layer has 784 outputs per sample; view each as 28x28.
rec_image = rec.reshape(-1, 28, 28)

# Convert to NumPy before handing the data to plotly (as done for UMAP).
# Axis 0 indexes the samples, so it becomes the facet axis.
px.imshow(
    np.asarray(rec_image),
    facet_col=0,
    facet_col_wrap=GRID_COLS,
    color_continuous_scale='gray',
)