
# JAX ResNet CIFAR10

Connect to a public Google Colab GPU runtime (or a TPU runtime, in which case also run the TPU setup cell below).

In [1]:
!pip install flax

Collecting flax
  Downloading flax-0.3.5-py3-none-any.whl (193 kB)
[K     |████████████████████████████████| 193 kB 998 kB/s 
Collecting optax
  Downloading optax-0.0.9-py3-none-any.whl (118 kB)
[K     |████████████████████████████████| 118 kB 5.1 MB/s 
Collecting chex>=0.0.4
  Downloading chex-0.0.8-py3-none-any.whl (57 kB)
[K     |████████████████████████████████| 57 kB 2.5 MB/s 
Installing collected packages: chex, optax, flax
Successfully installed chex-0.0.8 flax-0.3.5 optax-0.0.9


In [2]:
#@title If connecting to TPU run this
# TPU setup : Boilerplate for connecting JAX to TPU.

# Cell purpose: when running inside Colab with a TPU address exposed in the
# environment, point JAX at the remote TPU driver; otherwise just report that
# no TPU was found.
import os
# Both conditions must hold: an IPython shell that identifies itself as Colab,
# and the COLAB_TPU_ADDR environment variable being set.
if 'google.colab' in str(get_ipython()) and 'COLAB_TPU_ADDR' in os.environ:
  # Make sure the Colab Runtime is set to Accelerator: TPU.
  import requests
  # TPU_DRIVER_MODE is a module-level marker so the version request below is
  # only posted once per kernel session, even if this cell is re-run.
  if 'TPU_DRIVER_MODE' not in globals():
    # Host part of COLAB_TPU_ADDR, port 8475; the path names a specific
    # tpu_driver build (presumably selecting that driver version — TODO confirm).
    url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1

  # The following is required to use TPU Driver as JAX's backend.
  from jax.config import config
  config.FLAGS.jax_xla_backend = "tpu_driver"
  config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
  print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
  print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

Registered TPU: grpc://10.96.247.122:8470


In [3]:
#@title imports

import io
import re

from functools import partial

import numpy as np

import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

import flax
import flax.linen as nn
import flax.optim as optim
import flax.jax_utils as flax_utils

# Assert that GPU is available
# assert 'Gpu' in str(jax.devices())

import tensorflow as tf
import tensorflow_datasets as tfds

import optax

from functools import partial
from typing import Any, Callable, Sequence, Tuple
ModuleDef = Any

import logging
from flax.training import train_state
logging.getLogger().setLevel(logging.INFO)

# Data from torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

In [4]:
#@title ResNet model

class ResNetBlock(nn.Module):
    """Two-convolution residual block with an optional projection shortcut.

    The main path is conv3x3 -> norm -> act -> conv3x3 -> norm (the last norm
    has its scale initialised to zero). When the main path output and the
    input differ in shape, the shortcut is passed through a 1x1 convolution
    (``conv_proj``) and a norm layer (``norm_proj``) before the addition.
    """

    # Number of output channels of both 3x3 convolutions.
    filters: int
    # Factory for convolution modules (called as conv(features, kernel, ...)).
    conv: ModuleDef
    # Factory for normalisation modules.
    norm: ModuleDef
    # Activation function applied after the first norm and after the addition.
    act: Callable
    # Strides of the first convolution (and of the projection shortcut).
    strides: Tuple[int, int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        """Apply the block to ``x`` and return the activated sum."""
        pad_one = ((1, 1), (1, 1))
        shortcut = x

        # Main path; submodules are created in the same order as before so
        # their automatically assigned names stay unchanged.
        out = self.conv(self.filters, (3, 3), self.strides, padding=pad_one)(x)
        out = self.norm()(out)
        out = self.act(out)
        out = self.conv(self.filters, (3, 3), padding=pad_one)(out)
        out = self.norm(scale_init=nn.initializers.zeros)(out)

        # Project the shortcut only when shapes disagree.
        if shortcut.shape != out.shape:
            project = self.conv(
                self.filters, (1, 1), self.strides, name="conv_proj"
            )
            shortcut = project(shortcut)
            shortcut = self.norm(name="norm_proj")(shortcut)

        return self.act(shortcut + out)


class BottleneckResNetBlock(nn.Module):
    """Three-convolution bottleneck residual block.

    The main path is conv1x1 -> norm -> act -> conv3x3 -> norm -> act ->
    conv1x1 (4x ``filters``) -> norm (scale initialised to zero). When the
    main path output and the input differ in shape, the shortcut goes through
    a 1x1 convolution (``conv_proj``) and a norm layer (``norm_proj``).
    """

    # Channel count of the first two convolutions; the last one uses 4x this.
    filters: int
    # Factory for convolution modules.
    conv: ModuleDef
    # Factory for normalisation modules.
    norm: ModuleDef
    # Activation function used inside the block and after the addition.
    act: Callable
    # Strides of the 3x3 convolution (and of the projection shortcut).
    strides: Tuple[int, int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        """Apply the block to ``x`` and return the activated sum."""
        expanded = self.filters * 4
        shortcut = x

        # 1x1 reduce.
        out = self.conv(self.filters, (1, 1))(x)
        out = self.norm()(out)
        out = self.act(out)
        # 3x3, carries the stride.
        out = self.conv(
            self.filters, (3, 3), self.strides, padding=((1, 1), (1, 1))
        )(out)
        out = self.norm()(out)
        out = self.act(out)
        # 1x1 expand; zero-initialised norm scale.
        out = self.conv(expanded, (1, 1))(out)
        out = self.norm(scale_init=nn.initializers.zeros)(out)

        # Project the shortcut only when shapes disagree.
        if shortcut.shape != out.shape:
            project = self.conv(expanded, (1, 1), self.strides, name="conv_proj")
            shortcut = project(shortcut)
            shortcut = self.norm(name="norm_proj")(shortcut)

        return self.act(shortcut + out)


class ResNet(nn.Module):
    """ResNetV1 built from a 3x3 stem, residual stages and a dense head.

    Each entry of ``stage_sizes`` is the number of blocks in one stage. Stage
    ``i`` uses ``num_filters * 2 ** i`` filters, and the first block of every
    stage except the first uses stride 2.
    """

    # Number of blocks per stage.
    stage_sizes: Sequence[int]
    # Block class instantiated for every residual block.
    block_cls: ModuleDef
    # Size of the final dense layer.
    num_classes: int
    # Filters of the stem and of the first stage.
    num_filters: int = 64
    # dtype passed to conv, norm and dense layers and used for the output.
    dtype: Any = jnp.float32
    # Activation handed to the residual blocks.
    act: Callable = nn.relu

    @nn.compact
    def __call__(self, x, train: bool = True):
        """Return logits for ``x``; ``train`` switches BatchNorm to batch statistics."""
        make_conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
        make_norm = partial(
            nn.BatchNorm,
            use_running_average=not train,
            momentum=0.9,
            epsilon=1e-5,
            dtype=self.dtype,
        )

        # Stem: stride-1 3x3 convolution, norm, ReLU.
        stem_conv = make_conv(
            self.num_filters, (3, 3), (1, 1), padding=((1, 1), (1, 1)), name="conv_init"
        )
        x = stem_conv(x)
        x = make_norm(name="bn_init")(x)
        x = nn.relu(x)
        # The stem max-pool is disabled:
        # x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))

        # Residual stages.
        for stage_idx, num_blocks in enumerate(self.stage_sizes):
            stage_filters = self.num_filters * 2 ** stage_idx
            for block_idx in range(num_blocks):
                downsample = stage_idx > 0 and block_idx == 0
                block = self.block_cls(
                    stage_filters,
                    strides=(2, 2) if downsample else (1, 1),
                    conv=make_conv,
                    norm=make_norm,
                    act=self.act,
                )
                x = block(x)

        # Head: mean over axes 1 and 2, then dense classifier.
        pooled = jnp.mean(x, axis=(1, 2))
        logits = nn.Dense(self.num_classes, dtype=self.dtype)(pooled)
        return jnp.asarray(logits, self.dtype)


# Pre-configured ResNet constructors: each fixes the per-stage block counts and
# the block class; remaining fields (e.g. num_classes) are supplied by the caller.
# Variants built from the two-convolution ResNetBlock.
ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock)
ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock)
# Variants built from the three-convolution BottleneckResNetBlock.
ResNet50 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=BottleneckResNetBlock)
ResNet101 = partial(ResNet, stage_sizes=[3, 4, 23, 3], block_cls=BottleneckResNetBlock)
ResNet152 = partial(ResNet, stage_sizes=[3, 8, 36, 3], block_cls=BottleneckResNetBlock)
ResNet200 = partial(ResNet, stage_sizes=[3, 24, 36, 3], block_cls=BottleneckResNetBlock)


In [5]:
#@title Helpers - preprocess, train, eval steps

class ToArray:
    """Transform converting an image-like input to a float32 array divided by 255."""

    def __call__(self, x):
        """Return a new float32 NumPy array holding ``x / 255``.

        Args:
            x: Anything ``np.array`` accepts (e.g. a PIL image or an ndarray).

        Returns:
            A freshly allocated float32 ndarray; ``x`` itself is never modified.
        """
        # np.array always copies, whereas np.asarray hands back the very same
        # object for a float32 ndarray, which made the in-place division below
        # silently rescale the caller's data.
        scaled = np.array(x, dtype=np.float32)
        scaled /= 255.0
        return scaled

class ArrayNormalize:
    """Transform computing ``(x - mean) / std`` with channel-last broadcasting."""

    def __init__(self, mean, std):
        """Store the statistics.

        Args:
            mean: Scalar or 1-D sequence of means.
            std: Scalar or 1-D sequence of standard deviations.
        """
        super().__init__()
        self.mean = mean
        self.std = std

    def __call__(self, x):
        """Return the normalised version of ``x`` without modifying ``x``.

        Args:
            x: Array-like whose trailing axis broadcasts against ``mean``/``std``
                once 1-D statistics are reshaped to ``(1, 1, C)``.

        Returns:
            A new array ``(x - mean) / std``.
        """
        mean = np.asarray(self.mean, dtype=np.float32)
        std = np.asarray(self.std, dtype=np.float32)
        # 1-D statistics are reshaped so they broadcast over the last axis.
        if mean.ndim == 1:
            mean = mean.reshape(1, 1, -1)
        if std.ndim == 1:
            std = std.reshape(1, 1, -1)
        # Out-of-place arithmetic: the previous `-=` / `/=` overwrote the
        # caller's ndarray and raised a casting error for integer arrays.
        return (x - mean) / std

def array_collate(batch):
    """Collate ``(image, target)`` samples into a pair of NumPy arrays.

    Args:
        batch: Sequence of 2-tuples ``(image, target)``.

    Returns:
        ``(images, targets)`` where images are stacked along a new leading
        axis and targets are packed into one array.
    """
    image_seq, target_seq = zip(*batch)
    stacked_images = np.stack(image_seq, axis=0)
    target_array = np.array(target_seq)
    return stacked_images, target_array

def shard(xs):
    """Split the leading axis of every array in ``xs`` across local devices.

    Args:
        xs: Pytree of arrays whose leading dimension is divisible by
            ``jax.local_device_count()``. Zero-dimensional leaves are passed
            through unchanged.

    Returns:
        A pytree of the same structure where each array of shape
        ``(B, ...)`` becomes ``(local_device_count, B // local_device_count, ...)``.
    """
    # Query the device count once instead of once per leaf.
    n_devices = jax.local_device_count()

    def _split(x):
        if len(x.shape) == 0:
            return x
        return x.reshape((n_devices, -1) + x.shape[1:])

    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias has been removed from recent JAX releases.
    return jax.tree_util.tree_map(_split, xs)

class TrainState(train_state.TrainState):
    """Train state extended with the model's ``batch_stats`` collection."""

    # Batch-norm statistics carried alongside params; replaced after every
    # training step and read back when applying the model.
    batch_stats: Any

def ce_loss(logits, labels):
    """Mean softmax cross-entropy between ``logits`` and integer ``labels``.

    Args:
        logits: Array whose last axis indexes classes.
        labels: Integer class ids with the same leading shape as ``logits``.

    Returns:
        Scalar mean cross-entropy over all examples.
    """
    # Derive the class count from the logits instead of hard-coding 10, so the
    # loss also works for models with a different number of outputs.
    num_classes = logits.shape[-1]
    one_hot_labels = jax.nn.one_hot(labels, num_classes)
    xentropy = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)
    return jnp.mean(xentropy)

def compute_metrics(logits, labels):
    """Return loss and accuracy averaged over the ``'batch'`` mapped axis.

    Must be called inside a mapped computation that defines the axis name
    ``'batch'`` because the result is reduced with ``jax.lax.pmean``.

    Args:
        logits: Model outputs; the last axis indexes classes.
        labels: Integer class ids.

    Returns:
        Dict with keys ``'loss'`` and ``'accuracy'``.
    """
    loss_value = ce_loss(logits, labels)
    predicted = jnp.argmax(logits, -1)
    per_replica = {
        'loss': loss_value,
        'accuracy': jnp.mean(predicted == labels),
    }
    return jax.lax.pmean(per_replica, axis_name='batch')

def train_step(state, batch):
    """Run one optimisation step on a per-device batch.

    Intended to be wrapped in a mapped computation with axis name ``'batch'``:
    gradients and metrics are averaged across that axis.

    Args:
        state: Train state providing ``apply_fn``, ``params``, ``batch_stats``
            and ``apply_gradients``.
        batch: Tuple ``(imgs, targets)``.

    Returns:
        ``(new_state, metrics)`` where ``metrics`` comes from
        ``compute_metrics`` (cross-entropy only, without the weight penalty).
    """
    imgs, targets = batch
    weight_decay = 5e-4

    def loss_fn(params):
        """Cross-entropy plus L2 penalty; aux carries updated stats and logits."""
        logits, new_model_state = state.apply_fn(
            {'params': params, 'batch_stats': state.batch_stats},
            imgs,
            mutable=['batch_stats'],
        )
        loss = ce_loss(logits, targets)
        # jax.tree_util.tree_leaves replaces the removed jax.tree_leaves alias.
        # Only leaves with more than one dimension are penalised, which skips
        # 1-D parameters such as biases and norm scales.
        weight_l2 = sum(
            jnp.sum(leaf ** 2)
            for leaf in jax.tree_util.tree_leaves(params)
            if leaf.ndim > 1
        )
        loss = loss + weight_decay * 0.5 * weight_l2
        return loss, (new_model_state, logits)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, (new_model_state, logits)), grads = grad_fn(state.params)
    # Average gradients across replicas so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name='batch')
    metrics = compute_metrics(logits, targets)

    new_state = state.apply_gradients(
        grads=grads, batch_stats=new_model_state['batch_stats']
    )
    return new_state, metrics

