In [8]:
# An implementation of Masked Autoencoder for Density Estimation
# inspired by https://github.com/karpathy/pytorch-made

import jax.numpy as np
import jax.random as random
import jax.nn as nn
from jax import device_put, grad, jit, random, vmap

from jax.experimental import optimizers

import numpy as onp
import numpy.random as onpr

from tqdm import trange

In [9]:
mnist = onp.load('./binarized_mnist.npz')
xtr, xte = mnist['train_data'], mnist['valid_data']

In [10]:
def get_mask(nin,
             hidden_sizes, seed, natural_ordering):
    # Return a randomly generated mask
    # the mask generated is a deterministic function of the seed
    
    rng = onpr.RandomState(seed)
    L = len(hidden_sizes)
    
    m = {}
    
    # Sample the order of the input
    m[-1] = onp.arange(nin) if natural_ordering else rng.permutation(nin)
    
    # Sample the connectivity of all hidden layers
    for l in range(L):
        
        # For each unit in layer l,
        # it can be connected to at most nin 
        # and at least m[l-1] dimension in the input
        m[l] = rng.randint(
            m[l-1].min(), nin-1, 
            size=hidden_sizes[l]
        )
        
    # Construct the mask matrices
    masks = []
    for l in range(L):
        
        # The mask at each hidden layer is 1_{m^l >= m^{l-1}}
        masks.append(
            m[l-1][:, None] <= m[l][None, :]
        )
        
    # Construct the mask at the output layer
    # 1_{d > m^L}
    masks.append(
        m[L-1][:, None] < m[-1][None, :]
    )
        
    return masks

In [11]:
# Define parameters and optim
D = xtr.shape[1]

DEBUG = False

if DEBUG:
    hid_size = 4
    hidden_sizes = [] 
    sizes = [D, D]

else:
    hid_size = 500
    hidden_sizes = [hid_size] 
    sizes = [D, hid_size, D]

# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

key = random.PRNGKey(0)
params = init_network_params(sizes, key)

init_fnc, update_fnc, get_params = optimizers.adam(step_size=0.001)

opt_state = init_fnc(params)

In [12]:
# Define forward pass of the model
def forward(ipt, ps, ms):
    # ps, ms are param and mask respectively
    
    activation = ipt 
    
    for (w, b), m in zip(ps[:-1], ms[:-1]):
        
        masked_weight = np.multiply(w, m.T)
        out = np.dot(masked_weight, activation) + b
        activation = nn.relu(out)
        
    final_w, final_b = ps[-1]
    final_m = ms[-1]
    
    masked_w = np.multiply(final_w, final_m.T)
    logits = np.dot(masked_w, activation) + final_b
    return logits 

b_forward = vmap(forward, in_axes=(0, None, None))

In [13]:
# Define loss function
def binary_cross_entropy_with_logits(target, logit):
    
    eps = 1e-7
    
    # y log sigmoid(x) + (1 - y) log (1 - sigmoid(x))
    sig = nn.sigmoid(logit)
    
    # clipping for numerical stability
    sig = np.clip(sig, eps, 1.0 - eps)
    
    return - (target * np.log(sig) + (1.0 - target) * np.log(1.0 - sig))


b_binary_cross_entropy_with_logits = vmap(binary_cross_entropy_with_logits, 
                                          in_axes=(0, 0))

def loss_fnc(params, masks, b_x):
    
    b_logits = b_forward(b_x, params, masks)
    loss = b_binary_cross_entropy_with_logits(b_x.flatten(), b_logits.flatten())
    loss = np.sum(loss) / len(b_x)
            
    return loss

In [14]:
# define the training loop

N = xtr.shape[0]
B = 100
nsteps = N//B 

xte_subset = xte[:5]


from jax.experimental.optimizers import l2_norm



def loss_fnc_with_reg(params, masks, b_x):
    
    loss = loss_fnc(params, masks, b_x)
    
    return loss + 1e-4 * l2_norm(params)


@jit
def update(update_idx, opt_state, masks, b_x):
    
    params = get_params(opt_state)

    grads = grad(loss_fnc_with_reg)(params, masks, b_x)   

    return update_fnc(update_idx, grads, opt_state)


@jit
def j_loss_fnc(opt_state, masks, b_x):
    
    params = get_params(opt_state)
    loss = loss_fnc(params, masks, b_x)
    return loss


masks = get_mask(D, hidden_sizes, 1, True)


for epoch in trange(51):
    
    losses = []
    for step in range(nsteps):

        b_x = xtr[step*B:step*B+B]
        
        loss = j_loss_fnc(opt_state, masks, b_x)
        losses.append(loss)

        opt_state = update(epoch * nsteps + step, opt_state, masks, b_x)
    
    if epoch % 10 == 0:
        print('epoch', epoch)
        print('train loss: ', np.mean(np.array(losses)))
        print('test loss: ', j_loss_fnc(opt_state, masks, xte))

  0%|          | 0/51 [00:00<?, ?it/s]

epoch 0
train loss:  216.67427


  2%|▏         | 1/51 [00:01<01:17,  1.56s/it]

test loss:  152.27048


 22%|██▏       | 11/51 [00:06<00:23,  1.70it/s]

epoch 10
train loss:  102.56361
test loss:  104.3317


 41%|████      | 21/51 [00:12<00:16,  1.77it/s]

epoch 20
train loss:  99.15046
test loss:  102.308304


 61%|██████    | 31/51 [00:18<00:11,  1.79it/s]

epoch 30
train loss:  97.79139
test loss:  101.80868


 80%|████████  | 41/51 [00:23<00:05,  1.81it/s]

epoch 40
train loss:  97.048065
test loss:  101.68934


100%|██████████| 51/51 [00:28<00:00,  1.76it/s]

epoch 50
train loss:  96.56069
test loss:  101.697495



