-
Notifications
You must be signed in to change notification settings - Fork 94
/
posenet.py
67 lines (57 loc) · 2.38 KB
/
posenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from torch import nn
from models.layers import Conv, Hourglass, Pool, Residual
from task.loss import HeatmapLoss
class UnFlatten(nn.Module):
    """Reshape a tensor into a batch of ``(C, H, W)`` feature maps.

    The input is viewed as ``(-1, C, H, W)``, so the leading (batch)
    dimension is inferred from the number of elements. ``view`` is used, so
    the input must be view-compatible with the target shape.

    Args:
        out_shape: optional ``(C, H, W)`` sequence. When omitted, the
            historical default ``(256, 4, 4)`` is used.
    """

    # Class-level default: also serves as a fallback for instances that were
    # created without going through ``__init__`` of this version.
    out_shape = (256, 4, 4)

    def __init__(self, out_shape=None):
        super(UnFlatten, self).__init__()
        if out_shape is not None:
            self.out_shape = tuple(out_shape)

    def forward(self, input):
        """Return ``input`` viewed as ``(-1, *out_shape)``."""
        return input.view(-1, *self.out_shape)
class Merge(nn.Module):
    """Project features from ``x_dim`` to ``y_dim`` with a single ``Conv``.

    The wrapped ``Conv`` is built as ``Conv(x_dim, y_dim, 1)`` with its
    ``relu`` and ``bn`` flags both disabled (the ``1`` is presumably the
    kernel size -- confirm against ``models.layers.Conv``).
    """

    def __init__(self, x_dim, y_dim):
        super(Merge, self).__init__()
        projection = Conv(x_dim, y_dim, 1, bn=False, relu=False)
        self.conv = projection

    def forward(self, x):
        """Apply the wrapped ``Conv`` layer to ``x`` and return the result."""
        merged = self.conv(x)
        return merged
class PoseNet(nn.Module):
    """Stack of hourglass modules emitting one prediction per stack.

    Each stack runs ``Hourglass -> (Residual, Conv) -> output Conv``. Between
    consecutive stacks, the prediction and the feature map are projected back
    to ``inp_dim`` channels by ``Merge`` layers and added to the running
    input, so every stack refines the previous one.

    Args:
        nstack: number of hourglass stacks.
        inp_dim: channel width used inside the stacks.
        oup_dim: number of output channels of each per-stack prediction.
        bn: forwarded to every ``Hourglass``.
        increase: forwarded to every ``Hourglass``.
        **kwargs: accepted and ignored, so a wider config dict can be
            splatted into the constructor.
    """

    def __init__(self, nstack, inp_dim, oup_dim, bn=False, increase=0, **kwargs):
        super(PoseNet, self).__init__()
        self.nstack = nstack

        # Stem applied once before the first stack; ends with inp_dim channels.
        self.pre = nn.Sequential(
            Conv(3, 64, 7, 2, bn=True, relu=True),
            Residual(64, 128),
            Pool(2, 2),
            Residual(128, 128),
            Residual(128, inp_dim),
        )

        # NOTE: each Hourglass stays wrapped in nn.Sequential so parameter
        # names (``hgs.<i>.0...``) remain compatible with saved checkpoints.
        self.hgs = nn.ModuleList([
            nn.Sequential(Hourglass(4, inp_dim, bn, increase))
            for _ in range(nstack)
        ])

        self.features = nn.ModuleList([
            nn.Sequential(
                Residual(inp_dim, inp_dim),
                Conv(inp_dim, inp_dim, 1, bn=True, relu=True),
            )
            for _ in range(nstack)
        ])

        # Per-stack prediction heads.
        self.outs = nn.ModuleList([
            Conv(inp_dim, oup_dim, 1, relu=False, bn=False)
            for _ in range(nstack)
        ])

        # Only nstack - 1 merges are needed: the last stack feeds nothing.
        self.merge_features = nn.ModuleList([
            Merge(inp_dim, inp_dim) for _ in range(nstack - 1)
        ])
        self.merge_preds = nn.ModuleList([
            Merge(oup_dim, inp_dim) for _ in range(nstack - 1)
        ])

        self.heatmapLoss = HeatmapLoss()

    def forward(self, imgs):
        """Run all stacks and return their predictions stacked on dim 1.

        Args:
            imgs: 4-D tensor whose axes are permuted with ``(0, 3, 1, 2)``
                before the stem, i.e. the last axis is moved to position 1
                (presumably a channels-last image batch with 3 channels,
                matching the stem's first ``Conv`` -- confirm with caller).

        Returns:
            Tensor with the per-stack predictions stacked along dimension 1,
            so index ``[:, i]`` is the output of stack ``i``.
        """
        x = self.pre(imgs.permute(0, 3, 1, 2))

        last_stack = self.nstack - 1
        stack_preds = []
        for i in range(self.nstack):
            feature = self.features[i](self.hgs[i](x))
            preds = self.outs[i](feature)
            stack_preds.append(preds)
            if i < last_stack:
                # Feed both the prediction and the features into the next stack.
                x = x + self.merge_preds[i](preds) + self.merge_features[i](feature)
        return torch.stack(stack_preds, 1)

    def calc_loss(self, combined_hm_preds, heatmaps):
        """Compute the heatmap loss of every stack against ``heatmaps``.

        Args:
            combined_hm_preds: indexable container whose element ``[0]`` is
                the stacked prediction tensor; only ``[0]`` is read
                (presumably the ``forward`` output wrapped in a list/tuple by
                the training code -- confirm with caller).
            heatmaps: target passed unchanged to ``self.heatmapLoss`` for
                every stack.

        Returns:
            The per-stack results of ``self.heatmapLoss`` stacked along
            dimension 1.
        """
        stacked_preds = combined_hm_preds[0]
        per_stack_loss = [
            self.heatmapLoss(stacked_preds[:, i], heatmaps)
            for i in range(self.nstack)
        ]
        return torch.stack(per_stack_loss, dim=1)