def eval_step(state, batch):
    """Evaluate one per-device batch and return averaged metrics.

    The model is applied with ``train=False`` and ``mutable=False``, so the
    stored batch statistics are read but not updated.

    Args:
        state: Train state providing ``apply_fn``, ``params`` and ``batch_stats``.
        batch: Tuple ``(imgs, targets)``.

    Returns:
        The dict produced by ``compute_metrics``.
    """
    images, labels = batch
    model_vars = dict(params=state.params, batch_stats=state.batch_stats)
    predictions = state.apply_fn(model_vars, images, train=False, mutable=False)
    return compute_metrics(predictions, labels)

def stack_forest(forest):
    """Stack a sequence of identically structured pytrees leaf by leaf.

    Args:
        forest: Non-empty sequence of pytrees sharing one structure.

    Returns:
        A single pytree of that structure whose leaves are ``np.stack`` of the
        corresponding leaves, i.e. with a new leading axis of ``len(forest)``.
    """
    stack_args = lambda *args: np.stack(args)
    # tree_map accepts several trees; jax.tree_multimap has been removed from JAX.
    return jax.tree_util.tree_map(stack_args, *forest)

def get_metrics(device_metrics):
    """Gather per-step replicated metrics into host-side stacked arrays.

    Args:
        device_metrics: Sequence of metric pytrees whose leaves carry a leading
            device axis.

    Returns:
        A pytree whose leaves stack, over steps, the first device's value of
        each metric.
    """
    # Keep only index 0 along the device axis. jax.tree_util.tree_map replaces
    # the removed jax.tree_map alias.
    first_replica = jax.tree_util.tree_map(lambda x: x[0], device_metrics)
    metrics_np = jax.device_get(first_replica)
    return stack_forest(metrics_np)

