In [1]:
#import matplotlib.pyplot as plt
from datetime import date
import jax
from jax import jit
import jax.numpy as jnp
import flax.linen as nn
import distrax
import jax.scipy.stats as stats
import optax
#from jax_resnet import ResNet18, pretrained_resnet
from resnet import *
from kernels import *
from cifar_train_test import *

2025-03-19 09:51:32.474566: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-19 09:51:32.597020: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-19 09:51:32.629573: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar10/cifar-10-python.tar.gz


100%|████████████████████████████████████████| 170M/170M [00:23<00:00, 7.36MB/s]


Extracting /tmp/cifar10/cifar-10-python.tar.gz to /tmp/cifar10/
Files already downloaded and verified
Files already downloaded and verified


In [32]:
import orbax.checkpoint as ocp
import os
import shutil
from flax.training import orbax_utils
import orbax


In [22]:
model=ResNet(num_classes= 10,
               c_hidden= (16, 32, 64),
               num_blocks= (3, 3, 3),
               act_fn= nn.relu,
               block_class= ResNetBlock)

In [23]:
X, y = next(iter(train_loader))
print("Batch mean", X.mean(axis=(0,1,2)))
print("Batch std", X.std(axis=(0,1,2)))

Batch mean [ 0.01642355 -0.01215357 -0.04992252]
Batch std [0.98853619 0.9731538  0.97723662]


In [24]:
def log_prior(params):
    leaves, _ = jax.tree_util.tree_flatten(params)
    flat_params = jnp.concatenate([jnp.ravel(a) for a in leaves])
    return distrax.Normal(0., 10.).log_prob(flat_params).sum()/len(train_set)

def loglikelihood(params, batch_stats,batch,train=True):
    X, y = batch
    outs= model.apply({'params':params,'batch_stats':batch_stats}, 
                                         X,
                                         train=train,
                                         mutable=['batch_stats'] if train else False)
    logits, new_model_state = outs if train else (outs, None)
    dist=distrax.Categorical(logits=logits)
    nll=-1.0*jnp.mean(dist.log_prob(y))
    return nll,new_model_state 

def log_posterior(params,batch_stats,batch):
    nll,new_model_state=loglikelihood(params,batch_stats,batch)
    return nll-log_prior(params),new_model_state

def acc_top1(params,stats,data_loader):
    y_pred=list()
    y_true=list()
    for batch in data_loader:
        X, y = batch
        X_batch=jnp.array(X)
        y_batch=jnp.array(y)
        prediction = model.apply({'params':params,'batch_stats':stats}, X_batch, train=False, mutable=False)
        y_pred.append(jnp.argmax(prediction, axis=1))
        y_true.append(y_batch)
    y_pred=jnp.concatenate(y_pred)
    y_true=jnp.concatenate(y_true)   
    return jnp.mean(y_pred == y_true)

In [25]:
from functools import partial 

grad_log_post=jax.jit(jax.value_and_grad(log_posterior,has_aux=True))

In [26]:
#key = jax.random.PRNGKey(10)
#model_key, data_key = jax.random.split(key)
#variables = model.init(model_key, jnp.array(X),train=True)
#params, batch_stats = variables['params'], variables['batch_stats']
#batch=(jnp.array(X),jnp.array(y))
#ret,grads=grad_log_post(params,batch_stats,batch)

In [49]:
curdir = os.getcwd()
ckpt_dir =os.path.join(curdir,'posterior_samples')

if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)  

