A PyTorch implementation of MobileNet V2 architecture and pretrained model.

A PyTorch implementation of MobileNetV2

This is a PyTorch implementation of the MobileNetV2 architecture as described in the paper "Inverted Residuals and Linear Bottlenecks: Mobile Networks for Classification, Detection and Segmentation".

[NEW] I fixed a difference in implementation compared to the official TensorFlow model. Please use the new model file and checkpoint!

Training Recipe

I recently found a training setup that works well:

  1. number of epochs: 150
  2. learning rate schedule: cosine learning rate, initial lr=0.05
  3. weight decay: 4e-5
  4. remove dropout

You should get >72% top-1 accuracy with this training recipe!
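
As a rough illustration of item 2, the sketch below computes a cosine learning rate in plain Python. It assumes the rate decays from the initial 0.05 to zero over the 150 epochs and is updated once per epoch with no warmup. The recipe does not specify these details, and the function name `cosine_lr` is hypothetical.

```python
import math


def cosine_lr(epoch, total_epochs=150, base_lr=0.05):
    """Cosine-annealed learning rate for a given epoch."""
    # Assumptions (not stated in the recipe): per-epoch updates,
    # no warmup, and a final learning rate of zero.
    progress = epoch / total_epochs
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


# Print the learning rate at a few points along the schedule.
for epoch in (0, 50, 100, 149):
    print(f"epoch {epoch:3d}: lr = {cosine_lr(epoch):.5f}")
```

In PyTorch, `torch.optim.lr_scheduler.CosineAnnealingLR` with `T_max=150` follows the same curve when stepped once per epoch. Whether the reported results used that scheduler or a per-iteration variant is not stated here.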

Accuracy & Statistics

Here is a comparison of statistics against the official TensorFlow implementation.

              FLOPs       Parameters   Top-1 acc.   Pretrained model
Official TF   300 M       3.47 M       71.8%        -
Ours          300.775 M   3.471 M      71.8%        [google drive]

Usage

To use the pretrained model, run:

import torch
from MobileNetV2 import MobileNetV2

net = MobileNetV2(n_class=1000)
# Pass map_location='cpu' to torch.load if no GPU is available.
state_dict = torch.load('mobilenetv2.pth.tar')
net.load_state_dict(state_dict)

Data Pre-processing

I used the following code for data pre-processing on ImageNet:

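import torch
from torchvision import datasets, transforms

# traindir, valdir, batch_size and n_worker must be defined by the caller.
# Channel-wise ImageNet mean and std.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])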

input_size = 224
train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.RandomResizedCrop(input_size, scale=(0.2, 1.0)), 
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    num_workers=n_worker, pin_memory=True)

val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(int(input_size/0.875)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=batch_size, shuffle=False,
    num_workers=n_worker, pin_memory=True)
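
The validation transform resizes the shorter side of each image to `int(input_size/0.875)` before the center crop, so the crop keeps roughly the central 87.5% of the shorter side. Here is a small sketch of that arithmetic. The helper name `val_resize_size` is hypothetical, and only `input_size = 224` comes from the code above; the other size is purely illustrative.

```python
def val_resize_size(input_size, crop_ratio=0.875):
    """Shorter-side size to resize to before center-cropping to input_size."""
    # int() truncates, matching int(input_size/0.875) in the validation transform.
    return int(input_size / crop_ratio)


# 224 is the input size used above; 192 is an illustrative alternative.
for size in (224, 192):
    print(size, "->", val_resize_size(size))
```

For the default `input_size = 224` this gives a resize to 256 followed by a 224 center crop.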