def _mean_over_replicas(x):
    """Average ``x`` across the mapped axis named ``'x'``."""
    return jax.lax.pmean(x, 'x')


# Mapped function that replaces every replica's value with the cross-replica mean.
cross_replica_mean = jax.pmap(_mean_over_replicas, axis_name='x')


def sync_batch_stats(state):
    """Return ``state`` with ``batch_stats`` averaged across replicas."""
    averaged_stats = cross_replica_mean(state.batch_stats)
    return state.replace(batch_stats=averaged_stats)


In [6]:
# Run configuration shared by the data-loading and training cells.
BATCH_SIZE = 128  # batch size before division by jax.process_count() below
DATA_DIR = "data/"  # root directory handed to the torchvision CIFAR10 dataset
EPOCHS = 30  # number of passes of the epoch loop in the training functions

# Fail early if the batch cannot be split evenly over all devices.
if BATCH_SIZE % jax.device_count() != 0:
    raise ValueError('Batch size must be divisible by the number of devices')
# Batch size per JAX process; used as batch_size of the torch DataLoaders.
LOCAL_BATCH_SIZE = BATCH_SIZE // jax.process_count()

INFO:absl:Unable to initialize backend 'gpu': FAILED_PRECONDITION: No visible GPU devices.
INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.


In [17]:
#@title torch data loader
# Per-channel statistics and DataLoader options shared by both splits.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2023, 0.1994, 0.2010)
_LOADER_OPTIONS = dict(
    batch_size=LOCAL_BATCH_SIZE,
    collate_fn=array_collate,
    num_workers=4,
    drop_last=True,
)

