diff --git a/models/yolov5l.yaml b/models/yolov5l.yaml
index 31362f876932..19bb18b664a8 100644
--- a/models/yolov5l.yaml
+++ b/models/yolov5l.yaml
@@ -1,6 +1,7 @@
 # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
 
 # Parameters
+input_channels: 3 # number of input channels, RGB is 3
 nc: 80 # number of classes
 depth_multiple: 1.0 # model depth multiple
 width_multiple: 1.0 # layer channel multiple
diff --git a/models/yolov5m.yaml b/models/yolov5m.yaml
index a76900c5a2e2..6d492821ae6b 100644
--- a/models/yolov5m.yaml
+++ b/models/yolov5m.yaml
@@ -1,6 +1,7 @@
 # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
 
 # Parameters
+input_channels: 3 # number of input channels, RGB is 3
 nc: 80 # number of classes
 depth_multiple: 0.67 # model depth multiple
 width_multiple: 0.75 # layer channel multiple
diff --git a/models/yolov5n.yaml b/models/yolov5n.yaml
index aba96cfc54f4..558c06210b3d 100644
--- a/models/yolov5n.yaml
+++ b/models/yolov5n.yaml
@@ -1,6 +1,7 @@
 # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
 
 # Parameters
+input_channels: 3 # number of input channels, RGB is 3
 nc: 80 # number of classes
 depth_multiple: 0.33 # model depth multiple
 width_multiple: 0.25 # layer channel multiple
diff --git a/models/yolov5s.yaml b/models/yolov5s.yaml
index 5d05364c4936..d92e99e2ae04 100644
--- a/models/yolov5s.yaml
+++ b/models/yolov5s.yaml
@@ -1,6 +1,7 @@
 # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
 
 # Parameters
+input_channels: 3 # number of input channels, RGB is 3
 nc: 80 # number of classes
 depth_multiple: 0.33 # model depth multiple
 width_multiple: 0.50 # layer channel multiple
diff --git a/train.py b/train.py
index 004c8eeda121..9c656c9f28fb 100644
--- a/train.py
+++ b/train.py
@@ -127,14 +127,15 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
         with torch_distributed_zero_first(LOCAL_RANK):
             weights = attempt_download(weights) # download if not found locally
         ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
-        model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
+        model = Model(cfg or ckpt['model'].yaml, ch=hyp.get('input_channels', 3), nc=nc,
+                      anchors=hyp.get('anchors')).to(device) # create
         exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
         csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
         csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
         model.load_state_dict(csd, strict=False) # load
         LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
     else:
-        model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
+        model = Model(cfg, ch=hyp.get('input_channels', 3), nc=nc, anchors=hyp.get('anchors')).to(device) # create
     amp = check_amp(model) # check AMP
 
     # Freeze