options = orbax.checkpoint.CheckpointManagerOptions(save_interval_steps=len(train_set)//batch_size,max_to_keep=100, create=True)
manager = ocp.CheckpointManager(ckpt_dir, options=options)

In [50]:
learning_rate=1e-5
n_epochs=200

key = jax.random.PRNGKey(10)
key, subkey = jax.random.split(key)
variables = model.init(subkey, jnp.array(X),train=True)
params, batch_stats = variables['params'], variables['batch_stats']
momemtum=jax.tree_util.tree_map(lambda p:jnp.zeros_like(p),params)
kernel=jit(psgld_momemtum)
num_iter=0
for j in range(n_epochs):
    for i,(X,y) in enumerate(train_loader):
        #learning_rate=learning_rate*0.99
        X_batch,y_batch=jnp.array(X),jnp.array(y)
        (nll,new_state),grads=grad_log_post(params,batch_stats,(X_batch,y_batch))
        batch_stats=new_state['batch_stats']
        key,params,momemtum=kernel(key,params,momemtum,grads,learning_rate)
        position={'params':params,'batch_stats':batch_stats}
        #samples.append(position)
        manager.save(num_iter,args=ocp.args.StandardSave(position))
        num_iter+=1
    if (j%(n_epochs//10)==0):
        logits=model.apply({'params':params,'batch_stats':batch_stats},X_batch,train=False,mutable=False)
        accuracy = (logits.argmax(axis=-1) == y_batch).mean()
        nll,_=loglikelihood(params,batch_stats,(X_batch,y_batch),train=False)
        val_accuracy=acc_top1(params,batch_stats,val_loader)
        print('Epoch {0}, Log-likelihood : {1:8.2f}, Train Accuracy : {2:8.2f}, Val Accuracy {3:8.2f}: '.
              format(j,nll,accuracy,val_accuracy))
    

Epoch 0, Log-likelihood :     2.11, Train Accuracy :     0.36, Val Accuracy     0.37: 
Epoch 20, Log-likelihood :     0.47, Train Accuracy :     0.84, Val Accuracy     0.72: 
Epoch 40, Log-likelihood :     0.34, Train Accuracy :     0.87, Val Accuracy     0.77: 
Epoch 60, Log-likelihood :     0.19, Train Accuracy :     0.94, Val Accuracy     0.81: 
Epoch 80, Log-likelihood :     0.36, Train Accuracy :     0.86, Val Accuracy     0.79: 
Epoch 100, Log-likelihood :     0.21, Train Accuracy :     0.93, Val Accuracy     0.81: 
Epoch 120, Log-likelihood :     0.31, Train Accuracy :     0.89, Val Accuracy     0.82: 
Epoch 140, Log-likelihood :     0.51, Train Accuracy :     0.85, Val Accuracy     0.78: 
Epoch 160, Log-likelihood :     0.24, Train Accuracy :     0.93, Val Accuracy     0.83: 
Epoch 180, Log-likelihood :     0.32, Train Accuracy :     0.89, Val Accuracy     0.81: 


In [51]:
acc_top1(params,batch_stats,test_loader)

Array(0.7507, dtype=float32)

In [52]:
last_ckpt=manager.all_steps()[0]
restored = manager.restore(last_ckpt)
acc_top1(restored['params'],restored['batch_stats'],test_loader)



Array(0.78639996, dtype=float32)

# Quantization

In [53]:
def quantization(x, s, z, alpha_q, beta_q):
    x_q = jnp.round(1 / s * x + z, decimals=0)
    x_q = jnp.clip(x_q, a_min=alpha_q, a_max=beta_q)
    return x_q.astype(jnp.uint8)


def quantization_int8(x, s, z):
    x_q = quantization(x, s, z, alpha_q=-128, beta_q=127)
    x_q = x_q.astype(jnp.int8)
    return x_q

def dequantization(x_q, s, z):
    # x_q - z might go outside the quantization range.
    x_q = x_q.astype(jnp.int32)
    x = s * (x_q - z)
    x = x.astype(jnp.float16)
    return x


def generate_quantization_constants_scale(alpha, beta, alpha_q, beta_q):
    # Affine quantization mapping
    s = (beta - alpha) / (beta_q - alpha_q)
    return s

def generate_quantization_constants_bias(alpha, beta, alpha_q, beta_q):
    # Affine quantization mapping
    z = jnp.int8((beta * alpha_q - alpha * beta_q) / (beta - alpha))
    return z


In [54]:
sgld_samples = [manager.restore(ckpt) for ckpt in manager.all_steps()]







In [139]:
def tree_stack(trees):
    return jax.tree.map(lambda *v: jnp.stack(v), *trees)

def tree_unstack(tree):
    leaves, treedef = jax.tree.flatten(tree)
    return [treedef.unflatten(leaf) for leaf in zip(*leaves, strict=True)]

stacked_samples=tree_stack(sgld_samples)
alpha=jax.tree.map(lambda p:jnp.min(p,axis=0),stacked_samples)
beta=jax.tree.map(lambda p:jnp.max(p,axis=0),stacked_samples)
b=8
alpha_q = 0
beta_q = 255
s=jax.tree.map(lambda a,b:generate_quantization_constants_scale(a,b,alpha_q,beta_q),alpha,beta)
z=jax.tree.map(lambda a,b:generate_quantization_constants_bias(a,b,alpha_q,beta_q),alpha,beta)

In [140]:
quantized_stacked_samples=jax.tree.map(lambda x,s,z:quantization(x,s,z,alpha_q,beta_q),stacked_samples,s,z)

In [141]:
dequantized_stacked_samples=jax.tree.map(
    lambda x,s,z:dequantization(x,s,z),quantized_stacked_samples,s,z)

In [156]:
jax.tree.flatten(jax.tree.map(lambda p,q:jnp.linalg.norm(p-q),dequantized_stacked_samples,stacked_samples))[0]

[Array(0.18904808, dtype=float32),
 Array(2.406212, dtype=float32),
 Array(45.370583, dtype=float32),
 Array(264.3012, dtype=float32),
 Array(200.23672, dtype=float32),
 Array(104.53427, dtype=float32),
 Array(94.76575, dtype=float32),
 Array(174.6066, dtype=float32),
 Array(229.02715, dtype=float32),
 Array(367.3588, dtype=float32),
 Array(103.81532, dtype=float32),
 Array(523.71796, dtype=float32),
 Array(208.92838, dtype=float32),
 Array(405.59525, dtype=float32),
 Array(333.01276, dtype=float32),
 Array(633.8947, dtype=float32),
 Array(259.0445, dtype=float32),
 Array(73.722664, dtype=float32),
 Array(375.2109, dtype=float32),
 Array(85.95269, dtype=float32),
 Array(401.93546, dtype=float32),
 Array(91.125656, dtype=float32),
 Array(728.72687, dtype=float32),
 Array(181.37878, dtype=float32),
 Array(170.66014, dtype=float32),
 Array(83.59627, dtype=float32),
 Array(817.59534, dtype=float32),
 Array(460.04758, dtype=float32),
 Array(182.50426, dtype=float32),
 Array(1694.9888, dtype

In [142]:
dequantized_samples=tree_unstack(dequantized_stacked_samples)

In [143]:
fp32_size=sum(jax.tree.flatten(jax.tree.map(lambda p:p.nbytes,stacked_samples))[0])*1e-6

In [144]:
int8_size=sum(jax.tree.flatten(jax.tree.map(lambda p:p.nbytes,quantized_stacked_samples))[0])*1e-6

In [145]:
print('Tamaño Muestras FP32 : {0:.2f}Mb'.format(fp32_size))
print('Tamaño Muestras INT8 : {0:.2f}Mb'.format(int8_size))

Tamaño Muestras FP32 : 109.50Mb
Tamaño Muestras INT8 : 27.38Mb


# Pruebas Clasificación

In [63]:
def ensemble_acc_top1(params,stats,data_loader):
    y_pred=list()
    y_true=list()
    for batch in data_loader:
        X, y = batch
        X_batch=jnp.array(X)
        y_batch=jnp.array(y)
        prediction = model.apply({'params':params,'batch_stats':stats}, X_batch, train=False, mutable=False)
        y_pred.append(prediction)
        y_true.append(y_batch)
    y_pred=jnp.concatenate(y_pred)
    y_true=jnp.concatenate(y_true)   
    return y_pred,y_true

In [73]:
ensemble_pred=list()
for sample in sgld_samples:
    y_pred,y_true = ensemble_acc_top1(sample['params'],sample['batch_stats'],test_loader)
    ensemble_pred.append(y_pred)

In [75]:
ensemble_pred=jnp.stack(ensemble_pred)

In [77]:
mean_pred=jnp.mean(ensemble_pred,axis=0)

In [80]:
mean_pred.argmax(axis=1).shape

(10000,)

In [81]:
from sklearn.metrics import classification_report

print(classification_report(np.array(y_true), np.asarray(mean_pred.argmax(axis=1))))

              precision    recall  f1-score   support

           0       0.88      0.91      0.89      1000
           1       0.94      0.96      0.95      1000
           2       0.85      0.83      0.84      1000
           3       0.75      0.75      0.75      1000
           4       0.89      0.85      0.87      1000
           5       0.85      0.80      0.82      1000
           6       0.86      0.94      0.90      1000
           7       0.93      0.91      0.92      1000
           8       0.94      0.93      0.94      1000
           9       0.91      0.93      0.92      1000

    accuracy                           0.88     10000
   macro avg       0.88      0.88      0.88     10000
weighted avg       0.88      0.88      0.88     10000



In [85]:
quantized_ensemble_pred=list()
for sample in dequantized_samples:
    y_pred,y_true = ensemble_acc_top1(sample['params'],sample['batch_stats'],test_loader)
    quantized_ensemble_pred.append(y_pred)

In [86]:
quantized_ensemble_pred=jnp.stack(quantized_ensemble_pred)

In [87]:
qmean_pred=jnp.mean(quantized_ensemble_pred,axis=0)

In [88]:
print(classification_report(np.array(y_true), np.asarray(qmean_pred.argmax(axis=1))))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00      1000
           1       0.00      0.00      0.00      1000
           2       0.00      0.00      0.00      1000
           3       0.10      1.00      0.18      1000
           4       0.00      0.00      0.00      1000
           5       0.00      0.00      0.00      1000
           6       0.00      0.00      0.00      1000
           7       0.00      0.00      0.00      1000
           8       0.00      0.00      0.00      1000
           9       0.00      0.00      0.00      1000

    accuracy                           0.10     10000
   macro avg       0.01      0.10      0.02     10000
weighted avg       0.01      0.10      0.02     10000



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [146]:
y_pred,y_true=ensemble_acc_top1(dequantized_samples[0]['params'],dequantized_samples[0]['batch_stats'],test_loader)

In [147]:
jnp.mean(jnp.argmax(y_pred,axis=1)==y_true)

Array(0.09999999, dtype=float32)