# Training split: random crop with padding and random horizontal flip, then
# conversion to a scaled float array and per-channel normalisation.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        ToArray(),
        ArrayNormalize(_CIFAR_MEAN, _CIFAR_STD),
    ]
)
train_dset = datasets.CIFAR10(
    root=DATA_DIR, download=True, train=True, transform=train_transform
)
data_train_iter = DataLoader(train_dset, shuffle=True, **_LOADER_OPTIONS)

# Validation split: conversion and normalisation only, no shuffling.
val_transform = transforms.Compose(
    [
        ToArray(),
        ArrayNormalize(_CIFAR_MEAN, _CIFAR_STD),
    ]
)
val_dset = datasets.CIFAR10(
    root=DATA_DIR, download=True, train=False, transform=val_transform
)
data_test_iter = DataLoader(val_dset, shuffle=False, **_LOADER_OPTIONS)

Files already downloaded and verified


  cpuset_checked))


Files already downloaded and verified


In [10]:
#@title Main train and eval function
def main():
    """Train ResNet-18 on CIFAR-10 using the torch data loaders.

    Initialises the model and optimiser, replicates the train state across
    local devices, then runs ``EPOCHS`` epochs over ``data_train_iter``.
    Every fifth epoch the batch statistics are synchronised across replicas
    and the model is evaluated on ``data_test_iter``. Mean metrics are logged
    once per epoch.
    """
    logging.info(f"JAX process: {jax.process_index()} / {jax.process_count()}")
    logging.info(f"JAX local devices: {jax.local_device_count()}")

    key = jax.random.PRNGKey(0)

    model = ResNet18(num_classes=10)

    @jax.jit
    def init(*args):
        return model.init(*args)

    # Named `variables` / `tx` rather than `vars` / `optim` so that neither the
    # builtin `vars` nor the imported `optim` module is shadowed.
    variables = init({"params": key}, jnp.ones((1, 32, 32, 3)))
    params, batch_stats = variables["params"], variables["batch_stats"]
    # NOTE(review): in optax.rmsprop `decay` is the moving-average factor of the
    # squared gradients, not a learning-rate decay; 1e-06 makes that average
    # follow almost only the latest gradient. Left unchanged — confirm intent.
    tx = optax.rmsprop(learning_rate=0.0001, decay=1e-06, momentum=0.9)
    state = TrainState.create(
        apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats
    )
    state = flax.jax_utils.replicate(state)

    p_train_step = jax.pmap(train_step, axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    def _summarize(prefix, step_metrics):
        """Mean of each metric over the collected steps, keyed '<prefix> <name>'."""
        stacked = get_metrics(step_metrics)
        # jax.tree_util.tree_map replaces the removed jax.tree_map alias.
        means = jax.tree_util.tree_map(lambda x: x.mean(), stacked)
        return {f'{prefix} {name}': value for name, value in means.items()}

    for epoch in range(1, EPOCHS + 1):
        logging.info(f"EPOCH: {epoch}")
        train_metrics, val_metrics = [], []
        for batch in data_train_iter:
            state, metrics = p_train_step(state, shard(batch))
            train_metrics.append(metrics)
        train_summary = _summarize('train', train_metrics)

        val_summary = {}
        if epoch % 5 == 0:
            # Average batch statistics across replicas before evaluating.
            state = sync_batch_stats(state)
            for batch in data_test_iter:
                metrics = p_eval_step(state, shard(batch))
                val_metrics.append(metrics)
            val_summary = _summarize('val', val_metrics)

        summary = {**train_summary, **val_summary}
        msg = "".join(
            "[{}] {:.5f} ".format(name, value) for name, value in summary.items()
        )
        logging.info(msg)

    # Blocks on a small device computation before reporting completion
    # (presumably to let previously dispatched work finish — TODO confirm).
    jax.random.normal(key, ()).block_until_ready()
    logging.info("Completed, Cleaning up... Done!")

In [None]:
#@title Output from a GPU run
main()

INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
INFO:root:JAX process: 0 / 1
INFO:root:JAX local devices: 1


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar-10-python.tar.gz to data/


  cpuset_checked))


