# Torch

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

print(torch.__version__)

1.9.0a0+gitd69c22d


In [2]:
BATCH_SIZE = 32
EPOCHS = 15

transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

## download and load training dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2, pin_memory=True)

## download and load testing dataset
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1   = nn.Linear(16*5*5, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out
    
    
def get_accuracy(logit, target, batch_size):
    ''' Obtain accuracy for training round '''
    corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
    accuracy = 100.0 * corrects/batch_size
    return accuracy.item()

In [4]:
model = LeNet()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

model = model.to(device)

cuda:0


In [5]:
learning_rate = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

In [6]:
%%time
from tqdm import tqdm

for epoch in tqdm(range(EPOCHS)):
    train_running_loss = 0.0
    train_acc = 0.0

    model = model.train()

    ## training step
    for images, labels in trainloader:
        
        images = images.to(device)
        labels = labels.to(device)

        ## forward + backprop + loss
        logits = model(images)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()

        ## update model params
        optimizer.step()

        train_running_loss += loss.detach().item()
        train_acc += get_accuracy(logits, labels, BATCH_SIZE)
    
    model.eval()
    print('\t Epoch: %d | Train Loss: %.4f | Train Accuracy: %.4f' \
          %(epoch, train_running_loss / len(trainloader), train_acc / len(trainloader)))

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  7%|▋         | 1/15 [00:11<02:47, 11.98s/it]

	 Epoch: 0 | Train Loss: 1.5733 | Train Accuracy: 42.9023


 13%|█▎        | 2/15 [00:16<01:42,  7.85s/it]

	 Epoch: 1 | Train Loss: 1.2739 | Train Accuracy: 54.5146


 20%|██        | 3/15 [00:21<01:18,  6.54s/it]

	 Epoch: 2 | Train Loss: 1.1629 | Train Accuracy: 58.8012


 27%|██▋       | 4/15 [00:26<01:05,  5.92s/it]

	 Epoch: 3 | Train Loss: 1.0916 | Train Accuracy: 61.7123


 33%|███▎      | 5/15 [00:31<00:55,  5.59s/it]

	 Epoch: 4 | Train Loss: 1.0408 | Train Accuracy: 63.3677


 40%|████      | 6/15 [00:36<00:48,  5.37s/it]

	 Epoch: 5 | Train Loss: 0.9938 | Train Accuracy: 65.3551


 47%|████▋     | 7/15 [00:41<00:41,  5.25s/it]

	 Epoch: 6 | Train Loss: 0.9641 | Train Accuracy: 66.0549


 53%|█████▎    | 8/15 [00:46<00:36,  5.17s/it]

	 Epoch: 7 | Train Loss: 0.9295 | Train Accuracy: 67.3984


 60%|██████    | 9/15 [00:51<00:30,  5.10s/it]

	 Epoch: 8 | Train Loss: 0.9062 | Train Accuracy: 68.1482


 67%|██████▋   | 10/15 [00:56<00:25,  5.06s/it]

	 Epoch: 9 | Train Loss: 0.8843 | Train Accuracy: 68.8340


 73%|███████▎  | 11/15 [01:01<00:20,  5.03s/it]

	 Epoch: 10 | Train Loss: 0.8597 | Train Accuracy: 69.6797


 80%|████████  | 12/15 [01:06<00:15,  5.01s/it]

	 Epoch: 11 | Train Loss: 0.8437 | Train Accuracy: 70.2435


 87%|████████▋ | 13/15 [01:11<00:10,  5.03s/it]

	 Epoch: 12 | Train Loss: 0.8281 | Train Accuracy: 70.7993


 93%|█████████▎| 14/15 [01:16<00:05,  5.03s/it]

	 Epoch: 13 | Train Loss: 0.8147 | Train Accuracy: 71.4271


100%|██████████| 15/15 [01:21<00:00,  5.47s/it]

	 Epoch: 14 | Train Loss: 0.8072 | Train Accuracy: 71.8310
CPU times: user 1min 12s, sys: 8.03 s, total: 1min 20s
Wall time: 1min 21s





In [7]:
test_acc = 0.0

model.eval()
with torch.no_grad():
    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        test_acc += get_accuracy(logits, labels, BATCH_SIZE)
print('Test Accuracy: %.4f'%(test_acc / len(testloader)))



Test Accuracy: 66.3538


# TF

In [2]:
import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

2021-10-27 18:21:37.513663: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


In [3]:
print(tf.__version__)

2.4.2


In [4]:
import tensorflow.image as transforms

cifar10 = tf.keras.datasets.cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Add a channels dimension
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")

In [5]:
BATCH_SIZE = 32
EPOCHS = 15

def transform_train(image, label):
    image = transforms.resize(image, (32, 32))
    image = transforms.random_flip_left_right(image)
    image = transforms.per_image_standardization(image)
    return image, label

def transform(image, label):
    image = transforms.resize(image, (32, 32))
    image = transforms.per_image_standardization(image)
    return image, label

train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).map(transform_train).shuffle(10000).batch(BATCH_SIZE)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).map(transform).batch(BATCH_SIZE)


