Goal

- Code for training a denoising autoencoder.
- Replicate the 1D and 2D examples showing that r(x) - x approximates the score of the data distribution (up to a factor of σ², the variance of the corruption noise).




1. What regularized auto-encoders learn from the data-generating distribution

In [None]:
import numpy as onp

import jax
import jax.numpy as np
from jax import random
from flax import linen as nn
from flax.training import train_state
import optax

from functools import partial

import matplotlib.pylab as plt
import matplotlib as mpl
# Global matplotlib styling applied to every figure created afterwards.
mpl.rcParams.update({
    'lines.linewidth': 3,
    'font.size': 25,
    'font.family': 'Times New Roman',
})
# `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` is the supported accessor and returns the same colormap.
cmap = plt.get_cmap('bwr')

from jax_utils import get_data_stream

In [None]:
## parameters

# sample size: how many data points are drawn
n = 1000

# dimensionality of each input point
d = 2

# layer widths: the decoder mirrors the encoder's layers in reverse order
# (the encoder's last width excluded) and ends by projecting back to d
encoder_dims = [200, 100]
decoder_dims = list(reversed(encoder_dims[:-1])) + [d]
print(encoder_dims, decoder_dims)

# reconstruction loss, one of {'l2_norm', 'bce'}
criterion_name = 'l2_norm'
# hidden-layer activation, one of {'tanh', 'relu'}
nonlinearity_name = 'relu'

# optimization: Adam step size, minibatch size, passes over the data
lr = 2e-3
batch_size = 50
n_epochs = 100


In [None]:
# https://github.com/google/flax/blob/main/examples/vae/train.py
#
# x --Encoder--> z --Decoder --> x_recon
#

def nonlinearity():
    """Return the activation function selected by ``nonlinearity_name``.

    Returns:
        ``nn.relu`` for 'relu' or ``nn.tanh`` for 'tanh'.

    Raises:
        ValueError: if ``nonlinearity_name`` is any other value.
    """
    if nonlinearity_name == 'relu':
        return nn.relu
    if nonlinearity_name == 'tanh':
        return nn.tanh
    raise ValueError(f'{nonlinearity_name} not valid.')

class Encoder(nn.Module):
    """Fully connected encoder whose layer widths come from ``encoder_dims``.

    Every layer except the last is followed by the activation returned by
    ``nonlinearity()``; the last layer is left linear.  Layers are named
    'fc1', 'fc2', ... in order of application.
    """

    @nn.compact
    def __call__(self, x):
        """Map the input ``x`` to its code."""
        h = x
        for layer_idx, width in enumerate(encoder_dims[:-1], start=1):
            dense = nn.Dense(width, name=f'fc{layer_idx}')
            h = dense(h)
            h = nonlinearity()(h)
        # Final layer: no activation on the code.
        out_layer = nn.Dense(encoder_dims[-1], name=f'fc{len(encoder_dims)}')
        return out_layer(h)

class Decoder(nn.Module):
    """Fully connected decoder whose layer widths come from ``decoder_dims``.

    Every layer except the last is followed by the activation returned by
    ``nonlinearity()``; the last layer is left linear, so the output is an
    unconstrained reconstruction (or logits).  Layers are named 'fc1',
    'fc2', ... in order of application.
    """

    @nn.compact
    def __call__(self, x):
        """Map a code ``x`` back to input space."""
        h = x
        for layer_idx, width in enumerate(decoder_dims[:-1], start=1):
            dense = nn.Dense(width, name=f'fc{layer_idx}')
            h = dense(h)
            h = nonlinearity()(h)
        # Final layer: no activation on the reconstruction.
        out_layer = nn.Dense(decoder_dims[-1], name=f'fc{len(decoder_dims)}')
        return out_layer(h)

