Skip to content

Commit

Permalink
update deeplabv3
Browse files Browse the repository at this point in the history
  • Loading branch information
tfzhou committed Jul 3, 2021
1 parent 76e319a commit 3101207
Show file tree
Hide file tree
Showing 18 changed files with 292 additions and 157 deletions.
13 changes: 8 additions & 5 deletions README.md
Expand Up @@ -4,7 +4,7 @@

> [**Exploring Cross-Image Pixel Contrast for Semantic Segmentation**](https://arxiv.org/abs/2101.11939),
> [Wenguan Wang](https://sites.google.com/view/wenguanwang/), [Tianfei Zhou](https://www.tfzhou.com/), [Fisher Yu](https://www.yf.io/), [Jifeng Dai](https://jifengdai.org/), [Ender Konukoglu](https://scholar.google.com/citations?user=OeEMrhQAAAAJ&hl=en) and [Luc Van Gool](https://scholar.google.com/citations?user=TwMib_QAAAAJ&hl=en) <br>
> *arXiv technical report ([arXiv 2101.11939](https://arxiv.org/abs/2101.11939))*
> *arXiv technical report ([arXiv 2101.11939](https://arxiv.org/abs/2101.11939))*
## News

Expand Down Expand Up @@ -37,10 +37,13 @@ Please follow the [Getting Started](https://github.com/openseg-group/openseg.pyt

### Cityscapes Dataset

| Backbone | Train Set | Val Set | Iterations | Batch Size | Contrast Loss | Memory | mIoU | Log | CKPT |Script |
| --------- | --------- | ------- | ---------- | ---------- | ------------- | ------ | ----- | --- | ---- | ---- |
| HRNet-W48 | train | val | 40000 | 8 | N | N | 79.27 | [log](https://github.com/tfzhou/pretrained_weights/releases/download/v0.1/hrnet_w48_lr1x_hrnet_ce.log) | [ckpt](https://github.com/tfzhou/pretrained_weights/releases/download/v0.1/hrnet_w48_lr1x_hrnet_ce_max_performance.pth) |```scripts/cityscapes/hrnet/run_h_48_d_4.sh```|
| HRNet-W48 | train | val | 40000 | 8 | Y | N | 80.18 | [log](https://github.com/tfzhou/pretrained_weights/releases/download/v0.1/hrnet_w48_contrast_lr1x_hrnet_contrast_t0.1.log) | [ckpt](https://github.com/tfzhou/pretrained_weights/releases/download/v0.1/hrnet_w48_contrast_lr1x_hrnet_contrast_t0.1_max_performance.pth) |```scripts/cityscapes/hrnet/run_h_48_d_4_contrast.sh```|
| Backbone | Model | Train Set | Val Set | Iterations | Batch Size | Contrast Loss | Memory | mIoU | Log | CKPT |Script |
| --------- | ---------- | --------- | ------- | ---------- | ---------- | ------------- | ------ | ----- | --- | ---- | ---- |
| ResNet-101| DeepLab-V3 |train | val | 40000 | 8 | N | N | 72.75 | [log](https://github.com/tfzhou/pretrained_weights/releases/download/v0.1/deeplab_v3_deepbase_resnet101_dilated8_deeplab_v3.log) | [ckpt](https://github.com/tfzhou/pretrained_weights/releases/download/v0.1/deeplab_v3_deepbase_resnet101_dilated8_deeplab_v3_max_performance.pth) |```scripts/cityscapes/deeplab/run_r_101_d_8_deeplabv3_train.sh```|
| ResNet-101| DeepLab-V3 |train | val | 40000 | 8 | Y | N | 77.67 | [log](https://github.com/tfzhou/pretrained_weights/releases/download/v0.1/deeplab_v3_contrast_deepbase_resnet101_dilated8_deeplab_v3_contrast.log) | [ckpt](https://github.com/tfzhou/pretrained_weights/releases/download/v0.1/deeplab_v3_contrast_deepbase_resnet101_dilated8_deeplab_v3_contrast_max_performance.pth) |```scripts/cityscapes/deeplab/run_r_101_d_8_deeplabv3_contrast_train.sh```|
| HRNet-W48 | HRNet-W48 |train | val | 40000 | 8 | N | N | 79.27 | [log](https://github.com/tfzhou/pretrained_weights/releases/download/v0.1/hrnet_w48_lr1x_hrnet_ce.log) | [ckpt](https://github.com/tfzhou/pretrained_weights/releases/download/v0.1/hrnet_w48_lr1x_hrnet_ce_max_performance.pth) |```scripts/cityscapes/hrnet/run_h_48_d_4.sh```|
| HRNet-W48 | HRNet-W48 |train | val | 40000 | 8 | Y | N | 80.18 | [log](https://github.com/tfzhou/pretrained_weights/releases/download/v0.1/hrnet_w48_contrast_lr1x_hrnet_contrast_t0.1.log) | [ckpt](https://github.com/tfzhou/pretrained_weights/releases/download/v0.1/hrnet_w48_contrast_lr1x_hrnet_contrast_t0.1_max_performance.pth) |```scripts/cityscapes/hrnet/run_h_48_d_4_contrast.sh```|
_It seems that the DeepLab-V3 baseline does not produce the expected performance on the new codebase. I will tune this later._


### Study of the temperature
Expand Down
11 changes: 11 additions & 0 deletions configs/cityscapes/R_101_D_8.json
Expand Up @@ -132,5 +132,16 @@
"ohem_minkeep": 100000,
"ohem_thresh": 0.7
}
},
"contrast": {
"proj_dim": 256,
"temperature": 0.1,
"base_temperature": 0.07,
"max_samples": 1024,
"max_views": 100,
"stride": 8,
"warmup_iters": 5000,
"loss_weight": 0.1,
"use_rmi": false
}
}
14 changes: 7 additions & 7 deletions lib/loss/loss_contrast.py
Expand Up @@ -210,25 +210,25 @@ def __init__(self, configer=None):

self.contrast_criterion = PixelContrastLoss(configer=configer)

def forward(self, preds, target):
def forward(self, preds, target, with_embed=False):
h, w = target.size(1), target.size(2)

assert "seg" in preds
assert "seg_aux" in preds
assert "embed" in preds

seg = preds['seg']
seg_aux = preds['seg_aux']

embedding = preds['embedding'] if 'embedding' in preds else None
embedding = preds['embed']

pred = F.interpolate(input=seg, size=(h, w), mode='bilinear', align_corners=True)
pred_aux = F.interpolate(input=seg_aux, size=(h, w), mode='bilinear', align_corners=True)
loss = self.seg_criterion([pred_aux, pred], target)

if embedding is not None:
_, predict = torch.max(seg, 1)
_, predict = torch.max(seg, 1)
loss_contrast = self.contrast_criterion(embedding, target, predict)

loss_contrast = self.contrast_criterion(embedding, target, predict)
if with_embed is True:
return loss + self.loss_weight * loss_contrast

return loss
return loss + 0 * loss_contrast # just a trick to avoid errors in distributed training
7 changes: 4 additions & 3 deletions lib/models/model_manager.py
Expand Up @@ -25,7 +25,7 @@

# HRNet
from lib.models.nets.hrnet import HRNet_W48, HRNet_W48_CONTRAST
from lib.models.nets.hrnet import HRNet_W48_OCR, HRNet_W48_ASPOCR, HRNet_W48_OCR_B, HRNet_W48_OCR_B_HA
from lib.models.nets.hrnet import HRNet_W48_OCR, HRNet_W48_OCR_B, HRNet_W48_OCR_B_HA, HRNet_W48_OCR_CONTRAST

# OCNet
from lib.models.nets.ocnet import BaseOCNet, AspOCNet
Expand All @@ -41,7 +41,7 @@

from lib.utils.tools.logger import Logger as Log

from lib.models.nets.deeplab import DeepLabV3, DeepLabV3_MobileNet, DeepLabV3_MobileNetV3, DeepLabV3_MobileNetV1
from lib.models.nets.deeplab import DeepLabV3, DeepLabV3Contrast

from lib.models.nets.ms_ocrnet import MscaleOCR

Expand All @@ -66,15 +66,16 @@
'hrnet_w48': HRNet_W48,
'hrnet_w48_ocr': HRNet_W48_OCR,
'hrnet_w48_ocr_b': HRNet_W48_OCR_B,
'hrnet_w48_asp_ocr': HRNet_W48_ASPOCR,
# CE2P series
'ce2p_asp_ocrnet': CE2P_ASPOCR,
'ce2p_ocrnet': CE2P_OCRNet,
'ce2p_ideal_ocrnet': CE2P_IdealOCRNet,
# baseline series
'fcnet': FcnNet,
'hrnet_w48_contrast': HRNet_W48_CONTRAST,
'hrnet_w48_ocr_contrast': HRNet_W48_OCR_CONTRAST,
'deeplab_v3': DeepLabV3,
'deeplab_v3_contrast': DeepLabV3Contrast,
'ms_ocr': MscaleOCR,
'hrnet_w48_ocr_b_ha': HRNet_W48_OCR_B_HA,
}
Expand Down
81 changes: 22 additions & 59 deletions lib/models/nets/deeplab.py
@@ -1,81 +1,44 @@
import torch.nn as nn

from lib.models.backbones.backbone_selector import BackboneSelector
from lib.models.modules.decoder_block import DeepLabHead, DeepLabHead_MobileNet, DeepLabHead_MobileNet_V3, DeepLabHead_MobileNet_V1
from lib.models.modules.decoder_block import DeepLabHead
from lib.models.modules.projection import ProjectionHead

class DeepLabV3_MobileNetV1(nn.Module):
def __init__(self, configer):
super(DeepLabV3_MobileNetV1, self).__init__()

self.configer = configer
self.num_classes = self.configer.get('data', 'num_classes')
self.backbone = BackboneSelector(configer).get_backbone()

self.decoder = DeepLabHead_MobileNet_V1(num_classes=self.num_classes,
bn_type=self.configer.get('network', 'bn_type'))

for m in self.decoder.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()

def forward(self, x_):
x = self.backbone(x_)

x = self.decoder(x)

return x

class DeepLabV3_MobileNetV3(nn.Module):
class DeepLabV3Contrast(nn.Module):
def __init__(self, configer):
super(DeepLabV3_MobileNetV3, self).__init__()
super(DeepLabV3Contrast, self).__init__()

self.configer = configer
self.num_classes = self.configer.get('data', 'num_classes')
self.backbone = BackboneSelector(configer).get_backbone()
self.proj_dim = self.configer.get('contrast', 'proj_dim')

self.decoder = DeepLabHead_MobileNet_V3(num_classes=self.num_classes,
bn_type=self.configer.get('network', 'bn_type'))

for m in self.decoder.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()

def forward(self, x_):
x = self.backbone(x_)

x = self.decoder(x)
# extra added layers
if "wide_resnet38" in self.configer.get('network', 'backbone'):
in_channels = [2048, 4096]
else:
in_channels = [1024, 2048]

return x
self.proj_head = ProjectionHead(dim_in=in_channels[1], proj_dim=self.proj_dim)

self.decoder = DeepLabHead(num_classes=self.num_classes, bn_type=self.configer.get('network', 'bn_type'))

class DeepLabV3_MobileNet(nn.Module):
def __init__(self, configer):
super(DeepLabV3_MobileNet, self).__init__()

self.configer = configer
self.num_classes = self.configer.get('data', 'num_classes')
self.backbone = BackboneSelector(configer).get_backbone()

self.decoder = DeepLabHead_MobileNet(num_classes=self.num_classes,
bn_type=self.configer.get('network', 'bn_type'))

for m in self.decoder.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
for modules in [self.proj_head, self.decoder]:
for m in modules.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()

def forward(self, x_):
def forward(self, x_, with_embed=False, is_eval=False):
x = self.backbone(x_)

x = self.decoder(x)
embedding = self.proj_head(x[-1])

return x
x = self.decoder(x[-4:])

return {'embed': embedding, 'seg_aux': x[1], 'seg': x[0]}

class DeepLabV3(nn.Module):
def __init__(self, configer):
Expand Down
54 changes: 33 additions & 21 deletions lib/models/nets/hrnet.py
Expand Up @@ -66,6 +66,7 @@ def __init__(self, configer):
self.configer = configer
self.num_classes = self.configer.get('data', 'num_classes')
self.backbone = BackboneSelector(configer).get_backbone()
self.proj_dim = self.configer.get('contrast', 'proj_dim')

# extra added layers
in_channels = 720 # 48 + 96 + 192 + 384
Expand All @@ -76,7 +77,7 @@ def __init__(self, configer):
nn.Conv2d(in_channels, self.num_classes, kernel_size=1, stride=1, padding=0, bias=False)
)

self.proj_head = ProjectionHead(dim_in=in_channels)
self.proj_head = ProjectionHead(dim_in=in_channels, proj_dim=self.proj_dim)

def forward(self, x_, with_embed=False, is_eval=False):
x = self.backbone(x_)
Expand All @@ -94,31 +95,38 @@ def forward(self, x_, with_embed=False, is_eval=False):
return {'seg': out, 'embed': emb}


class HRNet_W48_ASPOCR(nn.Module):
class HRNet_W48_OCR_CONTRAST(nn.Module):
def __init__(self, configer):
super(HRNet_W48_ASPOCR, self).__init__()
super(HRNet_W48_OCR_CONTRAST, self).__init__()
self.configer = configer
self.num_classes = self.configer.get('data', 'num_classes')
self.backbone = BackboneSelector(configer).get_backbone()
self.proj_dim = self.configer.get('contrast', 'proj_dim')

# extra added layers
in_channels = 720 # 48 + 96 + 192 + 384
from lib.models.modules.spatial_ocr_block import SpatialOCR_ASP_Module
self.asp_ocr_head = SpatialOCR_ASP_Module(features=720,
hidden_features=256,
out_features=256,
dilations=(24, 48, 72),
num_classes=self.num_classes,
bn_type=self.configer.get('network', 'bn_type'))

self.cls_head = nn.Conv2d(256, self.num_classes, kernel_size=1, stride=1, padding=0, bias=False)
self.aux_head = nn.Sequential(
in_channels = 720
self.conv3x3 = nn.Sequential(
nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1),
ModuleHelper.BNReLU(512, bn_type=self.configer.get('network', 'bn_type')),
nn.Conv2d(512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=False)
)
from lib.models.modules.spatial_ocr_block import SpatialGather_Module
self.ocr_gather_head = SpatialGather_Module(self.num_classes)
from lib.models.modules.spatial_ocr_block import SpatialOCR_Module
self.ocr_distri_head = SpatialOCR_Module(in_channels=512,
key_channels=256,
out_channels=512,
scale=1,
dropout=0.05,
bn_type=self.configer.get('network', 'bn_type'))
self.cls_head = nn.Conv2d(512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True)
self.aux_head = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
ModuleHelper.BNReLU(in_channels, bn_type=self.configer.get('network', 'bn_type')),
nn.Conv2d(in_channels, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True)
)

def forward(self, x_):
self.proj_head = ProjectionHead(dim_in=in_channels, proj_dim=self.proj_dim)

def forward(self, x_, with_embed=False, is_eval=False):
x = self.backbone(x_)
_, _, h, w = x[0].size()

Expand All @@ -130,12 +138,16 @@ def forward(self, x_):
feats = torch.cat([feat1, feat2, feat3, feat4], 1)
out_aux = self.aux_head(feats)

feats = self.asp_ocr_head(feats, out_aux)
emb = self.proj_head(feats)

feats = self.conv3x3(feats)

context = self.ocr_gather_head(feats, out_aux)
feats = self.ocr_distri_head(feats, context)

out = self.cls_head(feats)

out_aux = F.interpolate(out_aux, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True)
out = F.interpolate(out, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True)
return out_aux, out
return {'seg': out, 'seg_aux': out_aux, 'embed': emb}


class HRNet_W48_OCR(nn.Module):
Expand Down
14 changes: 4 additions & 10 deletions scripts/cityscapes/deeplab/job_run_r_101_d_8_deeplabv3.sh
@@ -1,26 +1,20 @@
#!/bin/bash
#BSUB -n 16
#BSUB -W 72:00
#BSUB -W 24:00
#BSUB -R "rusage[mem=4000,ngpus_excl_p=4,scratch=10000]"
#BSUB -R "select[gpu_model0=TITANRTX]"
#BSUB -R "select[gpu_model0=GeForceRTX2080Ti]"
#BSUB -J "deeplab_v3"
#BSUB -B
#BSUB -N
#BSUB -oo logs/

# activate env
#source /cluster/home/tiazhou/miniconda3/etc/profile.d/conda.sh
#conda activate pytorch-1.7.1

source ../../../pytorch-1.7.1/bin/activate
source ../../../../pytorch-1.7.1/bin/activate

# copy data
rsync -aP /cluster/work/cvl/tiazhou/data/CityscapesZIP/openseg.tar ${TMPDIR}/
mkdir ${TMPDIR}/Cityscapes/
tar -xf ${TMPDIR}/openseg.tar -C ${TMPDIR}/Cityscapes/

ls -l ${TMPDIR}/Cityscapes/train/label | wc -l
ls -l ${TMPDIR}/Cityscapes/val/label | wc -l
tar -xf ${TMPDIR}/openseg.tar -C ${TMPDIR}/Cityscapes

# copy assets
rsync -aP /cluster/work/cvl/tiazhou/assets/openseg/resnet101-imagenet.pth ${TMPDIR}/resnet101-imagenet.pth
Expand Down
25 changes: 25 additions & 0 deletions scripts/cityscapes/deeplab/job_run_r_101_d_8_deeplabv3_contrast.sh
@@ -0,0 +1,25 @@
#!/bin/bash
#BSUB -n 16
#BSUB -W 24:00
#BSUB -R "rusage[mem=4000,ngpus_excl_p=4,scratch=10000]"
#BSUB -R "select[gpu_model0=GeForceRTX2080Ti]"
#BSUB -J "deeplab_v3_contrast"
#BSUB -B
#BSUB -N
#BSUB -oo logs/


source ../../../../pytorch-1.7.1/bin/activate

# copy data
rsync -aP /cluster/work/cvl/tiazhou/data/CityscapesZIP/openseg.tar ${TMPDIR}/
mkdir ${TMPDIR}/Cityscapes/
tar -xf ${TMPDIR}/openseg.tar -C ${TMPDIR}/Cityscapes

# copy assets
rsync -aP /cluster/work/cvl/tiazhou/assets/openseg/resnet101-imagenet.pth ${TMPDIR}/resnet101-imagenet.pth

# define scratch dir
SCRATCH_DIR="/cluster/scratch/tiazhou/Openseg"

sh run_r_101_d_8_deeplabv3_contrast_train.sh train 'deeplab_v3_contrast' ${TMPDIR} ${SCRATCH_DIR} 'ss'

0 comments on commit 3101207

Please sign in to comment.