merge resnet into master (#111)
* save

* initial refactor

* jit

* jit + init_jit

* handle rng

* jit + value_and_grad

* save

* save

* save

* fix metrics_loss

* save

* save

* *_on_batch methods

* get_states

* save

* fix tests

* fix examples

* format black

* use pickle only to save

* clean model

* save

* [Fix] Return all files to 0644 file permissions

* fix docs

* update module-system guide

* update README

* fix elegy.jit

* update jax

* fix tests

* small refactor

* jupyter dev dependency

* update docs

* update poetry in github actions

* use --no-hashes

* use --without-hashes

* update requirements during docs deployment

* specify poetry >= 1.1.4 as a dev dependency

* fix wraps init

* Resnet (#108)

* added resnet18

* imagenet input pipeline, from https://github.com/google/flax

* experimental support for mixed precision

* full training script

* black + resnet test

* format black

* re-jit when loading a model for compatibility among platforms

* format black

* use different poetry installer

Co-authored-by: Cristian Garcia <cgarcia.e88@gmail.com>

Co-authored-by: David Cardozo <david@cerberusdata.ai>
Co-authored-by: alexander-g <3867427+alexander-g@users.noreply.github.com>
3 people committed Nov 18, 2020
1 parent 19ec87b commit 9538853
Showing 14 changed files with 559 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_test.yml
@@ -39,7 +39,7 @@ jobs:
python-version: ${{ matrix.python-version }}

- name: Install Poetry
uses: dschep/install-poetry-action@v1.2
uses: snok/install-poetry@v1.1.1
with:
version: 1.1.4

13 changes: 12 additions & 1 deletion elegy/__init__.py
@@ -1,7 +1,17 @@
__version__ = "0.2.2"


from . import callbacks, initializers, losses, metrics, model, module, nn, regularizers
from . import (
callbacks,
initializers,
losses,
metrics,
model,
module,
nets,
nn,
regularizers,
)
from .losses import Loss
from .metrics import Metric
from .model import Model
@@ -38,6 +48,7 @@
"losses",
"metrics",
"model",
"nets",
"nn",
"regularizers",
"hooks_context",
10 changes: 7 additions & 3 deletions elegy/model/model_base.py
@@ -32,14 +32,18 @@ def __init__(
self.loss = Losses(loss) if loss is not None else None
self.metrics = Metrics(metrics)
self.optimizer = Optimizer(optimizer) if optimizer is not None else None
self.predict_fn_jit = elegy_jit(self.predict_fn, modules=self)
self.test_fn_jit = elegy_jit(self.test_fn, modules=self)
self.train_fn_jit = elegy_jit(self.train_fn, modules=self)
self._jit_functions()
self.initial_metrics_state: tp.Optional[tp.Dict[str, tp.Any]] = None
self.run_eagerly = run_eagerly

utils.wraps(self.module)(self)

def _jit_functions(self):
super()._jit_functions()
self.predict_fn_jit = elegy_jit(self.predict_fn, modules=self)
self.test_fn_jit = elegy_jit(self.test_fn, modules=self)
self.train_fn_jit = elegy_jit(self.train_fn, modules=self)

def call(self, *args, **kwargs):
return self.module(*args, **kwargs)

11 changes: 11 additions & 0 deletions elegy/module.py
@@ -276,9 +276,16 @@ def init(*args, **kwargs):
utils.wraps(self.call)(self.init)
utils.wraps(self.call)(self)

self._jit_functions()

def _jit_functions(self):
self.jit = jit(self)
self.init_jit = jit(self.init, modules=self)

def __setstate__(self, d):
self.__dict__ = d
self._jit_functions()

@property
def initialized(self) -> bool:
return self._initialized
@@ -882,6 +889,10 @@ def get_static_context() -> "StaticContext":
def set_context(static: "StaticContext", dynamic: "DynamicContext"):
LOCAL.set_from(static, dynamic)

def _grad_fn(parameters_tuple: tp.Tuple[tp.Dict, ...], *args, **kwargs):
assert isinstance(parameters_tuple, tuple)
assert isinstance(modules, list)


# -------------------------------------------------------------
# transforms
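
Note: the _jit_functions / __setstate__ pattern above (used by both Module and ModelBase) rebuilds the jitted callables after unpickling, since compiled functions are tied to the process and backend they were traced on. Below is a standalone sketch of the same idea; the Scaler class and its __getstate__ handling are illustrative assumptions, not Elegy code.

import pickle

import jax
import jax.numpy as jnp


class Scaler:
    """Toy object that must rebuild its jitted function after unpickling."""

    def __init__(self, factor: float):
        self.factor = factor
        self._jit_functions()

    def _jit_functions(self):
        # compiled callables are tied to the current process/platform,
        # so they are created lazily instead of being serialized
        self.apply_jit = jax.jit(self.apply)

    def apply(self, x):
        return self.factor * x

    def __getstate__(self):
        # drop the jitted function before pickling (assumption: Elegy's own
        # serialization may handle this differently)
        state = self.__dict__.copy()
        state.pop("apply_jit", None)
        return state

    def __setstate__(self, state):
        self.__dict__ = state
        self._jit_functions()  # re-jit on the platform that loads the object


restored = pickle.loads(pickle.dumps(Scaler(2.0)))
print(restored.apply_jit(jnp.arange(3)))  # [0. 2. 4.]
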
1 change: 1 addition & 0 deletions elegy/nets/__init__.py
@@ -0,0 +1 @@
from . import resnet
116 changes: 116 additions & 0 deletions elegy/nets/resnet.py
@@ -0,0 +1,116 @@
# adapted from the flax library https://github.com/google/flax

import jax, jax.numpy as jnp
from elegy import module, nn


class ResNetBlock(module.Module):
"""ResNet (identity) block"""

def call(self, x, n_filters, strides=(1, 1)):
x0 = x
x = nn.Conv2D(
n_filters, (3, 3), with_bias=False, stride=strides, dtype=self.dtype
)(x)
x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x)
x = jax.nn.relu(x)

x = nn.Conv2D(n_filters, (3, 3), with_bias=False, dtype=self.dtype)(x)
x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x)

if x0.shape != x.shape:
x0 = nn.Conv2D(
n_filters, (1, 1), with_bias=False, stride=strides, dtype=self.dtype
)(x0)
x0 = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x0)
return jax.nn.relu(x0 + x)


class BottleneckResNetBlock(module.Module):
"""ResNet Bottleneck block."""

def call(self, x, n_filters, strides=(1, 1)):
x0 = x
x = nn.Conv2D(n_filters, (1, 1), with_bias=False, dtype=self.dtype)(x)
x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x)
x = jax.nn.relu(x)
x = nn.Conv2D(
n_filters, (3, 3), with_bias=False, stride=strides, dtype=self.dtype
)(x)
x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x)
x = jax.nn.relu(x)
x = nn.Conv2D(n_filters * 4, (1, 1), with_bias=False, dtype=self.dtype)(x)
x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5, scale_init=jnp.zeros)(x)

if x0.shape != x.shape:
x0 = nn.Conv2D(
n_filters * 4, (1, 1), with_bias=False, stride=strides, dtype=self.dtype
)(x0)
x0 = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x0)
return jax.nn.relu(x0 + x)


class ResNet(module.Module):
"""ResNet V1"""

def __init__(self, stages, block_type, *args, **kwargs):
super().__init__(*args, **kwargs)
self.stages = stages
self.block_type = block_type

def call(self, x):
x = nn.Conv2D(
64, (7, 7), stride=(2, 2), padding="SAME", with_bias=False, dtype=self.dtype
)(x)
x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x)
x = jax.nn.relu(x)

x = nn.linear.hk.max_pool(
x, window_shape=(1, 3, 3, 1), strides=(1, 2, 2, 1), padding="SAME"
)
for i, block_size in enumerate(self.stages):
for j in range(block_size):
strides = (2, 2) if i > 0 and j == 0 else (1, 1)
x = self.block_type(dtype=self.dtype)(x, 64 * 2 ** i, strides=strides)
x = jnp.mean(x, axis=(1, 2))
x = nn.Linear(1000, dtype=self.dtype)(x)
x = jnp.asarray(x, jnp.float32)
return x


class ResNet18(ResNet):
def __init__(self, *args, **kwargs):
super().__init__(stages=[2, 2, 2, 2], block_type=ResNetBlock, *args, **kwargs)


class ResNet34(ResNet):
def __init__(self, *args, **kwargs):
super().__init__(stages=[3, 4, 6, 3], block_type=ResNetBlock, *args, **kwargs)


class ResNet50(ResNet):
def __init__(self, *args, **kwargs):
super().__init__(
stages=[3, 4, 6, 3], block_type=BottleneckResNetBlock, *args, **kwargs
)


class ResNet101(ResNet):
def __init__(self, *args, **kwargs):
super().__init__(
stages=[3, 4, 23, 3], block_type=BottleneckResNetBlock, *args, **kwargs
)


class ResNet152(ResNet):
def __init__(self, *args, **kwargs):
super().__init__(
stages=[3, 8, 36, 3], block_type=BottleneckResNetBlock, *args, **kwargs
)


class ResNet200(ResNet):
def __init__(self, *args, **kwargs):
super().__init__(
stages=[3, 24, 36, 3], block_type=BottleneckResNetBlock, *args, **kwargs
)
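
A short usage sketch for the new elegy.nets.resnet module; the dtype keyword assumes the experimental mixed-precision support mentioned in the commit messages and is not a guaranteed part of the API.

import jax.numpy as jnp

import elegy

# half-precision activations; parameters and batch-norm statistics stay in
# float32, and ResNet.call casts the final logits back to float32
model = elegy.Model(elegy.nets.resnet.ResNet18(dtype=jnp.float16))

logits = model.predict(jnp.zeros((2, 224, 224, 3)))
print(logits.shape, logits.dtype)  # (2, 1000) float32
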
16 changes: 16 additions & 0 deletions elegy/nets/resnet_test.py
@@ -0,0 +1,16 @@
from elegy import utils

import jax.numpy as jnp
from unittest import TestCase

import elegy


class ResNetTest(TestCase):
def test_basic_predict(self):
# FIXME: test succeeds if run alone or if run on the cpu-only version of jax
# test fails with "DNN library is not found" if run on gpu with all other tests together

model = elegy.Model(elegy.nets.resnet.ResNet18())
y = model.predict(jnp.zeros((2, 224, 224, 3)))
assert jnp.all(y.shape == (2, 1000))
7 changes: 5 additions & 2 deletions elegy/nn/batch_normalization.py
@@ -119,6 +119,8 @@ def call(
Returns:
The array, normalized across all but the last dimension.
"""
inputs = jnp.asarray(inputs, jnp.float32)

if training is None:
training = module.is_training()

@@ -161,7 +163,7 @@ def call(
self.var_ema(var)

w_shape = [1 if i in axis else inputs.shape[i] for i in range(inputs.ndim)]
w_dtype = inputs.dtype
w_dtype = jnp.float32

if self.create_scale:
scale = self.add_parameter("scale", w_shape, w_dtype, self.scale_init)
@@ -174,4 +176,5 @@
offset = np.zeros([], dtype=w_dtype)

inv = scale * jax.lax.rsqrt(var + self.eps)
return (inputs - mean) * inv + offset
output = (inputs - mean) * inv + offset
return jnp.asarray(output, self.dtype)
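
The changes to BatchNormalization, Conv2D, and Linear in this commit all follow the same mixed-precision recipe: keep parameters and running statistics in float32, cast inputs and parameters to the layer's compute dtype, and cast back where full precision matters. A condensed, standalone sketch of that recipe (not Elegy code):

import jax
import jax.numpy as jnp


def mixed_precision_dense(x, w_fp32, b_fp32, compute_dtype=jnp.float16):
    # master parameters live in float32; only the matmul runs in compute_dtype
    x = jnp.asarray(x, compute_dtype)
    w = jnp.asarray(w_fp32, compute_dtype)
    b = jnp.asarray(b_fp32, compute_dtype)
    return jnp.dot(x, w) + b


key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (8, 4), dtype=jnp.float32)
b = jnp.zeros((4,), dtype=jnp.float32)
y = mixed_precision_dense(jnp.ones((2, 8)), w, b)
print(y.dtype)  # float16
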
7 changes: 5 additions & 2 deletions elegy/nn/conv.py
@@ -179,11 +179,13 @@ def call(self, inputs: np.ndarray) -> np.ndarray:
fan_in_shape = np.prod(w_shape[:-1])
stddev = 1.0 / np.sqrt(fan_in_shape)
w_init = initializers.TruncatedNormal(stddev=stddev)
w = self.add_parameter("w", w_shape, inputs.dtype, initializer=w_init)
w = self.add_parameter("w", w_shape, jnp.float32, initializer=w_init)

if self.mask is not None:
w *= self.mask

inputs = jnp.asarray(inputs, dtype=self.dtype)
w = jnp.asarray(w, dtype=self.dtype)
out = lax.conv_general_dilated(
inputs,
w,
@@ -201,9 +203,10 @@ def call(self, inputs: np.ndarray) -> np.ndarray:
else:
bias_shape = (self.output_channels,) + (1,) * self.num_spatial_dims
b = self.add_parameter(
"b", bias_shape, inputs.dtype, initializer=self.b_init
"b", bias_shape, jnp.float32, initializer=self.b_init
)
b = jnp.broadcast_to(b, out.shape)
b = jnp.asarray(b, self.dtype)
out = out + b

return out
5 changes: 4 additions & 1 deletion elegy/nn/linear.py
@@ -48,7 +48,7 @@ def call(self, inputs: np.ndarray) -> np.ndarray:

input_size = self.input_size = inputs.shape[-1]
output_size = self.output_size
dtype = inputs.dtype
dtype = jnp.float32

w_init = self.w_init

@@ -60,13 +60,16 @@ def call(self, inputs: np.ndarray) -> np.ndarray:
"w", [input_size, output_size], dtype, initializer=w_init
)

inputs = jnp.asarray(inputs, self.dtype)
w = jnp.asarray(w, self.dtype)
out = jnp.dot(inputs, w)

if self.with_bias:
b = self.add_parameter(
"b", [self.output_size], dtype, initializer=self.b_init
)
b = jnp.broadcast_to(b, out.shape)
b = jnp.asarray(b, self.dtype)
out = out + b

return out
6 changes: 3 additions & 3 deletions elegy/nn/moving_averages.py
@@ -67,10 +67,10 @@ def _cond(self, cond, t, f, dtype):
def initialize(self, value):
"""If uninitialized sets the average to ``zeros_like`` the given value."""
self.add_parameter(
"hidden", value.shape, value.dtype, initializer=jnp.zeros, trainable=False
"hidden", value.shape, jnp.float32, initializer=jnp.zeros, trainable=False
)
self.add_parameter(
"average", value.shape, value.dtype, initializer=jnp.zeros, trainable=False
"average", value.shape, jnp.float32, initializer=jnp.zeros, trainable=False
)

def call(self, value, update_stats=True):
@@ -104,7 +104,7 @@ def call(self, value, update_stats=True):

one = jnp.ones([], value.dtype)
hidden = self.add_parameter(
"hidden", value.shape, value.dtype, initializer=jnp.zeros, trainable=False
"hidden", value.shape, jnp.float32, initializer=jnp.zeros, trainable=False
)
hidden = hidden * decay + value * (one - decay)

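
For reference, the moving-average update in this file keeps its state in float32 and applies hidden_t = decay * hidden_(t-1) + (1 - decay) * value_t. A tiny standalone sketch of that update (illustrative only, not the Elegy module):

import jax.numpy as jnp


def ema_update(hidden, value, decay=0.9):
    # state is stored in float32, matching the change above
    hidden = jnp.asarray(hidden, jnp.float32)
    value = jnp.asarray(value, jnp.float32)
    return hidden * decay + value * (1.0 - decay)


hidden = jnp.zeros(())
for value in (1.0, 2.0, 3.0):
    hidden = ema_update(hidden, value)
print(hidden)  # ~0.561 after three steps
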
