In [1]:
#import matplotlib.pyplot as plt
from datetime import date
import jax
from jax import jit
import jax.numpy as jnp
import flax.linen as nn
import distrax
import jax.scipy.stats as stats
import optax
#from jax_resnet import ResNet18, pretrained_resnet
from resnet import *


2025-03-19 09:17:57.996867: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-19 09:17:58.131634: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-19 09:17:58.166612: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
model=ResNet(num_classes= 10,
               c_hidden= (16, 32, 64),
               num_blocks= (3, 3, 3),
               act_fn= nn.relu,
               block_class= ResNetBlock)

In [3]:
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import CIFAR10
from torchvision import transforms

mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]

def image_to_numpy(img):
    img = np.array(img, dtype=np.float32)
    img = (img / 255. - mean) / std
    return img

test_transform = image_to_numpy
# For training, we add some augmentation. Networks are too powerful and would overfit.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((32,32),scale=(0.8,1.0),ratio=(0.9,1.1)),
    image_to_numpy
])

In [4]:
cifar_dataset = CIFAR10('/tmp/cifar10/', download=True, transform=train_transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar10/cifar-10-python.tar.gz


100%|████████████████████████████████████████| 170M/170M [00:22<00:00, 7.66MB/s]


Extracting /tmp/cifar10/cifar-10-python.tar.gz to /tmp/cifar10/


In [5]:
cifar_test = CIFAR10('/tmp/cifar10/', download=True, train=False, transform=test_transform)

Files already downloaded and verified


In [6]:
cifar_val = CIFAR10('/tmp/cifar10/', download=True, transform=test_transform)

Files already downloaded and verified


In [7]:
import torch
from torch.utils.data import DataLoader
from torch.utils import data

train_set, _ = torch.utils.data.random_split(cifar_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))
_, val_set = torch.utils.data.random_split(cifar_val, [45000, 5000], generator=torch.Generator().manual_seed(42))


In [8]:
from torch.utils.data import DataLoader
from torch.utils import data
import os

batch_size=256

# We need to stack the batch elements
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)
    
train_loader = data.DataLoader(train_set,
                               batch_size=batch_size,
                               shuffle=True,
                               drop_last=True,
                               collate_fn=numpy_collate,
                               num_workers=0,
                               )
val_loader   = data.DataLoader(val_set,
                               batch_size=batch_size,
                               shuffle=False,
                               drop_last=False,
                               collate_fn=numpy_collate,
                               num_workers=0,
                               )
test_loader  = data.DataLoader(cifar_test,
                               batch_size=batch_size,
                               shuffle=False,
                               drop_last=False,
                               collate_fn=numpy_collate,
                               num_workers=0,
                               )

In [9]:
X, y = next(iter(train_loader))
print("Batch mean", X.mean(axis=(0,1,2)))
print("Batch std", X.std(axis=(0,1,2)))

Batch mean [-0.01136768  0.00461017  0.0126301 ]
Batch std [0.92082045 0.92998571 0.95497953]


In [10]:
X.shape

(256, 32, 32, 3)

In [110]:
len(train_set)

45000

In [None]:
def log_prior(params):
    leaves, _ = jax.tree_util.tree_flatten(params)
    flat_params = jnp.concatenate([jnp.ravel(a) for a in leaves])
    return distrax.Normal(0., 10.).log_prob(flat_params).sum()/len(train_set)

def loglikelihood(params, batch_stats,batch,train=True):
    X, y = batch
    outs= model.apply({'params':params,'batch_stats':batch_stats}, 
                                         X,
                                         train=train,
                                         mutable=['batch_stats'] if train else False)
    logits, new_model_state = outs if train else (outs, None)
    dist=distrax.Categorical(logits=logits)
    ll=dist.log_prob(y).sum()
    return ll,new_model_state 

def log_posterior(params,batch_stats,batch):
    nll,new_model_state=loglikelihood(params,batch_stats,batch)
    return nll+log_prior(params),new_model_state

@jit
def acc_top1(params,stats,data_loader):
    y_pred=list()
    y_true=list()
    for batch in data_loader:
        X, y = batch
        X_batch=jnp.array(X)
        y_batch=jnp.array(y)
        prediction = model.apply({'params':params,'batch_stats':stats}, X_batch, train=False, mutable=False)
        y_pred.append(jnp.argmax(prediction, axis=1))
        y_true.append(y_batch)
    y_pred=jnp.concatenate(y_pred)
    y_true=jnp.concatenate(y_true)   
    return jnp.mean(y_pred == y_true)

In [100]:
from functools import partial 

grad_log_post=jax.jit(jax.value_and_grad(log_posterior,has_aux=True))

In [101]:
key = jax.random.PRNGKey(10)
model_key, data_key = jax.random.split(key)
variables = model.init(model_key, jnp.array(X),train=True)
params, batch_stats = variables['params'], variables['batch_stats']
ret,grads=grad_log_post(params,batch_stats,(X,y))

In [102]:
import orbax.checkpoint as ocp
import os
import shutil

curdir = os.getcwd()
ckpt_dir =os.path.join(curdir,'posterior_samples')

if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)  # Remove any existing checkpoints from the last notebook run.

In [103]:
ckpt_dir

'/home/sergio/code/quantized_sgmcmc/posterior_samples'

In [104]:
from flax.training import orbax_utils
import orbax

ckpt = {'model': params, 'stats': batch_stats}
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=10, create=True)
mngr = ocp.CheckpointManager(ckpt_dir, options=options)

