use RGB instead of BGR #190

Merged · 5 commits · Aug 3, 2019
5 changes: 3 additions & 2 deletions README.md
@@ -18,8 +18,9 @@ Color encoding of semantic categories can be found here:
https://docs.google.com/spreadsheets/d/1se8YEtb2detS7OuPE86fXGyD269pMycAWe2mtKUj2W8/edit?usp=sharing

## Updates
- We use configuration files to store most options which were in argument parser. The definitions of options are detailed in ```config/defaults.py```.
- HRNet model is now supported.
- We use configuration files to store most options which were in argument parser. The definitions of options are detailed in ```config/defaults.py```.
- We conform to PyTorch practice in data preprocessing (RGB, [0, 1], subtract mean, divide std).


## Highlights
@@ -61,7 +62,7 @@ Decoder:
- UPerNet (Pyramid Pooling + FPN head, see [UperNet](https://arxiv.org/abs/1807.10221) for details.)

## Performance:
IMPORTANT: We use our self-trained base model on ImageNet. The model takes the input in BGR form (consistent with OpenCV) instead of the RGB form used by the default implementation of PyTorch. The base model will be automatically downloaded when needed.
IMPORTANT: The base ResNet in our repository is customized (different from the one in torchvision). The base models will be automatically downloaded when needed.

<table><tbody>
<th valign="bottom">Architecture</th>
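The "RGB, [0, 1], subtract mean, divide std" convention noted in the Updates above boils down to a few lines; here is a minimal sketch of the equivalent preprocessing (mirroring `img_transform` in the dataset.py diff below, with a sample image name borrowed from demo_test.sh):

```python
# Sketch of the new torchvision-style preprocessing (RGB, [0, 1], ImageNet stats).
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

img = Image.open('ADE_val_00001519.jpg').convert('RGB')  # RGB, not BGR
x = np.float32(np.array(img)) / 255.                     # HxWx3 in [0, 1]
x = torch.from_numpy(x.transpose((2, 0, 1)).copy())      # 3xHxW
x = normalize(x)                                         # subtract mean, divide std
```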
6 changes: 3 additions & 3 deletions config/ade20k-resnet101dilated-ppm_deepsup.yaml
@@ -16,7 +16,7 @@ MODEL:

TRAIN:
batch_size_per_gpu: 2
num_epoch: 20
num_epoch: 25
start_epoch: 0
epoch_iters: 5000
optim: "SGD"
@@ -33,10 +33,10 @@ TRAIN:

VAL:
visualize: False
checkpoint: "epoch_20.pth"
checkpoint: "epoch_25.pth"

TEST:
checkpoint: "epoch_20.pth"
checkpoint: "epoch_25.pth"
result: "./"

DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"
6 changes: 3 additions & 3 deletions config/ade20k-resnet50-upernet.yaml
@@ -16,7 +16,7 @@ MODEL:

TRAIN:
batch_size_per_gpu: 2
num_epoch: 40
num_epoch: 30
start_epoch: 0
epoch_iters: 5000
optim: "SGD"
@@ -33,10 +33,10 @@ TRAIN:

VAL:
visualize: False
checkpoint: "epoch_40.pth"
checkpoint: "epoch_30.pth"

TEST:
checkpoint: "epoch_40.pth"
checkpoint: "epoch_30.pth"
result: "./"

DIR: "ckpt/ade20k-resnet50-upernet"
146 changes: 76 additions & 70 deletions dataset.py
@@ -1,25 +1,22 @@
import os
import json
import torch
import cv2
from torchvision import transforms
import numpy as np
import PIL
from PIL import Image


def imresize(im, size, interp='bilinear'):
if interp == 'nearest':
resample = PIL.Image.NEAREST
resample = Image.NEAREST
elif interp == 'bilinear':
resample = PIL.Image.BILINEAR
resample = Image.BILINEAR
elif interp == 'bicubic':
resample = PIL.Image.BICUBIC
resample = Image.BICUBIC
else:
raise Exception('resample method undefined!')

return np.array(
PIL.Image.fromarray(im).resize((size[1], size[0]), resample)
)
return im.resize(size, resample)


class BaseDataset(torch.utils.data.Dataset):
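One behavioral note on the rewritten `imresize`: the old numpy-based version took `size` as (height, width) and swapped it internally for PIL; the new version hands `size` straight to `Image.resize`, which expects (width, height). The call sites later in this diff were updated to match. A quick sketch:

```python
# imresize now follows PIL's (width, height) size convention.
from PIL import Image

img = Image.new('RGB', (640, 480))                 # PIL size is (width, height)
out = imresize(img, (320, 240), interp='bilinear')
assert out.size == (320, 240)
```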
@@ -35,8 +32,8 @@ def __init__(self, odgt, opt, **kwargs):

# mean and std
self.normalize = transforms.Normalize(
mean=[102.9801, 115.9465, 122.7717],
std=[1., 1., 1.])
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
if isinstance(odgt, list):
@@ -54,12 +51,17 @@ def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
print('# samples: {}'.format(self.num_sample))

def img_transform(self, img):
# image to float
img = img.astype(np.float32)
# 0-255 to 0-1
img = np.float32(np.array(img)) / 255.
img = img.transpose((2, 0, 1))
img = self.normalize(torch.from_numpy(img.copy()))
return img

def segm_transform(self, segm):
# to tensor, -1 to 149
segm = torch.from_numpy(np.array(segm)).long() - 1
return segm

# Round x up to the nearest multiple of p, so that x' >= x
def round2nearest_multiple(self, x, p):
return ((x - 1) // p + 1) * p
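A worked example of the two helpers above: `round2nearest_multiple` rounds up to the next multiple of `p` (leaving exact multiples unchanged), and `segm_transform` shifts the raw ADE20K labels down by one so that 0 (unlabeled) becomes the ignore index -1:

```python
# round2nearest_multiple(x, p) == smallest multiple of p that is >= x
assert ((33 - 1) // 8 + 1) * 8 == 40   # 33 rounds up to 40
assert ((32 - 1) // 8 + 1) * 8 == 32   # exact multiples are unchanged

# segm_transform: raw label 0 (unlabeled) -> -1, labels 1..150 -> 0..149
```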
@@ -69,7 +71,6 @@ class TrainDataset(BaseDataset):
def __init__(self, root_dataset, odgt, opt, batch_per_gpu=1, **kwargs):
super(TrainDataset, self).__init__(odgt, opt, **kwargs)
self.root_dataset = root_dataset
self.random_flip = opt.random_flip
# downsampling rate of the segm label
self.segm_downsampling_rate = opt.segm_downsampling_rate
self.batch_per_gpu = batch_per_gpu
@@ -124,71 +125,74 @@ def __getitem__(self, index):

# calculate the BATCH's height and width
# since we concatenate more than one sample, the batch's h and w shall be no smaller than those of EACH sample
batch_resized_size = np.zeros((self.batch_per_gpu, 2), np.int32)
batch_widths = np.zeros(self.batch_per_gpu, np.int32)
batch_heights = np.zeros(self.batch_per_gpu, np.int32)
for i in range(self.batch_per_gpu):
img_height, img_width = batch_records[i]['height'], batch_records[i]['width']
this_scale = min(
this_short_size / min(img_height, img_width), \
self.imgMaxSize / max(img_height, img_width))
img_resized_height, img_resized_width = img_height * this_scale, img_width * this_scale
batch_resized_size[i, :] = img_resized_height, img_resized_width
batch_resized_height = np.max(batch_resized_size[:, 0])
batch_resized_width = np.max(batch_resized_size[:, 1])
batch_widths[i] = img_width * this_scale
batch_heights[i] = img_height * this_scale

# Here we must pad both input image and segmentation map to size h' and w' so that p | h' and p | w'
batch_resized_height = int(self.round2nearest_multiple(batch_resized_height, self.padding_constant))
batch_resized_width = int(self.round2nearest_multiple(batch_resized_width, self.padding_constant))

assert self.padding_constant >= self.segm_downsampling_rate,\
'padding constant must be equal to or larger than the segm downsampling rate'
batch_images = torch.zeros(self.batch_per_gpu, 3, batch_resized_height, batch_resized_width)
batch_width = np.max(batch_widths)
batch_height = np.max(batch_heights)
batch_width = int(self.round2nearest_multiple(batch_width, self.padding_constant))
batch_height = int(self.round2nearest_multiple(batch_height, self.padding_constant))

assert self.padding_constant >= self.segm_downsampling_rate, \
'padding constant must be equal to or larger than the segm downsampling rate'
batch_images = torch.zeros(
self.batch_per_gpu, 3, batch_height, batch_width)
batch_segms = torch.zeros(
self.batch_per_gpu, batch_resized_height // self.segm_downsampling_rate, \
batch_resized_width // self.segm_downsampling_rate).long()
self.batch_per_gpu,
batch_height // self.segm_downsampling_rate,
batch_width // self.segm_downsampling_rate).long()

for i in range(self.batch_per_gpu):
this_record = batch_records[i]

# load image and label
image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
img = cv2.imread(image_path, cv2.IMREAD_COLOR)
segm = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE)

assert(img.ndim == 3)
assert(segm.ndim == 2)
assert(img.shape[0] == segm.shape[0])
assert(img.shape[1] == segm.shape[1])
img = Image.open(image_path).convert('RGB')
segm = Image.open(segm_path)
assert(segm.mode == "L")
assert(img.size[0] == segm.size[0])
assert(img.size[1] == segm.size[1])

if self.random_flip is True:
random_flip = np.random.choice([0, 1])
if random_flip == 1:
img = cv2.flip(img, 1)
segm = cv2.flip(segm, 1)
# random_flip
if np.random.choice([0, 1]):
img = img.transpose(Image.FLIP_LEFT_RIGHT)
segm = segm.transpose(Image.FLIP_LEFT_RIGHT)

# note that each sample within a mini-batch has a different scale param
img = imresize(img, (batch_resized_size[i, 0], batch_resized_size[i, 1]), interp='bilinear')
segm = imresize(segm, (batch_resized_size[i, 0], batch_resized_size[i, 1]), interp='nearest')

# to avoid seg label misalignment
segm_rounded_height = self.round2nearest_multiple(segm.shape[0], self.segm_downsampling_rate)
segm_rounded_width = self.round2nearest_multiple(segm.shape[1], self.segm_downsampling_rate)
segm_rounded = np.zeros((segm_rounded_height, segm_rounded_width), dtype='uint8')
segm_rounded[:segm.shape[0], :segm.shape[1]] = segm

img = imresize(img, (batch_widths[i], batch_heights[i]), interp='bilinear')
segm = imresize(segm, (batch_widths[i], batch_heights[i]), interp='nearest')

# further downsample seg label, need to avoid seg label misalignment
segm_rounded_width = self.round2nearest_multiple(segm.size[0], self.segm_downsampling_rate)
segm_rounded_height = self.round2nearest_multiple(segm.size[1], self.segm_downsampling_rate)
segm_rounded = Image.new('L', (segm_rounded_width, segm_rounded_height), 0)
segm_rounded.paste(segm, (0, 0))
segm = imresize(
segm_rounded,
(segm_rounded.shape[0] // self.segm_downsampling_rate, \
segm_rounded.shape[1] // self.segm_downsampling_rate), \
(segm_rounded.size[0] // self.segm_downsampling_rate, \
segm_rounded.size[1] // self.segm_downsampling_rate), \
interp='nearest')

# image transform
# image transform, to torch float tensor 3xHxW
img = self.img_transform(img)

# segm transform, to torch long tensor HxW
segm = self.segm_transform(segm)

# put into batch arrays
batch_images[i][:, :img.shape[1], :img.shape[2]] = img
batch_segms[i][:segm.shape[0], :segm.shape[1]] = torch.from_numpy(segm.astype(np.int)).long()
batch_segms[i][:segm.shape[0], :segm.shape[1]] = segm

batch_segms = batch_segms - 1 # label from -1 to 149
output = dict()
output['img_data'] = batch_images
output['seg_label'] = batch_segms
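To see how the batch size computation above plays out, here is a worked sketch with assumed values `this_short_size=600`, `imgMaxSize=1000`, `padding_constant=8` (the real values come from the config):

```python
# Worked example of the per-batch scale and padding computation.
img_height, img_width = 480, 640
this_scale = min(600 / min(img_height, img_width),    # 600/480 = 1.25
                 1000 / max(img_height, img_width))   # 1000/640 = 1.5625
# -> this_scale = 1.25, so the sample resizes to 600x800 (HxW);
# rounding 600 and 800 up to multiples of 8 leaves them unchanged,
# so batch_height=600, batch_width=800 and 8 divides both.
```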
@@ -209,10 +213,13 @@ def __getitem__(self, index):
# load image and label
image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
img = cv2.imread(image_path, cv2.IMREAD_COLOR)
segm = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE)
img = Image.open(image_path).convert('RGB')
segm = Image.open(segm_path)
assert(segm.mode == "L")
assert(img.size[0] == segm.size[0])
assert(img.size[1] == segm.size[1])

ori_height, ori_width, _ = img.shape
ori_width, ori_height = img.size

img_resized_list = []
for this_short_size in self.imgSizes:
@@ -222,24 +229,23 @@ def __getitem__(self, index):
target_height, target_width = int(ori_height * scale), int(ori_width * scale)

# to avoid rounding in network
target_height = self.round2nearest_multiple(target_height, self.padding_constant)
target_width = self.round2nearest_multiple(target_width, self.padding_constant)
target_height = self.round2nearest_multiple(target_height, self.padding_constant)

# resize
img_resized = cv2.resize(img.copy(), (target_width, target_height))
# resize images
img_resized = imresize(img, (target_width, target_height), interp='bilinear')

# image transform
# image transform, to torch float tensor 3xHxW
img_resized = self.img_transform(img_resized)

img_resized = torch.unsqueeze(img_resized, 0)
img_resized_list.append(img_resized)

segm = torch.from_numpy(segm.astype(np.int)).long()
# segm transform, to torch long tensor HxW
segm = self.segm_transform(segm)
batch_segms = torch.unsqueeze(segm, 0)

batch_segms = batch_segms - 1 # label from -1 to 149
output = dict()
output['img_ori'] = img.copy()
output['img_ori'] = np.array(img)
output['img_data'] = [x.contiguous() for x in img_resized_list]
output['seg_label'] = batch_segms.contiguous()
output['info'] = this_record['fpath_img']
@@ -255,11 +261,11 @@ def __init__(self, odgt, opt, **kwargs):

def __getitem__(self, index):
this_record = self.list_sample[index]
# load image and label
# load image
image_path = this_record['fpath_img']
img = cv2.imread(image_path, cv2.IMREAD_COLOR)
img = Image.open(image_path).convert('RGB')

ori_height, ori_width, _ = img.shape
ori_width, ori_height = img.size

img_resized_list = []
for this_short_size in self.imgSizes:
@@ -269,19 +275,19 @@ def __getitem__(self, index):
target_height, target_width = int(ori_height * scale), int(ori_width * scale)

# to avoid rounding in network
target_height = self.round2nearest_multiple(target_height, self.padding_constant)
target_width = self.round2nearest_multiple(target_width, self.padding_constant)
target_height = self.round2nearest_multiple(target_height, self.padding_constant)

# resize
img_resized = cv2.resize(img.copy(), (target_width, target_height))
# resize images
img_resized = imresize(img, (target_width, target_height), interp='bilinear')

# image transform
# image transform, to torch float tensor 3xHxW
img_resized = self.img_transform(img_resized)
img_resized = torch.unsqueeze(img_resized, 0)
img_resized_list.append(img_resized)

output = dict()
output['img_ori'] = img.copy()
output['img_ori'] = np.array(img)
output['img_data'] = [x.contiguous() for x in img_resized_list]
output['info'] = this_record['fpath_img']
return output
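A hypothetical usage sketch of the reworked `TestDataset` (the `opt` object and its fields are assumptions here; the real options live in `config/defaults.py` per the README):

```python
# Hypothetical: build a TestDataset from an in-memory sample list.
from dataset import TestDataset

odgt = [{'fpath_img': 'ADE_val_00001519.jpg'}]
dataset = TestDataset(odgt, opt)   # opt assumed to carry imgSizes etc.
sample = dataset[0]
# sample['img_ori']  -> HxWx3 uint8 RGB array (np.array of the PIL image)
# sample['img_data'] -> list of 1x3xHxW normalized float tensors,
#                       one per short-side size in opt.imgSizes
# sample['info']     -> the image path
```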
4 changes: 2 additions & 2 deletions demo_test.sh
@@ -2,7 +2,7 @@

# Image and model names
TEST_IMG=ADE_val_00001519.jpg
MODEL_PATH=baseline-resnet50dilated-ppm_deepsup
MODEL_PATH=ade20k-resnet50dilated-ppm_deepsup
RESULT_PATH=./

ENCODER=$MODEL_PATH/encoder_epoch_20.pth
@@ -28,4 +28,4 @@ python3 -u test.py \
--cfg config/ade20k-resnet50dilated-ppm_deepsup.yaml \
DIR $MODEL_PATH \
TEST.result ./ \
TEST.suffix _epoch_20.pth
TEST.checkpoint epoch_20.pth
7 changes: 2 additions & 5 deletions eval.py
@@ -15,7 +15,7 @@
from utils import AverageMeter, colorEncode, accuracy, intersectionAndUnion, setup_logger
from lib.nn import user_scattered_collate, async_copy_to
from lib.utils import as_numpy
import cv2
from PIL import Image
from tqdm import tqdm

colors = loadmat('data/color150.mat')['colors']
@@ -35,10 +35,7 @@ def visualize_result(data, pred, dir_result):
axis=1).astype(np.uint8)

img_name = info.split('/')[-1]
cv2.imwrite(
os.path.join(dir_result, img_name.replace('.jpg', '.png')),
im_vis
)
Image.fromarray(im_vis).save(os.path.join(dir_result, img_name.replace('.jpg', '.png')))


def evaluate(segmentation_module, loader, cfg, gpu):
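The save-path change matters because `cv2.imwrite` assumes BGR channel order while the visualization array is now RGB; writing it with OpenCV would swap red and blue. A minimal sketch of the pitfall:

```python
# cv2.imwrite expects BGR; PIL writes RGB arrays as-is.
import numpy as np
from PIL import Image

im_vis = np.zeros((4, 4, 3), dtype=np.uint8)
im_vis[..., 0] = 255                       # pure red in RGB order
Image.fromarray(im_vis).save('vis.png')    # saved as red, as intended
# cv2.imwrite('vis.png', im_vis) would have saved it as blue
```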
7 changes: 2 additions & 5 deletions eval_multipro.py
@@ -16,7 +16,7 @@
from utils import AverageMeter, colorEncode, accuracy, intersectionAndUnion, parse_devices, setup_logger
from lib.nn import user_scattered_collate, async_copy_to
from lib.utils import as_numpy
import cv2
from PIL import Image
from tqdm import tqdm

colors = loadmat('data/color150.mat')['colors']
@@ -36,10 +36,7 @@ def visualize_result(data, pred, dir_result):
axis=1).astype(np.uint8)

img_name = info.split('/')[-1]
cv2.imwrite(
os.path.join(dir_result, img_name.replace('.jpg', '.png')),
im_vis
)
Image.fromarray(im_vis).save(os.path.join(dir_result, img_name.replace('.jpg', '.png')))


def evaluate(segmentation_module, loader, cfg, gpu_id, result_queue):