2021-10-27 18:21:52.444378: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2021-10-27 18:21:52.445174: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2021-10-27 18:21:52.593843: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties: 
pciBusID: 0000:87:00.0 name: A100-SXM4-40GB computeCapability: 8.0
coreClock: 1.41GHz coreCount: 108 deviceMemorySize: 39.59GiB deviceMemoryBandwidth: 1.41TiB/s
2021-10-27 18:21:52.593877: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
2021-10-27 18:21:52.596634: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2021-10-27 18:21:52.596662: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.11
2021-10-

In [6]:
import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten, Conv2D, Dropout, MaxPool2D
from tensorflow.keras import Model

class LeNet(Model):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = Conv2D(6, 5, activation='relu')
        self.conv2 = Conv2D(16, 5, activation='relu')
        self.flatten = Flatten()
        self.maxpool = MaxPool2D((2,2), 2)
        self.fc1 = Dense(120, activation='relu')
        self.fc2 = Dense(84, activation='relu')
        self.fc3 = Dense(10)

    def call(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

In [7]:
# Create an instance of the model
model = LeNet()

In [8]:
learning_rate = 0.001
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

In [9]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [10]:
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        # training=True is only needed if there are layers with different
        # behavior during training versus inference (e.g. Dropout).
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)

In [11]:
%%time 
from tqdm import tqdm

for epoch in tqdm(range(EPOCHS)):
    # Reset the metrics at the start of the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    for images, labels in train_ds:
        train_step(images, labels)

    print('\t Epoch: %d | Loss: %.4f | Train Accuracy: %.2f' \
        %(epoch, train_loss.result(), train_accuracy.result() * 100))

  0%|          | 0/15 [00:00<?, ?it/s]2021-10-27 18:22:32.983234: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2021-10-27 18:22:33.003723: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2245865000 Hz
