Skip to content

Commit

Permalink
feat(train): 添加SSD训练完整实现
Browse files Browse the repository at this point in the history
  • Loading branch information
zjZSTU committed May 11, 2020
1 parent cc89578 commit 0735d2b
Show file tree
Hide file tree
Showing 49 changed files with 2,744 additions and 18 deletions.
18 changes: 18 additions & 0 deletions py/configs/efficient_net_b3_ssd300_voc0712.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
MODEL:
NUM_CLASSES: 21
BACKBONE:
NAME: 'efficient_net-b3'
OUT_CHANNELS: (48, 136, 384, 256, 256, 256)
INPUT:
IMAGE_SIZE: 300
DATASETS:
TRAIN: ("voc_2007_trainval", "voc_2012_trainval")
TEST: ("voc_2007_test", )
SOLVER:
MAX_ITER: 160000
LR_STEPS: [105000, 135000]
GAMMA: 0.1
BATCH_SIZE: 24
LR: 1e-3

OUTPUT_DIR: 'outputs/efficient_net_b3_ssd300_voc0712'
27 changes: 27 additions & 0 deletions py/configs/mobilenet_v2_ssd320_voc0712.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
MODEL:
NUM_CLASSES: 21
BOX_HEAD:
PREDICTOR: 'SSDLiteBoxPredictor'
BACKBONE:
NAME: 'mobilenet_v2'
OUT_CHANNELS: (96, 1280, 512, 256, 256, 64)
PRIORS:
FEATURE_MAPS: [20, 10, 5, 3, 2, 1]
STRIDES: [16, 32, 64, 107, 160, 320]
MIN_SIZES: [60, 105, 150, 195, 240, 285]
MAX_SIZES: [105, 150, 195, 240, 285, 330]
ASPECT_RATIOS: [[2, 3], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]]
BOXES_PER_LOCATION: [6, 6, 6, 6, 6, 6]
INPUT:
IMAGE_SIZE: 320
DATASETS:
TRAIN: ("voc_2007_trainval", "voc_2012_trainval")
TEST: ("voc_2007_test", )
SOLVER:
MAX_ITER: 120000
LR_STEPS: [80000, 100000]
GAMMA: 0.1
BATCH_SIZE: 32
LR: 1e-3

OUTPUT_DIR: 'outputs/mobilenet_v2_ssd320_voc0712'
22 changes: 22 additions & 0 deletions py/configs/vgg_ssd300_coco_trainval35k.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
MODEL:
NUM_CLASSES: 81
PRIORS:
FEATURE_MAPS: [38, 19, 10, 5, 3, 1]
STRIDES: [8, 16, 32, 64, 100, 300]
MIN_SIZES: [21, 45, 99, 153, 207, 261]
MAX_SIZES: [45, 99, 153, 207, 261, 315]
ASPECT_RATIOS: [[2], [2, 3], [2, 3], [2, 3], [2], [2]]
BOXES_PER_LOCATION: [4, 6, 6, 6, 4, 4]
INPUT:
IMAGE_SIZE: 300
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival", )
SOLVER:
MAX_ITER: 400000
LR_STEPS: [280000, 360000]
GAMMA: 0.1
BATCH_SIZE: 32
LR: 1e-3

OUTPUT_DIR: 'outputs/vgg_ssd300_coco_trainval35k'
15 changes: 15 additions & 0 deletions py/configs/vgg_ssd300_voc0712.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
MODEL:
NUM_CLASSES: 21
INPUT:
IMAGE_SIZE: 300
DATASETS:
TRAIN: ("voc_2007_trainval", "voc_2012_trainval")
TEST: ("voc_2007_test", )
SOLVER:
MAX_ITER: 120000
LR_STEPS: [80000, 100000]
GAMMA: 0.1
BATCH_SIZE: 32
LR: 1e-3

OUTPUT_DIR: 'outputs/vgg_ssd300_voc0712'
24 changes: 24 additions & 0 deletions py/configs/vgg_ssd512_coco_trainval35k.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
MODEL:
NUM_CLASSES: 81
BACKBONE:
OUT_CHANNELS: (512, 1024, 512, 256, 256, 256, 256)
PRIORS:
FEATURE_MAPS: [64, 32, 16, 8, 4, 2, 1]
STRIDES: [8, 16, 32, 64, 128, 256, 512]
MIN_SIZES: [20.48, 51.2, 133.12, 215.04, 296.96, 378.88, 460.8]
MAX_SIZES: [51.2, 133.12, 215.04, 296.96, 378.88, 460.8, 542.72]
ASPECT_RATIOS: [[2], [2, 3], [2, 3], [2, 3], [2, 3], [2], [2]]
BOXES_PER_LOCATION: [4, 6, 6, 6, 6, 4, 4]
INPUT:
IMAGE_SIZE: 512
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival", )
SOLVER:
MAX_ITER: 520000
LR_STEPS: [360000, 480000]
GAMMA: 0.1
BATCH_SIZE: 24
LR: 1e-3

