Skip to content

Commit

Permalink
Detection support (#60)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
  • Loading branch information
3 people committed Dec 3, 2022
1 parent 5a52e76 commit 7ec7cf3
Show file tree
Hide file tree
Showing 9 changed files with 545 additions and 11 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,16 @@ jobs:
- name: Test detection
shell: bash # for Windows compatibility
run: |
echo "TODO"
yolo task=detect mode=train model=yolov5n.yaml data=coco128.yaml epochs=1 img_size=64
yolo task=detect mode=val model=runs/exp/weights/last.pt img_size=64
- name: Test segmentation
shell: bash # for Windows compatibility
# TODO: redo val test without hardcoded weights
run: |
yolo task=segment mode=train model=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64
yolo task=segment mode=val model=runs/exp/weights/last.pt data=coco128-seg.yaml img_size=64
yolo task=segment mode=val model=runs/exp2/weights/last.pt data=coco128-seg.yaml img_size=64
- name: Test classification
shell: bash # for Windows compatibility
run: |
yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 img_size=32
yolo task=classify mode=val model=runs/exp2/weights/last.pt data=mnist160
yolo task=classify mode=val model=runs/exp3/weights/last.pt data=mnist160
6 changes: 3 additions & 3 deletions ultralytics/yolo/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,14 +459,14 @@ def ap_per_class_box_and_mask(
"boxes": {
"p": results_boxes[0],
"r": results_boxes[1],
"ap": results_boxes[3],
"f1": results_boxes[2],
"ap": results_boxes[3],
"ap_class": results_boxes[4]},
"masks": {
"p": results_masks[0],
"r": results_masks[1],
"ap": results_masks[3],
"f1": results_masks[2],
"ap": results_masks[3],
"ap_class": results_masks[4]}}
return results

Expand Down Expand Up @@ -547,7 +547,7 @@ def update(self, results):
Args:
results: tuple(p, r, ap, f1, ap_class)
"""
p, r, all_ap, f1, ap_class_index = results
p, r, f1, all_ap, ap_class_index = results
self.p = p
self.r = r
self.all_ap = all_ap
Expand Down
106 changes: 105 additions & 1 deletion ultralytics/yolo/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,15 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,


@threaded
def plot_images_and_masks(images, batch_idx, cls, bboxes, masks, paths, confs=None, fname='images.jpg', names=None):
def plot_images_and_masks(images,
batch_idx,
cls,
bboxes,
masks,
confs=None,
paths=None,
fname='images.jpg',
names=None):
# Plot image grid with labels
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
Expand Down Expand Up @@ -327,3 +335,99 @@ def output_to_target(output, max_det=300):
targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
targets = torch.cat(targets, 0).numpy()
return targets[:, 0], targets[:, 1], targets[:, 2:6], targets[:, 6]


@threaded
def plot_images(images, batch_idx, cls, bboxes, confs=None, paths=None, fname='images.jpg', names=None):
# Plot image grid with labels
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
if isinstance(cls, torch.Tensor):
cls = cls.cpu().numpy()
if isinstance(bboxes, torch.Tensor):
bboxes = bboxes.cpu().numpy()
if isinstance(batch_idx, torch.Tensor):
batch_idx = batch_idx.cpu().numpy()

max_size = 1920 # max image size
max_subplots = 16 # max image subplots, i.e. 4x4
bs, _, h, w = images.shape # batch size, _, height, width
bs = min(bs, max_subplots) # limit plot images
ns = np.ceil(bs ** 0.5) # number of subplots (square)
if np.max(images[0]) <= 1:
images *= 255 # de-normalise (optional)

# Build Image
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
for i, im in enumerate(images):
if i == max_subplots: # if last batch has fewer images than we expect
break
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
im = im.transpose(1, 2, 0)
mosaic[y:y + h, x:x + w, :] = im

# Resize (optional)
scale = max_size / ns / max(h, w)
if scale < 1:
h = math.ceil(scale * h)
w = math.ceil(scale * w)
mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))

# Annotate
fs = int((h + w) * ns * 0.01) # font size
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
for i in range(i + 1):
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
if paths:
annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
if len(cls) > 0:
idx = batch_idx == i

boxes = xywh2xyxy(bboxes[idx]).T
classes = cls[idx].astype('int')
labels = confs is None # labels if no conf column
conf = None if labels else confs[idx] # check for confidence presence (label vs pred)

if boxes.shape[1]:
if boxes.max() <= 1.01: # if normalized with tolerance 0.01
boxes[[0, 2]] *= w # scale to pixels
boxes[[1, 3]] *= h
elif scale < 1: # absolute coords need scale if image scales
boxes *= scale
boxes[[0, 2]] += x
boxes[[1, 3]] += y
for j, box in enumerate(boxes.T.tolist()):
c = classes[j]
color = colors(c)
c = names[c] if names else c
if labels or conf[j] > 0.25: # 0.25 conf thresh
label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
annotator.box_label(box, label, color=color)
annotator.im.save(fname) # save


def plot_results(file='path/to/results.csv', dir=''):
# Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
save_dir = Path(file).parent if file else Path(dir)
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
ax = ax.ravel()
files = list(save_dir.glob('results*.csv'))
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
for f in files:
try:
data = pd.read_csv(f)
s = [x.strip() for x in data.columns]
x = data.values[:, 0]
for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
y = data.values[:, j].astype('float')
# y[y == 0] = np.nan # don't show zero values
ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
ax[i].set_title(s[j], fontsize=12)
# if j in [8, 9, 10]: # share train and val loss y axes
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
except Exception as e:
print(f'Warning: Plotting error for {f}: {e}')
ax[1].legend()
fig.savefig(save_dir / 'results.png', dpi=200)
plt.close()
4 changes: 2 additions & 2 deletions ultralytics/yolo/v8/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path

from ultralytics.yolo.v8 import classify, segment
from ultralytics.yolo.v8 import classify, detect, segment

ROOT = Path(__file__).parents[0] # yolov8 ROOT

__all__ = ["classify", "segment"]
__all__ = ["classify", "segment", "detect"]
2 changes: 2 additions & 0 deletions ultralytics/yolo/v8/detect/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from ultralytics.yolo.v8.detect.train import DetectionTrainer, train
from ultralytics.yolo.v8.detect.val import DetectionValidator, val
209 changes: 209 additions & 0 deletions ultralytics/yolo/v8/detect/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import hydra
import torch
import torch.nn as nn

from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
from ultralytics.yolo.utils.modeling.tasks import DetectionModel
from ultralytics.yolo.utils.plotting import plot_images, plot_results
from ultralytics.yolo.utils.torch_utils import de_parallel

from ..segment import SegmentationTrainer
from .val import DetectionValidator


# BaseTrainer python usage
class DetectionTrainer(SegmentationTrainer):

def load_model(self, model_cfg, weights, data):
model = DetectionModel(model_cfg or weights["model"].yaml,
ch=3,
nc=data["nc"],
anchors=self.args.get("anchors"))
if weights:
model.load(weights)
for _, v in model.named_parameters():
v.requires_grad = True # train all layers
return model

def get_validator(self):
return DetectionValidator(self.test_loader, save_dir=self.save_dir, logger=self.console, args=self.args)

def criterion(self, preds, batch):
head = de_parallel(self.model).model[-1]
sort_obj_iou = False
autobalance = False

# init losses
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.cls_pw], device=self.device))
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.obj_pw], device=self.device))

# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
cp, cn = smooth_BCE(eps=self.args.label_smoothing) # positive, negative BCE targets

# Focal loss
g = self.args.fl_gamma
if self.args.fl_gamma > 0:
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

balance = {3: [4.0, 1.0, 0.4]}.get(head.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
ssi = list(head.stride).index(16) if autobalance else 0 # stride 16 index
BCEcls, BCEobj, gr, autobalance = BCEcls, BCEobj, 1.0, autobalance

def build_targets(p, targets):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
nonlocal head
na, nt = head.na, targets.shape[0] # number of anchors, targets
tcls, tbox, indices, anch = [], [], [], []
gain = torch.ones(7, device=self.device) # normalized to gridspace gain
ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt)
targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), 2) # append anchor indices

g = 0.5 # bias
off = torch.tensor(
[
[0, 0],
[1, 0],
[0, 1],
[-1, 0],
[0, -1], # j,k,l,m
# [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
],
device=self.device).float() * g # offsets

for i in range(head.nl):
anchors, shape = head.anchors[i], p[i].shape
gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]] # xyxy gain

# Match targets to anchors
t = targets * gain # shape(3,n,7)
if nt:
# Matches
r = t[..., 4:6] / anchors[:, None] # wh ratio
j = torch.max(r, 1 / r).max(2)[0] < self.args.anchor_t # compare
# j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
t = t[j] # filter

# Offsets
gxy = t[:, 2:4] # grid xy
gxi = gain[[2, 3]] - gxy # inverse
j, k = ((gxy % 1 < g) & (gxy > 1)).T
l, m = ((gxi % 1 < g) & (gxi > 1)).T
j = torch.stack((torch.ones_like(j), j, k, l, m))
t = t.repeat((5, 1, 1))[j]
offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
else:
t = targets[0]
offsets = 0

# Define
bc, gxy, gwh, a = t.chunk(4, 1) # (image, class), grid xy, grid wh, anchors
a, (b, c) = a.long().view(-1), bc.long().T # anchors, image, class
gij = (gxy - offsets).long()
gi, gj = gij.T # grid indices

# Append
indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1))) # image, anchor, grid
tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
anch.append(anchors[a]) # anchors
tcls.append(c) # class

return tcls, tbox, indices, anch

if len(preds) == 2: # eval
_, p = preds
else: # len(3) train
p = preds

targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = targets.to(self.device)

lcls = torch.zeros(1, device=self.device)
lbox = torch.zeros(1, device=self.device)
lobj = torch.zeros(1, device=self.device)
tcls, tbox, indices, anchors = build_targets(p, targets)

# Losses
for i, pi in enumerate(p): # layer index, layer predictions
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj
bs = tobj.shape[0]
n = b.shape[0] # number of targets
if n:
pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, head.nc), 1) # subset of predictions

# Box regression
pxy = pxy.sigmoid() * 2 - 0.5
pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
pbox = torch.cat((pxy, pwh), 1) # predicted box
iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze() # iou(prediction, target)
lbox += (1.0 - iou).mean() # iou loss

# Objectness
iou = iou.detach().clamp(0).type(tobj.dtype)
if sort_obj_iou:
j = iou.argsort()
b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
if gr < 1:
iou = (1.0 - gr) + gr * iou
tobj[b, a, gj, gi] = iou # iou ratio

# Classification
if head.nc > 1: # cls loss (only if multiple classes)
t = torch.full_like(pcls, cn, device=self.device) # targets
t[range(n), tcls[i]] = cp
lcls += BCEcls(pcls, t) # BCE

obji = BCEobj(pi[..., 4], tobj)
lobj += obji * balance[i] # obj loss
if autobalance:
balance[i] = balance[i] * 0.9999 + 0.0001 / obji.detach().item()

if autobalance:
balance = [x / balance[ssi] for x in balance]
lbox *= self.args.box
lobj *= self.args.obj
lcls *= self.args.cls

loss = lbox + lobj + lcls
return loss * bs, torch.cat((lbox, lobj, lcls)).detach()

# TODO: improve from API users perspective
def label_loss_items(self, loss_items=None, prefix="train"):
# We should just use named tensors here in future
keys = [f"{prefix}/lbox", f"{prefix}/lobj", f"{prefix}/lcls"]
return dict(zip(keys, loss_items)) if loss_items is not None else keys

def progress_string(self):
return ('\n' + '%11s' * 6) % \
('Epoch', 'GPU_mem', 'box_loss', 'obj_loss', 'cls_loss', 'Size')

def plot_training_samples(self, batch, ni):
images = batch["img"]
cls = batch["cls"].squeeze(-1)
bboxes = batch["bboxes"]
paths = batch["im_file"]
batch_idx = batch["batch_idx"]
plot_images(images, batch_idx, cls, bboxes, paths=paths, fname=self.save_dir / f"train_batch{ni}.jpg")

def plot_metrics(self):
plot_results(file=self.csv) # save results.png


@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def train(cfg):
cfg.model = cfg.model or "models/yolov5n.yaml"
cfg.data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist")
trainer = DetectionTrainer(cfg)
trainer.train()


if __name__ == "__main__":
"""
CLI usage:
python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640
TODO:
Direct cli support, i.e, yolov8 classify_train args.epochs 10
"""
train()
Loading

0 comments on commit 7ec7cf3

Please sign in to comment.