Skip to content

Commit

Permalink
fix augmentation mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
poodarchu committed Oct 16, 2021
1 parent c9ebdb1 commit 0398918
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 9 deletions.
16 changes: 7 additions & 9 deletions datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
"""
from pathlib import Path

import numpy as np

import torch
import torch.utils.data
import torchvision
Expand Down Expand Up @@ -120,19 +122,15 @@ def make_coco_transforms(image_set):
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
scales = np.arange(240, 1600).tolist()

if image_set == 'train':
return T.Compose([
T.RandomHorizontalFlip(),
T.RandomSelect(
T.RandomResize(scales, max_size=1333),
T.Compose([
T.RandomResize([400, 500, 600]),
T.RandomSizeCrop(384, 600),
T.RandomResize(scales, max_size=1333),
])
),
T.RandomResize(scales),
T.RandomSizeCrop(240, 1600),
T.RandomResize([800, ], max_size=1333),
T.RandomDistortion(0.5, 0.5, 0.5, 0.5),
normalize,
])

Expand Down
19 changes: 19 additions & 0 deletions datasets/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import torchvision.transforms as T
import torchvision.transforms.functional as F

import numpy as np
import cv2

from util.box_ops import box_xyxy_to_cxcywh
from util.misc import interpolate

Expand Down Expand Up @@ -258,6 +261,22 @@ def __call__(self, image, target=None):
return image, target


class RandomDistortion(object):
"""
Distort image w.r.t hue, saturation and exposure.
"""

def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, prob=0.5):
self.prob = prob
self.tfm = T.ColorJitter(brightness, contrast, saturation, hue)

def __call__(self, img, target=None):
if np.random.random() < self.prob:
return self.tfm(img), target
else:
return img, target


class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
Expand Down
1 change: 1 addition & 0 deletions train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python -m torch.distributed.launch --master_port=3141 --nproc_per_node 16 --use_env main.py --coco_path ~/datasets/COCO/ --batch_size 3 --lr 0.001

0 comments on commit 0398918

Please sign in to comment.