class Autoencoder(nn.Module):
    """Encoder followed by decoder: x --Encoder--> z --Decoder--> x_recon."""

    def setup(self):
        self.encoder = Encoder()
        self.decoder = Decoder()

    def __call__(self, x):
        """Return the reconstruction r(x) of ``x``."""
        return self.decoder(self.encoder(x))

    def score_apx(self, x):
        """Return the reconstruction residual r(x) - x.

        This notebook uses the residual as an approximation to the score of
        the data distribution.
        """
        return self(x) - x
    

@jax.vmap
def bce_with_logits(logits, labels):
    """Binary cross-entropy between sigmoid(logits) and labels, per example.

    ``jax.vmap`` maps the function over the leading axis of both arguments,
    so the result holds one loss per leading-axis entry.

    Args:
        logits: pre-sigmoid predictions.
        labels: targets, combined elementwise with ``logits``.

    Returns:
        Array with one loss per leading-axis entry, each summed over the
        remaining axes.
    """
    # log(sigmoid(z)) and log(1 - sigmoid(z)) == log(sigmoid(-z)).
    # Using the identity avoids log(-expm1(log_sigmoid(z))): for large
    # positive z, log_sigmoid(z) underflows to 0, the log becomes -inf, and a
    # label of 1 then produces 0 * -inf = NaN.
    log_p = jax.nn.log_sigmoid(logits)
    log_not_p = jax.nn.log_sigmoid(-logits)
    return -np.sum(labels * log_p + (1. - labels) * log_not_p)

def l2_loss(predictions, targets):
    """Return the sum of squared differences between predictions and targets.

    The sum runs over every element (all axes), so for a batch the result is
    the total over the batch, not the mean.  Inputs are assumed real-valued.

    Args:
        predictions: model outputs.
        targets: reference values, broadcast-compatible with ``predictions``.

    Returns:
        Scalar squared L2 distance.
    """
    residual = predictions - targets
    # Sum the squares directly rather than squaring ``np.linalg.norm``: the
    # norm takes a square root, whose derivative is infinite at 0, so the
    # gradient becomes NaN whenever predictions equal targets exactly.
    return np.sum(residual ** 2)

def get_criterion():
    """Return the loss function selected by ``criterion_name``.

    Returns:
        ``l2_loss`` for 'l2_norm', or a callable computing the mean of
        ``bce_with_logits`` over its outputs for 'bce'.

    Raises:
        ValueError: if ``criterion_name`` is any other value.
    """
    if criterion_name == 'bce':
        def mean_bce(x, y):
            return bce_with_logits(x, y).mean()
        return mean_bce
    if criterion_name == 'l2_norm':
        return l2_loss
    raise ValueError(f'{criterion_name} not valid.')

def model():
    """Build a new ``Autoencoder`` module instance."""
    autoencoder = Autoencoder()
    return autoencoder


In [None]:
def func_1d(key, *, n_samples=None):
    """Sample points (x, y) on a fixed 1-D curve made of three sinusoids.

    Args:
        key: JAX PRNG key used to draw ``x``.
        n_samples: number of points to draw; defaults to the module-level
            ``n`` when omitted.

    Returns:
        Tuple ``(x, y)`` of arrays with shape ``(n_samples, 1)``; ``x`` is
        uniform on [-1, 1) and ``y`` is the curve evaluated at ``x``.
    """
    if n_samples is None:
        n_samples = n
    x = random.uniform(key, shape=(n_samples, 1), minval=-1, maxval=1)
    # NOTE: 3.14 looks like a coarse stand-in for pi; it is kept as-is so the
    # generated curve does not change.
    y = np.sin(x*3*3.14) + 0.3*np.cos(x*9*3.14) + 0.5*np.sin(x*7*3.14)
    return x, y

