PyTorch implementation for semantic segmentation (DeepLabV3+, UNet, etc.)

RJT1990/pytorch-segmentation

PytorchSegmentation

This repository implements general-purpose networks for semantic segmentation.
You can train various networks such as DeepLabV3+, PSPNet, and UNet just by writing a config file.

DeepLabV3+

Pretrained model

You can run a pretrained model converted from the official TensorFlow model.

DeepLabV3+ (Xception65 + ASPP)

$ cd tf_model
$ wget http://download.tensorflow.org/models/deeplabv3_cityscapes_train_2018_02_06.tar.gz
$ tar -xvf deeplabv3_cityscapes_train_2018_02_06.tar.gz
$ cd ../src
$ python -m converter.convert_xception65 ../tf_model/deeplabv3_cityscapes_train/model.ckpt 19 ../model/cityscapes_deeplab_v3_plus/model.pth

Then you can evaluate the performance of the converted network.

$ python eval.py
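
The README does not say which metric eval.py reports. Cityscapes results are commonly given as mean intersection-over-union (mIoU) over the 19 classes, with label 255 ignored, so the sketch below shows that computation on flat label lists. The function name `mean_iou` and the toy inputs are illustrative assumptions, not code from this repository.

```python
def mean_iou(preds, targets, num_classes, ignore_index=255):
    """Mean IoU over the classes that occur in preds or targets.

    Hypothetical helper for illustration; not taken from eval.py.
    """
    intersection = [0] * num_classes
    union = [0] * num_classes
    for p, t in zip(preds, targets):
        if t == ignore_index:
            continue  # ignored pixels do not count for any class
        if p == t:
            intersection[t] += 1
            union[t] += 1
        else:
            union[p] += 1  # false positive for the predicted class
            union[t] += 1  # false negative for the true class
    ious = [i / u for i, u in zip(intersection, union) if u > 0]
    return sum(ious) / len(ious) if ious else 0.0


# Toy example: 3 classes, the last pixel carries the ignore label.
preds = [0, 0, 1, 1, 2, 2]
targets = [0, 1, 1, 1, 2, 255]
print(mean_iou(preds, targets, num_classes=3))
```

Classes that never occur in either list are left out of the average here; other conventions exist, so compare against the evaluation script before quoting numbers.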

MobileNetV2

$ cd tf_model
$ wget http://download.tensorflow.org/models/deeplabv3_mnv2_cityscapes_train_2018_02_05.tar.gz
$ tar -xvf deeplabv3_mnv2_cityscapes_train_2018_02_05.tar.gz
$ cd ../src
$ python -m converter.mobilenetv2 ../tf_model/deeplabv3_mnv2_cityscapes_train/model.ckpt 19 ../model/cityscapes_mobilnetv2/model.pth

How to train

To train a model, you only need to set up a config file.
For example, write a config file like the one below and save it as config/pascal_unet_res18_scse.yaml.

Net:
  enc_type: 'resnet18'
  dec_type: 'unet_scse'
  num_filters: 8
  pretrained: True
Data:
  dataset: 'pascal'
  target_size: (512, 512)
Train:
  max_epoch: 20
  batch_size: 2
  fp16: True
  resume: False
  pretrained_path:
Loss:
  loss_type: 'Lovasz'
  ignore_index: 255
Optimizer:
  mode: 'adam'
  base_lr: 0.001
  t_max: 10
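
One detail of the config above is worth noting: YAML has no tuple syntax, so a YAML loader reads `target_size: (512, 512)` as the string `'(512, 512)'`. The sketch below shows one safe way to turn such a value into a tuple of ints with `ast.literal_eval`. How this repository parses the field is not stated here, and the helper name `parse_target_size` is hypothetical.

```python
import ast


def parse_target_size(value):
    """Convert a config value such as '(512, 512)' into a tuple of two ints."""
    # Accept either a string like '(512, 512)' or an already-parsed list/tuple.
    if isinstance(value, str):
        value = ast.literal_eval(value)
    size = tuple(int(v) for v in value)
    if len(size) != 2:
        raise ValueError(f"target_size must have two entries, got {size!r}")
    return size


# This dict stands in for the Data section as a YAML loader would return it.
data_config = {"dataset": "pascal", "target_size": "(512, 512)"}
print(parse_target_size(data_config["target_size"]))
```

`ast.literal_eval` only accepts Python literals, so it avoids the risks of calling `eval` on text read from a file.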

Then you can train this model by:

$ python train.py ../config/pascal_unet_res18_scse.yaml
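
The Optimizer section of the example config sets `base_lr` and `t_max`. Assuming `t_max` is the period of a cosine annealing schedule (as with the `T_max` argument of PyTorch's `CosineAnnealingLR`) and that the schedule advances once per epoch, the learning rate would follow the closed form sketched below. Both assumptions, the `eta_min` parameter, and the function name `cosine_lr` are illustrative and not confirmed by this README.

```python
import math


def cosine_lr(epoch, base_lr=0.001, t_max=10, eta_min=0.0):
    """Learning rate at `epoch` under cosine annealing with period t_max."""
    cosine = math.cos(math.pi * epoch / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + cosine)


# Print the schedule at a few epochs of a 20-epoch run.
for epoch in range(0, 21, 5):
    print(epoch, f"{cosine_lr(epoch):.6f}")
```

Under this closed form the rate falls from `base_lr` at epoch 0 to `eta_min` at epoch `t_max`, then rises again, so with `max_epoch: 20` and `t_max: 10` a run would end back near `base_lr`.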

Dataset

Directory tree

.
├── config
├── data
│   ├── cityscapes
│   │   ├── gtFine
│   │   └── leftImg8bit
│   └── pascal_voc_2012
│        └── VOCdevkit
│            └── VOC2012
│                ├── JPEGImages
│                ├── SegmentationClass
│                └── SegmentationClassAug
├── logs
├── model
└── src
    ├── dataset
    ├── logger
    ├── losses
    │   ├── binary
    │   └── multi
    ├── models
    └── utils
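
Before training, it can help to confirm that the datasets sit where the tree above expects them. The following is a minimal sketch that checks only the directories listed in the tree; the helper name `missing_dirs` is hypothetical and not part of this repository.

```python
from pathlib import Path

# Dataset directories taken from the tree above, relative to the repository root.
EXPECTED_DIRS = [
    "data/cityscapes/gtFine",
    "data/cityscapes/leftImg8bit",
    "data/pascal_voc_2012/VOCdevkit/VOC2012/JPEGImages",
    "data/pascal_voc_2012/VOCdevkit/VOC2012/SegmentationClass",
    "data/pascal_voc_2012/VOCdevkit/VOC2012/SegmentationClassAug",
]


def missing_dirs(root="."):
    """Return the expected dataset directories that do not exist under root."""
    root = Path(root)
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]


if __name__ == "__main__":
    for d in missing_dirs():
        print("missing:", d)
```

Run it from the repository root; no output means every listed directory was found.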

Environments

  • OS: Ubuntu 18.04
  • python: 3.7.0
  • pytorch: 1.0.0
  • pretrainedmodels: 0.7.4
  • albumentations: 0.1.8

If you want to train models in fp16, you also need:

  • NVIDIA/apex: 0.1

