This repository hosts the code to port NumPy model weights of BiT-ResNets to TensorFlow SavedModel format.

sayakpaul/BiT-jax2tf

BiT-jax2tf

This repository hosts the code to port the NumPy model weights of BiT-ResNets [1] to the TensorFlow SavedModel format. The models are the ones produced in [2], and the original model weights come from [3].
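As a rough illustration of what such a port involves, the sketch below loads arrays from an `.npz` file, copies them into TensorFlow variables, and exports the result as a SavedModel. It is not the code from this repository: the `TinyNet` module, the key names (`conv/kernel`, `head/kernel`, `head/bias`) and the tiny shapes are all hypothetical stand-ins for a real BiT-ResNet, and the archive is generated on the spot so the example runs on its own.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Write a small .npz archive that stands in for a downloaded checkpoint.
# The key names and shapes are hypothetical, not the real BiT ones.
rng = np.random.default_rng(0)
workdir = tempfile.mkdtemp()
npz_path = os.path.join(workdir, "weights.npz")
np.savez(
    npz_path,
    **{
        "conv/kernel": rng.normal(size=(3, 3, 3, 8)).astype(np.float32),  # HWIO
        "head/kernel": rng.normal(size=(8, 10)).astype(np.float32),
        "head/bias": np.zeros((10,), dtype=np.float32),
    },
)
numpy_weights = dict(np.load(npz_path))


class TinyNet(tf.Module):
    """Hypothetical stand-in for the ResNet: conv -> global pool -> dense."""

    def __init__(self, weights):
        super().__init__()
        # Each NumPy array becomes the initial value of a TensorFlow variable.
        self.conv_kernel = tf.Variable(weights["conv/kernel"])
        self.head_kernel = tf.Variable(weights["head/kernel"])
        self.head_bias = tf.Variable(weights["head/bias"])

    @tf.function(input_signature=[tf.TensorSpec([None, 32, 32, 3], tf.float32)])
    def __call__(self, images):
        x = tf.nn.conv2d(images, self.conv_kernel, strides=1, padding="SAME")
        x = tf.reduce_mean(x, axis=[1, 2])  # global average pooling
        return tf.matmul(x, self.head_kernel) + self.head_bias


model = TinyNet(numpy_weights)

# Export to the SavedModel format and load it back.
export_dir = os.path.join(workdir, "saved_model")
tf.saved_model.save(model, export_dir)
reloaded = tf.saved_model.load(export_dir)

# The reloaded model should reproduce the outputs of the in-memory one.
images = tf.constant(rng.normal(size=(2, 32, 32, 3)).astype(np.float32))
original_logits = model(images).numpy()
reloaded_logits = reloaded(images).numpy()
```

Comparing `original_logits` with `reloaded_logits` (and, in a real port, with the outputs of the source model) is a cheap sanity check that no weight was mapped to the wrong variable.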

Huge thanks to Willi Gierke (of Google) for helping with the porting.

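The TensorFlow SavedModels are available on TensorFlow Hub as a collection: https://tfhub.dev/sayakpaul/collections/bit-resnet/1. A total of 8 models are available, since each of the four model and resolution combinations comes as both a classifier and a feature extractor: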

| Model Name | Input Resolution | Classifier | Feature Extractor |
|---|---|---|---|
| BiT-ResNet152x2 | 384 | Link | Link |
| BiT-ResNet152x2 | 224 | Link | Link |
| BiT-ResNet50x1 | 224 | Link | Link |
| BiT-ResNet50x1 | 160 | Link | Link |

You can use the convert_jax_weights_tf.ipynb notebook to understand how model porting works between JAX and TensorFlow. The JAX team also provides an experimental tool called jax2tf (available as jax.experimental.jax2tf) that converts JAX functions to TensorFlow.

References

[1] Big Transfer (BiT): General Visual Representation Learning by Kolesnikov et al.

[2] Knowledge distillation: A good teacher is patient and consistent by Beyer et al.

[3] BiT GitHub