2021-10-27 18:22:34.152619: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2021-10-27 18:22:34.822175: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.11
2021-10-27 18:22:34.869730: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.8
2021-10-27 18:22:37.619292: I tensorflow/stream_executor/cuda/cuda_blas.cc:1838] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
  7%|▋         | 1/15 [00:09<02:11,  9.41s/it]

	 Epoch: 0 | Loss: 1.5148 | Train Accuracy: 45.12


 13%|█▎        | 2/15 [00:14<01:26,  6.67s/it]

	 Epoch: 1 | Loss: 1.2378 | Train Accuracy: 56.14


 20%|██        | 3/15 [00:18<01:09,  5.80s/it]

	 Epoch: 2 | Loss: 1.1251 | Train Accuracy: 60.03


 27%|██▋       | 4/15 [00:23<00:59,  5.40s/it]

	 Epoch: 3 | Loss: 1.0442 | Train Accuracy: 63.40


 33%|███▎      | 5/15 [00:28<00:51,  5.16s/it]

	 Epoch: 4 | Loss: 0.9870 | Train Accuracy: 65.27


 40%|████      | 6/15 [00:33<00:45,  5.06s/it]

	 Epoch: 5 | Loss: 0.9421 | Train Accuracy: 66.93


 47%|████▋     | 7/15 [00:38<00:39,  4.99s/it]

	 Epoch: 6 | Loss: 0.9034 | Train Accuracy: 68.30


 53%|█████▎    | 8/15 [00:42<00:34,  4.93s/it]

	 Epoch: 7 | Loss: 0.8809 | Train Accuracy: 69.10


 60%|██████    | 9/15 [00:47<00:29,  4.92s/it]

	 Epoch: 8 | Loss: 0.8492 | Train Accuracy: 70.16


 67%|██████▋   | 10/15 [00:52<00:24,  4.90s/it]

	 Epoch: 9 | Loss: 0.8297 | Train Accuracy: 70.82


 73%|███████▎  | 11/15 [00:57<00:19,  4.86s/it]

	 Epoch: 10 | Loss: 0.8035 | Train Accuracy: 71.77


 80%|████████  | 12/15 [01:02<00:14,  4.83s/it]

	 Epoch: 11 | Loss: 0.7894 | Train Accuracy: 72.34


 87%|████████▋ | 13/15 [01:07<00:09,  4.83s/it]

	 Epoch: 12 | Loss: 0.7719 | Train Accuracy: 72.82


 93%|█████████▎| 14/15 [01:11<00:04,  4.83s/it]

	 Epoch: 13 | Loss: 0.7581 | Train Accuracy: 73.39


100%|██████████| 15/15 [01:16<00:00,  5.11s/it]

	 Epoch: 14 | Loss: 0.7392 | Train Accuracy: 74.11
CPU times: user 1min 27s, sys: 7.95 s, total: 1min 35s
Wall time: 1min 16s





In [12]:
@tf.function
def test_step(images, labels):
    # training=False is only needed if there are layers with different
    # behavior during training versus inference (e.g. Dropout).
    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)

In [13]:
for test_images, test_labels in test_ds:
    test_step(test_images, test_labels)

print(f'Test Accuracy: {test_accuracy.result() * 100:.2f}')

Test Accuracy: 67.41


# JAX

In [14]:
!nvidia-smi

Wed Oct 27 18:24:11 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.142.00   Driver Version: 450.142.00   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  A100-SXM4-40GB      On   | 00000000:87:00.0 Off |                    0 |
| N/A   29C    P0    59W / 400W |  39424MiB / 40537MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [15]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Mon_May__3_19:15:13_PDT_2021
Cuda compilation tools, release 11.3, V11.3.109
Build cuda_11.3.r11.3/compiler.29920130_0


These are run in terminal 

``` bash
pip install tensorflow-datasets

# in case main install doesn't work, uninstall everything first
pip uninstall jax jaxlib dm-haiku optax -y 

# per Peng's recommendation 
pip install https://storage.googleapis.com/jax-releases/cuda110/jaxlib-0.1.71+cuda110-cp38-none-manylinux2010_x86_64.whl

# main install
pip install "jax[cuda]"
pip install git+https://github.com/deepmind/dm-haiku
pip install optax

```

Do this for testing 
``` python
import jax
import jax.numpy as jnp
print(jax.devices()) # should print out -> [GpuDevice(id=0, process_index=0)]
print(jnp.ones(3).device_buffer.device()) # should print out -> gpu:0
```

Additionally, in case there was an error about no kernel image when doing the above, per Greg's advice, set 
``` bash
XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 
```

In [1]:
import os
os.environ['XLA_FLAGS']='--xla_gpu_force_compilation_parallelism=1'
!echo $XLA_FLAGS

--xla_gpu_force_compilation_parallelism=1


