Skip to content

Commit

Permalink
train SH
Browse files Browse the repository at this point in the history
  • Loading branch information
tensorboy committed Aug 14, 2018
1 parent d08f4ec commit 6852b84
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 201 deletions.
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -51,6 +51,7 @@ Code repo for reproducing 2017 CVPR Oral paper using pytorch.
- Download the official training format at [Dropbox](https://www.dropbox.com/s/0sj2q24hipiiq5t/COCO.json?dl=0)
- `python train_VGG19.py --batch_size 100 --logdir {where to store tensorboardX logs}`
- `python train_ShuffleNetV2.py --batch_size 160 --logdir {where to store tensorboardX logs}`
- `python train_SH.py --batch_size 160 --logdir {where to store tensorboardX logs}`
## Related repository
- CVPR'16, [Convolutional Pose Machines](https://github.com/shihenw/convolutional-pose-machines-release).
- CVPR'17, [Realtime Multi-Person Pose Estimation](https://github.com/ZheC/Realtime_Multi-Person_Pose_Estimation).
Expand Down
100 changes: 65 additions & 35 deletions network/rtpose_hourglass.py
@@ -1,6 +1,11 @@
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from collections import OrderedDict

NUM_JOINTS = 18
NUM_LIMBS = 38
class Bottleneck(nn.Module):
expansion = 2

Expand Down Expand Up @@ -52,7 +57,7 @@ def __init__(self, block, num_blocks, planes, depth):
def _make_residual(self, block, num_blocks, planes):
layers = []
for i in range(0, num_blocks):
layers.append(block(planes*block.expansion, planes))
layers.append(block(planes * block.expansion, planes))
return nn.Sequential(*layers)

def _make_hour_glass(self, block, num_blocks, planes, depth):
Expand All @@ -67,15 +72,15 @@ def _make_hour_glass(self, block, num_blocks, planes, depth):
return nn.ModuleList(hg)

def _hour_glass_forward(self, n, x):
up1 = self.hg[n-1][0](x)
up1 = self.hg[n - 1][0](x)
low1 = F.max_pool2d(x, 2, stride=2)
low1 = self.hg[n-1][1](low1)
low1 = self.hg[n - 1][1](low1)

if n > 1:
low2 = self._hour_glass_forward(n-1, low1)
low2 = self._hour_glass_forward(n - 1, low1)
else:
low2 = self.hg[n-1][3](low1)
low3 = self.hg[n-1][2](low2)
low2 = self.hg[n - 1][3](low1)
low3 = self.hg[n - 1][2](low2)
up2 = self.upsample(low3)
out = up1 + up2
return out
Expand All @@ -86,39 +91,49 @@ def forward(self, x):

class HourglassNet(nn.Module):
'''Hourglass model from Newell et al ECCV 2016'''
def __init__(self, block, num_stacks=2, num_blocks=4, num_classes=256):

def __init__(self, block, num_stacks=2, num_blocks=4, paf_classes=NUM_LIMBS*2, ht_classes=NUM_JOINTS+1):
super(HourglassNet, self).__init__()

self.inplanes = 64
self.num_feats = 128
self.num_stacks = num_stacks
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=True)
self.bn1 = nn.BatchNorm2d(self.inplanes)
self.bn1 = nn.BatchNorm2d(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_residual(block, self.inplanes, 1)
self.layer2 = self._make_residual(block, self.inplanes, 1)
self.layer3 = self._make_residual(block, self.num_feats, 1)
self.maxpool = nn.MaxPool2d(2, stride=2)

# build hourglass modules
ch = self.num_feats*block.expansion
hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
ch = self.num_feats * block.expansion
hg, res, fc, score_paf, score_ht, fc_, paf_score_, ht_score_ = \
[], [], [], [], [], [], [], []
for i in range(num_stacks):
hg.append(Hourglass(block, num_blocks, self.num_feats, 4))
res.append(self._make_residual(block, self.num_feats, num_blocks))
fc.append(self._make_fc(ch, ch))
score.append(nn.Conv2d(ch, num_classes, kernel_size=1, bias=True))
if i < num_stacks-1:
score_paf.append(nn.Conv2d(ch, paf_classes, kernel_size=1, bias=True))
score_ht.append(nn.Conv2d(ch, ht_classes, kernel_size=1, bias=True))
if i < num_stacks - 1:
fc_.append(nn.Conv2d(ch, ch, kernel_size=1, bias=True))
score_.append(nn.Conv2d(num_classes, ch, kernel_size=1, bias=True))
paf_score_.append(nn.Conv2d(paf_classes, ch,
kernel_size=1, bias=True))
ht_score_.append(nn.Conv2d(ht_classes, ch,
kernel_size=1, bias=True))
self.hg = nn.ModuleList(hg)
self.res = nn.ModuleList(res)
self.fc = nn.ModuleList(fc)
self.score = nn.ModuleList(score)
self.fc_ = nn.ModuleList(fc_)
self.score_ = nn.ModuleList(score_)

self.score_ht = nn.ModuleList(score_ht)
self.score_paf = nn.ModuleList(score_paf)
self.fc_ = nn.ModuleList(fc_)
self.paf_score_ = nn.ModuleList(paf_score_)
self.ht_score_ = nn.ModuleList(ht_score_)

self._initialize_weights_norm()

def _make_residual(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
Expand All @@ -139,37 +154,52 @@ def _make_fc(self, inplanes, outplanes):
bn = nn.BatchNorm2d(inplanes)
conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=True)
return nn.Sequential(
conv,
bn,
self.relu,
)
conv,
bn,
self.relu,
)

def forward(self, x):
saved_for_loss = []
out = []
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.relu(x)

x = self.layer1(x)
x = self.layer1(x)
x = self.maxpool(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer2(x)
x = self.layer3(x)

for i in range(self.num_stacks):
y = self.hg[i](x)
y = self.res[i](y)
y = self.fc[i](y)
score = self.score[i](y)
out.append(score)
if i < self.num_stacks-1:
score_paf = self.score_paf[i](y)
score_ht = self.score_ht[i](y)
if i < self.num_stacks - 1:
fc_ = self.fc_[i](y)
score_ = self.score_[i](score)
x = x + fc_ + score_

return out

paf_score_ = self.paf_score_[i](score_paf)
ht_score_ = self.ht_score_[i](score_ht)
x = x + fc_ + paf_score_ + ht_score_

saved_for_loss.append(score_paf)
saved_for_loss.append(score_ht)

return (score_paf, score_ht), saved_for_loss

def _initialize_weights_norm(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.normal_(m.weight, std=0.01)
if m.bias is not None: # mobilenet conv2d doesn't add bias
init.constant_(m.bias, 0.0)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()

def hg(**kwargs):
model = HourglassNet(Bottleneck, num_stacks=kwargs['num_stacks'], num_blocks=kwargs['num_blocks'],
num_classes=kwargs['num_classes'])
model = HourglassNet(Bottleneck, num_stacks=kwargs['num_stacks'],
num_blocks=kwargs['num_blocks'], paf_classes=kwargs['paf_classes'],
ht_classes=kwargs['ht_classes'])
return model

0 comments on commit 6852b84

Please sign in to comment.