In [None]:
def run_sgmcmc(manager,key, Nsamples, init_fn, update, get_params,init_params,init_stats):
    "Run SGMCMC sampler and return the test accuracy list"
    loss = list()
    samples=list()
    params, batch_stats =init_params, init_stats
    key, subkey = jax.random.split(key)
    state = init_fn( params)
    num_iter=1
    for i in range(Nsamples):
        for batch in train_loader:
            key, subkey = jax.random.split(key)
            (nll,new_state),grads=grad_log_post(get_params(state),batch_stats,batch)
            batch_stats=new_state['batch_stats']
            state = update(num_iter, subkey, grads, state)
            loss.append(nll)
            position={'params':(get_params(state)),'batch_stats':batch_stats}
            #samples.append(position)
            manager.save(num_iter,args=ocp.args.StandardSave(position))
            num_iter+=1
        if (i%(Nsamples//10)==0):
            y_pred=list()
            y_true=list()
            for X,y in test_loader:
                X_batch=jnp.array(X)
                y_batch=jnp.array(y)
                preds = model.apply({'params':get_params(state),'batch_stats':batch_stats}, X_batch, train=False, mutable=False)
                y_pred.append(jnp.argmax(preds,axis=1))
                y_true.append(y_batch)
            y_pred=jnp.concatenate(y_pred)
            y_true=jnp.concatenate(y_true)   
            print('iteration {0}, loss {1:.2f}, test accuracy : {2:.2f}'.format(i,loss[-1],jnp.mean(y_pred == y_true)))  
    return position,loss
    

    

In [107]:
from sgmcmcjax.diffusions import sgld

learning_rate=1e-5
n_epochs=10
key = jax.random.PRNGKey(10)
init_fn, update, get_params = sgld(1e-5)
update = jit(update)

key = jax.random.PRNGKey(10)
model_key, data_key = jax.random.split(key)
variables = model.init(model_key, jnp.array(X),train=True)
params, batch_stats = variables['params'], variables['batch_stats']

In [108]:
position,loss=run_sgmcmc(mngr,key,n_epochs,init_fn,update,get_params,params,batch_stats)

iteration 0, loss -878058.69, test accuracy : 0.1111999973654747
iteration 1, loss -878051.31, test accuracy : 0.14569999277591705
iteration 2, loss -878064.50, test accuracy : 0.11180000007152557
iteration 3, loss -878068.81, test accuracy : 0.13300000131130219
iteration 4, loss -878048.06, test accuracy : 0.12700000405311584
iteration 5, loss -878058.75, test accuracy : 0.16509999334812164
iteration 6, loss -878068.00, test accuracy : 0.14499999582767487
iteration 7, loss -878072.94, test accuracy : 0.15719999372959137
iteration 8, loss -878085.81, test accuracy : 0.20679999887943268
iteration 9, loss -878089.38, test accuracy : 0.12889999151229858


In [77]:
mngr.all_steps()

[562491,
 562492,
 562493,
 562494,
 562495,
 562496,
 562497,
 562498,
 562499,
 562500]

In [None]:
restored = mngr.restore(562491)

In [86]:
y_pred=list()
y_test=list()
for X,y in train_loader:
    X_batch=jnp.array(X)
    y_batch=jnp.array(y)
    preds = model.apply({'params':restored['params'],'batch_stats':restored['batch_stats']}, X_batch, train=False, mutable=False)
    y_pred.append(jnp.argmax(preds,axis=1))
    y_test.append(y_batch)    

In [87]:
len(y_pred),len(y_test)

(5625, 5625)

In [88]:
y_pred=np.concatenate([p for p in y_pred])
y_test=np.concatenate([p for p in y_test])

In [89]:
y_test[:10]

array([8, 2, 4, 4, 4, 7, 0, 1, 5, 6], dtype=int32)

In [90]:
y_pred[:100]

array([7, 7, 7, 7, 2, 2, 9, 9, 7, 9, 9, 9, 7, 7, 9, 2, 7, 7, 9, 9, 9, 9,
       9, 9, 2, 9, 7, 9, 7, 9, 9, 9, 7, 9, 9, 7, 9, 2, 9, 9, 9, 7, 2, 9,
       2, 7, 2, 7, 7, 7, 2, 9, 9, 7, 7, 7, 9, 7, 7, 2, 7, 9, 9, 7, 7, 9,
       9, 7, 2, 9, 7, 2, 7, 9, 2, 2, 9, 9, 7, 9, 9, 7, 9, 7, 2, 2, 9, 7,
       7, 9, 2, 8, 9, 9, 7, 7, 9, 9, 9, 9], dtype=int32)

In [91]:
from sklearn.metrics import confusion_matrix

print(confusion_matrix(y_test,y_pred))


[[   0   19  355   12    9    3    0 2106    7 2001]
 [   1   37  756   25   20   10    0 1538    6 2095]
 [   0   31  385   10   73    4    0 1477    7 2481]
 [   0   66  783   10   70   11    0 1832    5 1752]
 [   0   23  452    2  164    3    0 1423    5 2457]
 [   0   30  615    8   41    5    0 1897   10 1880]
 [   2   47  694   11  171    5    0 1359    4 2200]
 [   0   41  675    9   54    6    0 1546   11 2158]
 [   0    8  453    8    7    5    1 1036    2 2976]
 [   0   16  793   18   12    1    1 1345    0 2313]]
