In [1]:
import jax
import jax.numpy as jnp

In [2]:
func = lambda x: x**2

jit_func = jax.jit(func)

In [3]:
from pprint import pprint

pprint(jit_func.lower(10.0).as_text())

('module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, '
 'mhlo.num_replicas = 1 : i32} {\n'
 '  func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> '
 '{jax.result_info = ""}) {\n'
 '    %0 = stablehlo.multiply %arg0, %arg0 : tensor<f32>\n'
 '    return %0 : tensor<f32>\n'
 '  }\n'
 '}\n')


In [4]:
jit_grad_func = jax.jit(jax.grad(func))

pprint(jit_grad_func.lower(10.0).as_text())

('module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, '
 'mhlo.num_replicas = 1 : i32} {\n'
 '  func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> '
 '{jax.result_info = ""}) {\n'
 '    %cst = stablehlo.constant dense<2.000000e+00> : tensor<f32>\n'
 '    %0 = stablehlo.multiply %cst, %arg0 : tensor<f32>\n'
 '    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>\n'
 '    %1 = stablehlo.multiply %cst_0, %0 : tensor<f32>\n'
 '    return %1 : tensor<f32>\n'
 '  }\n'
 '}\n')


In [5]:
from flax import nnx

class MLP(nnx.Module):
    def __init__(self, rngs, output=10):
        self.dense_0 = nnx.Linear(in_features=1000, out_features=output, rngs=rngs)
        self.dense_1 = nnx.Linear(in_features=1000, out_features=output, rngs=rngs)
        self.dense_2 = nnx.Linear(in_features=1000, out_features=output, rngs=rngs)
    
    def __call__(self, x):
        x_0 = nnx.relu(self.dense_0(x))
        x_1 = nnx.relu(self.dense_1(x))
        x_2 = nnx.relu(self.dense_2(x))
        
        return x_0 + x_1 + x_2


rngs = nnx.Rngs(0)

mlp = MLP(rngs)

inputs = jnp.ones((10, 1000))

mlp(inputs).shape

(10, 10)

In [6]:
class Linear:
    def __init__(self, in_features, out_features, num_policies, seed=0):
        
        self.weights = jax.random.uniform(jax.random.key(seed), shape=(num_policies, in_features, out_features)) # policies
        self.bias = jax.random.uniform(jax.random.key(seed), shape=(num_policies, 1, out_features))
    
    def __call__(self, x):
        dot_prod = jax.vmap(jnp.dot, in_axes=(None, 0))(x, self.weights)
        return jax.nn.relu(dot_prod) + self.bias

in_features = 1000
batch_size = 64

linear_jax = Linear(in_features=in_features, out_features=10, num_policies=10)

inputs = jax.random.normal(jax.random.key(0), shape=(batch_size, in_features)) # (64, 1000)
outputs = linear_jax(inputs)
outputs = jnp.argmax(outputs, axis=-1, keepdims=True)

In [7]:
labels = jnp.array([1,] * batch_size)[:, None] # (batch_size, 1)
labels.shape

(64, 1)

In [87]:
from jax.experimental import checkify
from functools import partial

@partial(jax.jit)
def policy_evaluation(outputs, labels):
    def _returns(outputs, labels):
        return jnp.sum(outputs == labels[None, :, None], axis=(-1, -2))
    
    # should have l (batch_size, 1) and out (num_policies, batch_size, 1)
    # jax.experimental.checkify.check(labels.shape == outputs.shape[1:], f"wrong shapes {labels.shape=} {outputs.shape[1:]}")
    # outputs = model(inputs)
    outputs = jnp.argmax(outputs, axis=-1, keepdims=True)
    return _returns(outputs, labels)


def policy_improvement():
    # take the best
    pass


# jit_policy_eval = jax.jit(checkify.checkify(policy_evalutation))
# jit_policy_eval(outputs, labels)[1].shape

In [9]:
# best_policy = jnp.argmax(jit_policy_eval(outputs, labels)[1])
# worst_policy = jnp.argmin(jit_policy_eval(outputs, labels)[1])

# print("accuracy of the best policy")
# print(jit_policy_eval(outputs, labels)[1][best_policy] / batch_size)

# print("accuracy of the worst policy")
# print(jit_policy_eval(outputs, labels)[1][worst_policy] / batch_size)

In [30]:
# class MLP_jax:
#     def __init__(self, in_features, out_features, seed=0):
#         self.dense_0 = Linear(in_features=in_features, out_features=1000, seed=seed)
#         self.dense_1 = Linear(in_features=1000, out_features=1000, seed=seed)
#         self.dense_2 = Linear(in_features=1000, out_features=out_features, seed=seed)
    
#     def __call__(self, x):
#         x = self.dense_0(x)
#         x = jax.nn.relu(x)
#         x = self.dense_1(x)
#         x = jax.nn.relu(x)
#         return self.dense_2(x)