def func_2d(key, minval=3, maxval=12, *, n_samples=None):
    """Sample points on a 2-D spiral.

    A parameter ``t`` is drawn uniformly from [minval, maxval) and mapped to
    ``(0.04*t*sin(t), 0.04*t*cos(t))``, so the distance from the origin grows
    linearly with ``t``.

    Args:
        key: JAX PRNG key used to draw ``t``.
        minval: lower bound of ``t``.
        maxval: upper bound of ``t``.
        n_samples: number of points to draw; defaults to the module-level
            ``n`` when omitted.

    Returns:
        Array of shape ``(n_samples, 2)`` holding the (x, y) coordinates.
    """
    if n_samples is None:
        n_samples = n
    t = random.uniform(key, shape=(n_samples, 1), minval=minval, maxval=maxval)
    x = 0.04*np.sin(t)*t
    y = 0.04*np.cos(t)*t
    return np.hstack((x, y))

# scale multiplying the standard-normal draws in noise_additive_gaussian
σ = 0.06
# threshold compared against uniform draws when building the mask in noise_mask
mask_ratio = 0.5

def noise_additive_gaussian(key, X):
    """Corrupt ``X`` with additive zero-mean Gaussian noise.

    Args:
        key: JAX PRNG key for the noise draw.
        X: array to corrupt.

    Returns:
        ``X`` plus standard-normal noise of the same shape scaled by the
        module-level ``σ``.
    """
    eps = random.normal(key, X.shape)
    return X + σ * eps

def noise_mask(key, X):
    """Corrupt ``X`` by zeroing a random subset of its entries.

    Each entry is kept when an independent uniform draw falls below the
    module-level ``mask_ratio`` and is multiplied by zero otherwise.

    Args:
        key: JAX PRNG key for the mask draw.
        X: array to corrupt.

    Returns:
        Array of the same shape as ``X`` with the masked entries zeroed.
    """
    draws = random.uniform(key, X.shape)
    keep = (draws < mask_ratio).astype(np.float32)
    return keep * X

key = random.PRNGKey(0)
# Draw the samples and the corruption noise from independent subkeys.
# Passing one key to both calls would derive the noise from the same random
# bits as the data.  `key` itself is replaced by a fresh key for later use.
key, data_key, noise_key = random.split(key, 3)
X = func_2d(data_key)
X̃ = noise_additive_gaussian(noise_key, X)

# Left panel: clean samples; right panel: corrupted samples (drawn
# semi-transparent so their density is visible).
fig, axs = plt.subplots(1, 2, figsize=(12,6))
for ax, points, title, alpha in (
        (axs[0], X, r'$X$', None),
        (axs[1], X̃, r'$\tilde{X}$', .1)):
    ax.scatter(points[:,0], points[:,1], alpha=alpha)
    ax.grid()
    ax.set_xlim((-1,1))
    ax.set_ylim((-1,1))
    ax.set_title(title)