Files already downloaded and verified


INFO:root:EPOCH: 1
INFO:root:[train accuracy] 0.41927 [train loss] 1.58606 
INFO:root:EPOCH: 2
INFO:root:[train accuracy] 0.58538 [train loss] 1.14783 
INFO:root:EPOCH: 3
INFO:root:[train accuracy] 0.65879 [train loss] 0.94817 
INFO:root:EPOCH: 4
INFO:root:[train accuracy] 0.71889 [train loss] 0.80188 
INFO:root:EPOCH: 5
INFO:root:[train accuracy] 0.74868 [train loss] 0.72249 [val accuracy] 0.72256 [val loss] 0.85418 
INFO:root:EPOCH: 6
INFO:root:[train accuracy] 0.77188 [train loss] 0.66075 
INFO:root:EPOCH: 7
INFO:root:[train accuracy] 0.78948 [train loss] 0.61210 
INFO:root:EPOCH: 8
INFO:root:[train accuracy] 0.80256 [train loss] 0.57567 
INFO:root:EPOCH: 9
INFO:root:[train accuracy] 0.81182 [train loss] 0.54749 
INFO:root:EPOCH: 10
INFO:root:[train accuracy] 0.82266 [train loss] 0.51819 [val accuracy] 0.74669 [val loss] 0.79857 
INFO:root:EPOCH: 11
INFO:root:[train accuracy] 0.82831 [train loss] 0.49704 
INFO:root:EPOCH: 12
INFO:root:[train accuracy] 0.83804 [train loss] 0.47616 
I

