This code modifies the output layer of image classification networks commonly used in PyTorch, so that each modified model can be trained on a classification dataset with any number of classes. The pre-trained parameters and the re-trained parameters are kept in separate groups, so they can be fine-tuned or trained separately. The following network architectures have been modified:
MobileNet | mobilenet_v2 | mobilenet_v3_small | mobilenet_v3_large | |
ResNet | resnet18 | resnet34 | resnet50 | resnet101 |
ResNeXt | resnext50_32x4d | resnext101_32x8d | | |
DenseNet | densenet121 | densenet161 | densenet169 | |
ShuffleNet | shufflenet_v2_x0_5 | shufflenet_v2_x1_0 | | |
SqueezeNet | squeezenet1_0 | squeezenet1_1 | | |
WideResNet | wide_resnet50_2 | wide_resnet101_2 | | |
- torch==1.9.0
- torchvision==0.10.0
- modify data/classifier.yaml
  - classes: list of class names
  - dataset_dir: path to the dataset, with one subfolder per class:
    - Cat
      - 1.img
      - 2.img
    - Dog
      - 1.img
      - 2.img
  - save_dir: path where the trained model is saved
  - train_size: number of samples randomly selected per class
- model train:
python train_classifier_model.py
- model inference:
python model_inference.py
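The `dataset_dir` layout and the `train_size` option described above can be illustrated with a short sketch. It assumes one subfolder per class and picks up to `train_size` files at random from each; the function name `sample_per_class` and the `seed` parameter are hypothetical and are not taken from the training script.

```python
import random
import tempfile
from pathlib import Path


def sample_per_class(dataset_dir, classes, train_size, seed=0):
    """Randomly pick up to `train_size` files from each class subfolder."""
    rng = random.Random(seed)  # fixed seed so the selection is reproducible
    selected = {}
    for name in classes:
        class_dir = Path(dataset_dir) / name
        files = sorted(p for p in class_dir.iterdir() if p.is_file())
        # Never ask for more samples than the class actually has.
        selected[name] = rng.sample(files, min(train_size, len(files)))
    return selected


# Build a throwaway dataset that mirrors the Cat/Dog layout, then sample from it.
with tempfile.TemporaryDirectory() as root:
    for name in ("Cat", "Dog"):
        (Path(root) / name).mkdir()
        for i in (1, 2):
            (Path(root) / name / f"{i}.img").touch()

    picked = sample_per_class(root, classes=["Cat", "Dog"], train_size=1)
    for name, files in picked.items():
        print(name, [f.name for f in files])
```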
We used ./data/ciassifier_cifar100.yaml and train_cifar100_model.py to train on the CIFAR-100 dataset. The training curves are in the ./images directory.
model | accuracy | FPS (batch=1, img_size=224x224) | GPU memory | GPU peak utilization (Tesla P40) | model size (file size in bytes) | total params |
mobilenet_v2 | 0.7123 | 164 | 847M | 43% | 9.3M | 2,236,682 |
mobilenet_v3_small | 0.7112 | 144 | 847M | 36% | 6.4M | 1,528,106 |
mobilenet_v3_large | 0.7521 | 138 | 855M | 50% | 17M | 4,214,842 |
resnet18 | 0.7139 | 236 | 913M | 56% | 43M | 11,181,642 |
resnet34 | 0.7573 | 182 | 955M | 59% | 82M | 21,289,802 |
resnet50 | 0.7683 | 123 | 967M | 74% | 91M | 23,528,522 |
resnet101 | 0.8025 | 81 | 1041M | 89% | 164M | 42,520,650 |
resnext50_32x4d | 0.7806 | 67 | 933M | 73% | 89M | 23,000,394 |
resnext101_32x8d | 0.8198 | 31 | 1195M | 88% | 333M | 86,762,826 |
densenet121 | 0.768 | 69 | 869M | 66% | 28M | 6,964,106 |
densenet161 | 0.8027 | 45 | 957M | 86% | 104M | 26,494,090 |
densenet169 | 0.7868 | 50 | 889M | 65% | 50M | 12,501,130 |
shufflenet_v2_x0_5 | 0.2203 | 136 | 843M | 24% | 1.9M | 352,042 |
shufflenet_v2_x1_0 | 0.3329 | 150 | 845M | 32% | 5.4M | 1,263,854 |
squeezenet1_0 | 0.5605 | 292 | 837M | 47% | 3.1M | 740,554 |
squeezenet1_1 | 0.575 | 266 | 837M | 41% | 3.0M | 727,626 |
wide_resnet50_2 | 0.789 | 90 | 1227M | 88% | 257M | 66,854,730 |
wide_resnet101_2 | 0.8122 | 50 | 1469M | 93% | 478M | 124,858,186 |
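As a rough consistency check between the last two columns of the table: if the weights are stored as 32-bit floats (an assumption, since the storage format is not stated here), each parameter takes 4 bytes, so the model size can be estimated from the parameter count. The helper name `estimate_size_mib` is hypothetical.

```python
def estimate_size_mib(num_params: int, bytes_per_param: int = 4) -> float:
    """Estimate model size in MiB, assuming float32 (4 bytes) per parameter."""
    return num_params * bytes_per_param / 2**20


# Parameter counts copied from the table above.
total_params = {
    "mobilenet_v2": 2_236_682,
    "resnet18": 11_181_642,
    "resnet50": 23_528_522,
    "wide_resnet101_2": 124_858_186,
}

for name, count in total_params.items():
    print(f"{name}: about {estimate_size_mib(count):.1f} MiB")
```

For these rows the estimate comes out slightly below the reported model size. One plausible reason is that a saved checkpoint also contains non-parameter buffers, such as batch-normalization statistics.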