-
-
Notifications
You must be signed in to change notification settings - Fork 982
/
vae_comparison.py
275 lines (234 loc) · 8.88 KB
/
vae_comparison.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import argparse
import itertools
import os
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
from torch.nn import functional
from torchvision.utils import save_image
from utils.mnist_cached import DATA_DIR, RESULTS_DIR
import pyro
from pyro.contrib.examples import util
from pyro.distributions import Bernoulli, Normal
from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO
from pyro.optim import Adam
"""
Comparison of VAE implementations in PyTorch and Pyro. This example can be
used for profiling purposes.
The PyTorch VAE example is taken (with minor modification) from pytorch/examples.
Source: https://github.com/pytorch/examples/tree/master/vae
"""
# Mode flags stored in ``VAE.mode``; they select between training and
# evaluation behaviour in the VAE classes below.
TRAIN = "train"
TEST = "test"
# Directory that reconstruction images are written to. ``setup`` rebinds this
# to a per-implementation subdirectory of RESULTS_DIR.
OUTPUT_DIR = RESULTS_DIR
# VAE encoder network
class Encoder(nn.Module):
    """
    Inference network: maps an image, flattened to 784 values, to the two
    20-dimensional outputs that the VAE classes treat as the latent mean
    and the latent variance.
    """

    def __init__(self):
        super().__init__()
        # One shared hidden layer followed by two parallel output heads.
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        :param x: tensor whose elements can be arranged as rows of 784 values.
        :return: tuple ``(mean, variance)``; the second entry is the
            exponential of the second head's output and is therefore positive.
        """
        flat = x.reshape(-1, 784)
        hidden = self.relu(self.fc1(flat))
        loc = self.fc21(hidden)
        variance = self.fc22(hidden).exp()
        return loc, variance
# VAE Decoder network
class Decoder(nn.Module):
    """
    Generative network: maps a 20-dimensional latent code to 784 values in
    the open interval (0, 1), one per pixel of the flattened image.
    """

    def __init__(self):
        super().__init__()
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)
        self.relu = nn.ReLU()

    def forward(self, z):
        """
        :param z: tensor of latent codes with 20 features in the last dimension.
        :return: sigmoid of the output layer, with 784 features per latent code.
        """
        hidden = self.relu(self.fc3(z))
        logits = self.fc4(hidden)
        return logits.sigmoid()
class VAE(metaclass=ABCMeta):
    """
    Abstract class for the variational auto-encoder. The abstract method
    for training the network is implemented by subclasses.
    """

    def __init__(self, args, train_loader, test_loader):
        """
        :param args: parsed options; ``args.batch_size`` (and, in subclasses,
            other fields) are read from it.
        :param train_loader: iterable of ``(x, label)`` training batches.
        :param test_loader: iterable of ``(x, label)`` test batches.
        """
        self.args = args
        self.vae_encoder = Encoder()
        self.vae_decoder = Decoder()
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.mode = TRAIN

    def set_train(self, is_train=True):
        """
        Switch the encoder and decoder between training and evaluation mode
        and record the choice in ``self.mode``.

        :param is_train: ``True`` for training mode, ``False`` for evaluation.
        """
        if is_train:
            self.mode = TRAIN
            self.vae_encoder.train()
            self.vae_decoder.train()
        else:
            self.mode = TEST
            self.vae_encoder.eval()
            self.vae_decoder.eval()

    @abstractmethod
    def compute_loss_and_gradient(self, x):
        """
        Given a batch of data `x`, run the optimizer (backpropagate the gradient),
        and return the computed loss.

        :param x: batch of data or a single datum (MNIST image).
        :return: loss computed on the data batch.
        """

    def model_eval(self, x):
        """
        Given a batch of data `x`, run it through the trained VAE network to get
        the reconstructed image.

        :param x: batch of data or a single datum (MNIST image).
        :return: reconstructed image, and the latent z's mean and variance.
        """
        z_mean, z_var = self.vae_encoder(x)
        if self.mode == TRAIN:
            # Reparameterised sample so gradients flow back into the encoder.
            z = Normal(z_mean, z_var.sqrt()).rsample()
        else:
            # Evaluation decodes the mean of the latent distribution directly.
            z = z_mean
        return self.vae_decoder(z), z_mean, z_var

    def train(self, epoch):
        """
        Run one pass over the training loader and print the accumulated loss
        divided by the number of samples in the training dataset.

        :param epoch: epoch index, used only in the printed message.
        """
        self.set_train(is_train=True)
        train_loss = 0
        for x, _ in self.train_loader:
            train_loss += self.compute_loss_and_gradient(x)
        print(
            "====> Epoch: {} \nTraining loss: {:.4f}".format(
                epoch, train_loss / len(self.train_loader.dataset)
            )
        )

    def test(self, epoch):
        """
        Evaluate on the test loader, print the accumulated loss divided by the
        number of samples in the test dataset, and save a side-by-side image
        of up to 8 inputs and their reconstructions from the first batch.

        :param epoch: epoch index, used in the name of the saved image.
        """
        self.set_train(is_train=False)
        test_loss = 0
        for i, (x, _) in enumerate(self.test_loader):
            with torch.no_grad():
                recon_x = self.model_eval(x)[0]
                test_loss += self.compute_loss_and_gradient(x)
            if i == 0:
                n = min(x.size(0), 8)
                # Infer the batch dimension from the data instead of assuming
                # the batch is full: a first batch smaller than
                # ``args.batch_size`` would otherwise make the reshape fail.
                comparison = torch.cat(
                    [x[:n], recon_x.reshape(-1, 1, 28, 28)[:n]]
                )
                save_image(
                    comparison.detach().cpu(),
                    os.path.join(OUTPUT_DIR, "reconstruction_" + str(epoch) + ".png"),
                    nrow=n,
                )
        test_loss /= len(self.test_loader.dataset)
        print("Test set loss: {:.4f}".format(test_loss))
class PyTorchVAEImpl(VAE):
    """
    Adapted from pytorch/examples.
    Source: https://github.com/pytorch/examples/tree/master/vae
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.optimizer = self.initialize_optimizer(lr=1e-3)

    def compute_loss_and_gradient(self, x):
        """
        Compute the per-element loss (reconstruction + KL) for a batch and,
        in training mode, take one optimizer step.

        :param x: batch of data or a single datum (MNIST image).
        :return: the loss as a Python float.
        """
        self.optimizer.zero_grad()
        recon_x, z_mean, z_var = self.model_eval(x)
        binary_cross_entropy = functional.binary_cross_entropy(
            recon_x, x.reshape(-1, 784)
        )
        # Uses analytical KL divergence expression for D_kl(q(z|x) || p(z))
        # Refer to Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # (https://arxiv.org/abs/1312.6114)
        kl_div = -0.5 * torch.sum(1 + z_var.log() - z_mean.pow(2) - z_var)
        # Normalise by the number of elements actually in this batch, which is
        # what the mean-reduced BCE term above is averaged over. Using the
        # configured batch size instead would under-weight the KL term on a
        # short final batch.
        kl_div /= recon_x.numel()
        loss = binary_cross_entropy + kl_div
        if self.mode == TRAIN:
            loss.backward()
            self.optimizer.step()
        return loss.item()

    def initialize_optimizer(self, lr=1e-3):
        """
        Build an Adam optimizer over the encoder and decoder parameters.

        :param lr: learning rate.
        :return: the ``torch.optim.Adam`` instance.
        """
        model_params = itertools.chain(
            self.vae_encoder.parameters(), self.vae_decoder.parameters()
        )
        return torch.optim.Adam(model_params, lr)