In [None]:
#@title Output from a TPU run
main()

INFO:absl:Unable to initialize backend 'gpu': FAILED_PRECONDITION: No visible GPU devices.
INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
INFO:root:JAX process: 0 / 1
INFO:root:JAX local devices: 8


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar-10-python.tar.gz to data/


  cpuset_checked))


Files already downloaded and verified


INFO:root:EPOCH: 1
INFO:root:[train accuracy] 0.39119 [train loss] 1.65428 
INFO:root:EPOCH: 2
INFO:root:[train accuracy] 0.56813 [train loss] 1.20797 
INFO:root:EPOCH: 3
INFO:root:[train accuracy] 0.65082 [train loss] 0.99867 
INFO:root:EPOCH: 4
INFO:root:[train accuracy] 0.69940 [train loss] 0.86388 
INFO:root:EPOCH: 5
INFO:root:[train accuracy] 0.73015 [train loss] 0.77962 [val accuracy] 0.72326 [val loss] 0.81792 
INFO:root:EPOCH: 6
INFO:root:[train accuracy] 0.75729 [train loss] 0.70349 
INFO:root:EPOCH: 7
INFO:root:[train accuracy] 0.77304 [train loss] 0.66201 
INFO:root:EPOCH: 8
INFO:root:[train accuracy] 0.78712 [train loss] 0.62450 
INFO:root:EPOCH: 9
INFO:root:[train accuracy] 0.79774 [train loss] 0.59005 
INFO:root:EPOCH: 10
INFO:root:[train accuracy] 0.80755 [train loss] 0.56362 [val accuracy] 0.79487 [val loss] 0.67416 
INFO:root:EPOCH: 11
INFO:root:[train accuracy] 0.81546 [train loss] 0.54117 
INFO:root:EPOCH: 12
INFO:root:[train accuracy] 0.82334 [train loss] 0.52120 
I

## Load data using tfds and run model [optional]

In [7]:
#@title TF data loader
data_builder = tfds.builder('cifar10')
data_builder.download_and_prepare()

CifarNormalizer = ArrayNormalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

