-
Notifications
You must be signed in to change notification settings - Fork 0
/
encoder.py
130 lines (114 loc) · 4.34 KB
/
encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch.nn as nn
from torchvision import models
from utils import normalize_imagenet
def get_encoder(model_name, c_dim, normalize=True, use_linear=True):
    r''' Build the ResNet encoder selected by *model_name*.

    Args:
        model_name (str): one of 'resnet18', 'resnet34', 'resnet50', 'resnet101'
        c_dim (int): output dimension of the latent embedding
        normalize (bool): whether the input images should be normalized
        use_linear (bool): whether a final linear layer should be used

    Raises:
        ValueError: if *model_name* is not a supported ResNet variant.
    '''
    # Dispatch table instead of an if/elif ladder.
    encoder_classes = {
        'resnet18': Resnet18,
        'resnet34': Resnet34,
        'resnet50': Resnet50,
        'resnet101': Resnet101,
    }
    if model_name not in encoder_classes:
        raise ValueError('The network name must be resnet* where * could be 18,34,50,101')
    return encoder_classes[model_name](c_dim, normalize, use_linear)
class Resnet18(nn.Module):
    r''' ResNet-18 encoder network for image input.

    Args:
        c_dim (int): output dimension of the latent embedding
        normalize (bool): whether the input images should be normalized
        use_linear (bool): whether a final linear layer should be used
    '''

    def __init__(self, c_dim, normalize=True, use_linear=True):
        super().__init__()
        self.normalize = normalize
        self.use_linear = use_linear
        # Pretrained backbone; replacing its classifier head with an empty
        # Sequential makes it emit the 512-d pooled feature vector.
        self.features = models.resnet18(pretrained=True)
        self.features.fc = nn.Sequential()
        if use_linear:
            self.fc = nn.Linear(512, c_dim)
        else:
            # Without a projection layer the output is the raw 512-d feature,
            # so c_dim must match exactly.
            if c_dim != 512:
                raise ValueError('c_dim must be 512 if use_linear is False')
            self.fc = nn.Sequential()

    def forward(self, x):
        r''' Encode a batch of images into latent codes. '''
        inputs = normalize_imagenet(x) if self.normalize else x
        feats = self.features(inputs)
        return self.fc(feats)
class Resnet34(nn.Module):
    r''' ResNet-34 encoder network.

    Args:
        c_dim (int): output dimension of the latent embedding
        normalize (bool): whether the input images should be normalized
        use_linear (bool): whether a final linear layer should be used
    '''

    def __init__(self, c_dim, normalize=True, use_linear=True):
        super().__init__()
        self.normalize = normalize
        self.use_linear = use_linear
        # ImageNet-pretrained backbone with the classification head stripped,
        # leaving the 512-d pooled feature output.
        self.features = models.resnet34(pretrained=True)
        self.features.fc = nn.Sequential()
        if use_linear:
            self.fc = nn.Linear(512, c_dim)
        else:
            # No projection: output is the raw 512-d feature vector.
            if c_dim != 512:
                raise ValueError('c_dim must be 512 if use_linear is False')
            self.fc = nn.Sequential()

    def forward(self, x):
        r''' Encode a batch of images into latent codes. '''
        inputs = normalize_imagenet(x) if self.normalize else x
        feats = self.features(inputs)
        return self.fc(feats)
class Resnet50(nn.Module):
    r''' ResNet-50 encoder network.

    Args:
        c_dim (int): output dimension of the latent embedding
        normalize (bool): whether the input images should be normalized
        use_linear (bool): whether a final linear layer should be used
    '''

    def __init__(self, c_dim, normalize=True, use_linear=True):
        super().__init__()
        self.normalize = normalize
        self.use_linear = use_linear
        # Pretrained backbone; the empty Sequential head exposes the
        # 2048-d pooled feature (ResNet-50 uses bottleneck blocks).
        self.features = models.resnet50(pretrained=True)
        self.features.fc = nn.Sequential()
        if use_linear:
            self.fc = nn.Linear(2048, c_dim)
        else:
            # No projection: output is the raw 2048-d feature vector.
            if c_dim != 2048:
                raise ValueError('c_dim must be 2048 if use_linear is False')
            self.fc = nn.Sequential()

    def forward(self, x):
        r''' Encode a batch of images into latent codes. '''
        inputs = normalize_imagenet(x) if self.normalize else x
        feats = self.features(inputs)
        return self.fc(feats)
class Resnet101(nn.Module):
    r''' ResNet-101 encoder network.

    Args:
        c_dim (int): output dimension of the latent embedding
        normalize (bool): whether the input images should be normalized
        use_linear (bool): whether a final linear layer should be used
    '''
    def __init__(self, c_dim, normalize=True, use_linear=True):
        super().__init__()
        self.normalize = normalize
        self.use_linear = use_linear
        # BUG FIX: the original instantiated models.resnet50 here, so this
        # "ResNet-101" encoder silently used a ResNet-50 backbone. Both
        # networks output 2048-d features, which is why the mistake never
        # raised an error.
        self.features = models.resnet101(pretrained=True)
        # Drop the ImageNet classifier head; keep the 2048-d pooled features.
        self.features.fc = nn.Sequential()
        if use_linear:
            self.fc = nn.Linear(2048, c_dim)
        elif c_dim == 2048:
            # Identity head: output is the raw 2048-d feature vector.
            self.fc = nn.Sequential()
        else:
            raise ValueError('c_dim must be 2048 if use_linear is False')

    def forward(self, x):
        r''' Encode a batch of images into latent codes. '''
        if self.normalize:
            x = normalize_imagenet(x)
        net = self.features(x)
        out = self.fc(net)
        return out