-
Notifications
You must be signed in to change notification settings - Fork 75
/
ls_gan.py
314 lines (242 loc) · 10.6 KB
/
ls_gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
""" (LS GAN) https://arxiv.org/abs/1611.04076
Least Squares GAN
The output of LSGAN's D is unbounded unless passed through an activation
function. In this implementation, we include a sigmoid activation function as
this empirically improves visualizations for binary MNIST.
Tackles the vanishing gradients problem associated with GANs by swapping out
the cross entropy loss function with the least squares (L2) loss function.
The authors show that minimizing this objective is equivalent to minimizing the
Pearson chi-squared divergence. They claim that using the L2 loss function
penalizes samples that appear to be real to the discriminator, but lie far away
from the decision boundary. In this way, the generated images are made to appear
closer to real data. It also stabilizes training.
"""
import torch, torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import os
import matplotlib.pyplot as plt
import numpy as np
from itertools import product
from tqdm import tqdm
from utils import *
class Generator(nn.Module):
    """Map latent noise vectors to flattened images with pixel values in (0, 1).

    Architecture: Linear(z_dim -> hidden_dim) -> ReLU
                  -> Linear(hidden_dim -> image_size) -> sigmoid.

    Args:
        image_size: number of pixels in one flattened output image.
        hidden_dim: width of the single hidden layer.
        z_dim: dimensionality of the input noise vector.
    """

    def __init__(self, image_size, hidden_dim, z_dim):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable
        # so saved checkpoints still load.
        self.linear = nn.Linear(z_dim, hidden_dim)
        self.generate = nn.Linear(hidden_dim, image_size)

    def forward(self, x):
        """Return sigmoid(generate(relu(linear(x))))."""
        hidden = self.linear(x).relu()
        return self.generate(hidden).sigmoid()
