This subpackage provides a variety of pre-trained, state-of-the-art models trained on the ImageNet dataset.
The pre-trained models can be used for both inference and training, as follows:
```python
# Create ResNet-50 for inference
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import numpy as np
from nnabla.models.imagenet import ResNet50
from nnabla.utils.image_utils import imread

model = ResNet50()
batch_size = 1
# model.input_shape is (3, 224, 224) for ResNet-50
x = nn.Variable((batch_size,) + model.input_shape)
y = model(x, training=False)

# Execute inference
# Load the input image as a uint8 array of shape (3, 224, 224)
img = imread('example.jpg', size=model.input_shape[1:], channel_first=True)
x.d[0] = img
y.forward()
predicted_label = np.argmax(y.d[0])
print('Predicted label:', model.category_names[predicted_label])

# Create ResNet-50 for fine-tuning
batch_size = 32
x = nn.Variable((batch_size,) + model.input_shape)
# * training=True puts batch normalization in training mode
#   and marks the parameters as trainable.
# * use_up_to='pool' builds the network only up to the output of
#   the final global average pooling.
pool = model(x, training=True, use_up_to='pool')

# Add a classification layer for a different 10-category dataset,
# plus a loss function
num_classes = 10
y = PF.affine(pool, num_classes, name='classifier10')
t = nn.Variable((batch_size, 1))  # integer class labels
loss = F.sum(F.softmax_cross_entropy(y, t))

# Training...
```
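In the fine-tuning example, `F.softmax_cross_entropy(y, t)` yields one cross-entropy value per sample, `-log(softmax(y)[t])`, and `F.sum` adds those values over the batch. To make that arithmetic concrete, here is a plain-Python sketch that does not use nnabla at all; `log_softmax` and `softmax_cross_entropy_sum` are hypothetical helper names, and the labels are assumed to be integer class indices as in `t` above.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum_exp for v in logits]


def softmax_cross_entropy_sum(logits_batch: Sequence[Sequence[float]],
                              labels: Sequence[int]) -> float:
    """Sum over the batch of -log p(label | logits).

    Hypothetical helper that mirrors F.sum(F.softmax_cross_entropy(y, t))
    for integer class labels; it is not nnabla's implementation.
    """
    total = 0.0
    for logits, label in zip(logits_batch, labels):
        total -= log_softmax(logits)[label]
    return total


# Two samples, three classes: uniform logits cost log(3) each, about 2.197 in total.
print(softmax_cross_entropy_sum([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [0, 2]))
```

Subtracting the maximum logit before exponentiating keeps the sketch finite even for logits in the hundreds or thousands.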
The available models are summarized in the following table. Error rates are measured using a single center crop.
Available ImageNet models

Name | Class | Top-1 error (%) | Top-5 error (%) | Trained by/with |
---|---|---|---|---|
ResNet-18 | ResNet18 | 30.28 | 10.90 | Neural Network Console |
ResNet-34 | ResNet34 | 26.72 | 8.89 | Neural Network Console |
ResNet-50 | ResNet50 | 24.59 | 7.48 | Neural Network Console |
ResNet-101 | ResNet101 | 23.81 | 7.01 | Neural Network Console |
ResNet-152 | ResNet152 | 23.48 | 7.09 | Neural Network Console |
MobileNet | MobileNet | 29.51 | 10.34 | Neural Network Console |
MobileNetV2 | MobileNetV2 | 29.94 | 10.82 | Neural Network Console |
SENet-154 | SENet | 22.04 | 6.29 | Neural Network Console |
SqueezeNet v1.0 | SqueezeNetV10 | 42.71 | 20.12 | Neural Network Console |
SqueezeNet v1.1 | SqueezeNetV11 | 41.23 | 19.18 | Neural Network Console |
VGG-11 | VGG11 | 30.85 | 11.38 | Neural Network Console |
VGG-13 | VGG13 | 29.51 | 10.46 | Neural Network Console |
VGG-16 | VGG16 | 29.03 | 10.07 | Neural Network Console |
NIN | NIN | 42.91 | 20.66 | Neural Network Console |
DenseNet-161 | DenseNet | 23.82 | 7.02 | Neural Network Console |
InceptionV3 | InceptionV3 | 21.82 | 5.88 | Neural Network Console |
Xception | Xception | 23.59 | 6.91 | Neural Network Console |
GoogLeNet | GoogLeNet | 31.22 | 11.34 | Neural Network Console |
ResNeXt-50 | ResNeXt50 | 22.95 | 6.73 | Neural Network Console |
ResNeXt-101 | ResNeXt101 | 22.80 | 6.74 | Neural Network Console |
ShuffleNet | ShuffleNet10 | 34.15 | 13.85 | Neural Network Console |
ShuffleNet-0.5x | ShuffleNet05 | 41.99 | 19.64 | Neural Network Console |
ShuffleNet-2.0x | ShuffleNet20 | 30.34 | 11.12 | Neural Network Console |
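The Top-1 and Top-5 columns give the percentage of evaluation images whose ground-truth class is not the single highest-scoring prediction, or is not among the five highest-scoring predictions, respectively. "Single center crop" means each image is scored once, on one crop taken from its center. The stand-alone sketch below shows both computations in plain Python; the helper names are hypothetical (not part of nnabla), and the 256-pixel image and 224-pixel crop are illustrative assumptions, since the exact resizing used to produce the table is not stated here.

```python
from typing import List, Sequence, Tuple


def center_crop_box(width: int, height: int, crop: int) -> Tuple[int, int, int, int]:
    """(left, top, right, bottom) of a crop x crop window centered in a width x height image."""
    if crop > width or crop > height:
        raise ValueError("crop size exceeds image size")
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


def top_k_indices(scores: Sequence[float], k: int) -> List[int]:
    """Class indices of the k highest scores, best first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def top_k_error(score_rows: Sequence[Sequence[float]],
                labels: Sequence[int], k: int) -> float:
    """Percentage of samples whose true label is not among the top-k predictions."""
    misses = sum(1 for scores, label in zip(score_rows, labels)
                 if label not in top_k_indices(scores, k))
    return 100.0 * misses / len(labels)


# One score vector per image (one center crop each), three classes.
scores = [[0.1, 0.7, 0.2], [0.5, 0.2, 0.3]]
labels = [1, 2]
print(center_crop_box(256, 256, 224))    # (16, 16, 240, 240)
print(top_k_error(scores, labels, k=1))  # 50.0
print(top_k_error(scores, labels, k=2))  # 0.0
```

With 1000 ImageNet classes, calling `top_k_error` with `k=1` and `k=5` corresponds to the two error columns in the table.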
`nnabla.models.imagenet.base`

- `ImageNetBase`

`nnabla.models.imagenet`

- `ResNet18`
- `ResNet34`
- `ResNet50`
- `ResNet101`
- `ResNet152`
- `ResNet`
- `MobileNet`
- `MobileNetV2`
- `SENet`
- `SqueezeNetV10`
- `SqueezeNetV11`
- `SqueezeNet`
- `VGG11`
- `VGG13`
- `VGG16`
- `VGG`
- `NIN`
- `DenseNet`
- `InceptionV3`
- `Xception`
- `GoogLeNet`
- `ResNeXt50`
- `ResNeXt101`
- `ResNeXt`
- `ShuffleNet10`
- `ShuffleNet05`
- `ShuffleNet20`
- `ShuffleNet`
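Every value in the table's Class column also appears in the `nnabla.models.imagenet` class list, so a model can be chosen by its reported error before it is instantiated. The sketch below hard-codes the table's numbers in plain Python. `IMAGENET_ERRORS`, `best_models`, and `models_below` are hypothetical names used only for illustration and are not part of nnabla.

```python
from typing import Dict, List, Tuple

# Top-1 / Top-5 error (%) per class name, copied from the table above.
IMAGENET_ERRORS: Dict[str, Tuple[float, float]] = {
    "ResNet18": (30.28, 10.90),
    "ResNet34": (26.72, 8.89),
    "ResNet50": (24.59, 7.48),
    "ResNet101": (23.81, 7.01),
    "ResNet152": (23.48, 7.09),
    "MobileNet": (29.51, 10.34),
    "MobileNetV2": (29.94, 10.82),
    "SENet": (22.04, 6.29),
    "SqueezeNetV10": (42.71, 20.12),
    "SqueezeNetV11": (41.23, 19.18),
    "VGG11": (30.85, 11.38),
    "VGG13": (29.51, 10.46),
    "VGG16": (29.03, 10.07),
    "NIN": (42.91, 20.66),
    "DenseNet": (23.82, 7.02),
    "InceptionV3": (21.82, 5.88),
    "Xception": (23.59, 6.91),
    "GoogLeNet": (31.22, 11.34),
    "ResNeXt50": (22.95, 6.73),
    "ResNeXt101": (22.80, 6.74),
    "ShuffleNet10": (34.15, 13.85),
    "ShuffleNet05": (41.99, 19.64),
    "ShuffleNet20": (30.34, 11.12),
}

TOP1, TOP5 = 0, 1  # column indices into each (top1, top5) tuple


def best_models(n: int, column: int = TOP1) -> List[str]:
    """Class names of the n models with the lowest error in the given column."""
    return sorted(IMAGENET_ERRORS, key=lambda name: IMAGENET_ERRORS[name][column])[:n]


def models_below(max_error: float, column: int = TOP1) -> List[str]:
    """Class names whose error in the given column is strictly below max_error."""
    return [name for name, errs in IMAGENET_ERRORS.items() if errs[column] < max_error]


print(best_models(3))               # ['InceptionV3', 'SENet', 'ResNeXt101']
print(best_models(3, column=TOP5))  # ['InceptionV3', 'SENet', 'ResNeXt50']
```

Because the keys match the class names listed above, a selected name could then be instantiated with ordinary Python attribute lookup, for example `getattr(nnabla.models.imagenet, name)()`. This is plain `getattr` and not a dedicated nnabla selection API.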