def get_data(split, repeats, batch_size, shuffle_buffer):
  """Build a batched, preprocessed ``tf.data`` pipeline for one CIFAR-10 split.

  Args:
    split: Split name passed to ``data_builder.as_dataset``; ``'train'``
      additionally enables augmentation.
    repeats: Argument forwarded to ``Dataset.repeat``.
    batch_size: Number of examples per emitted batch.
    shuffle_buffer: Buffer size forwarded to ``Dataset.shuffle``.

  Returns:
    A dataset yielding dicts with keys ``'image'`` (normalised float image)
    and ``'label'``.
  """
  data = data_builder.as_dataset(split=split)

  def _pp(data):
    """Augment (train only), scale to [0, 1] and normalise one example."""
    # Applied to both splits exactly as before; per the tf.image.resize docs
    # the result is float32, which the division below relies on.
    im = tf.image.resize(data['image'], [32, 32])
    if split == 'train':
      # Pad to 40x40, then take a random 32x32 crop.
      im = tf.image.resize_with_crop_or_pad(im, 40, 40)
      im = tf.image.random_crop(im, [32, 32, 3])
      # Fixed: flip_left_right mirrored *every* training image, which is no
      # augmentation at all; random_flip_left_right flips with probability 0.5.
      im = tf.image.random_flip_left_right(im)
    im /= 255.0
    im = CifarNormalizer(im)
    return {'image': im, 'label': data['label']}

  data = data.repeat(repeats)
  data = data.shuffle(shuffle_buffer)
  data = data.map(_pp)
  return data.batch(batch_size)

data_train = get_data(split='train', repeats=1,
                      batch_size=BATCH_SIZE, shuffle_buffer=500)
data_test = get_data(split='test', repeats=1,
                      batch_size=BATCH_SIZE, shuffle_buffer=1)

INFO:absl:Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: cifar10/3.0.2
INFO:absl:Load dataset info from /tmp/tmpnsa29z5xtfds
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Generating dataset cifar10 (/root/tensorflow_datasets/cifar10/3.0.2)


