/
gman.py
333 lines (316 loc) · 10.1 KB
/
gman.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
import argparse
import torch
import torch.utils.data as data
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.optim import Adam
from torchgan.losses import MinimaxDiscriminatorLoss, MinimaxGeneratorLoss
from torchgan.models import DCGANDiscriminator, DCGANGenerator
from torchgan.trainer import ParallelTrainer, Trainer
class MultiDiscriminatorMinimaxLoss(MinimaxDiscriminatorLoss):
    r"""Minimax discriminator loss combined over five discriminators (GMAN-style).

    The parent class's ``forward`` is evaluated once per discriminator, giving
    losses ``V_1 .. V_5``. They are combined into a single value
    ``sum_i w_i * V_i`` with weights ``w = softmax(lambd * V)``. With
    ``lambd = 0`` this is the plain arithmetic mean; increasing ``lambd``
    moves the combination towards the largest of the five losses.

    Args:
        *args: Forwarded to ``MinimaxDiscriminatorLoss``.
        lambd (float, optional): Inverse temperature of the softmax weighting.
        **kwargs: Forwarded to ``MinimaxDiscriminatorLoss``.
    """

    def __init__(self, *args, lambd=0.001, **kwargs):
        super(MultiDiscriminatorMinimaxLoss, self).__init__(*args, **kwargs)
        self.lambd = lambd

    # NOTE(review): the parameter names below mirror the "discriminatorN"
    # keys of the network configuration; presumably the torchgan trainer
    # supplies these arguments by name, so keep them in sync.
    def train_ops(
        self,
        generator,
        discriminator1,
        discriminator2,
        discriminator3,
        discriminator4,
        discriminator5,
        optimizer_discriminator1,
        optimizer_discriminator2,
        optimizer_discriminator3,
        optimizer_discriminator4,
        optimizer_discriminator5,
        real_inputs,
        device,
    ):
        """Run one optimisation step for all five discriminators.

        Args:
            generator: Network called as ``generator(noise)``; must expose
                ``encoding_dims`` (the width of the sampled noise).
            discriminator1 .. discriminator5: Networks scored on the real and
                the generated batch.
            optimizer_discriminator1 .. optimizer_discriminator5: Optimizers
                of the matching discriminators; all are zeroed and stepped.
            real_inputs: Batch of real samples; its first dimension is used as
                the batch size for the noise.
            device: Device on which the noise is sampled.

        Returns:
            float: The softmax-weighted combined loss (``loss.item()``).
        """
        batch_size = real_inputs.size(0)
        noise = torch.randn(batch_size, generator.encoding_dims, device=device)

        discriminators = (
            discriminator1,
            discriminator2,
            discriminator3,
            discriminator4,
            discriminator5,
        )
        optimizers = (
            optimizer_discriminator1,
            optimizer_discriminator2,
            optimizer_discriminator3,
            optimizer_discriminator4,
            optimizer_discriminator5,
        )

        for optimizer in optimizers:
            optimizer.zero_grad()

        # Detached: this step must not backpropagate into the generator.
        fake = generator(noise).detach()

        # One loss per discriminator; each scores the real batch first and
        # then the fake batch (arguments are evaluated left to right).
        values = torch.stack(
            [self.forward(disc(real_inputs), disc(fake)) for disc in discriminators])

        # softmax(lambd * V)_i == exp(lambd * V_i) / sum_j exp(lambd * V_j),
        # but computed stably: a large ``lambd`` no longer overflows
        # ``torch.exp`` and turns the loss into inf / inf = nan.
        weights = torch.softmax(self.lambd * values, dim=0)
        loss = (weights * values).sum(dim=0)

        # Assumes the parent's forward returns a scalar (its default
        # reduction); backward() requires a scalar loss.
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()
        return loss.item()
class MultiDiscriminatorGeneratorLoss(MinimaxGeneratorLoss):
    r"""Minimax generator loss combined over five discriminators (GMAN-style).

    The parent class's ``forward`` is evaluated on each discriminator's score
    of a generated batch, giving losses ``V_1 .. V_5``. They are combined as
    ``sum_i w_i * V_i`` with ``w = softmax(lambd * V)``. With ``lambd = 0``
    this is the arithmetic mean; larger ``lambd`` weights the largest loss
    (the discriminator the generator currently does worst against) more.

    Args:
        *args: Forwarded to ``MinimaxGeneratorLoss``.
        lambd (float, optional): Inverse temperature of the softmax weighting.
        **kwargs: Forwarded to ``MinimaxGeneratorLoss``.
    """

    def __init__(self, *args, lambd=0.001, **kwargs):
        super(MultiDiscriminatorGeneratorLoss, self).__init__(*args, **kwargs)
        self.lambd = lambd

    # NOTE(review): parameter names mirror the network configuration keys;
    # presumably the torchgan trainer passes them by name — keep in sync.
    def train_ops(
        self,
        generator,
        discriminator1,
        discriminator2,
        discriminator3,
        discriminator4,
        discriminator5,
        optimizer_generator,
        batch_size,
        device,
    ):
        """Run one optimisation step for the generator.

        Args:
            generator: Network called as ``generator(noise)``; must expose
                ``encoding_dims`` (the width of the sampled noise).
            discriminator1 .. discriminator5: Networks that score the
                generated batch. Only ``optimizer_generator`` is stepped here.
            optimizer_generator: Optimizer of the generator.
            batch_size (int): Number of noise vectors to sample.
            device: Device on which the noise is sampled.

        Returns:
            float: The softmax-weighted combined loss (``loss.item()``).
        """
        noise = torch.randn(batch_size, generator.encoding_dims, device=device)
        optimizer_generator.zero_grad()

        # Not detached: gradients must flow back into the generator.
        fake = generator(noise)

        discriminators = (
            discriminator1,
            discriminator2,
            discriminator3,
            discriminator4,
            discriminator5,
        )
        values = torch.stack([self.forward(disc(fake)) for disc in discriminators])

        # Stable equivalent of exp(lambd * V_i) / sum_j exp(lambd * V_j); the
        # raw-exp form overflows to inf / inf = nan for large ``lambd``.
        weights = torch.softmax(self.lambd * values, dim=0)
        loss = (weights * values).sum(dim=0)

        # Assumes the parent's forward returns a scalar (its default
        # reduction); backward() requires a scalar loss.
        loss.backward()
        optimizer_generator.step()
        return loss.item()