# mlp = MLP_jax(in_features=in_features, out_features=10)

# mlp(inputs).shape


In [38]:
from tensorflow.keras.datasets import mnist
from tensorflow.data import Dataset, AUTOTUNE

def get_mnsit_dataloaders(batch_size):
    # Load the MNIST dataset
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    
    # Normalize the images to [0, 1]
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0
    
    x_train = x_train.reshape(x_train.shape[0], -1)
    x_test = x_test.reshape(x_test.shape[0], -1)
    
    train_dataset = Dataset.from_tensor_slices((x_train, y_train)).shuffle(buffer_size=1024).batch(batch_size=batch_size).prefetch(AUTOTUNE)
    test_dataset = Dataset.from_tensor_slices((x_test, y_test)).shuffle(buffer_size=1024).batch(batch_size=batch_size).prefetch(AUTOTUNE)
    
    return train_dataset, test_dataset

2025-01-20 09:30:05.372698: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1737361805.507647 2109905 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1737361805.545180 2109905 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [61]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train.shape

(60000, 28, 28)

In [63]:
from jax.tree_util import tree_flatten

In [79]:
from jax.tree_util import tree_map

# @partial(jax.vmap, in_axes=({"w": 0, "b": 0}, None)) # vmap over policies
def dense(params, x):
    return jnp.matmul(x, params["w"]) #+ params["b"][:, None]


def init(in_features, hidden_dim, out_features, num_layers=4):
    shape_in = (in_features, hidden_dim)
    shape_hid = ((hidden_dim, hidden_dim),) * (num_layers - 2)
    shape_out = (hidden_dim, out_features)
    
    shapes = (shape_in, *shape_hid, shape_out)
    params = {f"dense_{layer}": {"w": (), "b": ()} for layer in range(num_layers)}
    
    for shape, layer in zip(shapes, params.values()):
        in_f, out_f = shape
        print(f"{shape=}")
        sigma = 2 / (in_f + out_f)
        layer["w"] = jnp.sqrt(sigma) * jax.random.normal(jax.random.key(0), shape=shape)
        layer["b"] = jnp.sqrt(sigma) * jax.random.normal(jax.random.key(0), shape=shape[slice(0, len(shape), len(shape)-1)])
    
    return params

pars = {"w": jnp.ones((10, 100, 1)), "b": jnp.ones((10, 1))}
inputs = jnp.ones((64, 100))

dense(pars, inputs).shape

(10, 64, 1)

In [15]:
x = jnp.ones((1, 64, 200))
x1 = jnp.ones((64, 200))
y = jnp.ones((10, 200, 1))
jnp.allclose(jnp.matmul(x, y), jnp.dot(x1, y).transpose(1, 0, 2))


Array(True, dtype=bool)

In [49]:
x = jnp.ones((64, 200))
y = jnp.ones((1, 200, 1))
z = jnp.ones((10, 1, 200))
f = jnp.matmul(x, y)
jnp.matmul(f, z).shape

(10, 64, 200)

In [91]:
params = init(in_features=100, hidden_dim=100, out_features=10)

@jax.jit
def forward(params, x):
    last = len(params.keys()) - 1
    for i, param in enumerate(params.values()):
        x = jax.nn.relu(dense(param, x)) if i != last else dense(param, x)
    return x

@jax.jit
def eval_step(params, x, labels):
    outputs = forward(params, x)
    outputs = jnp.argmax(outputs, axis=-1, keepdims=True)
    return jnp.mean(outputs == labels[:, None])

predicted = forward(params, inputs)
predicted.shape

shape=(100, 100)
shape=(100, 100)
shape=(100, 100)
shape=(100, 10)


(64, 10)

In [73]:
for i in reversed(params):
    print(params[i]["w"].shape)

(10, 100, 10)
(10, 100, 100)
(10, 100, 100)
(10, 100, 100)


In [36]:
jnp.var(predicted)

Array(0.19269249, dtype=float32)

In [96]:
import jax.numpy as jnp
import numpy as np
# from matplotlib.pyplot import matshow
# import matplotlib.pyplt as plt


train_dataset, test_dataset = get_mnsit_dataloaders(batch_size=512)
# initialize network
params = init(in_features=28*28, hidden_dim=100, out_features=10)