[1mDownloading and preparing dataset cifar10/3.0.2 (download: 162.17 MiB, generated: 132.40 MiB, total: 294.58 MiB) to /root/tensorflow_datasets/cifar10/3.0.2...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

INFO:absl:Downloading https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz into /root/tensorflow_datasets/downloads/cs.toronto.edu_kriz_cifar-10-binaryODHPtIjLh3oLcXirEISTO7dkzyKjRCuol6lV8Wc6C7s.tar.gz.tmp.3fd2c62f04fc473c8a35a01f71fe9470...
INFO:absl:Generating split train







0 examples [00:00, ? examples/s]

Shuffling and writing examples to /root/tensorflow_datasets/cifar10/3.0.2.incompleteN2W7GD/cifar10-train.tfrecord


  0%|          | 0/50000 [00:00<?, ? examples/s]

INFO:absl:Done writing /root/tensorflow_datasets/cifar10/3.0.2.incompleteN2W7GD/cifar10-train.tfrecord. Shard lengths: [50000]
INFO:absl:Generating split test


0 examples [00:00, ? examples/s]

Shuffling and writing examples to /root/tensorflow_datasets/cifar10/3.0.2.incompleteN2W7GD/cifar10-test.tfrecord


  0%|          | 0/10000 [00:00<?, ? examples/s]

INFO:absl:Done writing /root/tensorflow_datasets/cifar10/3.0.2.incompleteN2W7GD/cifar10-test.tfrecord. Shard lengths: [10000]
INFO:absl:Skipping computing stats for mode ComputeStatsMode.SKIP.
INFO:absl:Constructing tf.data.Dataset for split train, from /root/tensorflow_datasets/cifar10/3.0.2


[1mDataset cifar10 downloaded and prepared to /root/tensorflow_datasets/cifar10/3.0.2. Subsequent calls will reuse this data.[0m


INFO:absl:Constructing tf.data.Dataset for split test, from /root/tensorflow_datasets/cifar10/3.0.2


In [8]:
#@title main train function using tfds
def main_use_tfds():
    """Train ResNet-18 on CIFAR-10 using the tfds pipelines.

    Same procedure as the torch-based training function, but batches come from
    ``data_train`` / ``data_test`` via ``as_numpy_iterator()``. Every fifth
    epoch the batch statistics are synchronised across replicas and the model
    is evaluated. Mean metrics are logged once per epoch.
    """
    logging.info(f"JAX process: {jax.process_index()} / {jax.process_count()}")
    logging.info(f"JAX local devices: {jax.local_device_count()}")

    key = jax.random.PRNGKey(0)

    model = ResNet18(num_classes=10)

    @jax.jit
    def init(*args):
        return model.init(*args)

    # Named `variables` / `tx` rather than `vars` / `optim` so that neither the
    # builtin `vars` nor the imported `optim` module is shadowed.
    variables = init({"params": key}, jnp.ones((1, 32, 32, 3)))
    params, batch_stats = variables["params"], variables["batch_stats"]
    # NOTE(review): in optax.rmsprop `decay` is the moving-average factor of the
    # squared gradients, not a learning-rate decay; 1e-06 makes that average
    # follow almost only the latest gradient. Left unchanged — confirm intent.
    tx = optax.rmsprop(learning_rate=0.0001, decay=1e-06, momentum=0.9)
    state = TrainState.create(
        apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats
    )
    state = flax.jax_utils.replicate(state)

    p_train_step = jax.pmap(train_step, axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    def _sharded_pair(batch):
        """Shard a batch and return it as an (images, targets) tuple."""
        batch = shard(batch)
        # Dict batches carry 'image' / 'label' keys; anything else is
        # unpacked as a pair.
        if 'image' in batch:
            return batch['image'], batch['label']
        imgs, tgts = batch
        return imgs, tgts

    def _summarize(prefix, step_metrics):
        """Mean of each metric over the collected steps, keyed '<prefix> <name>'."""
        stacked = get_metrics(step_metrics)
        # jax.tree_util.tree_map replaces the removed jax.tree_map alias.
        means = jax.tree_util.tree_map(lambda x: x.mean(), stacked)
        return {f'{prefix} {name}': value for name, value in means.items()}

    for epoch in range(1, EPOCHS + 1):
        logging.info(f"EPOCH: {epoch}")
        train_metrics, val_metrics = [], []
        for batch in data_train.as_numpy_iterator():
            state, metrics = p_train_step(state, _sharded_pair(batch))
            train_metrics.append(metrics)
        train_summary = _summarize('train', train_metrics)

        val_summary = {}
        if epoch % 5 == 0:
            # Average batch statistics across replicas before evaluating.
            state = sync_batch_stats(state)
            for batch in data_test.as_numpy_iterator():
                metrics = p_eval_step(state, _sharded_pair(batch))
                val_metrics.append(metrics)
            val_summary = _summarize('val', val_metrics)

        summary = {**train_summary, **val_summary}
        msg = "".join(
            "[{}] {:.5f} ".format(name, value) for name, value in summary.items()
        )
        logging.info(msg)

    # Blocks on a small device computation before reporting completion
    # (presumably to let previously dispatched work finish — TODO confirm).
    jax.random.normal(key, ()).block_until_ready()
    logging.info("Completed, Cleaning up... Done!")

In [10]:
main_use_tfds()

INFO:root:JAX process: 0 / 1
INFO:root:JAX local devices: 8
INFO:root:EPOCH: 1
INFO:root:[train accuracy] 0.39228 [train loss] 1.64622 
INFO:root:EPOCH: 2
INFO:root:[train accuracy] 0.56955 [train loss] 1.21034 
INFO:root:EPOCH: 3
INFO:root:[train accuracy] 0.64558 [train loss] 1.00623 
INFO:root:EPOCH: 4
INFO:root:[train accuracy] 0.69981 [train loss] 0.86114 
INFO:root:EPOCH: 5
INFO:root:[train accuracy] 0.73903 [train loss] 0.75713 [val accuracy] 0.71489 [val loss] 1.00703 
INFO:root:EPOCH: 6
INFO:root:[train accuracy] 0.76267 [train loss] 0.69382 
INFO:root:EPOCH: 7
INFO:root:[train accuracy] 0.77674 [train loss] 0.64590 
INFO:root:EPOCH: 8
INFO:root:[train accuracy] 0.79456 [train loss] 0.60419 
INFO:root:EPOCH: 9
INFO:root:[train accuracy] 0.80217 [train loss] 0.57761 
INFO:root:EPOCH: 10
INFO:root:[train accuracy] 0.81592 [train loss] 0.53918 [val accuracy] 0.75969 [val loss] 0.83739 
INFO:root:EPOCH: 11
INFO:root:[train accuracy] 0.82434 [train loss] 0.51725 
INFO:root:EPOCH: 1