In [None]:
@jax.jit
def train_step(state, batch, rng):
    """Run one denoising training step.

    The batch is corrupted with ``noise_additive_gaussian``, reconstructed by
    the model, and the criterion compares the reconstruction against the
    clean batch.

    Args:
        state: training state exposing ``params`` and ``apply_gradients``.
        batch: clean inputs.
        rng: JAX PRNG key for the corruption noise.

    Returns:
        Tuple ``(state, loss, x_recon)``: the updated state, the loss value,
        and the reconstruction of the corrupted batch.
    """
    criterion = get_criterion()
    net = model()

    def loss_fn(params):
        noisy = noise_additive_gaussian(rng, batch)
        recon = net.apply(params, noisy)
        # The reconstruction is returned as auxiliary output.
        return criterion(recon, batch), recon

    value_and_grad = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, x_recon), grads = value_and_grad(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss, x_recon


@jax.jit
def ae_recon_jit(state, X):
    """Return the model's reconstruction of ``X`` using ``state.params``."""
    net = model()
    return net.apply(state.params, X)

@jax.jit
def ae_score_apx_jit(state, X):
    """Return ``Autoencoder.score_apx`` evaluated at ``X`` with ``state.params``."""
    net = model()
    return net.apply(state.params, X, method=Autoencoder.score_apx)


def plt_score_apx_one(ax, state, xlim, ylim, n_point_per_side = 30):
    """Draw the learned field r(x) - x on ``ax`` together with sample points.

    The field is evaluated on a regular grid covering ``xlim`` x ``ylim`` and
    drawn as arrows.  Fresh spiral samples (drawn with the module-level
    ``key``) and their reconstructions are overlaid as scatter plots.  The
    title shows the mean arrow length over the grid.

    Args:
        ax: matplotlib axes to draw on.
        state: training state passed on to the jitted model helpers.
        xlim: (min, max) of the horizontal range.
        ylim: (min, max) of the vertical range.
        n_point_per_side: number of grid points along each axis.
    """
    xgrid, ygrid = np.meshgrid(np.linspace(xlim[0], xlim[1], n_point_per_side),
                               np.linspace(ylim[0], ylim[1], n_point_per_side))
    # One row per grid point: (n_point_per_side**2, 2).
    Xgrid = np.stack((xgrid.ravel(), ygrid.ravel()), axis=-1)
    score_apx = ae_score_apx_jit(state, Xgrid)

    Xt = func_2d(key, minval=3, maxval=12)
    Xt_recon = ae_recon_jit(state, Xt)

    ax.scatter(Xt[:,0], Xt[:,1], color='b', label=r'$X$')
    # These points are reconstructions of the clean samples, so they are
    # labelled r(X); \tilde{X} denotes the noise-corrupted data elsewhere in
    # this notebook.
    ax.scatter(Xt_recon[:,0], Xt_recon[:,1], color='r', label=r'$r(X)$')
    ax.quiver(Xgrid[:,0], Xgrid[:,1], score_apx[:,0], score_apx[:,1])
    ax.grid()
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.legend(loc='upper left')
    ax.set_title(f'{np.linalg.norm(score_apx, axis=-1).mean():.4f}')
    

def plt_score_apx(state):
    """Plot the learned field at two zoom levels, side by side.

    The left panel covers [-1, 1] on both axes; the right panel covers the
    same ranges scaled by 0.3.

    Args:
        state: training state forwarded to ``plt_score_apx_one``.

    Returns:
        The created matplotlib figure.
    """
    fig, (ax_full, ax_zoom) = plt.subplots(1, 2, figsize=(15,7))
    full_range = np.array([-1,1])
    zoom_range = full_range*.3
    plt_score_apx_one(ax_full, state, full_range, full_range)
    plt_score_apx_one(ax_zoom, state, zoom_range, zoom_range)
    return fig


In [None]:


# Use separate subkeys for parameter initialisation and for the data stream
# instead of passing the same `key` to both; `key` is replaced by a fresh key
# that the training loop keeps splitting.
key, init_key, stream_key = random.split(key, 3)

# Dummy input of shape (1, d), passed to init alongside the key.
x_init = np.ones((1,d), np.float32)
params = model().init(init_key, x_init)

state = train_state.TrainState.create(
    apply_fn=model().apply,
    params=params,
    tx=optax.adam(lr))

# get_data_stream comes from jax_utils (not shown here); it is used below as
# returning the number of batches per epoch and an iterator over batches.
n_batches, batches = get_data_stream(
    stream_key, batch_size, X)

# Report about ten times per run.  max(1, ...) prevents a modulo-by-zero
# when n_epochs < 10.
log_every = max(1, n_epochs // 10)

for epoch in range(n_epochs):
    for j in range(n_batches):
        batch = next(batches)
        # Fresh noise key for every step.
        key, rng = random.split(key)
        state, loss, x_recon = train_step(state, batch, rng)

    if epoch % log_every == 0:
        # `loss` is the loss of the last batch of this epoch.
        print(f'[{epoch}] loss={loss}')
        fig = plt_score_apx(state)