In [1]:
from jax import random
from datasets.mnist import MNISTDataset
from src.autoencoders.simple_vae import model
from src.utils.autoencoder_trainer import AutoencoderTrainer
from src.utils.autoencoder_manager import get_latent_dataset, restore_model_state
import tensorflow_datasets as tfds

In [2]:
# --- Configuration ---------------------------------------------------------
# These values were previously repeated as bare literals (64 three times, and
# 196 with no visible link to the 14x14 image size). Defining them once keeps
# the dataset loaders and the trainer's input shape from drifting apart.
SEED = 0
LATENT_DIM = 20
BATCH_SIZE = 64
IMAGE_SIZE = (14, 14)  # forwarded to MNISTDataset as `image_size`

rng = random.PRNGKey(SEED)
binary_vae = model(latents=LATENT_DIM)
# Evaluates to (64, 196), the same value as before. 196 == 14 * 14, so this is
# presumably (batch, flattened pixels) -- TODO confirm against AutoencoderTrainer.
input_shape = (BATCH_SIZE, IMAGE_SIZE[0] * IMAGE_SIZE[1])
learning_rate = 3e-4

# Load MNIST dataset for training and testing
train_dataset = MNISTDataset(split='train', batch_size=BATCH_SIZE, image_size=IMAGE_SIZE).load()
test_dataset = MNISTDataset(split='test', batch_size=BATCH_SIZE, image_size=IMAGE_SIZE).load()

# Create an instance of the trainer with your binary VAE model
trainer = AutoencoderTrainer(binary_vae, learning_rate, rng, input_shape)
state = trainer.state
# NOTE(review): `params` is read from the trainer here, before any checkpoint is
# restored; code that needs restored weights should presumably use
# `state.params` after the restore step instead -- confirm.
params = trainer.params

In [3]:
# Relative path to the checkpoint that should be restored.
ckpt_dir = "weights/binary_vae_500epoch_3e-4lr/checkpoint_461"
# Rebind `state` to whatever restore_model_state returns; the current `state`
# is passed in as the second argument.
# NOTE(review): because `state` is both input and output, re-running this cell
# feeds the already-restored state back in -- presumably harmless, confirm.
state = restore_model_state(ckpt_dir, state)

# Printed unconditionally once the call above returns; it does not itself
# verify that the restored weights are the expected ones.
print("Weights loaded successfully.")

Checkpoint restored from directory '/Users/uribagi/Documents/GitHub/Latent-IQP/weights/binary_vae_500epoch_3e-4lr/checkpoint_461'.
Weights loaded successfully.




In [4]:
# Produce the (latents, labels) pair for the training split by calling
# get_latent_dataset with the model, `state.params`, and the training dataset
# wrapped by tfds.as_numpy (which yields NumPy values instead of tf.Tensors).
# NOTE(review): presumably this encodes every batch with the VAE and collects
# the labels alongside -- confirm in src/utils/autoencoder_manager.
train_latents, train_labels = get_latent_dataset(binary_vae, state.params, tfds.as_numpy(train_dataset))

In [6]:
# Show the first training latent vector and its label, one per line.
print(train_latents[0], train_labels[0], sep="\n")

[0 0 1 1 1 1 0 1 0 1 1 0 0 0 0 1 1 0 1 1]
4


In [8]:
# Sanity check: print the shape of the training latents.
print(train_latents.shape)

(60000, 20)


In [9]:
# Same call as for the training split, applied to the test dataset: the model,
# `state.params`, and the tfds.as_numpy-wrapped dataset go in, and a
# (latents, labels) pair comes out.
# NOTE(review): presumably encodes each test batch with the VAE -- confirm in
# src/utils/autoencoder_manager.
test_latents, test_labels = get_latent_dataset(binary_vae, state.params, tfds.as_numpy(test_dataset))

In [10]:
# Show the first test latent vector and its label, one per line.
print(test_latents[0], test_labels[0], sep="\n")

[0 0 1 1 1 0 1 0 0 1 0 1 0 0 0 1 0 0 0 0]
2


In [11]:
# Sanity check: print the shape of the test latents.
print(test_latents.shape)

(10000, 20)


In [14]:
import numpy as np

# Write each array to its own .npy file under datasets/. The (path, array)
# pairs are listed in the order they are saved: train split first, then test.
npy_outputs = (
    ("datasets/mnist_latent_train.npy", train_latents),
    ("datasets/mnist_labels_train.npy", train_labels),
    ("datasets/mnist_latent_test.npy", test_latents),
    ("datasets/mnist_labels_test.npy", test_labels),
)
for npy_path, npy_array in npy_outputs:
    np.save(npy_path, npy_array)

In [15]:
from datasets.latent_mnist import LatentDataset

In [18]:
# Example of loading a saved dataset
latent_file = './datasets/mnist_latent_train.npy'
label_file = './datasets/mnist_labels_train.npy'
dataset = LatentDataset(latent_file, label_file, batch_size=8)

In [21]:
# Materialise the dataset object; `ds` is what later code calls .take() on and
# iterates over.
ds = dataset.load()

In [25]:
# Note: this prints the repr of the dataset returned by take(1) (its element
# spec), not the batch contents -- iterate over the dataset to see values.
print(ds.take(1))

<_TakeDataset element_spec=(TensorSpec(shape=(None, 20), dtype=tf.int64, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>


In [26]:
# Peek at the first element only: take(1) limits iteration to a single batch,
# so no explicit early exit from the loop is needed.
for d in ds.take(1):
    print(d)

(<tf.Tensor: shape=(8, 20), dtype=int64, numpy=
array([[0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1],
       [0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
       [0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0],
       [1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1],
       [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0],
       [1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0],
       [0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0],
       [0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0]])>, <tf.Tensor: shape=(8,), dtype=int64, numpy=array([0, 5, 9, 7, 4, 3, 5, 3])>)