if __name__ == "__main__":
    # ---- Command-line interface -------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f",
        "--data_dir",
        help="directory where mnist/cifar10 will be downloaded/is available",
        default="./")
    parser.add_argument(
        "-d",
        "--dataset",
        choices=["mnist", "cifar10"],
        help="Dataset to use",
        default="mnist")
    parser.add_argument(
        "-sc",
        "--step_channels",
        help="step channels for the generator and discriminators",
        type=int,
        default=16)
    parser.add_argument(
        "-lr",
        "--learning_rate",
        help="The learning rate for the optimizers",
        type=float,
        default=0.0002)
    parser.add_argument(
        "--cpu",
        type=int,
        help="Set it to 1 if cpu is to be used for training",
        default=0)
    parser.add_argument(
        "-m",
        "--multigpu",
        choices=[0, 1],
        type=int,
        help="Choose 1 if multiple GPUs are available for training",
        default=0)
    parser.add_argument(
        "-l",
        "--list_gpus",
        type=int,
        nargs='+',
        help="List of GPUs to be used for training. Used if -m is set to 1",
        default=[0, 1])
    parser.add_argument(
        "-b",
        "--batch_size",
        type=int,
        help="Batch Size for training",
        default=32)
    parser.add_argument(
        "-s",
        "--sample_size",
        type=int,
        help="Number of Images Generated per Epoch",
        default=64)
    parser.add_argument(
        "-c",
        "--checkpoint",
        help="Place to store the trained model",
        default="./gman_")
    parser.add_argument(
        "-r",
        "--reconstructions",
        help="Directory to store the generated images",
        default="./gman_images")
    parser.add_argument(
        "-e",
        "--epochs",
        help="Total epochs for which the model will be trained",
        default=20,
        type=int)
    args = parser.parse_args()

    # Without --cpu 1 every path below targets CUDA ("cuda:0" or the GPUs in
    # --list_gpus); fail fast with a readable message when it is missing.
    if args.cpu != 1 and not torch.cuda.is_available():
        parser.error(
            "CUDA is not available; rerun with --cpu 1 to train on the CPU")

    # ---- Dataset preprocessing --------------------------------------------
    transformations = []
    if args.dataset == "cifar10":
        channels = 3
        dataset = dsets.CIFAR10
        norm = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    else:
        channels = 1
        dataset = dsets.MNIST
        # MNIST is resized to 32x32, the size used for CIFAR10 here.
        transformations.append(transforms.Resize((32, 32)))
        norm = transforms.Normalize(mean=(0.5, ), std=(0.5, ))
    transformations.append(transforms.ToTensor())
    transformations.append(norm)
    transformations = transforms.Compose(transformations)

    # ---- Networks and optimizers ------------------------------------------
    # Every network is trained with Adam using the same hyper-parameters.
    optimizer_args = {"lr": args.learning_rate, "betas": (0.5, 0.999)}
    network_configuration = {
        "generator": {
            "name": DCGANGenerator,
            "args": {
                "out_channels": channels,
                "step_channels": args.step_channels
            },
            "optimizer": {
                "name": Adam,
                "args": dict(optimizer_args)
            },
        },
    }
    # The GMAN loss classes take exactly five discriminators named
    # discriminator1 .. discriminator5. Presumably the trainer matches these
    # keys to the train_ops parameter names, so keep the two in sync.
    num_discriminators = 5
    for index in range(1, num_discriminators + 1):
        network_configuration["discriminator{}".format(index)] = {
            "name": DCGANDiscriminator,
            "args": {
                "in_channels": channels,
                "step_channels": args.step_channels
            },
            "optimizer": {
                "name": Adam,
                "args": dict(optimizer_args)
            },
        }

    losses = [
        MultiDiscriminatorGeneratorLoss(),
        MultiDiscriminatorMinimaxLoss()
    ]

    # ---- Trainer ----------------------------------------------------------
    if args.cpu == 0 and args.multigpu == 1:
        trainer = ParallelTrainer(
            network_configuration,
            losses,
            args.list_gpus,
            epochs=args.epochs,
            sample_size=args.sample_size,
            checkpoints=args.checkpoint,
            retain_checkpoints=1,
            recon=args.reconstructions,
        )
    else:
        if args.cpu == 1:
            device = torch.device("cpu")
        else:
            device = torch.device("cuda:0")
        trainer = Trainer(
            network_configuration,
            losses,
            device=device,
            epochs=args.epochs,
            sample_size=args.sample_size,
            checkpoints=args.checkpoint,
            retain_checkpoints=1,
            recon=args.reconstructions,
        )

    # ---- Training ---------------------------------------------------------
    train_dataset = dataset(
        root=args.data_dir,
        train=True,
        download=True,
        transform=transformations)
    train_loader = data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True)
    trainer(train_loader)
    trainer.complete()