In [1]:
import pandas as pd
import jax.numpy as jnp
import numpy as np
from jax import jit
from matplotlib import pyplot as plt

In [3]:
# read lc
lc = pd.read_parquet('data/lc1.parquet')
lc.head(2)

Unnamed: 0,object_id,passband,mjd,flux,flux_norm
0,13,0,59577.0,0.125404,-0.242782
1,13,0,59592.521127,0.243606,-0.167201


In [6]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
# from flax import optim
import optax

class Encoder(nn.Module):
 latents: int

 @nn.compact
 def __call__(self, x):
   x = nn.Dense(500, name='fc1')(x)
   x = nn.relu(x)
   mean_x = nn.Dense(self.latents, name='fc2_mean')(x)
   logvar_x = nn.Dense(self.latents, name='fc2_logvar')(x)
   return mean_x, logvar_x

In [7]:
class Decoder(nn.Module):

 @nn.compact
 def __call__(self, z):
   z = nn.Dense(500, name='fc1')(z)
   z = nn.relu(z)
   z = nn.Dense(784, name='fc2')(z)
   return z

In [47]:
class VAE(nn.Module):
 latents: int = 20

 def setup(self):
   self.encoder = Encoder(self.latents)
   self.decoder = Decoder()

 def __call__(self, x, z_rng):
   mean, logvar = self.encoder(x)
   z = reparameterize(z_rng, mean, logvar)
   recon_x = self.decoder(z)
   return recon_x, mean, logvar

def reparameterize(rng, mean, logvar):
 std = jnp.exp(0.5 * logvar)
 eps = random.normal(rng, logvar.shape)
 return mean + eps * std

def model():
 return VAE(latents=LATENTS)

In [48]:
@jax.vmap
def kl_divergence(mean, logvar):
 return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))

@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
 logits = nn.log_sigmoid(logits)
 return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))


In [24]:
import tensorflow_datasets as tfds
import tensorflow as tf

tf.config.experimental.set_visible_devices([], 'GPU')

def prepare_image(x):
 x = tf.cast(x['image'], tf.float32)
 x = tf.reshape(x, (-1,))
 return x

ds_builder = tfds.builder('binarized_mnist')
ds_builder.download_and_prepare()
train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)
train_ds = train_ds.map(prepare_image)
train_ds = train_ds.cache()
train_ds = train_ds.repeat()
train_ds = train_ds.shuffle(50000)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = iter(tfds.as_numpy(train_ds))

test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)
test_ds = test_ds.map(prepare_image).batch(10000)
test_ds = np.array(list(test_ds)[0])

2024-08-09 11:46:12.056960: W external/local_tsl/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".


[1mDownloading and preparing dataset 104.68 MiB (download: 104.68 MiB, generated: Unknown size, total: 104.68 MiB) to /Users/weixiang/tensorflow_datasets/binarized_mnist/1.0.0...[0m


  from .autonotebook import tqdm as notebook_tqdm
Dl Size...: 100%|██████████| 102/102 [00:07<00:00, 12.92 MiB/s]]
Dl Completed...: 100%|██████████| 3/3 [00:07<00:00,  2.63s/ url]
                                                                        

[1mDataset binarized_mnist downloaded and prepared to /Users/weixiang/tensorflow_datasets/binarized_mnist/1.0.0. Subsequent calls will reuse this data.[0m


In [49]:
rng = random.PRNGKey(0)
rng, key = random.split(rng)
BATCH_SIZE = 10000
LEARNING_RATE = jnp.array(0.3)
LATENTS = jnp.array(4)
NUM_EPOCHS = 5

init_data = jnp.ones((BATCH_SIZE, 784), jnp.float32)
params = model().init(key, init_data, rng)['params']

solver = optax.adam(learning_rate=LEARNING_RATE)
opt_state = solver.init(params)
# optimizer = jax.device_put(optimizer)

rng, z_key, eval_rng = random.split(rng, 3)
z = random.normal(z_key, (64, LATENTS))

steps_per_epoch = 50000 // BATCH_SIZE


for epoch in range(NUM_EPOCHS):
  for _ in range(steps_per_epoch):
    batch = jnp.array(next(train_ds))
    rng, key = random.split(rng)

    def loss_fn(params):
        recon_x, mean, logvar = model().apply({'params': params}, batch, key)

        bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
        kld_loss = kl_divergence(mean, logvar).mean()
        loss = bce_loss + kld_loss
        return loss, recon_x

    grad, _ = jax.grad(loss_fn)(params)
    updates, opt_state = solver.update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)

TypeError: Gradient only defined for scalar-output functions. Output was (Array(594.6463, dtype=float32), Array([[ 0.61815804, -0.41764262,  0.70575905, ...,  0.04700679,
        -0.5066801 ,  0.95056593],
       [ 0.45177734, -0.86700225,  0.54775393, ..., -0.04440165,
        -0.42860085, -0.40191096],
       [ 0.2963302 , -0.89565736,  1.5676657 , ..., -0.01574579,
        -0.7308027 ,  0.44394773],
       ...,
       [ 0.13361356, -0.9049638 ,  1.3185327 , ..., -0.23299226,
        -0.6748854 , -0.06925374],
       [ 0.41827285, -0.54509956,  0.6054094 , ...,  0.13635437,
        -0.5661339 ,  0.573411  ],
       [ 0.5138072 , -0.25701073,  0.07712258, ..., -0.00742109,
        -0.32133615,  0.01739584]], dtype=float32)).