-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* save * initial refactor * jit * jit + init_jit * handle rng * jit + value_and_grad * save * save * save * fix metrics_loss * save * save * *_on_batch methods * get_states * save * fix tests * fix examples * format black * use pickle only to save * clean model * save * [Fix] Return all files to 0644 file permisions * fix docs * update module-system guide * update README * fix elegy.jit * update jax * fix tests * small refactor * jupyter dev dependency * update docs * update poetry in github actions * use --no-hashes * use --without-hashes * update requirements during docs deployment * especify poetry >= 1.1.4 as a dev dependency * fix wraps init * Resnet (#108) * added resnet18 * imagenet input pipeline, from https://github.com/google/flax * experimental support for mixed precision * full training script * black + resnet test * format black * re-jit when loading a model for compability among platforms * format black * use different poetry installer Co-authored-by: Cristian Garcia <cgarcia.e88@gmail.com> Co-authored-by: David Cardozo <david@cerberusdata.ai> Co-authored-by: alexander-g <3867427+alexander-g@users.noreply.github.com>
- Loading branch information
1 parent
19ec87b
commit 9538853
Showing
14 changed files
with
559 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from . import resnet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# adapted from the flax library https://github.com/google/flax | ||
|
||
import jax, jax.numpy as jnp | ||
from elegy import module, nn | ||
|
||
|
||
class ResNetBlock(module.Module): | ||
"""ResNet (identity) block""" | ||
|
||
def call(self, x, n_filters, strides=(1, 1)): | ||
x0 = x | ||
x = nn.Conv2D( | ||
n_filters, (3, 3), with_bias=False, stride=strides, dtype=self.dtype | ||
)(x) | ||
x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) | ||
x = jax.nn.relu(x) | ||
|
||
x = nn.Conv2D(n_filters, (3, 3), with_bias=False, dtype=self.dtype)(x) | ||
x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) | ||
|
||
if x0.shape != x.shape: | ||
x0 = nn.Conv2D( | ||
n_filters, (1, 1), with_bias=False, stride=strides, dtype=self.dtype | ||
)(x0) | ||
x0 = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x0) | ||
return jax.nn.relu(x0 + x) | ||
|
||
|
||
class BottleneckResNetBlock(module.Module): | ||
"""ResNet Bottleneck block.""" | ||
|
||
def call(self, x, n_filters, strides=(1, 1)): | ||
x0 = x | ||
x = nn.Conv2D(n_filters, (1, 1), with_bias=False, dtype=self.dtype)(x) | ||
x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) | ||
x = jax.nn.relu(x) | ||
x = nn.Conv2D( | ||
n_filters, (3, 3), with_bias=False, stride=strides, dtype=self.dtype | ||
)(x) | ||
x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) | ||
x = jax.nn.relu(x) | ||
x = nn.Conv2D(n_filters * 4, (1, 1), with_bias=False, dtype=self.dtype)(x) | ||
x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5, scale_init=jnp.zeros)(x) | ||
|
||
if x0.shape != x.shape: | ||
x0 = nn.Conv2D( | ||
n_filters * 4, (1, 1), with_bias=False, stride=strides, dtype=self.dtype | ||
)(x0) | ||
x0 = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x0) | ||
return jax.nn.relu(x0 + x) | ||
|
||
|
||
class ResNet(module.Module): | ||
"""ResNet V1""" | ||
|
||
def __init__(self, stages, block_type, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.stages = stages | ||
self.block_type = block_type | ||
|
||
def call(self, x): | ||
x = nn.Conv2D( | ||
64, (7, 7), stride=(2, 2), padding="SAME", with_bias=False, dtype=self.dtype | ||
)(x) | ||
x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) | ||
x = jax.nn.relu(x) | ||
|
||
x = nn.linear.hk.max_pool( | ||
x, window_shape=(1, 3, 3, 1), strides=(1, 2, 2, 1), padding="SAME" | ||
) | ||
for i, block_size in enumerate(self.stages): | ||
for j in range(block_size): | ||
strides = (2, 2) if i > 0 and j == 0 else (1, 1) | ||
x = self.block_type(dtype=self.dtype)(x, 64 * 2 ** i, strides=strides) | ||
x = jnp.mean(x, axis=(1, 2)) | ||
x = nn.Linear(1000, dtype=self.dtype)(x) | ||
x = jnp.asarray(x, jnp.float32) | ||
return x | ||
|
||
|
||
class ResNet18(ResNet): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__(stages=[2, 2, 2, 2], block_type=ResNetBlock, *args, **kwargs) | ||
|
||
|
||
class ResNet34(ResNet): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__(stages=[3, 4, 6, 3], block_type=ResNetBlock, *args, **kwargs) | ||
|
||
|
||
class ResNet50(ResNet): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__( | ||
stages=[3, 4, 6, 3], block_type=BottleneckResNetBlock, *args, **kwargs | ||
) | ||
|
||
|
||
class ResNet101(ResNet): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__( | ||
stages=[3, 4, 23, 3], block_type=BottleneckResNetBlock, *args, **kwargs | ||
) | ||
|
||
|
||
class ResNet152(ResNet): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__( | ||
stages=[3, 8, 36, 3], block_type=BottleneckResNetBlock, *args, **kwargs | ||
) | ||
|
||
|
||
class ResNet200(ResNet): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__( | ||
stages=[3, 24, 36, 3], block_type=BottleneckResNetBlock, *args, **kwargs | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from elegy import utils | ||
|
||
import jax.numpy as jnp | ||
from unittest import TestCase | ||
|
||
import elegy | ||
|
||
|
||
class ResNetTest(TestCase): | ||
def test_basic_predict(self): | ||
# FIXME: test succeeds if run alone or if run on the cpu-only version of jax | ||
# test fails with "DNN library is not found" if run on gpu with all other tests together | ||
|
||
model = elegy.Model(elegy.nets.resnet.ResNet18()) | ||
y = model.predict(jnp.zeros((2, 224, 224, 3))) | ||
assert jnp.all(y.shape == (2, 1000)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.