Skip to content

Commit

Permalink
Added ESPNetv2
Browse files Browse the repository at this point in the history
  • Loading branch information
sacmehta committed Jun 7, 2019
1 parent 166120c commit 2902d3c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 17 deletions.
2 changes: 1 addition & 1 deletion model/classification/model_zoo/README.md
Expand Up @@ -12,7 +12,7 @@
| 1.25 | 224x224 | 98.11 | 67.77 | 87.64 | [here](/model/classification/model_zoo/dicenet/dicenet_s_1.25_imagenet_224x224.pth) |
| 1.5 | 224x224 | 135.48 | 69.51 | 88.76 | [here](/model/classification/model_zoo/dicenet/dicenet_s_1.5_imagenet_224x224.pth) |
| 1.75 | 224x224 | 182.24 | 70.26 | 89.33 | [here](/model/classification/model_zoo/dicenet/dicenet_s_1.75_imagenet_224x224.pth) |
| 2.0 | 224x224 | 236.29 | | | [here](/model/classification/model_zoo/dicenet/dicenet_s_2.0_imagenet_224x224.pth) |
| 2.0 | 224x224 | 236.29 | 70.99 | 89.80 | [here](/model/classification/model_zoo/dicenet/dicenet_s_2.0_imagenet_224x224.pth) |


## ESPNetv2 models
Expand Down
2 changes: 1 addition & 1 deletion test_classification.py
Expand Up @@ -17,7 +17,7 @@ def main(args):
if args.model == 'dicenet':
from model.classification import dicenet as net
model = net.CNNModel(args)
elif args.model == 'espnet':
elif args.model == 'espnetv2':
from model.classification import espnetv2 as net
model = net.EESPNet(args)
elif args.model == 'shufflenetv2':
Expand Down
21 changes: 6 additions & 15 deletions train_classification.py
Expand Up @@ -24,24 +24,15 @@ def main(args):
# -----------------------------------------------------------------------------
# Create model
# -----------------------------------------------------------------------------
if args.model == 'basic_dw':
from model.classification import basic_dw as net
model = net.CNNModel(args)
elif args.model == 'basic_vw':
from model.classification import basic_vw as net
model = net.CNNModel(args)
elif args.model == 'shuffle_dw':
from model.classification import shufflenetv2 as net
model = net.CNNModel(args)
elif args.model == 'e_shuffle_dw':
from model.classification import eff_shuffle_dw as net
model = net.CNNModel(args)
elif args.model == 'dicenet':
if args.model == 'dicenet':
from model.classification import dicenet as net
model = net.CNNModel(args)
elif args.model == 'espnet':
elif args.model == 'espnetv2':
from model.classification import espnetv2 as net
model = net.EESPNet(classes=args.num_classes, s=args.s)
model = net.EESPNet(args)
elif args.model == 'shufflenetv2':
from model.classification import shufflenetv2 as net
model = net.CNNModel(args)
else:
print_error_message('Model {} not yet implemented'.format(args.model))
exit()
Expand Down

0 comments on commit 2902d3c

Please sign in to comment.