Skip to content

Commit

Permalink
PANet update
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Jun 22, 2020
1 parent ef58dac commit 364fcfd
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 96 deletions.
15 changes: 8 additions & 7 deletions README.md
Expand Up @@ -4,7 +4,7 @@

This repository represents Ultralytics open-source research into future object detection methods, and incorporates our lessons learned and best practices evolved over training thousands of models on custom client datasets with our previous YOLO repository https://github.com/ultralytics/yolov3. **All code and models are under active development, and are subject to modification or deletion without notice.** Use at your own risk.

<img src="https://user-images.githubusercontent.com/26833433/84200349-729f2680-aa5b-11ea-8f9a-604c9e01a658.png" width="1000">** GPU Speed measures end-to-end time per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 32, and includes image preprocessing, PyTorch FP32 inference, postprocessing and NMS.
<img src="https://user-images.githubusercontent.com/26833433/85336627-c6663280-b493-11ea-9b0a-289b0f182b84.png" width="1000">** GPU Speed measures end-to-end time per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 8, and includes image preprocessing, PyTorch FP16 inference, postprocessing and NMS.

- **June 19, 2020**: [FP16](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.half) as new default for smaller checkpoints and faster inference. Comparison in [d4c6674](https://github.com/ultralytics/yolov5/commit/d4c6674c98e19df4c40e33a777610a18d1961145).
- **June 9, 2020**: [CSP](https://github.com/WongKinYiu/CrossStagePartialNetworks) updates to all YOLOv5 models. New models are faster, smaller and more accurate. Credit to @WongKinYiu for his excellent work with CSP.
Expand All @@ -14,13 +14,14 @@ This repository represents Ultralytics open-source research into future object d

## Pretrained Checkpoints

| Model | AP<sup>val</sup> | AP<sup>test</sup> | AP<sub>50</sub> | Speed<sub>GPU</sub> | FPS<sub>GPU</sub> || params | FLOPs |
| Model | AP<sup>val</sup> | AP<sup>test</sup> | AP<sub>50</sub> | Speed<sub>GPU</sub> | FPS<sub>GPU</sub> || params | FLOPS |
|---------- |------ |------ |------ | -------- | ------| ------ |------ | :------: |
| YOLOv5-s ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 35.5 | 35.5 | 55.0 | **2.1ms** | **476** || 7.1M | 12.6B
| YOLOv5-m ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 42.7 | 42.7 | 62.4 | 3.2ms | 312 || 22.0M | 39.0B
| YOLOv5-l ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 45.7 | 45.9 | 65.1 | 4.1ms | 243 || 50.3M | 89.0B
| YOLOv5-x ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | **47.2** | **47.3** | **66.6** | 6.5ms | 153 || 95.9M | 170.3B
| YOLOv3-SPP ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 45.6 | 45.5 | 65.2 | 4.8ms | 208 || 63.0M | 118.0B
| [YOLOv5s](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 36.5 | 36.5 | 55.6 | **2.2ms** | **455** || 7.5M | 13.2B
| [YOLOv5m](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 43.4 | 43.4 | 62.4 | 3.0ms | 333 || 21.8M | 39.4B
| [YOLOv5l](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 46.6 | 46.7 | 65.4 | 3.9ms | 256 || 47.8M | 88.1B
| [YOLOv5x](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | **48.2** | **48.3** | **66.9** | 6.1ms | 164 || 89.0M | 166.4B
| [YOLOv3-SPP](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 45.6 | 45.5 | 65.2 | 4.5ms | 222 || 63.0M | 118.0B


** AP<sup>test</sup> denotes COCO [test-dev2017](http://cocodataset.org/#upload) server results, all other AP results in the table denote val2017 accuracy.
** All AP numbers are for single-model single-scale without ensemble or test-time augmentation. Reproduce by `python test.py --img 736 --conf 0.001`
Expand Down
3 changes: 1 addition & 2 deletions models/yolov3-spp.yaml
Expand Up @@ -25,8 +25,7 @@ backbone:
[-1, 4, Bottleneck, [1024]], # 10
]

# yolov3-spp head
# na = len(anchors[0])
# YOLOv3-SPP head
head:
[[-1, 1, Bottleneck, [1024, False]], # 11
[-1, 1, SPP, [512, [5, 9, 13]]],
Expand Down
49 changes: 28 additions & 21 deletions models/yolov5l.yaml
Expand Up @@ -5,41 +5,48 @@ width_multiple: 1.0 # layer channel multiple

# anchors
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32

This comment has been minimized.

Copy link
@M-o-magic

M-o-magic Apr 24, 2023

order relative to feature map order. FPN ->PAN

- [30,61, 62,45, 59,119] # P4/16
- [10,13, 16,30, 33,23] # P3/8

# yolov5 backbone
# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Focus, [64, 3]], # 1-P1/2
[-1, 1, Conv, [128, 3, 2]], # 2-P2/4
[-1, 3, Bottleneck, [128]],
[-1, 1, Conv, [256, 3, 2]], # 4-P3/8
[[-1, 1, Focus, [64, 3]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, BottleneckCSP, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 9, BottleneckCSP, [256]],
[-1, 1, Conv, [512, 3, 2]], # 6-P4/16
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 8-P5/32
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 6, BottleneckCSP, [1024]], # 10
]

# yolov5 head
# YOLOv5 head
head:
[[-1, 3, BottleneckCSP, [1024, False]], # 11
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 12 (P5/32-large)
[[-1, 3, BottleneckCSP, [1024, False]], # 9

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 1, Conv, [512, 1, 1]],
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 17 (P4/16-medium)
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, BottleneckCSP, [512, False]], # 13

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P3/8-small)
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 18 (P3/8-small)

[-2, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P4/16-medium)

[-2, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, BottleneckCSP, [1024, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 26 (P5/32-large)

[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
[[], 1, Detect, [nc, anchors]], # Detect(P5, P4, P3)
]
49 changes: 28 additions & 21 deletions models/yolov5m.yaml
Expand Up @@ -5,41 +5,48 @@ width_multiple: 0.75 # layer channel multiple

# anchors
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
- [30,61, 62,45, 59,119] # P4/16
- [10,13, 16,30, 33,23] # P3/8

# yolov5 backbone
# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Focus, [64, 3]], # 1-P1/2
[-1, 1, Conv, [128, 3, 2]], # 2-P2/4
[-1, 3, Bottleneck, [128]],
[-1, 1, Conv, [256, 3, 2]], # 4-P3/8
[[-1, 1, Focus, [64, 3]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, BottleneckCSP, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 9, BottleneckCSP, [256]],
[-1, 1, Conv, [512, 3, 2]], # 6-P4/16
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 8-P5/32
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 6, BottleneckCSP, [1024]], # 10
]

# yolov5 head
# YOLOv5 head
head:
[[-1, 3, BottleneckCSP, [1024, False]], # 11
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 12 (P5/32-large)
[[-1, 3, BottleneckCSP, [1024, False]], # 9

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 1, Conv, [512, 1, 1]],
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 17 (P4/16-medium)
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, BottleneckCSP, [512, False]], # 13

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P3/8-small)
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 18 (P3/8-small)

[-2, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P4/16-medium)

[-2, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, BottleneckCSP, [1024, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 26 (P5/32-large)

[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
[[], 1, Detect, [nc, anchors]], # Detect(P5, P4, P3)
]
49 changes: 28 additions & 21 deletions models/yolov5s.yaml
Expand Up @@ -5,41 +5,48 @@ width_multiple: 0.50 # layer channel multiple

# anchors
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
- [30,61, 62,45, 59,119] # P4/16
- [10,13, 16,30, 33,23] # P3/8

# yolov5 backbone
# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Focus, [64, 3]], # 1-P1/2
[-1, 1, Conv, [128, 3, 2]], # 2-P2/4
[-1, 3, Bottleneck, [128]],
[-1, 1, Conv, [256, 3, 2]], # 4-P3/8
[[-1, 1, Focus, [64, 3]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, BottleneckCSP, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 9, BottleneckCSP, [256]],
[-1, 1, Conv, [512, 3, 2]], # 6-P4/16
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 8-P5/32
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 6, BottleneckCSP, [1024]], # 10
]

# yolov5 head
# YOLOv5 head
head:
[[-1, 3, BottleneckCSP, [1024, False]], # 11
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 12 (P5/32-large)
[[-1, 3, BottleneckCSP, [1024, False]], # 9

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 1, Conv, [512, 1, 1]],
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 17 (P4/16-medium)
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, BottleneckCSP, [512, False]], # 13

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P3/8-small)
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 18 (P3/8-small)

[-2, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P4/16-medium)

[-2, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, BottleneckCSP, [1024, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 26 (P5/32-large)

[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
[[], 1, Detect, [nc, anchors]], # Detect(P5, P4, P3)
]
49 changes: 28 additions & 21 deletions models/yolov5x.yaml
Expand Up @@ -5,41 +5,48 @@ width_multiple: 1.25 # layer channel multiple

# anchors
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
- [30,61, 62,45, 59,119] # P4/16
- [10,13, 16,30, 33,23] # P3/8

# yolov5 backbone
# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Focus, [64, 3]], # 1-P1/2
[-1, 1, Conv, [128, 3, 2]], # 2-P2/4
[-1, 3, Bottleneck, [128]],
[-1, 1, Conv, [256, 3, 2]], # 4-P3/8
[[-1, 1, Focus, [64, 3]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, BottleneckCSP, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 9, BottleneckCSP, [256]],
[-1, 1, Conv, [512, 3, 2]], # 6-P4/16
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 8-P5/32
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 6, BottleneckCSP, [1024]], # 10
]

# yolov5 head
# YOLOv5 head
head:
[[-1, 3, BottleneckCSP, [1024, False]], # 11
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 12 (P5/32-large)
[[-1, 3, BottleneckCSP, [1024, False]], # 9

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 1, Conv, [512, 1, 1]],
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 17 (P4/16-medium)
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, BottleneckCSP, [512, False]], # 13

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P3/8-small)
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 18 (P3/8-small)

[-2, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P4/16-medium)

[-2, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, BottleneckCSP, [1024, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 26 (P5/32-large)

[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
[[], 1, Detect, [nc, anchors]], # Detect(P5, P4, P3)
]
8 changes: 5 additions & 3 deletions utils/utils.py
Expand Up @@ -1094,12 +1094,14 @@ def plot_study_txt(f='study.txt', x=None): # from utils.utils import *; plot_st

ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [33.5, 39.1, 42.5, 45.9, 49., 50.5],
'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')

ax2.grid()
ax2.set_xlim(0, 30)
ax2.set_ylim(25, 50)
ax2.set_xlabel('GPU Latency (ms)')
ax2.set_ylim(28, 50)
ax2.set_yticks(np.arange(30, 55, 5))
ax2.set_xlabel('GPU Speed (ms/img)')
ax2.set_ylabel('COCO AP val')
ax2.legend(loc='lower right')
ax2.grid()
plt.savefig('study_mAP_latency.png', dpi=300)
plt.savefig(f.replace('.txt', '.png'), dpi=200)

Expand Down

8 comments on commit 364fcfd

@yxNONG
Copy link
Contributor

@yxNONG yxNONG commented on 364fcfd Jun 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@glenn-jocher thanks for great work! the neck is change to PAN now, my question is, is the pretrained weight also been updated now?

@glenn-jocher
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yxNONG yes, everything is updated including the pretrained checkpoints. If you delete your current checkpoints it will download the new ones automatically. It will not attempt to download new weights if it seems a matching checkpoint filename in your directory however.

@lucasjinreal
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@glenn-jocher how's the speed and accuracy tradeoff leveraged? readmes seems newest?

@HaxThePlanet
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excited for this, any performance numbers?

@glenn-jocher
Copy link
Member Author

@glenn-jocher glenn-jocher commented on 364fcfd Jun 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HaxThePlanet @jinfagang yes, the models are faster, smaller and more accurate. You can see this in the README table changes above.

Before:

Model APval APtest AP50 SpeedGPU FPSGPU params FLOPs
YOLOv5-s 35.5 35.5 55.0 2.1ms 476 7.1M 12.6B
YOLOv5-m 42.7 42.7 62.4 3.2ms 312 22.0M 39.0B
YOLOv5-l 45.7 45.9 65.1 4.1ms 243 50.3M 89.0B
YOLOv5-x 47.2 47.3 66.6 6.5ms 153 95.9M 170.3B
YOLOv3-SPP 45.6 45.5 65.2 4.8ms 208 63.0M 118.0B

After:

Model APval APtest AP50 SpeedGPU FPSGPU params FLOPS
YOLOv5s 36.5 36.5 55.6 2.1ms 476 7.5M 13.2B
YOLOv5m 43.4 43.4 62.4 3.0ms 333 21.8M 39.4B
YOLOv5l 46.6 46.7 65.4 3.9ms 256 47.8M 88.1B
YOLOv5x 48.4 48.4 66.9 6.1ms 164 89.0M 166.4B
YOLOv3-SPP 45.6 45.5 65.2 4.5ms 222 63.0M 118.0B

@Laughing-q
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i see that all the pretrained checkpoints are based on F16, i wonder that all the results are based on F16 or F32 ?

@glenn-jocher
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Laughing-q all checkpoints and results are in FP16.

@glenn-jocher
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All older FPN models are still accessible in yolov5/models/hub Screen Shot 2020-06-24 at 11 59 52 AM

And the pretrained weights are accessable from the Google Drive folder in the hub/ directory:
Screen Shot 2020-06-24 at 12 01 31 PM

Please sign in to comment.