v0.5.0 (#124)
fixes #118
fixes #111
fixes #84
fixes #67
fixes #88
zhanghang1989 committed Sep 26, 2018
1 parent d4e1955 commit f891919
Showing 68 changed files with 2,128 additions and 485 deletions.
42 changes: 42 additions & 0 deletions docs/source/_static/js/hidebib.js
@@ -0,0 +1,42 @@
// adapted from: http://www.robots.ox.ac.uk/~vedaldi/assets/hidebib.js
// Hide the bibliography <pre> block inside every div.paper on the page.
function hideallbibs()
{
    var el = document.getElementsByTagName("div");
    for (var i = 0; i < el.length; ++i) {
        if (el[i].className === "paper") {
            var bib = el[i].getElementsByTagName("pre");
            if (bib.length > 0) {
                bib[0].style.display = 'none';
            }
        }
    }
}

// Toggle the bibliography block of the paper with the given element id.
function togglebib(paperid)
{
    var paper = document.getElementById(paperid);
    var bib = paper.getElementsByTagName('pre');
    if (bib.length > 0) {
        bib[0].style.display = (bib[0].style.display === 'none') ? 'block' : 'none';
    }
}

// Toggle visibility of an arbitrary block element by id.
function toggleblock(blockId)
{
    var block = document.getElementById(blockId);
    block.style.display = (block.style.display === 'none') ? 'block' : 'none';
}

// Hide a block element by id.
function hideblock(blockId)
{
    var block = document.getElementById(blockId);
    block.style.display = 'none';
}
2 changes: 1 addition & 1 deletion docs/source/_templates/layout.html
@@ -3,5 +3,5 @@
{%- block extrahead %}


<script type="text/javascript" src="http://zhanghang1989.github.io/files/hidebib.js"></script>
<script type="text/javascript" src="../_static/js/hidebib.js"></script>
{% endblock %}
52 changes: 37 additions & 15 deletions docs/source/experiments/segmentation.rst
@@ -23,26 +23,34 @@ Test Pre-trained Model

model = encoding.models.get_model('FCN_ResNet50_PContext', pretrained=True)
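
As a quick sanity check, the loaded model can be run on a single image. The sketch below is a minimal illustration under stated assumptions (ImageNet mean/std preprocessing, and that the network's first output holds the segmentation logits); the repo's own demo code is in the Quick Demo section below::

    import torch
    import torchvision.transforms as transforms
    from PIL import Image
    import encoding

    model = encoding.models.get_model('FCN_ResNet50_PContext', pretrained=True)
    model.eval()

    # assumed preprocessing: ImageNet mean/std normalization
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    img = transform(Image.open('example.jpg').convert('RGB')).unsqueeze(0)  # hypothetical image

    with torch.no_grad():
        output = model(img)
    logits = output[0] if isinstance(output, (tuple, list)) else output
    pred = torch.argmax(logits, dim=1)  # (1, H, W) per-pixel class indices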

Prepare the datasets by running the scripts in the ``scripts/`` folder, for example preparing the ``PASCAL Context`` dataset::

python scripts/prepare_pcontext.py

The test script is in the ``experiments/segmentation/`` folder. To evaluate a model with
multi-scale (MS) testing, for example ``Encnet_ResNet50_PContext``, run::

python test.py --dataset PContext --model-zoo Encnet_ResNet50_PContext --eval
# pixAcc: 0.7888, mIoU: 0.5056: 100%|████████████████████████| 1276/1276 [46:31<00:00, 2.19s/it]
# pixAcc: 0.792, mIoU: 0.510: 100%|████████████████████████| 1276/1276 [46:31<00:00, 2.19s/it]
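
Here ``MS`` refers to multi-scale testing: the input is resized to several scales, and the per-pixel class probabilities are upsampled back to the original resolution and averaged. The repo's actual evaluator is implemented in ``test.py``; the sketch below is only a conceptual illustration (the scale set and interpolation settings are assumptions)::

    import torch
    import torch.nn.functional as F

    def multiscale_predict(model, img, scales=(0.5, 0.75, 1.0, 1.25, 1.5, 1.75)):
        # conceptual multi-scale evaluation: average class probabilities
        # over several input scales at the original resolution
        _, _, h, w = img.shape
        total = 0
        for s in scales:
            x = F.interpolate(img, scale_factor=s, mode='bilinear', align_corners=True)
            with torch.no_grad():
                out = model(x)
            logits = out[0] if isinstance(out, (tuple, list)) else out
            total = total + F.interpolate(F.softmax(logits, dim=1), size=(h, w),
                                          mode='bilinear', align_corners=True)
        return total.argmax(dim=1)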

The command for training the model can be found by clicking ``cmd`` in the table.

.. role:: raw-html(raw)
:format: html

+----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
| Model | pixAcc | mIoU | Note | Command | Logs |
+==================================+===========+===========+===========+==============================================================================================+============+
| Encnet_ResNet50_PContext | 78.9% | 50.6% | | :raw-html:`<a href="javascript:toggleblock('cmd_enc50_pcont')" class="toggleblock">cmd</a>` | ENC50PC_ |
+----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
| EncNet_ResNet101_PContext | 80.3% | 53.2% | | :raw-html:`<a href="javascript:toggleblock('cmd_enc101_pcont')" class="toggleblock">cmd</a>` | ENC101PC_ |
+----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
| EncNet_ResNet50_ADE | 79.9% | 41.2% | | :raw-html:`<a href="javascript:toggleblock('cmd_enc50_ade')" class="toggleblock">cmd</a>` | ENC50ADE_ |
+----------------------------------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
| Model | pixAcc | mIoU | Command | Logs |
+==================================+===========+===========+==============================================================================================+============+
| Encnet_ResNet50_PContext | 79.2% | 51.0% | :raw-html:`<a href="javascript:toggleblock('cmd_enc50_pcont')" class="toggleblock">cmd</a>` | ENC50PC_ |
+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
| EncNet_ResNet101_PContext | 80.7% | 54.1% | :raw-html:`<a href="javascript:toggleblock('cmd_enc101_pcont')" class="toggleblock">cmd</a>` | ENC101PC_ |
+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
| EncNet_ResNet50_ADE | 80.1% | 41.5% | :raw-html:`<a href="javascript:toggleblock('cmd_enc50_ade')" class="toggleblock">cmd</a>` | ENC50ADE_ |
+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
| EncNet_ResNet101_ADE | 81.3% | 44.4% | :raw-html:`<a href="javascript:toggleblock('cmd_enc101_ade')" class="toggleblock">cmd</a>` | ENC101ADE_ |
+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+
| EncNet_ResNet101_VOC | N/A | 85.9% | :raw-html:`<a href="javascript:toggleblock('cmd_enc101_voc')" class="toggleblock">cmd</a>` | ENC101VOC_ |
+----------------------------------+-----------+-----------+----------------------------------------------------------------------------------------------+------------+

.. _ENC50PC: https://github.com/zhanghang1989/image-data/blob/master/encoding/segmentation/logs/encnet_resnet50_pcontext.log?raw=true
.. _ENC101PC: https://github.com/zhanghang1989/image-data/blob/master/encoding/segmentation/logs/encnet_resnet101_pcontext.log?raw=true
@@ -71,6 +79,19 @@ Test Pre-trained Model
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ADE20K --model EncNet --aux --se-loss
</code>


<code xml:space="preserve" id="cmd_enc101_ade" style="display: none; text-align: left; white-space: pre-wrap">
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset ADE20K --model EncNet --aux --se-loss --backbone resnet101
</code>

<code xml:space="preserve" id="cmd_enc101_voc" style="display: none; text-align: left; white-space: pre-wrap">
# First fine-tune the COCO-pretrained model on the augmented set
# (you can also pretrain on COCO from scratch yourself)
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset Pascal_aug --model-zoo EncNet_Resnet101_COCO --aux --se-loss --lr 0.001 --syncbn --ngpus 4 --checkname res101
# Then fine-tune on the original set
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset Pascal_voc --model encnet --aux --se-loss --backbone resnet101 --lr 0.0001 --syncbn --ngpus 4 --checkname res101 --resume runs/Pascal_aug/encnet/res101/checkpoint.params
</code>

Quick Demo
~~~~~~~~~~

@@ -116,13 +137,14 @@ Train Your Own Model

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --dataset pcontext --model encnet --aux --se-loss

- For detailed training options, please run ``python train.py -h``.
- For detailed training options, please run ``python train.py -h``. Commands for reproducing the pre-trained models can be found in the table above.

- The validation metrics during training use center-crop only and are meant just for monitoring
  training correctness. To evaluate a pretrained model on the validation set with multi-scale (MS)
  testing, use the command::

.. hint::
    The validation metrics during training use center-crop only and are meant just for monitoring
    training correctness. To evaluate a pretrained model on the validation set with multi-scale (MS)
    testing, use the command::

CUDA_VISIBLE_DEVICES=0,1,2,3 python test.py --dataset pcontext --model encnet --aux --se-loss --resume mycheckpoint --eval
CUDA_VISIBLE_DEVICES=0,1,2,3 python test.py --dataset pcontext --model encnet --aux --se-loss --resume mycheckpoint --eval

Citation
--------
15 changes: 9 additions & 6 deletions docs/source/experiments/texture.rst
@@ -17,24 +17,29 @@ Test Pre-trained Model

- Install PyTorch Encoding (if not yet). Please follow the installation guide `Installing PyTorch Encoding <../notes/compile.html>`_.

- Download the `MINC-2500 <http://opensurfaces.cs.cornell.edu/publications/minc/>`_ dataset to ``$HOME/data/minc-2500/`` folder. Download pre-trained model (pre-trained on train-1 split using single training size of 224, with an error rate of :math:`19.70\%` using single crop on test-1 set)::
- Download the `MINC-2500 <http://opensurfaces.cs.cornell.edu/publications/minc/>`_ dataset using the provided script::

cd PyTorch-Encoding/experiments/recognition
cd PyTorch-Encoding/
python scripts/prepare_minc.py

- Download pre-trained model (pre-trained on train-1 split using single training size of 224, with an error rate of :math:`19.70\%` using single crop on test-1 set)::

cd experiments/recognition
python model/download_models.py

- Test pre-trained model on MINC-2500::

python main.py --dataset minc --model deepten --nclass 23 --resume deepten_minc.pth --eval
# Terminal Output:
# Loss: 1.005 | Err: 19.704% (1133/5750): 100%|████████████████████| 23/23 [00:18<00:00, 1.26it/s]
# Loss: 1.005 | Err: 18.96% (1090/5750): 100%|████████████████████| 23/23 [00:18<00:00, 1.26it/s]


Train Your Own Model
--------------------

- Example training command for training above model::

CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --dataset minc --model deepten --nclass 23 --batch-size 512 --lr 0.004 --epochs 80 --lr-step 60
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --dataset minc --model deepten --nclass 23 --batch-size 512 --lr 0.004 --epochs 80 --lr-step 60 --lr-scheduler step

- Detailed training options::

@@ -56,8 +61,6 @@ Train Your Own Model
--checkname set the checkpoint name
--eval evaluating

.. todo::
Provide example code for extracting features.

Extending the Software
----------------------
4 changes: 2 additions & 2 deletions docs/source/functions.rst
@@ -20,10 +20,10 @@ encoding.functions
.. autofunction:: aggregate


:hidden:`scaledL2`
:hidden:`scaled_l2`
~~~~~~~~~~~~~~~~~~~

.. autofunction:: scaledL2
.. autofunction:: scaled_l2
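
For reference, ``scaled_l2`` computes the smoothed pairwise distances between features and codewords used by the encoding layer, :math:`sl_{ik} = s_k \lVert x_i - c_k \rVert^2`. Below is a non-optimized sketch of the intended semantics; the shipped function is a fused CUDA/autograd implementation, so treat this only as documentation::

    import torch

    def scaled_l2_reference(X, C, S):
        # X: (B, N, D) features, C: (K, D) codewords, S: (K,) smoothing factors
        # returns (B, N, K) scaled squared L2 distances
        diff = X.unsqueeze(2) - C.view(1, 1, *C.size())   # (B, N, K, D)
        return S.view(1, 1, -1) * diff.pow(2).sum(3)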


:hidden:`sum_square`
2 changes: 1 addition & 1 deletion encoding/__init__.py
@@ -10,4 +10,4 @@

"""An optimized PyTorch package with CUDA backend."""
from .version import __version__
from . import nn, functions, dilated, parallel, utils, models, datasets
from . import nn, functions, dilated, parallel, utils, models, datasets, optimizer
4 changes: 4 additions & 0 deletions encoding/datasets/__init__.py
@@ -1,14 +1,18 @@
from .base import *
from .coco import COCOSegmentation
from .ade20k import ADE20KSegmentation
from .pascal_voc import VOCSegmentation
from .pascal_aug import VOCAugSegmentation
from .pcontext import ContextSegmentation
from .cityscapes import CitySegmentation

datasets = {
'coco': COCOSegmentation,
'ade20k': ADE20KSegmentation,
'pascal_voc': VOCSegmentation,
'pascal_aug': VOCAugSegmentation,
'pcontext': ContextSegmentation,
'citys': CitySegmentation,
}

def get_segmentation_dataset(name, **kwargs):
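
(The function body is elided in the view above; presumably it just resolves the name through the ``datasets`` registry, along the lines of this sketch)::

    def get_segmentation_dataset(name, **kwargs):
        # look up the dataset class by name and instantiate it
        return datasets[name.lower()](**kwargs)
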
13 changes: 9 additions & 4 deletions encoding/datasets/ade20k.py
@@ -58,8 +58,8 @@ def __getitem__(self, index):
return img, mask

def _mask_transform(self, mask):
target = np.array(mask).astype('int32') - 1
return torch.from_numpy(target).long()
target = np.array(mask).astype('int64') - 1
return torch.from_numpy(target)

def __len__(self):
return len(self.images)
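
The switch from ``int32`` to ``int64`` in ``_mask_transform`` above matters because PyTorch loss functions such as ``torch.nn.CrossEntropyLoss`` expect ``LongTensor`` targets; casting in NumPy removes the need for a separate ``.long()`` call. The ``- 1`` shift maps ADE20K's label 0 ('unlabeled') to -1, presumably the ignore index used by the training criterion. A small illustration with a hypothetical mask file::

    import numpy as np
    import torch
    from PIL import Image

    mask = Image.open('ADE_train_00000001.png')  # hypothetical annotation path
    target = torch.from_numpy(np.array(mask).astype('int64') - 1)
    # label 0 ('unlabeled') becomes -1, which the loss can ignore
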
@@ -90,17 +90,22 @@ def get_path_pairs(img_folder, mask_folder):
img_folder = os.path.join(folder, 'images/training')
mask_folder = os.path.join(folder, 'annotations/training')
img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
print('len(img_paths):', len(img_paths))
assert len(img_paths) == 20210
elif split == 'val':
img_folder = os.path.join(folder, 'images/validation')
mask_folder = os.path.join(folder, 'annotations/validation')
img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
assert len(img_paths) == 2000
else:
assert split == 'trainval'
train_img_folder = os.path.join(folder, 'images/training')
train_mask_folder = os.path.join(folder, 'annotations/training')
val_img_folder = os.path.join(folder, 'images/validation')
val_mask_folder = os.path.join(folder, 'annotations/validation')
train_img_paths, train_mask_paths = get_path_pairs(train_img_folder, train_mask_folder)
val_img_paths, val_mask_paths = get_path_pairs(val_img_folder, val_mask_folder)
return train_img_paths + val_img_paths, train_mask_paths + val_mask_paths

img_paths = train_img_paths + val_img_paths
mask_paths = train_mask_paths + val_mask_paths
assert len(img_paths) == 22210
return img_paths, mask_paths
7 changes: 3 additions & 4 deletions encoding/datasets/base.py
@@ -37,6 +37,9 @@ def num_class(self):
def pred_offset(self):
raise NotImplementedError

def make_pred(self, x):
return x + self.pred_offset

def _val_sync_transform(self, img, mask):
outsize = self.crop_size
short_size = outsize
@@ -75,10 +78,6 @@ def _sync_transform(self, img, mask):
ow = int(1.0 * w * oh / h)
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)
# random rotate -10~10, mask using NN rotate
deg = random.uniform(-10, 10)
img = img.rotate(deg, resample=Image.BILINEAR)
mask = mask.rotate(deg, resample=Image.NEAREST)
# pad crop
if short_size < crop_size:
padh = crop_size - oh if oh < crop_size else 0
