In [1]:
## Standard libraries
import os
import math
import numpy as np
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgba
import seaborn as sns
sns.set()

## Progress bar
from tqdm.auto import tqdm

  set_matplotlib_formats('svg', 'pdf') # For export


In [2]:
import jax
from jax import numpy as jnp
# Report which JAX version the kernel is running.
print("Using jax", jax.__version__)

Using jax 0.4.7


## Pseudo-random numbers

In [3]:
# Master PRNG key for this notebook; JAX randomness always takes an explicit key.
rng = jax.random.PRNGKey(seed=42)



In [4]:
# Deliberately reusing one key to show the pitfall: JAX's PRNG is functional,
# so drawing twice from the same key returns the same number.
jax_random_number_1 = jax.random.normal(key=rng)
jax_random_number_2 = jax.random.normal(key=rng)
print('JAX - Random number 1:', jax_random_number_1)
print('JAX - Random number 2:', jax_random_number_2)

# Contrast: NumPy's legacy global generator carries hidden state that advances
# on every call, so consecutive draws after one seed differ.
np.random.seed(seed=42)
np_random_number_1 = np.random.normal(loc=0.0, scale=1.0)
np_random_number_2 = np.random.normal(loc=0.0, scale=1.0)
print('NumPy - Random number 1:', np_random_number_1)
print('NumPy - Random number 2:', np_random_number_2)

JAX - Random number 1: -0.18471177
JAX - Random number 2: -0.18471177
NumPy - Random number 1: 0.4967141530112327
NumPy - Random number 2: -0.13826430117118466


In [5]:
# Correct usage: split the key into fresh subkeys (keeping one as the new
# master key) and consume each subkey exactly once.
rng, subkey1, subkey2 = jax.random.split(key=rng, num=3)
jax_random_number_1 = jax.random.normal(key=subkey1)
jax_random_number_2 = jax.random.normal(key=subkey2)
print('JAX new - Random number 1:', jax_random_number_1)
print('JAX new - Random number 2:', jax_random_number_2)
# Reset the master key so the following cells start again from seed 42.
rng = jax.random.PRNGKey(seed=42)

JAX new - Random number 1: 0.107961535
JAX new - Random number 2: -1.2226542


## Model

In [6]:
try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax

In [7]:
from flax import linen as nn

In [8]:
class SimpleClassifier(nn.Module):
  """Two-layer MLP: Dense(num_hidden) -> tanh -> Dense(num_outputs).

  The output layer has no activation, so the returned values are unnormalized.
  """
  num_hidden : int   # Number of hidden neurons
  num_outputs : int  # Number of output neurons

  @nn.compact
  def __call__(self, x):
    # Layers are created inline (nn.compact); Flax names them Dense_0 and
    # Dense_1 in creation order, so the hidden layer must be built first.
    hidden = nn.tanh(nn.Dense(features=self.num_hidden)(x))
    out = nn.Dense(features=self.num_outputs)(hidden)
    return out

In [9]:
# Small model: 8 hidden units feeding a single output unit.
model = SimpleClassifier(num_outputs=1, num_hidden=8)
print(model)

SimpleClassifier(
    # attributes
    num_hidden = 8
    num_outputs = 1
)


In [10]:
# Fresh keys: one for the dummy input batch, one for parameter initialization.
rng, inp_rng, init_rng = jax.random.split(key=rng, num=3)
inp = jax.random.normal(key=inp_rng, shape=(8, 2))  # Batch size 8, input size 2
# Dense layers only declare their output width; the input width is taken
# from this example batch when the parameters are created.
params = model.init(init_rng, inp)
print(params)
model.apply(params, inp)

FrozenDict({
    params: {
        Dense_0: {
            kernel: Array([[-0.8734889 ,  0.03292416,  0.45095628,  0.9860286 ,  0.9650168 ,
                    -0.50356966, -0.567441  , -0.32092765],
                   [ 0.6106076 , -0.8035141 , -0.8497237 , -1.0364467 ,  0.11642699,
                    -0.37274948, -0.06301995,  0.23880544]], dtype=float32),
            bias: Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
        },
        Dense_1: {
            kernel: Array([[-0.08973367],
                   [-0.15572299],
                   [ 0.12597609],
                   [-0.02248076],
                   [ 0.48822802],
                   [ 0.19107282],
                   [-0.32372728],
                   [-0.04857434]], dtype=float32),
            bias: Array([0.], dtype=float32),
        },
    },
})


Array([[ 0.6650056 ],
       [ 0.20904619],
       [ 0.30390516],
       [-0.4101298 ],
       [ 0.5975989 ],
       [-0.66803074],
       [-0.11172031],
       [-0.82014966]], dtype=float32)

## Data

In [None]:
import torch

In [None]:
class XORDataset(torch.utils.data.Dataset):
  """Toy continuous-XOR dataset.

  Each point starts at a random corner of the unit square (coordinates 0 or 1),
  is labelled 1 when exactly one coordinate is 1 (XOR) and 0 otherwise, and is
  then perturbed with Gaussian noise.
  """

  def __init__(self, size, seed, std=0.1):
    '''
    Inputs:
      size - Number of data points we want to generate
      seed - The random seed for this dataset's private np.random.RandomState
      std - The standard deviation of the Gaussian noise added to every point
    '''
    super().__init__()
    self.size = size
    self.np_rng = np.random.RandomState(seed=seed)
    self.std = std
    self.generate_continuous_xor()

  def generate_continuous_xor(self):
    """Sample noisy XOR points; stores float32 `self.data` (size, 2) and int32 `self.label` (size,)."""
    # Draw order matters for reproducibility: corners first, then the noise.
    corners = self.np_rng.randint(low=0, high=2, size=(self.size, 2)).astype(np.float32)
    # Coordinates are exactly 0.0 or 1.0 here, so "they differ" == "sum is 1".
    labels = (corners[:, 0] != corners[:, 1]).astype(np.int32)
    noise = self.np_rng.normal(loc=0.0, scale=self.std, size=corners.shape)
    corners += noise  # in-place add keeps the float32 dtype
    self.data = corners
    self.label = labels

  def __len__(self):
    # Number of generated points (equals self.data.shape[0] and self.label.shape[0]).
    return self.size

  def __getitem__(self, idx):
    # One sample as a (point, label) tuple.
    return self.data[idx], self.label[idx]