In [None]:
def quantization(x, s, z, alpha_q, beta_q, dtype=jnp.uint8):
    """Affine-quantize ``x`` onto the integer grid ``[alpha_q, beta_q]``.

    Computes ``round(x / s + z)``, clamps the result to the quantized range
    and casts it to ``dtype``.

    Args:
        x: Floating-point array to quantize.
        s: Scale of the affine mapping.
        z: Zero point of the affine mapping.
        alpha_q: Smallest allowed quantized value.
        beta_q: Largest allowed quantized value.
        dtype: Integer dtype of the result. Defaults to ``jnp.uint8``, which
            matches the previous behaviour; it must be able to represent
            every value in ``[alpha_q, beta_q]``.

    Returns:
        The quantized array with dtype ``dtype``.
    """
    x_q = jnp.round(1 / s * x + z, decimals=0)
    # Bounds are passed positionally: the ``a_min``/``a_max`` keywords of
    # jnp.clip are deprecated in recent JAX releases.
    x_q = jnp.clip(x_q, alpha_q, beta_q)
    return x_q.astype(dtype)


def quantization_int8(x, s, z):
    """Affine-quantize ``x`` to signed 8-bit integers in ``[-128, 127]``.

    Args:
        x: Floating-point array to quantize.
        s: Scale of the affine mapping.
        z: Zero point of the affine mapping.

    Returns:
        The quantized array with dtype ``jnp.int8``.
    """
    # Cast straight to int8. Going through uint8 first would cast negative
    # floats to an unsigned type, and the result of out-of-range
    # float-to-int casts is implementation-dependent in JAX.
    return quantization(x, s, z, alpha_q=-128, beta_q=127, dtype=jnp.int8)

def dequantization(x_q, s, z):
    """Map quantized values back to real values as ``s * (x_q - z)``.

    Args:
        x_q: Quantized integer array.
        s: Scale of the affine mapping.
        z: Zero point of the affine mapping.

    Returns:
        The reconstructed array, cast to ``jnp.float16``.
    """
    # Widen to int32 before subtracting the zero point: the difference can
    # fall outside the range of the quantized dtype.
    shifted = x_q.astype(jnp.int32) - z
    return (s * shifted).astype(jnp.float16)


def generate_quantization_constants_scale(alpha, beta, alpha_q, beta_q):
    """Scale of the affine quantization mapping.

    Args:
        alpha: Lower end of the real-valued range.
        beta: Upper end of the real-valued range.
        alpha_q: Lower end of the quantized range.
        beta_q: Upper end of the quantized range.

    Returns:
        ``(beta - alpha) / (beta_q - alpha_q)``, the width of the real range
        divided by the width of the quantized range.
    """
    real_span = beta - alpha
    quantized_span = beta_q - alpha_q
    return real_span / quantized_span

def generate_quantization_constants_bias(alpha, beta, alpha_q, beta_q):
    """Zero point of the affine quantization mapping.

    Args:
        alpha: Lower end of the real-valued range.
        beta: Upper end of the real-valued range; must differ from ``alpha``.
        alpha_q: Lower end of the quantized range.
        beta_q: Upper end of the quantized range.

    Returns:
        ``round((beta * alpha_q - alpha * beta_q) / (beta - alpha))`` as a
        ``jnp.int32`` scalar.
    """
    # Affine quantization mapping
    z = (beta * alpha_q - alpha * beta_q) / (beta - alpha)
    # Round to nearest instead of truncating, and store as int32 rather than
    # int8: for an unsigned grid such as [0, 255] the zero point regularly
    # exceeds 127, and it leaves [alpha_q, beta_q] entirely whenever the real
    # range [alpha, beta] does not contain 0.
    return jnp.round(z).astype(jnp.int32)


In [None]:
# Restore a checkpoint through the externally defined `ckpt_load` and
# `restore_args` (the first argument is presumably the checkpoint step).
data = ckpt_load.restore(1, args=restore_args)
# Pull the 'model' and 'stats' entries out of the restored state
# (presumably the model parameters and batch-norm statistics).
params = data.state['model']
batch_stats = data.state['stats']
# sgld_samples = model.apply({'params':params,'batch_stats':batch_stats},X_batch,train=False,mutable=False)
# NOTE(review): this is a single parameter pytree (a dict), not a list of
# samples; `tree_stack` unpacks its argument with `*trees`.
sgld_samples = params

In [None]:
# Show the shape of every leaf in the sample pytree.
jax.tree.map(lambda leaf: leaf.shape, sgld_samples)

{'BatchNorm_0': {'bias': (16,), 'scale': (16,)},
 'Conv_0': {'kernel': (3, 3, 3, 16)},
 'Dense_0': {'bias': (10,), 'kernel': (64, 10)},
 'ResNetBlock_0': {'BatchNorm_0': {'bias': (16,), 'scale': (16,)},
  'BatchNorm_1': {'bias': (16,), 'scale': (16,)},
  'Conv_0': {'kernel': (3, 3, 16, 16)},
  'Conv_1': {'kernel': (3, 3, 16, 16)}},
 'ResNetBlock_1': {'BatchNorm_0': {'bias': (16,), 'scale': (16,)},
  'BatchNorm_1': {'bias': (16,), 'scale': (16,)},
  'Conv_0': {'kernel': (3, 3, 16, 16)},
  'Conv_1': {'kernel': (3, 3, 16, 16)}},
 'ResNetBlock_2': {'BatchNorm_0': {'bias': (16,), 'scale': (16,)},
  'BatchNorm_1': {'bias': (16,), 'scale': (16,)},
  'Conv_0': {'kernel': (3, 3, 16, 16)},
  'Conv_1': {'kernel': (3, 3, 16, 16)}},
 'ResNetBlock_3': {'BatchNorm_0': {'bias': (32,), 'scale': (32,)},
  'BatchNorm_1': {'bias': (32,), 'scale': (32,)},
  'Conv_0': {'kernel': (3, 3, 16, 32)},
  'Conv_1': {'kernel': (3, 3, 32, 32)},
  'Conv_2': {'bias': (32,), 'kernel': (1, 1, 16, 32)}},
 'ResNetBlock_4':

In [None]:
# Display the restored parameter pytree.
params

{'BatchNorm_0': {'bias': Array([ 0.00022452,  0.00321693,  0.00076451, -0.00253368, -0.01168925,
         -0.00172721,  0.00936044,  0.00894074, -0.00487682, -0.00907283,
          0.00326797, -0.00600898, -0.00438914,  0.00268972, -0.00438866,
         -0.00393576], dtype=float32),
  'scale': Array([1.0051334 , 1.0034027 , 1.0008631 , 0.997909  , 0.9977386 ,
         0.9995442 , 1.0100495 , 1.0070015 , 1.0009319 , 0.9835908 ,
         1.0032496 , 0.9994465 , 0.99563104, 1.0039631 , 0.9959028 ,
         0.99217033], dtype=float32)},
 'Conv_0': {'kernel': Array([[[[ 1.73774987e-01,  7.49341846e-02, -2.26074651e-01,
             4.62646224e-02, -4.94431816e-02, -1.18665025e-01,
             9.58214924e-02, -1.09561734e-01,  1.72167465e-01,
            -1.20421678e-01, -1.89866364e-01,  8.00102726e-02,
             1.36617005e-01, -1.27204031e-01,  1.19669989e-01,
             1.49182215e-01],
           [-2.84515798e-01, -7.84979165e-02, -1.34562626e-01,
            -2.48099923e-01, -2.1

In [None]:
def tree_stack(trees):
    """Stack a sequence of identically structured pytrees leaf by leaf.

    Args:
        trees: Iterable of pytrees sharing one structure. It is unpacked with
            ``*trees``, so a single pytree must be wrapped in a list.

    Returns:
        One pytree whose leaves are ``jnp.stack`` of the corresponding input
        leaves, i.e. with a new leading axis indexing the input trees.
    """
    def _stack_leaves(*leaves):
        return jnp.stack(leaves)

    return jax.tree.map(_stack_leaves, *trees)

def tree_unstack(tree):
    leaves, treedef = jax.tree.flatten(tree)
    return [treedef.unflatten(leaf) for leaf in zip(*leaves, strict=True)]

In [None]:
# `tree_stack` unpacks its argument with `*trees`, so it needs a sequence of
# pytrees. Passing a single parameter dict would unpack its string keys and
# raise a TypeError inside `jnp.stack`; wrap a lone pytree in a list so it is
# stacked as one sample with a leading axis of length 1.
sample_trees = sgld_samples if isinstance(sgld_samples, (list, tuple)) else [sgld_samples]
stacked_samples = tree_stack(sample_trees)

TypeError: Error interpreting argument to <function broadcast_in_dim at 0x7375702dd080> as an abstract array. The problematic value is of type <class 'str'> and was passed to the function at path args[0].
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

In [None]:
# Display the stacked sample pytree.
stacked_samples

{'BatchNorm_0': {'bias': Array([[ 0.00022452,  0.00321693,  0.00076451, -0.00253368, -0.01168925,
          -0.00172721,  0.00936044,  0.00894074, -0.00487682, -0.00907283,
           0.00326797, -0.00600898, -0.00438914,  0.00268972, -0.00438866,
          -0.00393576]], dtype=float32),
  'scale': Array([[1.0051334 , 1.0034027 , 1.0008631 , 0.997909  , 0.9977386 ,
          0.9995442 , 1.0100495 , 1.0070015 , 1.0009319 , 0.9835908 ,
          1.0032496 , 0.9994465 , 0.99563104, 1.0039631 , 0.9959028 ,
          0.99217033]], dtype=float32)},
 'Conv_0': {'kernel': Array([[[[[ 1.73774987e-01,  7.49341846e-02, -2.26074651e-01,
              4.62646224e-02, -4.94431816e-02, -1.18665025e-01,
              9.58214924e-02, -1.09561734e-01,  1.72167465e-01,
             -1.20421678e-01, -1.89866364e-01,  8.00102726e-02,
              1.36617005e-01, -1.27204031e-01,  1.19669989e-01,
              1.49182215e-01],
            [-2.84515798e-01, -7.84979165e-02, -1.34562626e-01,
             -2.

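Block quantization: for every parameter tensor, compute one affine scale and zero point from the tensor's minimum and maximum, then quantize the tensor to 8-bit unsigned integers.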

In [None]:
# Real-valued range of every leaf: its smallest and largest entry.
alpha = jax.tree.map(jnp.min, stacked_samples)
beta = jax.tree.map(jnp.max, stacked_samples)

# Quantized range [alpha_q, beta_q] used for all leaves.
b = 8
alpha_q = 0
beta_q = 255

# One scale and one zero point per leaf, derived from that leaf's range.
s = jax.tree.map(
    lambda lo, hi: generate_quantization_constants_scale(lo, hi, alpha_q, beta_q),
    alpha,
    beta,
)
z = jax.tree.map(
    lambda lo, hi: generate_quantization_constants_bias(lo, hi, alpha_q, beta_q),
    alpha,
    beta,
)


In [None]:
# Quantize every leaf with its own scale and zero point.
quantized_stacked_samples = jax.tree.map(
    lambda leaf, scale, zero_point: quantization(leaf, scale, zero_point, alpha_q, beta_q),
    stacked_samples,
    s,
    z,
)

In [None]:
# Show the dtype of every quantized leaf.
jax.tree.map(lambda leaf: leaf.dtype, quantized_stacked_samples)

dtype('uint8')

In [None]:
# Show the dtype of every original (unquantized) leaf for comparison.
jax.tree.map(lambda leaf: leaf.dtype, stacked_samples)

dtype('float32')

In [None]:
# Reconstruct every leaf from its quantized values, scale and zero point.
# `dequantization` already takes (x_q, s, z) in this order, so it is mapped
# over the three pytrees directly.
dequantized_stacked_samples = jax.tree.map(
    dequantization, quantized_stacked_samples, s, z
)

In [None]:
# Split the stacked, dequantized pytree back into a list of per-sample pytrees.
dequantized_samples=tree_unstack(dequantized_stacked_samples)

In [None]:
# Per-leaf norm of the difference between the dequantized and the original
# parameters, i.e. the quantization round-trip error.
jax.tree.map(
    lambda restored, original: jnp.linalg.norm(restored - original),
    dequantized_stacked_samples,
    stacked_samples,
)

Array(149.76529, dtype=float32)