Skip to content

Commit

Permalink
--img-size stride-multiple verification
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Jun 13, 2020
1 parent 4ba47b1 commit 099e6f5
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 4 deletions.
1 change: 1 addition & 0 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def detect(save_img=False):
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
parser.add_argument('--augment', action='store_true', help='augmented inference')
opt = parser.parse_args()
opt.img_size = check_img_size(opt.img_size)
print(opt)

with torch.no_grad():
Expand Down
1 change: 1 addition & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def test(data,
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--verbose', action='store_true', help='report mAP by class')
opt = parser.parse_args()
opt.img_size = check_img_size(opt.img_size)
opt.save_json = opt.save_json or opt.data.endswith('coco.yaml')
opt.data = glob.glob('./**/' + opt.data, recursive=True)[0] # find file
print(opt)
Expand Down
4 changes: 1 addition & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ def train(hyp):

# Image sizes
gs = int(max(model.stride)) # grid size (max stride)
if any(x % gs != 0 for x in opt.img_size):
print('WARNING: --img-size %g,%g must be multiple of %s max stride %g' % (*opt.img_size, opt.cfg, gs))
imgsz, imgsz_test = [make_divisible(x, gs) for x in opt.img_size] # image sizes (train, test)
imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples

# Optimizer
nbs = 64 # nominal batch size
Expand Down
9 changes: 8 additions & 1 deletion utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,20 @@ def init_seeds(seed=0):


def check_git_status():
# Suggest 'git pull' if repo is out of date
if platform in ['linux', 'darwin']:
# Suggest 'git pull' if repo is out of date
s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
if 'Your branch is behind' in s:
print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')


def check_img_size(img_size, s=32):
# Verify img_size is a multiple of stride s
if img_size % s != 0:
print('WARNING: --img-size %g must be multiple of max stride %g' % (img_size, s))
return make_divisible(img_size, s) # nearest gs-multiple


def make_divisible(x, divisor):
# Returns x evenly divisble by divisor
return math.ceil(x / divisor) * divisor
Expand Down

0 comments on commit 099e6f5

Please sign in to comment.