[ViT] Graduate ViT from prototype (#5173)
* graduate vit from prototype

* nit

* add vit to docs and hubconf

* ufmt

* re-correct ufmt

* again

* fix linter
yiwen-song committed Jan 10, 2022
1 parent d675c0c commit 68f511e
Showing 5 changed files with 464 additions and 332 deletions.
26 changes: 26 additions & 0 deletions docs/source/models.rst
@@ -40,6 +40,7 @@ architectures for image classification:
- `MNASNet`_
- `EfficientNet`_
- `RegNet`_
- `VisionTransformer`_

You can construct a model with random weights by calling its constructor:

@@ -82,6 +83,10 @@ You can construct a model with random weights by calling its constructor:
regnet_x_8gf = models.regnet_x_8gf()
regnet_x_16gf = models.regnet_x_16gf()
regnet_x_32gf = models.regnet_x_32gf()
vit_b_16 = models.vit_b_16()
vit_b_32 = models.vit_b_32()
vit_l_16 = models.vit_l_16()
vit_l_32 = models.vit_l_32()
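
For example, a minimal sketch of building one of the new ViT variants with random weights and checking the output shape; the 224x224 input size and 1000-way ImageNet head are the library defaults:

import torch
from torchvision import models

# ViT-B/16 with randomly initialized weights
vit_b_16 = models.vit_b_16()

# Default configuration: 224x224 RGB inputs, 1000 ImageNet classes out
dummy = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = vit_b_16(dummy)
print(logits.shape)  # torch.Size([1, 1000])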
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
@@ -125,6 +130,10 @@ These can be constructed by passing ``pretrained=True``:
regnet_x_8gf = models.regnet_x_8gf(pretrained=True)
regnet_x_16gf = models.regnet_x_16gf(pretrained=True)
regnet_x_32gf = models.regnet_x_32gf(pretrained=True)
vit_b_16 = models.vit_b_16(pretrained=True)
vit_b_32 = models.vit_b_32(pretrained=True)
vit_l_16 = models.vit_l_16(pretrained=True)
vit_l_32 = models.vit_l_32(pretrained=True)
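
A sketch of running inference with one of the pretrained variants follows; the resize/crop sizes and normalization constants are the usual ImageNet preprocessing and are assumptions here, and ``img`` stands for a PIL image loaded elsewhere:

import torch
from torchvision import models, transforms

vit_b_16 = models.vit_b_16(pretrained=True)
vit_b_16.eval()

# Assumed preprocessing: standard ImageNet resize, crop and normalization
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

batch = preprocess(img).unsqueeze(0)  # `img`: a PIL image (hypothetical)
with torch.no_grad():
    probs = vit_b_16(batch).softmax(dim=-1)
print(probs.argmax(dim=-1))  # predicted ImageNet class index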
Instancing a pre-trained model will download its weights to a cache directory.
This directory can be set using the `TORCH_HOME` environment variable. See
@@ -233,6 +242,10 @@ regnet_y_3_2gf 78.948 94.576
regnet_y_8gf 80.032 95.048
regnet_y_16gf 80.424 95.240
regnet_y_32gf 80.878 95.340
vit_b_16 81.072 95.318
vit_b_32 75.912 92.466
vit_l_16 79.662 94.638
vit_l_32 76.972 93.070
================================ ============= =============


@@ -250,6 +263,7 @@ regnet_y_32gf 80.878 95.340
.. _MNASNet: https://arxiv.org/abs/1807.11626
.. _EfficientNet: https://arxiv.org/abs/1905.11946
.. _RegNet: https://arxiv.org/abs/2003.13678
.. _VisionTransformer: https://arxiv.org/abs/2010.11929

.. currentmodule:: torchvision.models

@@ -433,6 +447,18 @@ RegNet
regnet_x_16gf
regnet_x_32gf

VisionTransformer
-----------------

.. autosummary::
:toctree: generated/
:template: function.rst

vit_b_16
vit_b_32
vit_l_16
vit_l_32

Quantized Models
----------------

9 changes: 6 additions & 3 deletions hubconf.py
@@ -1,7 +1,6 @@
# Optional list of dependencies required by the package
dependencies = ["torch"]

# classification
from torchvision.models.alexnet import alexnet
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
from torchvision.models.efficientnet import (
@@ -47,8 +46,6 @@
wide_resnet50_2,
wide_resnet101_2,
)

# segmentation
from torchvision.models.segmentation import (
fcn_resnet50,
fcn_resnet101,
@@ -60,3 +57,9 @@
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.vision_transformer import (
vit_b_16,
vit_b_32,
vit_l_16,
vit_l_32,
)
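
Because hubconf.py now exposes the ViT builders, they should also be loadable through torch.hub; a minimal sketch that pulls from the repository's default branch (pinning a release tag may be preferable):

import torch

# Fetch the ViT-B/16 entry point registered in hubconf.py
model = torch.hub.load("pytorch/vision", "vit_b_16", pretrained=True)
model.eval()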
1 change: 1 addition & 0 deletions torchvision/models/__init__.py
@@ -10,6 +10,7 @@
from .shufflenetv2 import *
from .efficientnet import *
from .regnet import *
from .vision_transformer import *
from . import detection
from . import feature_extraction
from . import optical_flow
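
With this re-export in place, the builders live in the stable torchvision.models namespace rather than under the prototype package; a small sketch:

import torchvision.models as models

# Available directly after this change; no torchvision.prototype import needed
model = models.vit_b_32()
print(sum(p.numel() for p in model.parameters()))  # total parameter count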
