In [2]:
#import matplotlib.pyplot as plt
from datetime import date
import jax
from jax import jit
import jax.numpy as jnp
import flax.linen as nn
import distrax
import jax.scipy.stats as stats
import optax
#from jax_resnet import ResNet18, pretrained_resnet
from resnet import *
from kernels import *


In [3]:
model=ResNet(num_classes= 10,
               c_hidden= (16, 32, 64),
               num_blocks= (3, 3, 3),
               act_fn= nn.relu,
               block_class= ResNetBlock)

In [4]:
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import CIFAR10
from torchvision import transforms

mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]

def image_to_numpy(img):
    img = np.array(img, dtype=np.float32)
    img = (img / 255. - mean) / std
    return img

test_transform = image_to_numpy
# For training, we add some augmentation. Networks are too powerful and would overfit.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((32,32),scale=(0.8,1.0),ratio=(0.9,1.1)),
    image_to_numpy
])

In [5]:
cifar_dataset = CIFAR10('/tmp/cifar10/', download=True, transform=train_transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar10/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:11<00:00, 14283757.09it/s]


Extracting /tmp/cifar10/cifar-10-python.tar.gz to /tmp/cifar10/


In [6]:
cifar_test = CIFAR10('/tmp/cifar10/', download=True, train=False, transform=test_transform)

Files already downloaded and verified


In [9]:
cifar_val = CIFAR10('/tmp/cifar10/', download=True, transform=test_transform)

Files already downloaded and verified


In [10]:
import torch
from torch.utils.data import DataLoader
from torch.utils import data

train_set, _ = torch.utils.data.random_split(cifar_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))
_, val_set = torch.utils.data.random_split(cifar_val, [45000, 5000], generator=torch.Generator().manual_seed(42))


In [11]:
from torch.utils.data import DataLoader
from torch.utils import data
import os

batch_size=32

# We need to stack the batch elements
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)
    
train_loader = data.DataLoader(train_set,
                               batch_size=batch_size,
                               shuffle=True,
                               drop_last=True,
                               collate_fn=numpy_collate,
                               num_workers=0,
                               )
val_loader   = data.DataLoader(val_set,
                               batch_size=batch_size,
                               shuffle=False,
                               drop_last=False,
                               collate_fn=numpy_collate,
                               num_workers=0,
                               )
test_loader  = data.DataLoader(cifar_test,
                               batch_size=batch_size,
                               shuffle=False,
                               drop_last=False,
                               collate_fn=numpy_collate,
                               num_workers=0,
                               )

In [12]:
X, y = next(iter(train_loader))
print("Batch mean", X.mean(axis=(0,1,2)))
print("Batch std", X.std(axis=(0,1,2)))

Batch mean [ 0.05059145 -0.00235559  0.05920585]
Batch std [0.97349558 0.98742023 0.98531993]


In [13]:
def log_prior(params):
    return 1.

def loglikelihood(params, batch_stats,batch,train=True):
    X, y = batch
    outs= model.apply({'params':params,'batch_stats':batch_stats}, 
                                         X,
                                         train=train,
                                         mutable=['batch_stats'] if train else False)
    logits, new_model_state = outs if train else (outs, None)
    dist=distrax.Categorical(logits=logits)
    nll=-1.0*jnp.mean(dist.log_prob(y))
    return nll,new_model_state 

def log_posterior(params,batch_stats,batch):
    nll,new_model_state=loglikelihood(params,batch_stats,batch)
    return nll-log_prior(params),new_model_state

@jit
def acc_top1(params,stats,data_loader):
    y_pred=list()
    y_true=list()
    for batch in enumerate(data_loader):
        X, y = batch
        X_batch=jnp.array(X)
        y_batch=jnp.array(y)
        prediction = model.apply({'params':params,'batch_stats':stats}, X_batch, train=False, mutable=False)
        y_pred.append(jnp.argmax(prediction, axis=1))
        y_true.append(y_batch)
    y_pred=jnp.concatenate(y_pred)
    y_true=jnp.concatenate(y_true)   
    return jnp.mean(y_pred == y_true)

In [14]:
from functools import partial 

grad_log_post=jax.jit(jax.value_and_grad(log_posterior,has_aux=True))

In [15]:
key = jax.random.PRNGKey(10)
model_key, data_key = jax.random.split(key)
variables = model.init(model_key, jnp.array(X),train=True)
params, batch_stats = variables['params'], variables['batch_stats']
batch=(jnp.array(X),jnp.array(y))
ret,grads=grad_log_post(params,batch_stats,batch)

In [32]:
import orbax.checkpoint as ocp
import os
import shutil

curdir = os.getcwd()
ckpt_dir =os.path.join(curdir,'posterior_samples')

if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)  # Remove any existing checkpoints from the last notebook run.

In [34]:
ckpt_dir

'/Users/sergio/code/quantized_sgmcmc/posterior_samples'

In [39]:
from flax.training import orbax_utils
import orbax

ckpt = {'model': params, 'stats': batch_stats}
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=10, create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(ckpt_dir, orbax_checkpointer, options)

In [None]:
import numpy as np
import orbax.checkpoint as ocp

learning_rate=1e-5
n_epochs=20
num_iter=0

key = jax.random.PRNGKey(10)
key, subkey = jax.random.split(key)
variables = model.init(subkey, jnp.array(X),train=True)
params, batch_stats = variables['params'], variables['batch_stats']
momemtum=jax.tree_util.tree_map(lambda p:jnp.zeros_like(p),params)
kernel=jit(psgld_momemtum)
num_iter=0
for j in range(n_epochs):
    for i,(X,y) in enumerate(train_loader):
        #learning_rate=learning_rate*0.99
        X_batch,y_batch=jnp.array(X),jnp.array(y)
        (nll,new_state),grads=grad_log_post(params,batch_stats,(X_batch,y_batch))
        batch_stats=new_state['batch_stats']
        key,params,momemtum=kernel(key,params,momemtum,grads,learning_rate)
        ckpt = {'model': params, 'stats': batch_stats}
        checkpoint_manager.save(num_iter, ckpt)
        num_iter+=1
    logits=model.apply({'params':params,'batch_stats':batch_stats},X_batch,train=False,mutable=False)
    accuracy = (logits.argmax(axis=-1) == y_batch).mean()
    nll,_=loglikelihood(params,batch_stats,(X_batch,y_batch),train=False)
    print('Epoch {0}, Log-likelihood : {1:8.2f}, Accuracy {2:8.2f}: '.format(j,nll,accuracy))
    

Epoch 0, Log-likelihood :     1.24, Accuracy     0.53: 


In [33]:
batch_stats['batch_stats'].keys()

dict_keys(['BatchNorm_0', 'ResNetBlock_0', 'ResNetBlock_1', 'ResNetBlock_2', 'ResNetBlock_3', 'ResNetBlock_4', 'ResNetBlock_5', 'ResNetBlock_6', 'ResNetBlock_7', 'ResNetBlock_8'])