OUTPUT_DIR: 'outputs/vgg_ssd512_coco_trainval35k'
24 changes: 24 additions & 0 deletions py/configs/vgg_ssd512_voc0712.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
MODEL:
NUM_CLASSES: 21
BACKBONE:
OUT_CHANNELS: (512, 1024, 512, 256, 256, 256, 256)
PRIORS:
FEATURE_MAPS: [64, 32, 16, 8, 4, 2, 1]
STRIDES: [8, 16, 32, 64, 128, 256, 512]
MIN_SIZES: [35.84, 76.8, 153.6, 230.4, 307.2, 384.0, 460.8]
MAX_SIZES: [76.8, 153.6, 230.4, 307.2, 384.0, 460.8, 537.65]
ASPECT_RATIOS: [[2], [2, 3], [2, 3], [2, 3], [2, 3], [2], [2]]
BOXES_PER_LOCATION: [4, 6, 6, 6, 6, 4, 4]
INPUT:
IMAGE_SIZE: 512
DATASETS:
TRAIN: ("voc_2007_trainval", "voc_2012_trainval")
TEST: ("voc_2007_test", )
SOLVER:
MAX_ITER: 120000
LR_STEPS: [80000, 100000]
GAMMA: 0.1
BATCH_SIZE: 24
LR: 1e-3

OUTPUT_DIR: 'outputs/vgg_ssd512_voc0712'
70 changes: 70 additions & 0 deletions py/outputs/vgg_ssd300_voc0712/log.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
2020-05-11 14:38:39,024 SSD INFO: Using 1 GPUs
2020-05-11 14:38:39,426 SSD INFO: Namespace(config_file='configs/vgg_ssd300_voc0712.yaml', distributed=False, eval_step=2500, local_rank=0, log_step=10, num_gpus=1, opts=[], save_step=2500, skip_test=False, use_tensorboard=True)
2020-05-11 14:38:41,864 SSD INFO: Loaded configuration file configs/vgg_ssd300_voc0712.yaml
2020-05-11 14:38:47,925 SSD INFO:
MODEL:
NUM_CLASSES: 21
INPUT:
IMAGE_SIZE: 300
DATASETS:
TRAIN: ("voc_2007_trainval", "voc_2012_trainval")
TEST: ("voc_2007_test", )
SOLVER:
MAX_ITER: 120000
LR_STEPS: [80000, 100000]
GAMMA: 0.1
BATCH_SIZE: 32
LR: 1e-3

OUTPUT_DIR: 'outputs/vgg_ssd300_voc0712'
2020-05-11 14:38:49,140 SSD INFO: Running with config:
DATASETS:
TEST: ('voc_2007_test',)
TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
DATA_LOADER:
NUM_WORKERS: 8
PIN_MEMORY: True
INPUT:
IMAGE_SIZE: 300
PIXEL_MEAN: [123, 117, 104]
MODEL:
BACKBONE:
NAME: vgg
OUT_CHANNELS: (512, 1024, 512, 256, 256, 256)
PRETRAINED: True
BOX_HEAD:
NAME: SSDBoxHead
PREDICTOR: SSDBoxPredictor
CENTER_VARIANCE: 0.1
DEVICE: cuda
META_ARCHITECTURE: SSDDetector
NEG_POS_RATIO: 3
NUM_CLASSES: 21
PRIORS:
ASPECT_RATIOS: [[2], [2, 3], [2, 3], [2, 3], [2], [2]]
BOXES_PER_LOCATION: [4, 6, 6, 6, 4, 4]
CLIP: True
FEATURE_MAPS: [38, 19, 10, 5, 3, 1]
MAX_SIZES: [60, 111, 162, 213, 264, 315]
MIN_SIZES: [30, 60, 111, 162, 213, 264]
STRIDES: [8, 16, 32, 64, 100, 300]
SIZE_VARIANCE: 0.2
THRESHOLD: 0.5
OUTPUT_DIR: outputs/vgg_ssd300_voc0712
SOLVER:
BATCH_SIZE: 32
GAMMA: 0.1
LR: 0.001
LR_STEPS: [80000, 100000]
MAX_ITER: 120000
MOMENTUM: 0.9
WARMUP_FACTOR: 0.3333333333333333
WARMUP_ITERS: 500
WEIGHT_DECAY: 0.0005
TEST:
BATCH_SIZE: 10
CONFIDENCE_THRESHOLD: 0.01
MAX_PER_CLASS: -1
MAX_PER_IMAGE: 100
NMS_THRESHOLD: 0.45
2020-05-11 14:42:47,577 SSD.trainer INFO: No checkpoint found.
2 changes: 1 addition & 1 deletion py/ssd/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from ssd.config.path_catlog import DatasetCatalog
from .voc import VOCDataset
# from .coco import COCODataset
from .coco import COCODataset

