
Segmentation support & other enhancements #40

Merged 32 commits on Nov 8, 2022

Commits
e9ec11c  generic get_model (AyushExel, Oct 31, 2022)
857d88a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 31, 2022)
dbf9a59  first commit (AyushExel, Nov 4, 2022)
d0bcc77  generic get_model (AyushExel, Oct 31, 2022)
c18a4c4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 31, 2022)
7b90bfd  update (AyushExel, Nov 6, 2022)
2cbc72f  update (AyushExel, Nov 7, 2022)
bad94fc  update (AyushExel, Nov 7, 2022)
bc7a296  update (AyushExel, Nov 7, 2022)
bf0c48b  update (AyushExel, Nov 7, 2022)
82eb537  Deconflicting default.yaml (glenn-jocher, Nov 7, 2022)
3ce34a8  Merge master (glenn-jocher, Nov 7, 2022)
587b62b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 7, 2022)
4d187e0  ops.py pre-commit fixes (glenn-jocher, Nov 7, 2022)
e10fecb  Merge remote-tracking branch 'origin/callbacks' into callbacks (glenn-jocher, Nov 7, 2022)
f776c0d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 7, 2022)
14ea7ec  segment/val.py pre-commit fixes (glenn-jocher, Nov 7, 2022)
81e19e1  Merge remote-tracking branch 'origin/callbacks' into callbacks (glenn-jocher, Nov 7, 2022)
a2c19f6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 7, 2022)
7e808f7  fix errors from merge (AyushExel, Nov 7, 2022)
466ff5d  segment/train.py pre-commit fixes (glenn-jocher, Nov 7, 2022)
453dab5  comment TODO (AyushExel, Nov 7, 2022)
328ff39  fix ci (AyushExel, Nov 7, 2022)
20af456  Add CLS CLI CI (glenn-jocher, Nov 7, 2022)
bb8d54d  Update CI imgsz 32 no torch 1.7 (glenn-jocher, Nov 7, 2022)
822db9a  fix (glenn-jocher, Nov 7, 2022)
231fb55  MNIST160 (glenn-jocher, Nov 7, 2022)
c77f125  Merge branch 'main' into callbacks (glenn-jocher, Nov 8, 2022)
233bae0  update (AyushExel, Nov 8, 2022)
292c78e  Add yolov5n-seg.yaml (glenn-jocher, Nov 8, 2022)
9408e60  update: add seg model (AyushExel, Nov 8, 2022)
9b6e777  Remove model/ dir (duplicate of models/ dir) (glenn-jocher, Nov 8, 2022)

16 changes: 8 additions & 8 deletions .github/workflows/ci.yaml
@@ -21,7 +21,8 @@ jobs:
os: [ ubuntu-latest ]
python-version: [ '3.10' ]
model: [ yolov5n ]
include:
torch: [ latest ]
# include:
# - os: ubuntu-latest
# python-version: '3.7' # '3.6.8' min
# model: yolov5n
@@ -31,10 +32,10 @@ jobs:
# - os: ubuntu-latest
# python-version: '3.9'
# model: yolov5n
- os: ubuntu-latest
python-version: '3.8' # torch 1.7.0 requires python >=3.6, <=3.8
model: yolov5n
torch: '1.7.0' # min torch version CI https://pypi.org/project/torchvision/
# - os: ubuntu-latest
# python-version: '3.8' # torch 1.7.0 requires python >=3.6, <=3.8
# model: yolov5n
# torch: '1.7.0' # min torch version CI https://pypi.org/project/torchvision/
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
@@ -93,9 +94,8 @@ jobs:
- name: Test segmentation
shell: bash # for Windows compatibility
run: |
echo "TODO"
python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=1 img_size=64
- name: Test classification
shell: bash # for Windows compatibility
run: |
echo "TODO"
# python ultralytics/yolo/v8/classify/train.py model=resnet18 data=mnist2560 epochs=1 img_size=64
python ultralytics/yolo/v8/classify/train.py model=resnet18 data=mnist160 epochs=1 img_size=32
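
The two CI steps above swap the previous echo "TODO" placeholders for real one-epoch smoke tests. Below is a hedged sketch, not part of the PR, for reproducing them locally by shelling out to the same train.py entry points with the same Hydra-style key=value overrides copied from ci.yaml:

import subprocess
import sys

# Commands mirror the "Test segmentation" and "Test classification" CI steps.
SMOKE_TESTS = [
    ["ultralytics/yolo/v8/segment/train.py",
     "cfg=yolov5n-seg.yaml", "data=coco128-segments", "epochs=1", "img_size=64"],
    ["ultralytics/yolo/v8/classify/train.py",
     "model=resnet18", "data=mnist160", "epochs=1", "img_size=32"],
]

for test in SMOKE_TESTS:
    # Run each entry point with the current interpreter; fail fast like CI does.
    subprocess.run([sys.executable, *test], check=True)
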
3 changes: 2 additions & 1 deletion ultralytics/yolo/data/dataset.py
@@ -1,6 +1,7 @@
from itertools import repeat
from multiprocessing.pool import Pool
from pathlib import Path
from typing import OrderedDict

import torchvision
from tqdm import tqdm
@@ -205,7 +206,7 @@ def __getitem__(self, i):
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
else:
sample = self.torch_transforms(im)
return sample, j
return OrderedDict(img=sample, cls=j)


# TODO: support semantic segmentation
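
With this change, ClassificationDataset.__getitem__ returns a mapping with "img" and "cls" keys instead of a (sample, label) tuple, so trainers and validators can index batches by name regardless of task. A minimal consumer sketch, assuming PyTorch's default DataLoader collation (which preserves dict keys); the loop itself is illustrative, not code from the PR:

from torch.utils.data import DataLoader

def run_one_epoch(dataset, model, criterion, device):
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    for batch in loader:
        # default_collate keeps the OrderedDict keys, so the batch is indexed by name
        imgs = batch["img"].to(device, non_blocking=True)
        cls = batch["cls"].to(device)
        loss = criterion(model(imgs), cls)
        loss.backward()
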
57 changes: 46 additions & 11 deletions ultralytics/yolo/engine/trainer.py
@@ -1,12 +1,17 @@
"""
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
"""
# TODOs
# 1. finish _set_model_attributes
# 2. allow num_class update for both pretrained and csv_loaded models
# 3. save

import os
import time
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from telnetlib import TLS
from typing import Dict, Union

import torch
@@ -52,6 +57,8 @@ def __init__(self, config=DEFAULT_CONFIG, overrides={}):

# Model and Dataloaders.
self.trainset, self.testset = self.get_dataset(self.args.data)
if self.args.cfg is not None:
self.model = self.load_cfg(self.args.cfg)
if self.args.model is not None:
self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)

@@ -133,6 +140,20 @@ def _setup_train(self, rank):
self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank)
self.validator = self.get_validator()
print("created testloader :", rank)
self.console.info(self.progress_string())

def _set_model_attributes(self):
# TODO: fix and use after self.data_dict is available
'''
head = utils.torch_utils.de_parallel(self.model).model[-1]
self.args.box *= 3 / head.nl # scale to layers
self.args.cls *= head.nc / 80 * 3 / head.nl # scale to classes and layers
self.args.obj *= (self.args.img_size / 640) ** 2 * 3 / nl # scale to image size and layers
model.nc = nc # attach number of classes to model
model.hyp = hyp # attach hyperparameters to model
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
model.names = names
'''

def _do_train(self, rank, world_size):
if world_size > 1:
@@ -153,13 +174,17 @@ def _do_train(self, rank, world_size):
pbar = tqdm(enumerate(self.train_loader),
total=len(self.train_loader),
bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
tloss = 0
for i, (images, labels) in pbar:
tloss = None
for i, batch in pbar:
# img, label (classification)/ img, targets, paths, _, masks(detection)
# callback hook. on_batch_start
# forward
images, labels = self.preprocess_batch(images, labels)
self.loss = self.criterion(self.model(images), labels)
tloss = (tloss * i + self.loss.item()) / (i + 1)
batch = self.preprocess_batch(batch)

# TODO: warmup, multiscale
preds = self.model(batch["img"])
self.loss, self.loss_items = self.criterion(preds, batch)
tloss = (tloss * i + self.loss_items) / (i + 1) if tloss is not None else self.loss_items

# backward
self.model.zero_grad(set_to_none=True)
@@ -170,9 +195,13 @@ def optimizer_step(self):
self.trigger_callbacks('on_batch_end')

# log
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
loss_len = tloss.shape[0] if len(tloss.size()) else 1
losses = tloss if loss_len > 1 else torch.unsqueeze(tloss, 0)
if rank in {-1, 0}:
pbar.desc = f"{f'{epoch + 1}/{self.args.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36
pbar.set_description(
(" {} " + "{:.3f} " * (2 + loss_len)).format(f'{epoch + 1}/{self.args.epochs}', mem, *losses,
batch["img"].shape[-1]))

if rank in [-1, 0]:
# validation
@@ -240,6 +269,9 @@ def get_model(self, model, pretrained):

return model

def load_cfg(self, cfg):
raise NotImplementedError("This task trainer doesn't support loading cfg files")

def get_validator(self):
pass

@@ -250,11 +282,11 @@ def optimizer_step(self):
self.scaler.update()
self.optimizer.zero_grad()

def preprocess_batch(self, images, labels):
def preprocess_batch(self, batch):
"""
Allows custom preprocessing model inputs and ground truths depending on task type
"""
return images.to(self.device, non_blocking=True), labels.to(self.device)
return batch

def validate(self):
"""
@@ -270,14 +302,17 @@ def validate(self):
def build_targets(self, preds, targets):
pass

def criterion(self, preds, targets):
def criterion(self, preds, batch):
"""
Returns loss and individual loss items as Tensor
"""
pass

def progress_string(self):
"""
Returns progress string depending on task type.
"""
pass
return ''

def usage_help(self):
"""
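
The trainer refactor above routes whole batch dicts through preprocess_batch() and criterion(), adds a load_cfg() hook for building models from .yaml cfg files, and keeps the running mean loss as a tensor so multi-term losses can be reported per item. A hedged sketch of how a task trainer might override the new hooks; the MyClassificationTrainer name is illustrative, and importing the base class as BaseTrainer is an assumption since the class statement falls outside this diff:

import torch.nn.functional as F
from ultralytics.yolo.engine.trainer import BaseTrainer  # assumed class name

class MyClassificationTrainer(BaseTrainer):

    def preprocess_batch(self, batch):
        # move the tensors in the batch dict onto the training device
        batch["img"] = batch["img"].to(self.device, non_blocking=True)
        batch["cls"] = batch["cls"].to(self.device)
        return batch

    def criterion(self, preds, batch):
        # return (total loss, detached loss items), as _do_train() now expects
        loss = F.cross_entropy(preds, batch["cls"])
        return loss, loss.detach()
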
40 changes: 24 additions & 16 deletions ultralytics/yolo/engine/validator.py
@@ -1,8 +1,10 @@
import logging

import torch
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import select_device

@@ -12,58 +14,64 @@ class BaseValidator:
Base validator class.
"""

def __init__(self, dataloader, device='', half=False, pbar=None, logger=None):
def __init__(self, dataloader, pbar=None, logger=None, args=None):
self.dataloader = dataloader
self.half = half
self.device = select_device(device, dataloader.batch_size)
self.pbar = pbar
self.logger = logger or logging.getLogger()
self.args = args or OmegaConf.load(DEFAULT_CONFIG)
self.device = select_device(self.args.device, dataloader.batch_size)
self.cuda = self.device.type != 'cpu'
self.batch_i = None
self.training = True

def __call__(self, trainer=None, model=None):
"""
Supports validation of a pre-trained model if passed or a model being trained
if trainer is passed (trainer gets priority).
"""
training = trainer is not None
self.training = training
# trainer = trainer or self.trainer_class.get_trainer()
assert training or model is not None, "Either trainer or model is needed for validation"
if training:
model = trainer.model
self.half &= self.device.type != 'cpu'
model = model.half() if self.half else model
self.args.half &= self.device.type != 'cpu'
model = model.half() if self.args.half else model
else: # TODO: handle this when detectMultiBackend is supported
# model = DetectMultiBacked(model)
pass
# TODO: implement init_model_attributes()

model.eval()
dt = Profile(), Profile(), Profile(), Profile()
loss = 0
n_batches = len(self.dataloader)
desc = self.set_desc()
desc = self.get_desc()
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
self.init_metrics()
self.init_metrics(model)
with torch.cuda.amp.autocast(enabled=self.device.type != 'cpu'):
for images, labels in bar:
for batch_i, batch in enumerate(bar):
self.batch_i = batch_i
# pre-process
with dt[0]:
images, labels = self.preprocess_batch(images, labels)
batch = self.preprocess_batch(batch)

# inference
with dt[1]:
preds = model(images)
preds = model(batch["img"])
# TODO: remember to add native augmentation support when implementing model, like:
# preds, train_out = model(im, augment=augment)

# loss
with dt[2]:
if training:
loss += trainer.criterion(preds, labels) / images.shape[0]
loss += trainer.criterion(preds, batch)[0]

# pre-process predictions
with dt[3]:
preds = self.preprocess_preds(preds)

self.update_metrics(preds, labels)
self.update_metrics(preds, batch)

stats = self.get_stats()
self.check_stats(stats)
@@ -81,16 +89,16 @@ def __call__(self, trainer=None, model=None):

return stats

def preprocess_batch(self, images, labels):
return images.to(self.device, non_blocking=True), labels.to(self.device)
def preprocess_batch(self, batch):
return batch

def preprocess_preds(self, preds):
return preds

def init_metrics(self):
pass

def update_metrics(self, preds, targets):
def update_metrics(self, preds, batch):
pass

def get_stats(self):
@@ -102,5 +110,5 @@ def check_stats(self, stats):
def print_results(self):
pass

def set_desc(self):
def get_desc(self):
pass
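
BaseValidator now takes its device and half-precision settings from an OmegaConf args object, tracks the batch index, and passes full batch dicts to the metric hooks (init_metrics also receives the model). A minimal, hedged sketch of a subclass wiring those hooks; the accuracy metric and class name are illustrative only:

import torch
from ultralytics.yolo.engine.validator import BaseValidator

class MyClassificationValidator(BaseValidator):

    def init_metrics(self, model):
        # reset counters at the start of validation
        self.correct = 0
        self.total = 0

    def update_metrics(self, preds, batch):
        # preds assumed to be raw class logits in this sketch
        self.correct += (preds.argmax(1) == batch["cls"]).sum().item()
        self.total += batch["cls"].shape[0]

    def get_stats(self):
        return {"accuracy": self.correct / max(self.total, 1)}

    def get_desc(self):
        return "validating"
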
19 changes: 19 additions & 0 deletions ultralytics/yolo/utils/configs/default.yaml
@@ -4,6 +4,7 @@

# Train settings -------------------------------------------------------------------------------------------------------
model: null # i.e. yolov5s.pt
cfg: null # i.e. yolov5s.yaml
data: null # i.e. coco128.yaml
epochs: 300
batch_size: 16
@@ -20,6 +21,23 @@ optimizer: 'SGD' # choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
verbose: False
seed: 0
local_rank: -1
single_cls: False # train multi-class data as single-class
image_weights: False # use weighted image selection for training
shuffle: True
rect: False # support rectangular training
overlap_mask: True # Segmentation masks overlap
mask_ratio: 4 # Segmentation mask downsample ratio

# Val/Test settings ----------------------------------------------------------------------------------------------------
save_json: False
save_hybrid: False
conf_thres: 0.001
iou_thres: 0.6
max_det: 300
half: True
plots: False
save_txt: False
task: 'val'

# Hyperparameters ------------------------------------------------------------------------------------------------------
lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
@@ -51,6 +69,7 @@ fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)
label_smoothing: 0.0

# Hydra configs --------------------------------------------------------------------------------------------------------
hydra:
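
default.yaml gains the cfg key, segmentation options (overlap_mask, mask_ratio), and a Val/Test block (conf_thres, iou_thres, max_det, half, task, and so on). A hedged sketch of loading and overriding it with plain OmegaConf, mirroring the OmegaConf.load(DEFAULT_CONFIG) usage visible in validator.py; the override values below are illustrative, and running through Hydra as the trainers do would apply the same keys:

from omegaconf import OmegaConf

# Load the shipped defaults, then merge a handful of overrides on top.
cfg = OmegaConf.load("ultralytics/yolo/utils/configs/default.yaml")
overrides = OmegaConf.create({"cfg": "yolov5n-seg.yaml",
                              "data": "coco128-segments",
                              "epochs": 1,
                              "overlap_mask": True,
                              "mask_ratio": 4})
cfg = OmegaConf.merge(cfg, overrides)
print(cfg.cfg, cfg.epochs, cfg.conf_thres)
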