Resnet #108
Conversation
Codecov Report

```
@@           Coverage Diff            @@
##   feature/custom_jit   #108   +/- ##
========================================
  Coverage            ?   75.91%
========================================
  Files               ?      101
  Lines               ?     4695
  Branches            ?        0
========================================
  Hits                ?     3564
  Misses              ?     1131
  Partials            ?        0
```

Continue to review the full report at Codecov.
@alexander-g this is awesome, thanks! I'll start reviewing the code.
```diff
@@ -0,0 +1,3 @@
+#additional requirements
+tensorflow-datasets==4.0.1
+tensorflow-gpu==2.2.0 #tensorflow-cpu also ok, but with gpu faster
```
Why is it faster with tf-gpu? 🤔 Here we are only using `tf.data`, which runs on the CPU, right?
I believe this is because JPEG decoding is performed on the GPU. To be honest, I have not tested it with Elegy; I noticed this when training with Flax and simply adopted it here.
```python
x = self.block_type(dtype=self.dtype)(x, 64 * 2 ** i, strides=strides)
x = jnp.mean(x, axis=(1, 2))
x = nn.Linear(1000, dtype=self.dtype)(x)
x = jnp.asarray(x, jnp.float32)
```
The `asarray` here is for casting, right? So this would be similar to `x = x.astype(jnp.float32)`?
By default, `astype` performs a copy of the array even if the type is already the same, whereas `asarray` only performs the copy if a conversion is needed.
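This copy behavior is easy to check directly. The sketch below uses NumPy for illustration (an assumption of this example, not code from the thread) since it follows the same `asarray`/`astype` convention being discussed here:

```python
import numpy as np

a = np.ones(4, dtype=np.float32)

# asarray with a matching dtype returns the same array object: no copy
same = np.asarray(a, dtype=np.float32)
print(same is a)  # True

# astype copies by default, even when the dtype already matches
copied = a.astype(np.float32)
print(copied is a)  # False

# astype can be told to skip the redundant copy explicitly
no_copy = a.astype(np.float32, copy=False)
print(no_copy is a)  # True
```

Under `jit` the distinction largely disappears anyway, since XLA can elide no-op casts, but for eager code `asarray` avoids the redundant copy.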
Hey @alexander-g, thanks a lot, this is an amazing contribution! We really appreciate it. Some notes:
CC: @charlielito
I'm glad you appreciate it. I really like JAX and Elegy and would like to contribute more in the future. Regarding 1: I want to do some profiling myself and will let you know if I find anything.
Another issue I forgot: I first tried to convert a pretrained Flax model to Elegy, but I got completely different results for identical inputs. Even the very first convolutional layer gave different outputs.
Our current implementation for most of the layers is taken from Haiku; in the case of `Conv` it only has a slight modification to support feature grouping. This should be a separate issue, but I think we should do the same as we do for losses and metrics:
@alexander-g I am going to merge this branch as it contains useful fixes.
* save
* initial refactor
* jit
* jit + init_jit
* handle rng
* jit + value_and_grad
* save
* save
* save
* fix metrics_loss
* save
* save
* *_on_batch methods
* get_states
* save
* fix tests
* fix examples
* format black
* use pickle only to save
* clean model
* save
* [Fix] Return all files to 0644 file permisions
* fix docs
* update module-system guide
* update README
* fix elegy.jit
* update jax
* fix tests
* small refactor
* jupyter dev dependency
* update docs
* update poetry in github actions
* use --no-hashes
* use --without-hashes
* update requirements during docs deployment
* especify poetry >= 1.1.4 as a dev dependency
* fix wraps init
* Resnet (#108)
  * added resnet18
  * imagenet input pipeline, from https://github.com/google/flax
  * experimental support for mixed precision
  * full training script
  * black + resnet test
  * format black
  * re-jit when loading a model for compability among platforms
  * format black
  * use different poetry installer

Co-authored-by: Cristian Garcia <cgarcia.e88@gmail.com>
Co-authored-by: David Cardozo <david@cerberusdata.ai>
Co-authored-by: alexander-g <3867427+alexander-g@users.noreply.github.com>
[…] `dtype` to the input's `dtype`. This is incorrect: for numerical stability reasons, all parameters should be `float32` even when performing `float16` computations. See more here.
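A minimal sketch of the pattern this comment describes: parameters stored in `float32`, the computation done in `float16`, and the result cast back to `float32`, mirroring the final `jnp.asarray(x, jnp.float32)` in the ResNet code above. NumPy is used here for illustration, and the layer function and its names are hypothetical, not Elegy's API:

```python
import numpy as np

def dense_mixed_precision(x, w, b, compute_dtype=np.float16):
    """Hypothetical dense layer: float32 parameter storage, low-precision compute."""
    # Parameters stay float32 in storage for numerical stability;
    # cast them only for the duration of the computation.
    y = x.astype(compute_dtype) @ w.astype(compute_dtype) + b.astype(compute_dtype)
    # Cast the output back to float32 so downstream ops (softmax, loss)
    # run in full precision.
    return np.asarray(y, dtype=np.float32)

rng = np.random.default_rng(0)
w = rng.standard_normal((8, 4)).astype(np.float32)  # stored in float32
b = np.zeros(4, dtype=np.float32)
x = rng.standard_normal((2, 8)).astype(np.float32)

out = dense_mixed_precision(x, w, b)
print(out.dtype)  # float32
```

The key point from the review: only the *compute* dtype drops to `float16`; the master copies of `w` and `b` never do, which is what keeps optimizer updates from losing small gradients to rounding.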