Skip to content

Commit

Permalink
rename ch to input_channels for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
jere357 committed Nov 15, 2023
1 parent 8b6bf0b commit e29b190
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,14 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
with torch_distributed_zero_first(LOCAL_RANK):
weights = attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
model = Model(cfg or ckpt['model'].yaml, ch=hyp.get('ch', 3), nc=nc, anchors=hyp.get('anchors')).to(device) # create
model = Model(cfg or ckpt['model'].yaml, ch=hyp.get('input_channels', 3), nc=nc, anchors=hyp.get('anchors')).to(device) # create
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
model.load_state_dict(csd, strict=False) # load
LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
else:
model = Model(cfg, ch=hyp.get('ch', 3), nc=nc, anchors=hyp.get('anchors')).to(device) # create
model = Model(cfg, ch=hyp.get('input_channels', 3), nc=nc, anchors=hyp.get('anchors')).to(device) # create
amp = check_amp(model) # check AMP

# Freeze
Expand Down

0 comments on commit e29b190

Please sign in to comment.