In [1]:
#import matplotlib.pyplot as plt
from datetime import date
import jax
from jax import jit
import jax.numpy as jnp
import flax.linen as nn
import distrax
import jax.scipy.stats as stats
import optax
#from jax_resnet import ResNet18, pretrained_resnet
from resnet import *


2025-01-30 16:20:04.722445: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2025-01-30 16:20:04.722506: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [2]:
model=ResNet(num_classes= 10,
               c_hidden= (16, 32, 64),
               num_blocks= (3, 3, 3),
               act_fn= nn.relu,
               block_class= ResNetBlock)

In [8]:
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import CIFAR10
from torchvision import transforms

mean, std = [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]

def image_to_numpy(img):
    img = np.array(img, dtype=np.float32)
    img = (img / 255. - mean) / std
    return img

test_transform = image_to_numpy
# For training, we add some augmentation. Networks are too powerful and would overfit.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((32,32),scale=(0.8,1.0),ratio=(0.9,1.1)),
    image_to_numpy
])

In [9]:
cifar_dataset = CIFAR10('/tmp/cifar10/', download=True, transform=train_transform)

Files already downloaded and verified


In [10]:
cifar_test = CIFAR10('/tmp/cifar10/', download=True, train=False, transform=test_transform)

Files already downloaded and verified


In [11]:
cifar_val = CIFAR10('/tmp/cifar10/', download=True, transform=test_transform)

Files already downloaded and verified


In [12]:
import torch
from torch.utils.data import DataLoader
from torch.utils import data

train_set, _ = torch.utils.data.random_split(cifar_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))
_, val_set = torch.utils.data.random_split(cifar_val, [45000, 5000], generator=torch.Generator().manual_seed(42))


In [13]:
from torch.utils.data import DataLoader
from torch.utils import data
import os

batch_size=8

# We need to stack the batch elements
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)
    
train_loader = data.DataLoader(train_set,
                               batch_size=batch_size,
                               shuffle=True,
                               drop_last=True,
                               collate_fn=numpy_collate,
                               num_workers=0,
                               )
val_loader   = data.DataLoader(val_set,
                               batch_size=batch_size,
                               shuffle=False,
                               drop_last=False,
                               collate_fn=numpy_collate,
                               num_workers=0,
                               )
test_loader  = data.DataLoader(cifar_test,
                               batch_size=batch_size,
                               shuffle=False,
                               drop_last=False,
                               collate_fn=numpy_collate,
                               num_workers=0,
                               )

In [14]:
X, y = next(iter(train_loader))
print("Batch mean", X.mean(axis=(0,1,2)))
print("Batch std", X.std(axis=(0,1,2)))

Batch mean [0.27015958 0.20856956 0.12457092]
Batch std [0.96578211 1.0372434  1.04271615]


In [15]:
X.shape

(8, 32, 32, 3)

In [16]:
def log_prior(params):
    return 1.

def loglikelihood(params, batch_stats,batch,train=True):
    X, y = batch
    outs= model.apply({'params':params,'batch_stats':batch_stats}, 
                                         X,
                                         train=train,
                                         mutable=['batch_stats'] if train else False)
    logits, new_model_state = outs if train else (outs, None)
    dist=distrax.Categorical(logits=logits)
    ll=jnp.mean(dist.log_prob(y))
    return ll,new_model_state 

def log_posterior(params,batch_stats,batch):
    nll,new_model_state=loglikelihood(params,batch_stats,batch)
    return nll-log_prior(params),new_model_state

@jit
def acc_top1(params,stats,data_loader):
    y_pred=list()
    y_true=list()
    for batch in enumerate(data_loader):
        X, y = batch
        X_batch=jnp.array(X)
        y_batch=jnp.array(y)
        prediction = model.apply({'params':params,'batch_stats':stats}, X_batch, train=False, mutable=False)
        y_pred.append(jnp.argmax(prediction, axis=1))
        y_true.append(y_batch)
    y_pred=jnp.concatenate(y_pred)
    y_true=jnp.concatenate(y_true)   
    return jnp.mean(y_pred == y_true)

In [17]:
from functools import partial 

grad_log_post=jax.jit(jax.value_and_grad(log_posterior,has_aux=True))

In [18]:
key = jax.random.PRNGKey(10)
model_key, data_key = jax.random.split(key)
variables = model.init(model_key, jnp.array(X),train=True)
params, batch_stats = variables['params'], variables['batch_stats']
ret,grads=grad_log_post(params,batch_stats,(X,y))

In [19]:
import orbax.checkpoint as ocp
import os
import shutil

curdir = os.getcwd()
ckpt_dir =os.path.join(curdir,'posterior_samples')

if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)  # Remove any existing checkpoints from the last notebook run.

In [20]:
ckpt_dir

'/home/shernandez/quantized_sgmcmc/posterior_samples'

In [21]:
from flax.training import orbax_utils
import orbax

ckpt = {'model': params, 'stats': batch_stats}
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=10, create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(ckpt_dir, orbax_checkpointer, options)



In [28]:
import numpy as np
import orbax.checkpoint as ocp


def run_sgmcmc(key, Nsamples, init_fn, update, get_params,init_params,init_stats):
    "Run SGMCMC sampler and return the test accuracy list"
    loss = list()
    samples=list()
    params, batch_stats =init_params, init_stats
    key, subkey = jax.random.split(key)
    state = init_fn( params)
    num_iter=1
    for i in range(Nsamples):
        for batch in train_loader:
            key, subkey = jax.random.split(key)
            print(batch[0].shape)
            (nll,new_state),grads=grad_log_post(get_params(state),batch_stats,batch)
            batch_stats=new_state['batch_stats']
            state = update(num_iter, subkey, grads, state)
            loss.append(nll)
            position={'params':(get_params(state)),'batch_stats':batch_stats}
            samples.append(position)
            num_iter+=1
        #if (i%(Nsamples//10)==0):
            print('iteration {0}, loss {1:.2f}'.format(i,loss[-1]))
        #if (i%(Nsamples//10)==0):
        #    logits=model.apply({'params':params,'batch_stats':batch_stats},X_batch,train=False,mutable=False)
        #    accuracy = (logits.argmax(axis=-1) == y_batch).mean()
        #    nll,_=loglikelihood(params,batch_stats,(X_batch,y_batch),train=False)
        #    print('Epoch {0}, Log-likelihood : {1:8.2f}, Accuracy {2:8.2f}: '.format(j,nll,accuracy))
    return samples,loss
    

    

In [29]:
from sgmcmcjax.diffusions import sgld

learning_rate=1e-5
n_epochs=100
key = jax.random.PRNGKey(10)
init_fn, update, get_params = sgld(1e-5)
update = jit(update)

key = jax.random.PRNGKey(10)
model_key, data_key = jax.random.split(key)
variables = model.init(model_key, jnp.array(X),train=True)
params, batch_stats = variables['params'], variables['batch_stats']

In [30]:
run_sgmcmc(key,n_epochs,init_fn,update,get_params,params,batch_stats)

(8, 32, 32, 3)


2025-01-30 16:27:46.773483: W external/xla/xla/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 13.18MiB (rounded to 13816320)requested by op 
2025-01-30 16:27:47.326498: W external/xla/xla/tsl/framework/bfc_allocator.cc:494] ****************************************************************************************************
E0130 16:27:47.326576 3886244 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 13816248 bytes.


ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 13816248 bytes.