class PyroVAEImpl(VAE):
    """
    Implementation of VAE using Pyro. Only the model and the guide specification
    is needed to run the optimizer (the objective function does not need to be
    specified as in the PyTorch implementation).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.optimizer = self.initialize_optimizer(lr=1e-3)

    def model(self, data):
        """
        Generative model: a standard-normal prior over a 20-dimensional latent
        code, decoded into Bernoulli parameters for the 784 observed pixels.

        :param data: batch of images; ``data.size(0)`` is the batch size.
        """
        decoder = pyro.module("decoder", self.vae_decoder)
        z_mean, z_std = torch.zeros([data.size(0), 20]), torch.ones([data.size(0), 20])
        with pyro.plate("data", data.size(0)):
            z = pyro.sample("latent", Normal(z_mean, z_std).to_event(1))
            img = decoder.forward(z)
            pyro.sample(
                "obs",
                Bernoulli(img, validate_args=False).to_event(1),
                obs=data.reshape(-1, 784),
            )

    def guide(self, data):
        """
        Variational guide: the encoder produces the mean and variance of the
        normal distribution the latent code is sampled from.

        :param data: batch of images; ``data.size(0)`` is the batch size.
        """
        encoder = pyro.module("encoder", self.vae_encoder)
        with pyro.plate("data", data.size(0)):
            z_mean, z_var = encoder.forward(data)
            pyro.sample("latent", Normal(z_mean, z_var.sqrt()).to_event(1))

    def compute_loss_and_gradient(self, x):
        """
        Take an SVI step (training mode) or only evaluate the loss (otherwise)
        and return the loss scaled to a per-element value.

        :param x: batch of data (MNIST images).
        :return: loss divided by the number of elements in ``x``.
        """
        if self.mode == TRAIN:
            loss = self.optimizer.step(x)
        else:
            loss = self.optimizer.evaluate_loss(x)
        # Scale by the size of the batch that was actually processed rather
        # than the configured batch size, so a short final batch is not
        # under-reported.
        loss /= x.numel()
        return loss

    def initialize_optimizer(self, lr):
        """
        Build the SVI object that drives inference.

        :param lr: learning rate for the Pyro Adam optimizer.
        :return: an ``SVI`` instance using ``JitTrace_ELBO`` when ``args.jit``
            is set and ``Trace_ELBO`` otherwise.
        """
        optimizer = Adam({"lr": lr})
        elbo = JitTrace_ELBO() if self.args.jit else Trace_ELBO()
        return SVI(self.model, self.guide, optimizer, loss=elbo)
def setup(args):
    """
    Seed the RNG, build the MNIST data loaders, prepare the output directory
    and clear Pyro's parameter store.

    Rebinds the module-level ``OUTPUT_DIR`` to ``RESULTS_DIR/<args.impl>``.

    :param args: parsed options; reads ``rng_seed``, ``batch_size`` and ``impl``.
    :return: tuple ``(train_loader, test_loader)``.
    """
    pyro.set_rng_seed(args.rng_seed)
    train_loader = util.get_data_loader(
        dataset_name="MNIST",
        data_dir=DATA_DIR,
        batch_size=args.batch_size,
        is_training_set=True,
        shuffle=True,
    )
    test_loader = util.get_data_loader(
        dataset_name="MNIST",
        data_dir=DATA_DIR,
        batch_size=args.batch_size,
        is_training_set=False,
        shuffle=True,
    )
    global OUTPUT_DIR
    OUTPUT_DIR = os.path.join(RESULTS_DIR, args.impl)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    pyro.clear_param_store()
    return train_loader, test_loader
def main(args):
    """
    Build the requested VAE implementation and run the train/test loop.

    :param args: parsed options; reads ``impl``, ``num_epochs`` and
        ``skip_eval`` (plus whatever ``setup`` and the VAE classes read).
    :raises ValueError: if ``args.impl`` is neither ``"pyro"`` nor ``"pytorch"``.
    """
    train_loader, test_loader = setup(args)

    # Implementation name -> (class to instantiate, message printed afterwards).
    backends = {
        "pyro": (PyroVAEImpl, "Running Pyro VAE implementation"),
        "pytorch": (PyTorchVAEImpl, "Running PyTorch VAE implementation"),
    }
    selected = backends.get(args.impl)
    if selected is None:
        raise ValueError("Incorrect implementation specified: {}".format(args.impl))

    vae_cls, banner = selected
    vae = vae_cls(args, train_loader, test_loader)
    print(banner)

    for epoch in range(args.num_epochs):
        vae.train(epoch)
        if args.skip_eval:
            continue
        vae.test(epoch)
if __name__ == "__main__":
    # The example only runs against the Pyro release named here.
    assert pyro.__version__.startswith("1.9.1")

    parser = argparse.ArgumentParser(description="VAE using MNIST dataset")
    parser.add_argument("-n", "--num-epochs", nargs="?", default=10, type=int)
    # The remaining integer options differ only in name and default.
    for option, default in (("--batch_size", 128), ("--rng_seed", 0)):
        parser.add_argument(option, nargs="?", default=default, type=int)
    parser.add_argument("--impl", nargs="?", default="pyro", type=str)
    # Boolean switches, off unless given on the command line.
    for switch in ("--skip_eval", "--jit"):
        parser.add_argument(switch, action="store_true")
    parser.set_defaults(skip_eval=False)

    args = parser.parse_args()
    main(args)