In [1]:
import jax

# Count the accelerators addressable from this process; the value is reused
# further down to split each batch across devices.
# NOTE(review): the label says "GPUs", but JAX reports whatever the default
# backend is (CPU/TPU too) — confirm the wording is what you want.
num_devices = len(jax.local_devices())
print(f'# of GPUs : {num_devices}')

import haiku as hk
import optax
import numpy as np
import jax.numpy as jnp
from tqdm import tqdm
from functools import partial

# Project-local helpers live one directory above this notebook, so that
# directory is put on the import path (once) before importing from it.
import os
import sys
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)
# NOTE(review): the wildcard import hides where names come from; the helpers
# used in later cells (e.g. Trainer, load_ckpt) presumably come from here —
# confirm, and consider listing them explicitly.
from utils import *
from models.resnet import ResNet, Block
from datasets.cifar import load_dataset

import matplotlib.pyplot as plt

# Notebook-wide plotting defaults: ggplot theme, small figures, small tick text.
plt.style.use('ggplot')
plt.rcParams.update({
    'figure.figsize': (4, 3),
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
})

# of GPUs : 8


  PyTreeDef = type(jax.tree_structure(None))


In [2]:
# --- Experiment configuration -------------------------------------------
NUM_CLASSES = 100      # size of the classifier head; also passed to load_dataset
NUM_EPOCH = 200        # fed to the learning-rate schedule
NUM_TRAIN = 50_000     # fed to the learning-rate schedule
BATCH_SIZE = 1_000     # global batch size, later divided across devices
WARMUP_RATIO = 0.1     # fed to the learning-rate schedule
PEAK_LR = 0.4          # fed to the learning-rate schedule

# Fixed PRNG key and an all-zero (1, 32, 32, 3) float32 input used to
# initialise the network below.
rng = jax.random.PRNGKey(42)
batch = jnp.zeros((1, 32, 32, 3), jnp.float32)

# Architecture settings for the ResNet-18 variant, kept together so the
# constructor call stays short.
resnet18_config = dict(
    name='ResNet_18',
    stage_sizes=[2, 2, 2, 2],
    num_filters=[64, 128, 256, 512],
    strides=[1, 2, 2, 2],
    block_cls=Block,
    num_classes=NUM_CLASSES,
)

# make_forward_with_state is a project helper; the object it returns exposes
# .init (yielding params and state) and .apply.
# NOTE(review): presumably a haiku transform_with_state wrapper — confirm.
net = make_forward_with_state(partial(ResNet, **resnet18_config))
params, state = net.init(rng, batch, train=True, print_shape=True)

# Optimiser: clip the global gradient norm to 1.0, then SGD with momentum 0.9
# driven by the schedule built from the constants above.
lr_schedule = create_lr_sched(NUM_EPOCH, NUM_TRAIN, BATCH_SIZE, WARMUP_RATIO, PEAK_LR)
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.sgd(learning_rate=lr_schedule, momentum=0.9),
)

# Bundle the apply function, initial parameters/state and optimiser together.
trainer = Trainer.create(apply_fn=net.apply, params=params, state=state, tx=tx)

input : (1, 32, 32, 3)
embedding : (1, 32, 32, 64)
block_0_0 : (1, 32, 32, 64)
block_0_1 : (1, 32, 32, 64)
block_1_0 : (1, 16, 16, 128)
block_1_1 : (1, 16, 16, 128)
block_2_0 : (1, 8, 8, 256)
block_2_1 : (1, 8, 8, 256)
block_3_0 : (1, 4, 4, 512)
block_3_1 : (1, 4, 4, 512)
representation : (1, 512)
classifier head : (1, 100)


In [5]:
# Restore the trainer from the checkpoint written by the image-classification
# notebook.
# NOTE(review): the directory is spelled 'renset_18' (not 'resnet_18'); it
# presumably matches the name on disk — confirm before "fixing" it.
ckpt_dir = '../3_image_classification/result/cifar100/renset_18'
trainer = load_ckpt(ckpt_dir, trainer)

# Leading axis = device, second axis = examples handled by each device.
per_device_batch = BATCH_SIZE // num_devices
batch_dims = (num_devices, per_device_batch)

train_dataset = load_dataset(NUM_CLASSES, batch_dims)
# The test split is materialised into a list here.
# NOTE(review): the three positional False flags are opaque from this cell —
# check load_dataset's signature for their meaning.
test_dataset = list(load_dataset(NUM_CLASSES, batch_dims, False, False, False))

# Sanity check: evaluate the restored trainer on the test split.
test_acc = compute_acc_dataset(replicate(trainer), test_dataset)
print(f'Recovered Test Accuracy : {test_acc:.4f}')

Recovered Test Accuracy : 0.7883