for epoch in range(10):
    print(f"{epoch=}")
    
    for param in reversed(params.keys()):
        print(f"{param=}")
        
        mean = 0 if epoch == 0 else params[param]["w"]
        std =  1
        
        for step, (imgs, labs) in enumerate(train_dataset):
            imgs, labs = jnp.array(imgs), jnp.array(labs)
            
            # explore
            num_policies = 1000
            shape_pars = (num_policies,) + params[param]["w"].shape
            
            params[param]["w"] = std * jax.random.normal(jax.random.key(0), shape=shape_pars) + mean
            params[param]["b"] = std * jax.random.normal(jax.random.key(0), shape=shape_pars) + mean
            
            # predict
            outputs = forward(params, imgs)
    #         break
    #     break
    # break
            # evaluate
            policy_returns = policy_evaluation(outputs, labs)
            mean = params[param]["w"][best_policy]
            
            best_policy = jnp.argmax(policy_returns, axis=-1)
            accuracy = jnp.max(policy_returns) / imgs.shape[0]
            
            print(f"For batch {step}")
            print(f"{best_policy=:,d} accuracy {accuracy=:.4f}")
            
            params[param]["w"] = params[param]["w"][best_policy]
            # params[param]["b"] = params[param]["b"][best_policy]
            
            # accuracy = jnp.max(policy_returns) / imgs.shape[0]
            
            # print(f"For batch {step}")
            # print(f"{best_policy=:,d} accuracy {accuracy=:.4f}")
    
    val_accuracy = []
    for step, (imgs, labs) in enumerate(test_dataset):
        imgs, labs = jnp.array(imgs), jnp.array(labs)
        
        val_accuracy.append(eval_step(params, imgs, labs))

    print(f"Validation accuracy: {np.mean(val_accuracy):.2f}")


shape=(784, 100)
shape=(100, 100)
shape=(100, 100)
shape=(100, 10)
epoch=0
param='dense_3'
For batch 0
best_policy=13 accuracy accuracy=0.1816
For batch 1
best_policy=58 accuracy accuracy=0.1484
For batch 2
best_policy=58 accuracy accuracy=0.1855
For batch 3
best_policy=13 accuracy accuracy=0.2168
For batch 4
best_policy=59 accuracy accuracy=0.1953
For batch 5
best_policy=79 accuracy accuracy=0.2031
For batch 6
best_policy=92 accuracy accuracy=0.2363
For batch 7
best_policy=33 accuracy accuracy=0.2285
For batch 8
best_policy=91 accuracy accuracy=0.2012
For batch 9
best_policy=24 accuracy accuracy=0.1953
For batch 10
best_policy=2 accuracy accuracy=0.2051
For batch 11
best_policy=27 accuracy accuracy=0.2148
For batch 12
best_policy=96 accuracy accuracy=0.2559
For batch 13
best_policy=33 accuracy accuracy=0.2422
For batch 14
best_policy=3 accuracy accuracy=0.2207
For batch 15
best_policy=30 accuracy accuracy=0.2598
For batch 16
best_policy=85 accuracy accuracy=0.2402
For batch 17
best_po

In [48]:
import jax.numpy as jnp
# from matplotlib.pyplot import matshow
# import matplotlib.pyplt as plt


train_dataset, test_dataset = get_mnsit_dataloaders(batch_size=512)
# initialize network
params = init(in_features=28*28, hidden_dim=100, out_features=10)

mean = 0
std = 0.1

for epoch in range(10):
    print(f"{epoch=}")
    
    for param in reversed(params.keys()):
        print(f"{param=}")
        
        mean = 0 if epoch == 0 else params[param]["w"]
        std = 0.1 #if epoch == 0 else 
        
        for step, (imgs, labs) in enumerate(train_dataset):
            imgs, labs = jnp.array(imgs), jnp.array(labs)
            
            # explore
            num_policies = 100
            shape_pars = (num_policies,) + params[param]["w"].shape
            
            params[param]["w"] = std * jax.random.normal(jax.random.key(0), shape=shape_pars) + mean
            params[param]["b"] = std * jax.random.normal(jax.random.key(0), shape=shape_pars) + mean
            
            # predict
            outputs = forward(params, imgs)
            
            # evaluate
            policy_returns = policy_evaluation(outputs, labs)
            print(params[param]["w"][best_policy].shape)
            mean = params[param]["w"][best_policy]
            
            best_policy = jnp.argmax(policy_returns, axis=-1)
            accuracy = jnp.max(policy_returns) / imgs.shape[0]
            
            print(f"For batch {step}")
            print(f"{best_policy=:,d} accuracy {accuracy=:.4f}")
            
            params[param]["w"] = params[param]["w"][best_policy]
            # params[param]["b"] = params[param]["b"][best_policy]
            
            # accuracy = jnp.max(policy_returns) / imgs.shape[0]
            
            # print(f"For batch {step}")
            # print(f"{best_policy=:,d} accuracy {accuracy=:.4f}")
        


Array([16,  9, 13, 10, 10, 21, 15, 12, 13, 15], dtype=int32)

In [54]:
class PolicySampler:
    def __init__(self, shape):
        self.shape = shape

    def normal_sample(self, seed, mean, std):
        return std * jax.random.normal(jax.random.key(seed), shape=self.shape) + mean
    
    def uniform_sample(self, seed, min, max):
        return jax.random.uniform(jax.random.key(seed), shape=self.shape, minval=min, maxval=max)