From d3a0773a1270a4511cfc2a79870aeb0524d2ce10 Mon Sep 17 00:00:00 2001 From: Ilya Ovodov Date: Wed, 13 May 2020 10:01:07 +0300 Subject: [PATCH] convert(...) changed to save converted file alongside the original file --- README.md | 4 ++-- models.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 06a3ecddae..a9301c8ffb 100755 --- a/README.md +++ b/README.md @@ -107,11 +107,11 @@ $ git clone https://github.com/ultralytics/yolov3 && cd yolov3 # convert darknet cfg/weights to pytorch model $ python3 -c "from models import *; convert('cfg/yolov3-spp.cfg', 'weights/yolov3-spp.weights')" -Success: converted 'weights/yolov3-spp.weights' to 'converted.pt' +Success: converted 'weights/yolov3-spp.weights' to 'weights/yolov3-spp.pt' # convert cfg/pytorch model to darknet weights $ python3 -c "from models import *; convert('cfg/yolov3-spp.cfg', 'weights/yolov3-spp.pt')" -Success: converted 'weights/yolov3-spp.pt' to 'converted.weights' +Success: converted 'weights/yolov3-spp.pt' to 'weights/yolov3-spp.weights' ``` # mAP diff --git a/models.py b/models.py index afd5b87b0e..ebe151b6a3 100755 --- a/models.py +++ b/models.py @@ -423,8 +423,9 @@ def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'): # Load weights and save if weights.endswith('.pt'): # if PyTorch format model.load_state_dict(torch.load(weights, map_location='cpu')['model']) - save_weights(model, path='converted.weights', cutoff=-1) - print("Success: converted '%s' to 'converted.weights'" % weights) + target = weights.rsplit('.', 1)[0] + '.weights' + save_weights(model, path=target, cutoff=-1) + print("Success: converted '%s' to '%s'" % (weights, target)) elif weights.endswith('.weights'): # darknet format _ = load_darknet_weights(model, weights) @@ -435,8 +436,9 @@ def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'): 'model': model.state_dict(), 'optimizer': None} - torch.save(chkpt, 'converted.pt') - print("Success: converted '%s' to 'converted.pt'" % weights) + target = weights.rsplit('.', 1)[0] + '.pt' + torch.save(chkpt, target) + print("Success: converted '%s' to '%'" % (weights, target)) else: print('Error: extension not supported.')