In [2]:
import tensorflow as tf
import tensorflow.image as transforms
import tensorflow_datasets as tfds 
import numpy as np

2021-10-27 18:51:03.123525: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


In [3]:
import jax
import jax.numpy as jnp
import haiku as hk
import optax

# testing if jax can find GPU device properly
print(tf.config.list_physical_devices('GPU'))
print(jax.devices())
print(jnp.ones(3).device_buffer.device())

2021-10-27 18:51:05.910163: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2021-10-27 18:51:05.910987: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2021-10-27 18:51:06.054378: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties: 
pciBusID: 0000:87:00.0 name: A100-SXM4-40GB computeCapability: 8.0
coreClock: 1.41GHz coreCount: 108 deviceMemorySize: 39.59GiB deviceMemoryBandwidth: 1.41TiB/s
2021-10-27 18:51:06.054413: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
2021-10-27 18:51:06.056893: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2021-10-27 18:51:06.056925: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublasLt.so.11
2021-10-

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
[GpuDevice(id=0, process_index=0)]
gpu:0


In [4]:
BATCH_SIZE = 32
EPOCHS = 15
LR = 1e-3

In [5]:
cifar10 = tf.keras.datasets.cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")

def transform_train(image, label):
    image = transforms.resize(image, (32, 32))
    image = transforms.random_flip_left_right(image)
    image = transforms.per_image_standardization(image)
    return image, label

def transform(image, label):
    image = transforms.resize(image, (32, 32))
    image = transforms.per_image_standardization(image)
    return image, label

def load_dataset(is_training: bool, batch_size: int):
    if is_training:
        ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(transform_train).shuffle(10000)
    else:
        ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).map(transform)
    ds = ds.batch(batch_size)
    return tfds.as_numpy(ds)

train_ds = load_dataset(is_training = True, batch_size = BATCH_SIZE)
test_ds = load_dataset(is_training = False, batch_size = BATCH_SIZE)


2021-10-27 18:53:36.180567: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-10-27 18:53:36.183636: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties: 
pciBusID: 0000:87:00.0 name: A100-SXM4-40GB computeCapability: 8.0
coreClock: 1.41GHz coreCount: 108 deviceMemorySize: 39.59GiB deviceMemoryBandwidth: 1.41TiB/s
2021-10-27 18:53:36.183673: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
2021-10-27 18:53:36.183707: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2021-10-27 18:53:36.183718: I tensorflow/stream_executor/platform/

In [6]:
def lenet(images):
    net = hk.Sequential([
      hk.Conv2D(6,5), jax.nn.relu,
      hk.MaxPool((2,2),2,'VALID'),
      hk.Conv2D(16,5), jax.nn.relu,
      hk.MaxPool((2,2),2,'VALID'),
      hk.Flatten(),
      hk.Linear(120), jax.nn.relu,
      hk.Linear(84), jax.nn.relu,
      hk.Linear(10),
    ])
    return net(images)

@jax.jit
def accuracy(params, images, labels):
    predictions = net.apply(params, images)
    return jnp.mean(jnp.argmax(predictions, axis=-1) == labels.flatten())


In [7]:
net = hk.without_apply_rng(hk.transform(lenet))
opt = optax.adam(LR)

In [8]:
def loss_and_acc(params, images, labels):
    logits = net.apply(params, images)
    labels = labels.flatten()
    acc = jnp.mean(jnp.argmax(logits, axis=-1) == labels)

    labels = jax.nn.one_hot(labels, logits.shape[-1])
    softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
    softmax_xent /= labels.shape[0]

    return softmax_xent, acc
    
@jax.jit
def update(params, opt_state, images, labels):
    (loss_val, acc_val), grads = jax.value_and_grad(loss_and_acc, has_aux=True)(params, images, labels)
    updates, opt_state = opt.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss_val, acc_val


In [9]:
# for some reason took a long time
params = net.init(jax.random.PRNGKey(42), next(iter(train_ds))[0])
opt_state = opt.init(params)

