Resnet #108

Merged: 10 commits merged into poets-ai:feature/custom_jit on Nov 18, 2020

Conversation

alexander-g
Contributor

  • ResNet model architecture and an example for training on ImageNet
    • code is mostly adapted from the flax library
    • pretrained ResNet50 with 76.5% accuracy
    • pretrained ResNet18 with 68.7% accuracy
  • Experimental support for mixed precision: previously, all layers set their parameters' dtype to the input's dtype. This is incorrect; for numerical-stability reasons, all parameters should stay float32 even when performing float16 computations (a minimal sketch of the pattern follows this list). See more here.
  • Some issues I had during training:
    • There seems to be a memory leak during training: RAM usage constantly increased
    • I had to use smaller batch sizes than when training with Flax or with TensorFlow before maxing out GPU memory (64 instead of 128 for ResNet50 on an RTX 2080 Ti). This might of course be due to a mistake in my code, but the number of parameters is identical to the Flax and PyTorch versions, so I suspect the cause lies elsewhere.
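A minimal sketch of the mixed-precision pattern mentioned above, written in plain jax.numpy rather than Elegy's actual module API: parameters are created and stored in float32, the heavy computation runs in float16, and the result is cast back to float32 at the end (mirroring the final cast in the reviewed ResNet code):

import jax
import jax.numpy as jnp

def init_dense(rng, in_features, out_features):
    # Parameters are created and kept in float32.
    w = jax.random.normal(rng, (in_features, out_features), dtype=jnp.float32) * 0.01
    b = jnp.zeros((out_features,), dtype=jnp.float32)
    return {"w": w, "b": b}

def dense_apply(params, x, compute_dtype=jnp.float16):
    # Cast inputs and parameters to the compute dtype only for the matmul...
    w = params["w"].astype(compute_dtype)
    b = params["b"].astype(compute_dtype)
    y = x.astype(compute_dtype) @ w + b
    # ...and cast the result back to float32 for numerically sensitive steps
    # such as the softmax/loss.
    return y.astype(jnp.float32)

params = init_dense(jax.random.PRNGKey(0), 64, 10)
x = jnp.ones((8, 64), dtype=jnp.float16)
logits = dense_apply(params, x)  # float32 output, float32 parameters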

@codecov-io

codecov-io commented Nov 15, 2020

Codecov Report

❗ No coverage uploaded for pull request base (feature/custom_jit@f341e2b).
The diff coverage is n/a.

@@                  Coverage Diff                  @@
##             feature/custom_jit     #108   +/-   ##
=====================================================
  Coverage                      ?   75.91%           
=====================================================
  Files                         ?      101           
  Lines                         ?     4695           
  Branches                      ?        0           
=====================================================
  Hits                          ?     3564           
  Misses                        ?     1131           
  Partials                      ?        0           

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@cgarciae
Collaborator

@alexander-g this is awesome, thanks! I'll start reviewing the code.

@@ -0,0 +1,3 @@
#additional requirements
tensorflow-datasets==4.0.1
tensorflow-gpu==2.2.0 #tensorflow-cpu also ok, but with gpu faster
Collaborator

why is it faster with tf-gpu? 🤔 Here we are only using tf.data which runs in CPU, right?

Contributor Author

I believe this is because JPEG decoding is performed on the GPU. To be honest, I have not tested it with Elegy; I noticed this when training with Flax and simply adopted it here.

x = self.block_type(dtype=self.dtype)(x, 64 * 2 ** i, strides=strides)
x = jnp.mean(x, axis=(1, 2))
x = nn.Linear(1000, dtype=self.dtype)(x)
x = jnp.asarray(x, jnp.float32)
Collaborator

The asarray here is for casting, right? Would that be similar to the following?

x = x.astype(jnp.float32)

Contributor Author

By default astype performs a copy of the array even if the type is the same, whereas asarray only performs the copy if conversion is needed.
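A small illustration of the copy-vs-no-copy distinction being described, shown with plain NumPy; whether the same eager-mode copy occurs in jax.numpy, or matters under jit, is not verified here:

import numpy as np

a = np.ones((3, 3), dtype=np.float32)

# astype copies by default, even when the dtype is unchanged...
b = a.astype(np.float32)
print(b is a)  # False: a new array was allocated

# ...whereas asarray returns the input untouched when no conversion is needed.
c = np.asarray(a, dtype=np.float32)
print(c is a)  # True: no copy was made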

@cgarciae
Collaborator

cgarciae commented Nov 16, 2020

Hey @alexander-g, thanks a lot, this is an amazing contribution! We really appreciate it.

Some notes:

  1. Regarding the memory leak: since I don't yet have the ImageNet files needed to train the model, I created a script that uses a generator of random images and labels solely to test the training code; you can find it here. I wasn't able to reproduce the memory leak on my machine. The only thing that could potentially explain it right now is that, since we don't yet have any device placement policies, the scalar arrays for the logs that accumulate during training remain on the GPU indefinitely. This isn't guaranteed to be the problem, and two floats per step doesn't seem like much. Regardless, a device placement API could be discussed in a separate issue.
  2. We have to decide on a name for the module where we will keep these standard architectures; nets sounds nice, but Keras uses applications.
  3. A point we can start discussing is how to better load pre-trained models. We should try (again) to save the weights separately from the model in a serializable format to avoid the library-version dependence that comes with pickle; HDF5 through the tables package was relatively easy. We should probably follow the weights="imagenet" API from Keras to load pretrained weights (a minimal sketch of the idea follows this list).
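For point 3, a minimal sketch of what saving weights separately from the model could look like, assuming the parameters have already been flattened into a plain name-to-array dict; the function names and file layout here are illustrative, not an Elegy API:

import numpy as np

def save_weights(path, params):
    # params: a flat dict mapping parameter names to arrays.
    # np.savez writes plain numeric arrays, so no pickle is involved.
    np.savez(path, **{name: np.asarray(value) for name, value in params.items()})

def load_weights(path):
    # Returns the same flat name-to-array dict that was saved.
    with np.load(path) as data:
        return {name: data[name] for name in data.files}

# e.g. save_weights("resnet50_imagenet.npz", flat_params)
# A weights="imagenet" constructor argument could then simply download
# and load such a file.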

CC: @charlielito

@alexander-g
Contributor Author

I'm glad you appreciate it. I really like JAX and Elegy and would like to contribute more in the future.

Re 1: I want to do some profiling myself and will let you know if I find anything.
Re 3: Consider adding some kind of version control for the pretrained models. I may retrain R18 because its performance is slightly worse than PyTorch's; R50 is fine though.

@alexander-g
Contributor Author

alexander-g commented Nov 16, 2020

Another issue I forgot: I first tried to convert a pretrained Flax model to Elegy, but I got completely different results for identical inputs. Even the very first convolutional layer gave different outputs.
This doesn't necessarily mean something is wrong, but I wanted to mention it.

@cgarciae
Collaborator

Our current implementation of most of the layers is taken from Haiku; in the case of Conv there is only a slight modification to support feature grouping. This should be a separate issue, but I think we should do the same as we do for losses and metrics:

  • Try to expose the Keras API
  • Test numerical equivalence against a base implementation; this could be Flax if we get the benefit of porting their pretrained models (a sketch of such a check follows below).
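A minimal sketch of that kind of numerical-equivalence check, using np.testing; assert_layers_match and the two apply_* callables are hypothetical placeholders for the Elegy layer and the reference (e.g. Flax) layer, both built from the same parameter arrays:

import numpy as np

def assert_layers_match(apply_ours, apply_reference, input_shape, rtol=1e-5, atol=1e-5):
    # Feed both implementations the same random input and require
    # element-wise agreement within tolerance.
    rng = np.random.RandomState(0)
    x = rng.standard_normal(input_shape).astype(np.float32)
    np.testing.assert_allclose(
        np.asarray(apply_ours(x)),
        np.asarray(apply_reference(x)),
        rtol=rtol,
        atol=atol,
    )

# e.g. assert_layers_match(elegy_conv, flax_conv, (1, 224, 224, 3)),
# where elegy_conv and flax_conv were initialized with identical weights.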

@cgarciae
Collaborator

@alexander-g I am going to merge this branch as it contains useful fixes.
Feel free to open a new PR if you want to continue improving it.

@cgarciae cgarciae merged commit 22757e0 into poets-ai:feature/custom_jit Nov 18, 2020
cgarciae added a commit that referenced this pull request Nov 18, 2020
* save

* initial refactor

* jit

* jit + init_jit

* handle rng

* jit + value_and_grad

* save

* save

* save

* fix metrics_loss

* save

* save

* *_on_batch methods

* get_states

* save

* fix tests

* fix examples

* format black

* use pickle only to save

* clean model

* save

* [Fix] Return all files to 0644 file permissions

* fix docs

* update module-system guide

* update README

* fix elegy.jit

* update jax

* fix tests

* small refactor

* jupyter dev dependency

* update docs

* update poetry in github actions

* use --no-hashes

* use --without-hashes

* update requirements during docs deployment

* specify poetry >= 1.1.4 as a dev dependency

* fix wraps init

* Resnet (#108)

* added resnet18

* imagenet input pipeline, from https://github.com/google/flax

* experimental support for mixed precision

* full training script

* black + resnet test

* format black

* re-jit when loading a model for compatibility among platforms

* format black

* use different poetry installer

Co-authored-by: Cristian Garcia <cgarcia.e88@gmail.com>

Co-authored-by: David Cardozo <david@cerberusdata.ai>
Co-authored-by: alexander-g <3867427+alexander-g@users.noreply.github.com>