Skip to content

Commit

Permalink
convert(...) changed to save converted file alongside the original fi…
Browse files Browse the repository at this point in the history
…le (#1167)
  • Loading branch information
IlyaOvodov committed May 13, 2020
1 parent c066d7d commit b2fcfc5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ $ git clone https://github.com/ultralytics/yolov3 && cd yolov3

# convert darknet cfg/weights to pytorch model
$ python3 -c "from models import *; convert('cfg/yolov3-spp.cfg', 'weights/yolov3-spp.weights')"
Success: converted 'weights/yolov3-spp.weights' to 'converted.pt'
Success: converted 'weights/yolov3-spp.weights' to 'weights/yolov3-spp.pt'

# convert cfg/pytorch model to darknet weights
$ python3 -c "from models import *; convert('cfg/yolov3-spp.cfg', 'weights/yolov3-spp.pt')"
Success: converted 'weights/yolov3-spp.pt' to 'converted.weights'
Success: converted 'weights/yolov3-spp.pt' to 'weights/yolov3-spp.weights'
```

# mAP
Expand Down
10 changes: 6 additions & 4 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,9 @@ def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'):
# Load weights and save
if weights.endswith('.pt'): # if PyTorch format
model.load_state_dict(torch.load(weights, map_location='cpu')['model'])
save_weights(model, path='converted.weights', cutoff=-1)
print("Success: converted '%s' to 'converted.weights'" % weights)
target = weights.rsplit('.', 1)[0] + '.weights'
save_weights(model, path=target, cutoff=-1)
print("Success: converted '%s' to '%s'" % (weights, target))

elif weights.endswith('.weights'): # darknet format
_ = load_darknet_weights(model, weights)
Expand All @@ -435,8 +436,9 @@ def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'):
'model': model.state_dict(),
'optimizer': None}

torch.save(chkpt, 'converted.pt')
print("Success: converted '%s' to 'converted.pt'" % weights)
target = weights.rsplit('.', 1)[0] + '.pt'
torch.save(chkpt, target)
print("Success: converted '%s' to '%'" % (weights, target))

else:
print('Error: extension not supported.')
Expand Down

0 comments on commit b2fcfc5

Please sign in to comment.