2021-10-27 18:53:45.403761: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2021-10-27 18:53:45.423805: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2245865000 Hz


In [10]:
%%time
from tqdm import tqdm

for epoch in tqdm(range(EPOCHS)):
    train_loss, train_acc, cnt = jnp.array(0), jnp.array(0), 0
    train_ds = load_dataset(is_training = True, batch_size = BATCH_SIZE)

    for images, labels in train_ds:        
        params, opt_state, loss_val, acc_val = update(params, opt_state, images, labels)
        train_loss += loss_val
        train_acc += acc_val
        cnt += 1

    train_loss, train_acc = jax.device_get((train_loss, train_acc))

    print('\t Epoch: %d | Train Loss: %.4f | Train Accuracy: %.4f' \
          %(epoch, train_loss / cnt, 100*train_acc / cnt))

  7%|▋         | 1/15 [00:07<01:46,  7.63s/it]

	 Epoch: 0 | Train Loss: 1.5637 | Train Accuracy: 43.7900


 13%|█▎        | 2/15 [00:12<01:18,  6.03s/it]

	 Epoch: 1 | Train Loss: 1.2865 | Train Accuracy: 54.1747


 20%|██        | 3/15 [00:17<01:06,  5.52s/it]

	 Epoch: 2 | Train Loss: 1.1723 | Train Accuracy: 58.4633


 27%|██▋       | 4/15 [00:22<00:58,  5.28s/it]

	 Epoch: 3 | Train Loss: 1.1081 | Train Accuracy: 60.7126


 33%|███▎      | 5/15 [00:27<00:51,  5.15s/it]

	 Epoch: 4 | Train Loss: 1.0564 | Train Accuracy: 62.7759


 40%|████      | 6/15 [00:32<00:45,  5.05s/it]

	 Epoch: 5 | Train Loss: 1.0092 | Train Accuracy: 64.5134


 47%|████▋     | 7/15 [00:37<00:40,  5.02s/it]

	 Epoch: 6 | Train Loss: 0.9811 | Train Accuracy: 65.5230


 53%|█████▎    | 8/15 [00:42<00:35,  5.02s/it]

	 Epoch: 7 | Train Loss: 0.9489 | Train Accuracy: 66.6007


 60%|██████    | 9/15 [00:47<00:29,  5.00s/it]

	 Epoch: 8 | Train Loss: 0.9186 | Train Accuracy: 67.8223


 67%|██████▋   | 10/15 [00:51<00:24,  4.97s/it]

	 Epoch: 9 | Train Loss: 0.8967 | Train Accuracy: 68.5921


 73%|███████▎  | 11/15 [00:56<00:19,  4.96s/it]

	 Epoch: 10 | Train Loss: 0.8779 | Train Accuracy: 69.2258


 80%|████████  | 12/15 [01:01<00:14,  4.98s/it]

	 Epoch: 11 | Train Loss: 0.8570 | Train Accuracy: 69.9076


 87%|████████▋ | 13/15 [01:06<00:09,  4.98s/it]

	 Epoch: 12 | Train Loss: 0.8377 | Train Accuracy: 70.3875


 93%|█████████▎| 14/15 [01:11<00:04,  4.97s/it]

	 Epoch: 13 | Train Loss: 0.8237 | Train Accuracy: 70.9113


100%|██████████| 15/15 [01:16<00:00,  5.11s/it]

	 Epoch: 14 | Train Loss: 0.8095 | Train Accuracy: 71.3552
CPU times: user 1min 25s, sys: 10.8 s, total: 1min 36s
Wall time: 1min 16s





In [12]:
test_acc, cnt = jnp.array(0), 0
for images, labels in test_ds:  
    test_acc += accuracy(params, images, labels)
    cnt += 1

test_acc = jax.device_get(test_acc)
print('Test acc = %.2f' %(test_acc * 100/cnt))

Test acc = 64.68