class Discriminator(nn.Module):
    """Score flattened images (real or generated) with a value in (0, 1).

    Architecture: Linear(image_size -> hidden_dim) -> ReLU
                  -> Linear(hidden_dim -> output_dim) -> sigmoid.
    The training code regresses these scores toward target labels with a
    squared-error loss.

    Args:
        image_size: number of pixels in one flattened input image.
        hidden_dim: width of the single hidden layer.
        output_dim: number of scores produced per image.
    """

    def __init__(self, image_size, hidden_dim, output_dim):
        super().__init__()
        # Attribute names double as state_dict keys; do not rename.
        self.linear = nn.Linear(image_size, hidden_dim)
        self.discriminate = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return sigmoid(discriminate(relu(linear(x))))."""
        features = torch.relu(self.linear(x))
        logits = self.discriminate(features)
        return torch.sigmoid(logits)
class LSGAN(nn.Module):
    """Container module holding the Discriminator (D) and Generator (G).

    Args:
        image_size: number of pixels in one flattened image. ``shape`` is
            computed as its integer square root, so this is assumed to be a
            perfect square (e.g. 784 for 28x28 images).
        hidden_dim: hidden-layer width used by both G and D.
        z_dim: dimensionality of the noise fed to G.
        output_dim: number of D scores per image (default 1).
    """

    def __init__(self, image_size, hidden_dim, z_dim, output_dim=1):
        super().__init__()
        # Store hyperparameters explicitly. ``self.__dict__.update(locals())``
        # also copied ``self`` (a self-reference cycle) and the implicit
        # ``__class__`` cell into the instance dict.
        self.image_size = image_size
        self.hidden_dim = hidden_dim
        self.z_dim = z_dim
        self.output_dim = output_dim
        self.G = Generator(image_size, hidden_dim, z_dim)
        self.D = Discriminator(image_size, hidden_dim, output_dim)
        # Side length of a square image, used to reshape generator output.
        self.shape = int(image_size ** 0.5)
class LSGANTrainer:
    """Hold data iterators and train a least-squares GAN.

    Args:
        model: module exposing ``G``, ``D``, ``z_dim`` and ``shape`` (e.g. an
            ``LSGAN``). It is passed through ``to_cuda``, which presumably
            moves it to the GPU when one is available.
        train_iter, val_iter, test_iter: iterables yielding
            ``(images, labels)`` batches. Only ``train_iter`` is read by
            ``train``.
        viz: if True, plot generator samples after every epoch.
    """

    def __init__(self, model, train_iter, val_iter, test_iter, viz=False):
        self.model = to_cuda(model)
        self.name = model.__class__.__name__
        self.train_iter = train_iter
        self.val_iter = val_iter
        self.test_iter = test_iter
        # Per-step loss histories, accumulated across calls to ``train``.
        self.Glosses = []
        self.Dlosses = []
        self.viz = viz
        # Total epochs trained so far; used for the x-axis in ``viz_loss``.
        self.num_epochs = 0

    def train(self, num_epochs, G_lr=1e-4, D_lr=1e-4, D_steps=1):
        """Train the least-squares GAN (no gradient penalty is applied).

        Prints the mean G and D loss each epoch and, if ``self.viz`` is set,
        plots generator output.

        Inputs:
            num_epochs: int, number of epochs to train for
            G_lr: float, learning rate for the generator's Adam optimizer
            D_lr: float, learning rate for the discriminator's Adam optimizer
            D_steps: int, number of D updates per G update
        """
        G_optimizer = optim.Adam(
            params=[p for p in self.model.G.parameters() if p.requires_grad],
            lr=G_lr)
        D_optimizer = optim.Adam(
            params=[p for p in self.model.D.parameters() if p.requires_grad],
            lr=D_lr)

        # Each outer step draws D_steps batches. Shrinking the step count
        # keeps the number of batches drawn per epoch close to
        # len(train_iter).
        epoch_steps = int(np.ceil(len(self.train_iter) / D_steps))

        for epoch in tqdm(range(1, num_epochs + 1)):
            self.model.train()
            G_losses, D_losses = [], []

            for _ in range(epoch_steps):
                D_step_loss = []

                for _ in range(D_steps):
                    images = self.process_batch(self.train_iter)

                    # TRAINING D
                    D_optimizer.zero_grad()
                    D_loss = self.train_D(images)
                    D_loss.backward()
                    D_optimizer.step()
                    D_step_loss.append(D_loss.item())

                # One D entry per G entry so the two histories line up.
                D_losses.append(np.mean(D_step_loss))

                # TRAINING G. ``images`` is only used for its batch size.
                # G_loss.backward() also fills D's grads; those are cleared
                # by D_optimizer.zero_grad() before the next D update.
                G_optimizer.zero_grad()
                G_loss = self.train_G(images)
                G_losses.append(G_loss.item())
                G_loss.backward()
                G_optimizer.step()

            self.Glosses.extend(G_losses)
            self.Dlosses.extend(D_losses)

            print("Epoch[%d/%d], G Loss: %.4f, D Loss: %.4f"
                  % (epoch, num_epochs, np.mean(G_losses), np.mean(D_losses)))
            self.num_epochs += 1

            if self.viz:
                self.generate_images(epoch)
                plt.show()

    def train_D(self, images, a=0, b=1):
        """Compute the discriminator's least-squares loss on one batch.

        Input:
            images: batch of real images (reshaped to [batch_size, -1])
            a: target label for generated data (default 0)
            b: target label for real data (default 1)
        Output:
            D_loss: 0.50 * E[(D(x) - b)^2] + 0.50 * E[(D(G(z)) - a)^2]
        """
        noise = self.compute_noise(images.shape[0], self.model.z_dim)
        # G is not updated in this step. Skipping its autograd graph leaves
        # D's gradients unchanged and keeps G's .grad untouched.
        with torch.no_grad():
            G_output = self.model.G(noise)

        DX_score = self.model.D(images)      # D(x)
        DG_score = self.model.D(G_output)    # D(G(z))

        D_loss = (0.50 * torch.mean((DX_score - b) ** 2)
                  + 0.50 * torch.mean((DG_score - a) ** 2))
        return D_loss

    def train_G(self, images, c=1):
        """Compute the generator's least-squares loss on one batch.

        Input:
            images: batch of images (reshaped to [batch_size, -1]); only its
                batch size is used
            c: label G wants D to assign to generated data (default 1)
        Output:
            G_loss: 0.50 * E[(D(G(z)) - c)^2]
        """
        noise = self.compute_noise(images.shape[0], self.model.z_dim)  # z
        G_output = self.model.G(noise)                                  # G(z)
        DG_score = self.model.D(G_output)                               # D(G(z))
        G_loss = 0.50 * torch.mean((DG_score - c) ** 2)
        return G_loss

    def compute_noise(self, batch_size, z_dim):
        """Return standard-normal noise of shape (batch_size, z_dim) via to_cuda."""
        return to_cuda(torch.randn(batch_size, z_dim))

    def process_batch(self, iterator):
        """Draw one batch of images from ``iterator``, flattened to [batch_size, -1]."""
        # NOTE(review): ``iter(iterator)`` builds a fresh iterator on every
        # call, so this returns the first batch of a new pass. That batch is
        # only random if the loader shuffles, and a DataLoader may re-create
        # its worker processes each time.
        images, _ = next(iter(iterator))
        images = to_cuda(images.view(images.shape[0], -1))
        return images

    def generate_images(self, epoch, num_outputs=36, save=True):
        """Plot a square grid of generator samples and optionally save it.

        Inputs:
            epoch: int, used in the saved file name
            num_outputs: number of samples drawn; the grid side is
                int(sqrt(num_outputs))
            save: if True, write ../viz/<name>/reconst_<epoch>.png
        """
        self.model.eval()
        # Sampling only; no gradients needed.
        with torch.no_grad():
            noise = self.compute_noise(num_outputs, self.model.z_dim)
            images = self.model.G(noise)

        # [N, side, side, 1] -> [N, side, side]. Squeeze only the channel
        # axis so that num_outputs == 1 keeps its batch dimension.
        images = images.view(images.shape[0],
                             self.model.shape,
                             self.model.shape,
                             -1).squeeze(-1)
        # matplotlib cannot read CUDA tensors, so move to host memory once.
        images = images.cpu()

        plt.close()
        grid_size = int(num_outputs ** 0.5)
        # squeeze=False always yields a 2-D Axes array, including 1x1 grids.
        _, ax = plt.subplots(grid_size, grid_size, figsize=(5, 5),
                             squeeze=False)
        for k, (i, j) in enumerate(product(range(grid_size), range(grid_size))):
            ax[i, j].get_xaxis().set_visible(False)
            ax[i, j].get_yaxis().set_visible(False)
            ax[i, j].imshow(images[k].numpy(), cmap='gray')

        if save:
            outname = '../viz/' + self.name + '/'
            os.makedirs(outname, exist_ok=True)
            torchvision.utils.save_image(images.unsqueeze(1),
                                         outname + 'reconst_%d.png'
                                         % (epoch), nrow=grid_size)

    def viz_loss(self):
        """Plot the per-step discriminator (red) and generator (green) losses."""
        plt.style.use('ggplot')
        plt.rcParams["figure.figsize"] = (8, 6)
        # Each series is spread over [1, num_epochs] according to its own
        # length.
        plt.plot(np.linspace(1, self.num_epochs, len(self.Dlosses)),
                 self.Dlosses,
                 'r')
        plt.plot(np.linspace(1, self.num_epochs, len(self.Glosses)),
                 self.Glosses,
                 'g')
        plt.legend(['Discriminator', 'Generator'])
        plt.title(self.name)
        plt.show()

    def save_model(self, savepath):
        """Save the model's state dictionary to ``savepath``."""
        torch.save(self.model.state_dict(), savepath)

    def load_model(self, loadpath, map_location=None):
        """Load a state dictionary from ``loadpath`` into the model.

        ``map_location`` is forwarded to ``torch.load``. For example, pass
        'cpu' to load a GPU-saved checkpoint on a CPU-only machine.
        """
        state = torch.load(loadpath, map_location=map_location)
        self.model.load_state_dict(state)
if __name__ == "__main__":
    # Binarized MNIST, split into train / validation / test loaders.
    train_iter, val_iter, test_iter = get_data()

    # 784 flattened pixels -> LSGAN reshapes samples to 28x28.
    model = LSGAN(image_size=784, hidden_dim=400, z_dim=20)

    trainer = LSGANTrainer(
        model=model,
        train_iter=train_iter,
        val_iter=val_iter,
        test_iter=test_iter,
        viz=False,
    )

    training_config = dict(num_epochs=25, G_lr=1e-4, D_lr=1e-4, D_steps=1)
    trainer.train(**training_config)