In [None]:
!pip3 install -U optax dm-haiku -q

[K     |████████████████████████████████| 122kB 13.2MB/s 
[K     |████████████████████████████████| 286kB 20.3MB/s 
[K     |████████████████████████████████| 61kB 6.7MB/s 
[?25h

In [None]:
import jax
import jax.numpy as jnp
from jax import random
import optax
import haiku as hk

In [None]:
def softmax_cross_entropy(logits, labels):
  """Per-example cross-entropy between softmax(logits) and integer class labels.

  Args:
    logits: unnormalised scores; the last axis indexes the classes.
    labels: class indices, one per example.

  Returns:
    The negative log-probability assigned to the labelled class, with the
    class axis reduced away.
  """
  num_classes = logits.shape[-1]
  log_probs = jax.nn.log_softmax(logits)
  # One-hot mask that keeps only the log-probability of the labelled class.
  targets = jax.nn.one_hot(labels, num_classes)
  return -jnp.sum(log_probs * targets, axis=-1)

In [None]:
# Number of JAX devices addressable by this process.
# NOTE(review): not referenced by any of the visible cells — confirm it is still needed.
num_devices = jax.local_device_count()

In [None]:
 
from typing import Generator, Mapping, Tuple
 
from absl import app
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
 
# A batch maps feature names to NumPy arrays; the training code in this
# notebook reads the "image" and "label" entries.
Batch = Mapping[str, np.ndarray]

In [None]:
!python3 --version

Python 3.7.10


In [None]:
def load_dataset(
    split: str,
    *,
    is_training: bool,
    batch_size: int,
) -> Generator[Batch, None, None]:
  """Builds an iterator of NumPy batches over one MNIST split.

  The split is cached and repeated before batching, so the returned iterator
  keeps yielding batches for as long as it is consumed.

  Args:
    split: TFDS split name, e.g. "train" or "test".
    is_training: when true, examples are shuffled (fixed seed) before batching.
    batch_size: number of examples per yielded batch.

  Returns:
    An iterator over batches converted to NumPy by `tfds.as_numpy`.
  """
  dataset = tfds.load("mnist:3.*.*", split=split)
  dataset = dataset.cache()
  dataset = dataset.repeat()
  if is_training:
    # The shuffle buffer holds ten batches' worth of examples.
    shuffle_buffer = 10 * batch_size
    dataset = dataset.shuffle(shuffle_buffer, seed=0)
  batched = dataset.batch(batch_size)
  return iter(tfds.as_numpy(batched))
 
# Shuffled stream of 1000-example mini-batches consumed by the training loop.
train = load_dataset("train", is_training=True, batch_size=1000)
# Unshuffled 10000-example streams over the train and test splits.
# NOTE(review): neither is consumed in the visible cells — presumably meant for evaluation.
train_eval = load_dataset("train", is_training=False, batch_size=10000)
test_eval = load_dataset("test", is_training=False, batch_size=10000)

[1mDownloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.



HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…



[1mDataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m


In [None]:
class MLP(hk.Module):
  """Two-hidden-layer perceptron producing 10 temperature-scaled logits.

  The input is flattened, passed through two ReLU-activated linear layers and
  a final 10-way linear layer, and the result is divided by `temperature`.
  """

  def __init__(self,
               hidden_dim_1: int = 512,
               hidden_dim_2: int = 128,
               temperature: float = 0.007,
               name=None):
    """Stores the layer sizes and the logit temperature.

    Args:
      hidden_dim_1: width of the first hidden layer.
      hidden_dim_2: width of the second hidden layer.
      temperature: positive divisor applied to the output logits. The default
        reproduces the previously hard-coded value; values below 1 sharpen
        the downstream softmax, values above 1 flatten it.
      name: optional Haiku module name, forwarded to `hk.Module`.

    Raises:
      ValueError: if `temperature` is not strictly positive.
    """
    super().__init__(name=name)
    if temperature <= 0:
      raise ValueError(f"temperature must be positive, got {temperature!r}")
    self.h1 = hidden_dim_1
    self.h2 = hidden_dim_2
    self.temperature = temperature

  def __call__(self, x):
    """Maps a batch of inputs to temperature-scaled class logits."""
    x = hk.Flatten()(x)
    x = hk.Linear(self.h1)(x)
    x = jax.nn.relu(x)
    x = hk.Linear(self.h2)(x)
    x = jax.nn.relu(x)
    logits = hk.Linear(10)(x)

    # Dividing by a small temperature scales the logits up, which makes the
    # softmax computed from them much more peaked.
    return logits / self.temperature

In [None]:
 
 
def _model(input):
  """Forward function for `hk.transform`: applies a default-configured MLP."""
  return MLP()(input)
 
# Fixed-seed sequence of PRNG keys shared by the cells below.
rng_keys = hk.PRNGSequence(5322)

# Pure init/apply pair for the network; `apply` is called without an RNG key.
model = hk.without_apply_rng(hk.transform(_model))

# Dummy tensors drawn from the key sequence: `x` has the shape of a single
# 28x28 input, `y` has one entry per output class.
x = random.normal(next(rng_keys), (1, 28, 28))
y = random.normal(next(rng_keys), (10,))



In [None]:
def loss_f(params, input, labels):
  """Mean softmax cross-entropy of the model's predictions over a batch.

  Args:
    params: network parameters passed to `model.apply`.
    input: batch of inputs fed to the network.
    labels: class indices, one per example.

  Returns:
    The cross-entropy averaged over all examples.
  """
  per_example = softmax_cross_entropy(model.apply(params, input), labels)
  return jnp.mean(per_example)

In [None]:
# Adam optimizer with a fixed learning rate of 1e-3.
opt = optax.adam(1e-3)
 
 
# Initialise the network parameters by tracing the model on the dummy input `x`.
params = model.init(next(rng_keys), x)
 
# Optimizer state built from the freshly initialised parameters.
opt_state = opt.init(params)

In [None]:
@jax.jit
def train_step(batch, params, opt_state):
  """Runs one optimisation step on a single batch.

  The function is compiled with `jax.jit`. It closes over the module-level
  `loss_f` and `opt`, so re-run this cell after redefining either of them to
  avoid training with a stale compiled version.

  Args:
    batch: mapping with an "image" array of raw pixel values and a "label"
      array of integer class indices.
    params: current network parameters.
    opt_state: current optimizer state.

  Returns:
    A `(loss, new_params, new_state)` tuple: the loss evaluated at the
    incoming `params`, the updated parameters and the updated optimizer state.
  """
  # Rescale raw pixel values into [0, 1].
  x = batch["image"].astype(jnp.float32) / 255.
  # Labels are class indices, so keep them integral; a float16 cast cannot
  # represent every integer above 2048 exactly.
  y = batch["label"].astype(jnp.int32)

  loss, grads = jax.value_and_grad(loss_f)(params, x, y)

  updates, new_state = opt.update(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)

  return loss, new_params, new_state

In [None]:
# Run 180 optimisation steps, rebinding the global `params` and `opt_state`
# after every step.
for step in range(6*30):
  batch = next(train)
  
  loss, params, opt_state = train_step(batch, params, opt_state)
 
  # The loss is printed on every step. A periodic guard was tried and disabled:
  # `step+1 % 100 == 0` parses as `step + (1 % 100) == 0`, which is never true
  # here; the intended test would be `(step + 1) % 100 == 0`.
  print(loss)

26.093689
18.403444
14.273032
12.678823
9.254094
7.155964
4.879031
2.8370836
2.0081465
2.0044057
2.6979413
2.1526723
1.8961673
1.1878397
1.0918144
1.0188046
1.082282
1.0657595
0.89700973
0.97095025
1.0158923
0.9057389
0.8451142
0.90824515
0.95075977
0.7087435
0.6762309
0.6196399
0.65618885
0.59841835
0.57153213
0.4829418
0.58315027
0.55353343
0.46622682
0.522438
0.63020796
0.42958814
0.37804466
0.4834515
0.432176
0.41795257
0.39416224
0.430921
0.40570354
0.40630338
0.43873733
0.3684984
0.43351117
0.35383353
0.45855662
0.3238036
0.40179837
0.3713284
0.29529378
0.42883945
0.39610118
0.38812608
0.29434985
0.25721738
0.26304486
0.297356
0.28069127
0.37101153
0.34849668
0.3445206
0.29595545
0.2537937
0.40256327
0.2628358
0.23771578
0.2582818
0.26050845
0.2668162
0.2714106
0.2919264
0.29753554
0.25187933
0.20026307
0.25052473
0.21849391
0.26209986
0.27636898
0.18609782
0.25044376
0.186051
0.26291746
0.1955873
0.19530097
0.26065564
0.24837945
0.21697068
0.20775148
0.2517066
0.2184599
0.233275

In [None]:
!pip3 install gradio -q

[K     |████████████████████████████████| 1.6MB 13.7MB/s 
[K     |████████████████████████████████| 215kB 59.5MB/s 
[K     |████████████████████████████████| 1.9MB 61.7MB/s 
[K     |████████████████████████████████| 71kB 8.2MB/s 
[K     |████████████████████████████████| 962kB 57.4MB/s 
[K     |████████████████████████████████| 3.2MB 54.1MB/s 
[?25h  Building wheel for flask-cachebuster (setup.py) ... [?25l[?25hdone
  Building wheel for ffmpy (setup.py) ... [?25l[?25hdone


In [None]:
def recognize_digit(img):
  """Classifies a single 28x28 image with the trained network.

  Args:
    img: array-like with 784 elements, reshaped to `(1, 28, 28)`.

  Returns:
    A dict mapping each class index (as a string, "0", "1", ...) to the
    softmax probability the model assigns to it.
  """
  # NOTE(review): training divides pixel values by 255 before the forward
  # pass, whereas `img` is fed to the model unscaled — confirm the caller
  # already supplies values in the same range.
  batch = jnp.reshape(jnp.array(img), (1, 28, 28))
  probabilities = jax.nn.softmax(model.apply(params, batch), -1).squeeze().tolist()

  # Key by position so the mapping follows the model's output width, and avoid
  # shadowing the builtin `dict` with a local name.
  return {str(digit): prob for digit, prob in enumerate(probabilities)}

In [None]:
import gradio as gr

In [None]:
# Serve `recognize_digit` behind a Gradio UI with a sketchpad input and a
# label output, and launch it.
gr.Interface(fn=recognize_digit, inputs="sketchpad", outputs="label").launch()

Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`
This share link will expire in 24 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)
Running on External URL: https://50132.gradio.app
Interface loading below...


(<Flask 'gradio.networking'>,
 'http://127.0.0.1:7861/',
 'https://50132.gradio.app')

In [None]:
# Smoke test on a 28x28 array of Gaussian noise rather than a real digit.
# NOTE(review): the recorded output puts probability 1.0 on a single class —
# presumably a consequence of the logits being divided by 0.007 in `MLP`.
recognize_digit(jax.random.normal(next(rng_keys), (28, 28)))

{'0': 0.0,
 '1': 0.0,
 '2': 0.0,
 '3': 0.0,
 '4': 0.0,
 '5': 0.0,
 '6': 0.0,
 '7': 1.0,
 '8': 0.0,
 '9': 0.0}