_DATASETS = {
'VOCDataset': VOCDataset,
Expand Down
92 changes: 92 additions & 0 deletions py/ssd/data/datasets/coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import os
import torch.utils.data
import numpy as np
from PIL import Image

from ssd.structures.container import Container


class COCODataset(torch.utils.data.Dataset):
class_names = ('__background__',
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush')

def __init__(self, data_dir, ann_file, transform=None, target_transform=None, remove_empty=False):
from pycocotools.coco import COCO
self.coco = COCO(ann_file)
self.data_dir = data_dir
self.transform = transform
self.target_transform = target_transform
self.remove_empty = remove_empty
if self.remove_empty:
# when training, images without annotations are removed.
self.ids = list(self.coco.imgToAnns.keys())
else:
# when testing, all images used.
self.ids = list(self.coco.imgs.keys())
coco_categories = sorted(self.coco.getCatIds())
self.coco_id_to_contiguous_id = {coco_id: i + 1 for i, coco_id in enumerate(coco_categories)}
self.contiguous_id_to_coco_id = {v: k for k, v in self.coco_id_to_contiguous_id.items()}

def __getitem__(self, index):
image_id = self.ids[index]
boxes, labels = self._get_annotation(image_id)
image = self._read_image(image_id)
if self.transform:
image, boxes, labels = self.transform(image, boxes, labels)
if self.target_transform:
boxes, labels = self.target_transform(boxes, labels)
targets = Container(
boxes=boxes,
labels=labels,
)
return image, targets, index

def get_annotation(self, index):
image_id = self.ids[index]
return image_id, self._get_annotation(image_id)

def __len__(self):
return len(self.ids)

def _get_annotation(self, image_id):
ann_ids = self.coco.getAnnIds(imgIds=image_id)
ann = self.coco.loadAnns(ann_ids)
# filter crowd annotations
ann = [obj for obj in ann if obj["iscrowd"] == 0]
boxes = np.array([self._xywh2xyxy(obj["bbox"]) for obj in ann], np.float32).reshape((-1, 4))
labels = np.array([self.coco_id_to_contiguous_id[obj["category_id"]] for obj in ann], np.int64).reshape((-1,))
# remove invalid boxes
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
boxes = boxes[keep]
labels = labels[keep]
return boxes, labels

def _xywh2xyxy(self, box):
x1, y1, w, h = box
return [x1, y1, x1 + w, y1 + h]

def get_img_info(self, index):
image_id = self.ids[index]
img_data = self.coco.imgs[image_id]
return img_data

def _read_image(self, image_id):
file_name = self.coco.loadImgs(image_id)[0]['file_name']
image_file = os.path.join(self.data_dir, file_name)
image = Image.open(image_file).convert("RGB")
image = np.array(image)
return image
24 changes: 24 additions & 0 deletions py/ssd/data/datasets/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from ssd.data.datasets import VOCDataset, COCODataset
from .coco import coco_evaluation
from .voc import voc_evaluation


def evaluate(dataset, predictions, output_dir, **kwargs):
"""evaluate dataset using different methods based on dataset type.
Args:
dataset: Dataset object
predictions(list[(boxes, labels, scores)]): Each item in the list represents the
prediction results for one image. And the index should match the dataset index.
output_dir: output folder, to save evaluation files or results.
Returns:
evaluation result
"""
args = dict(
dataset=dataset, predictions=predictions, output_dir=output_dir, **kwargs,
)
if isinstance(dataset, VOCDataset):
return voc_evaluation(**args)
elif isinstance(dataset, COCODataset):
return coco_evaluation(**args)
else:
raise NotImplementedError

0 comments on commit 0735d2b

Please sign in to comment.