/
bglow.py
763 lines (601 loc) · 24.7 KB
/
bglow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
import math
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
def compute_same_pad(kernel_size, stride):
    """Return per-dimension padding for a "same"-style convolution.

    ``kernel_size`` and ``stride`` may each be an int (wrapped into a
    one-element list) or an iterable; after wrapping both must have the same
    length. Each entry is ``((k - 1) * s + 1) // 2``.
    """
    kernels = [kernel_size] if isinstance(kernel_size, int) else kernel_size
    strides = [stride] if isinstance(stride, int) else stride
    assert len(strides) == len(
        kernels
    ), "Pass kernel size and stride both as int, or both as equal length iterable"
    return [((k - 1) * s + 1) // 2 for k, s in zip(kernels, strides)]
def uniform_binning_correction(x, n_bits=8):
    """Dequantize ``x`` with uniform noise and return the matching objective term.

    Replaces x with q(x) = U(x, x + 1 / n_bins), where ``n_bins = 2 ** n_bits``,
    i.e. adds noise drawn uniformly from [0, 1 / n_bins) to every element.

    Args:
        x: 4-D tensor of shape (N, C, H, W). It is NOT modified in place.
        n_bits: number of bits per value (default 8, i.e. 256 bins).

    Returns:
        x: a new tensor ``x + noise``.
        objective: tensor of shape (N,) filled with ``-log(n_bins) * C * H * W``,
            the constant discretization term that starts the log-likelihood.
    """
    b, c, h, w = x.size()
    n_bins = 2**n_bits
    chw = c * h * w
    # Out-of-place add: an in-place ``x += ...`` would mutate the caller's
    # tensor (repeated calls on one batch would accumulate noise) and fails
    # on leaf tensors that require grad.
    x = x + torch.zeros_like(x).uniform_(0, 1.0 / n_bins)
    objective = -math.log(n_bins) * chw * torch.ones(b, device=x.device)
    return x, objective
def split_feature(tensor, type="split"):
    """Split ``tensor`` into two parts along dim 1 (the channel axis).

    Args:
        tensor: tensor whose dim 1 is split.
        type: ``"split"`` returns the contiguous halves ``[:C // 2]`` and
            ``[C // 2:]``; ``"cross"`` returns the even- and odd-indexed
            channels. (The name shadows the builtin but is kept so keyword
            callers keep working.)

    Returns:
        A 2-tuple of tensors.

    Raises:
        ValueError: if ``type`` is neither ``"split"`` nor ``"cross"``
            (previously this silently returned ``None``).
    """
    C = tensor.size(1)
    if type == "split":
        return tensor[:, : C // 2, ...], tensor[:, C // 2 :, ...]
    if type == "cross":
        return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
    raise ValueError(
        f"split_feature: unknown type {type!r}; expected 'split' or 'cross'"
    )
def gaussian_p(mean, logs, x):
    """Element-wise log-density of ``x`` under N(mean, exp(logs) ** 2).

    ``logs`` is the log standard deviation, so the variance is exp(2 * logs):

        ln p(x) = -1/2 * (2 * logs + (x - mean) ** 2 / exp(2 * logs) + ln(2 * pi))

    Inputs are broadcast together; no reduction is performed.
    """
    log_two_pi = math.log(2 * math.pi)
    log_var = logs * 2.0
    sq_dist = (x - mean) ** 2
    return -0.5 * (log_var + sq_dist / torch.exp(log_var) + log_two_pi)
def gaussian_likelihood(mean, logs, x):
    """Per-sample Gaussian log-likelihood.

    Evaluates ``gaussian_p`` element-wise and sums over dims 1, 2 and 3, so
    for (N, C, H, W) inputs the result has shape (N,).
    """
    log_density = gaussian_p(mean, logs, x)
    return log_density.sum(dim=[1, 2, 3])
def gaussian_sample(mean, logs, temperature=1):
    """Draw a sample from N(mean, (exp(logs) * temperature) ** 2).

    Args:
        mean: tensor of means.
        logs: tensor of log standard deviations, same shape as ``mean``.
        temperature: multiplier on the standard deviation. ``None`` is
            treated as 1, because callers may forward an unset temperature
            (multiplying by ``None`` would raise a TypeError).

    Returns:
        A tensor of the same shape as ``mean``.
    """
    if temperature is None:
        temperature = 1.0
    std = torch.exp(logs) * temperature
    return torch.normal(mean, std)
def squeeze2d(input, factor):
    """Space-to-depth: fold each ``factor x factor`` spatial patch into channels.

    Maps (B, C, H, W) to (B, C * factor**2, H // factor, W // factor). Output
    channel ``c * factor**2 + i * factor + j`` holds
    ``input[:, c, i::factor, j::factor]``. ``factor == 1`` returns ``input``
    itself.
    """
    if factor == 1:
        return input
    batch, chans, height, width = input.size()
    assert height % factor == 0 and width % factor == 0, "H or W modulo factor is not 0"
    out_h, out_w = height // factor, width // factor
    blocks = input.view(batch, chans, out_h, factor, out_w, factor)
    # move both intra-patch offsets next to the channel axis
    blocks = blocks.permute(0, 1, 3, 5, 2, 4).contiguous()
    return blocks.view(batch, chans * factor * factor, out_h, out_w)
def unsqueeze2d(input, factor):
    """Depth-to-space: the inverse of ``squeeze2d``.

    Maps (B, C, H, W) to (B, C // factor**2, H * factor, W * factor). Input
    channel ``c * factor**2 + i * factor + j`` lands at
    ``output[:, c, i::factor, j::factor]``. ``factor == 1`` returns ``input``
    itself.
    """
    if factor == 1:
        return input
    area = factor**2
    batch, chans, height, width = input.size()
    assert chans % (area) == 0, "C module factor squared is not 0"
    out_c = chans // area
    blocks = input.view(batch, out_c, factor, factor, height, width)
    # interleave each spatial axis with its matching intra-patch offset
    blocks = blocks.permute(0, 1, 4, 2, 5, 3).contiguous()
    return blocks.view(batch, out_c, height * factor, width * factor)
class _ActNorm(nn.Module):
"""
Activation Normalization
Initialize the bias and scale with a given minibatch,
so that the output per-channel have zero mean and unit variance for that.
After initialization, `bias` and `logs` will be trained as parameters.
"""
def __init__(self, num_features, scale=1.0):
super().__init__()
# register mean and scale
size = [1, num_features, 1, 1]
self.bias = nn.Parameter(torch.zeros(*size))
self.logs = nn.Parameter(torch.zeros(*size))
self.num_features = num_features
self.scale = scale
self.inited = False
def initialize_parameters(self, input):
if not self.training:
raise ValueError("In Eval mode, but ActNorm not inited")
with torch.no_grad():
bias = -torch.mean(input.clone(), dim=[0, 2, 3], keepdim=True)
vars = torch.mean((input.clone() + bias)**2,
dim=[0, 2, 3],
keepdim=True)
logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6))
self.bias.data.copy_(bias.data)
self.logs.data.copy_(logs.data)
self.inited = True
def _center(self, input, reverse=False):
if reverse:
return input - self.bias
else:
return input + self.bias
def _scale(self, input, logdet=None, reverse=False):
if reverse:
input = input * torch.exp(-self.logs)
else:
input = input * torch.exp(self.logs)
if logdet is not None:
"""
logs is log_std of `mean of channels`
so we need to multiply by number of pixels
"""
b, c, h, w = input.shape
dlogdet = torch.sum(self.logs) * h * w
if reverse:
dlogdet *= -1
logdet = logdet + dlogdet
return input, logdet
def forward(self, input, logdet=None, reverse=False):
self._check_input_dim(input)
if not self.inited:
self.initialize_parameters(input)
if reverse:
input, logdet = self._scale(input, logdet, reverse)
input = self._center(input, reverse)
else:
input = self._center(input, reverse)
input, logdet = self._scale(input, logdet, reverse)
return input, logdet
class ActNorm2d(_ActNorm):
    """ActNorm for 4-D ``(B, C, H, W)`` inputs with ``C == num_features``."""

    def __init__(self, num_features, scale=1.0):
        super().__init__(num_features, scale)

    def _check_input_dim(self, input):
        """Assert that ``input`` is 4-D with ``num_features`` channels."""
        assert len(input.size()) == 4, (
            "[ActNorm]: input should be in shape as `BCHW`,"
            " got a {}-D input".format(len(input.size())))
        # report the actual channel count (previously the whole shape was
        # printed where the message promises the number of channels)
        assert input.size(1) == self.num_features, (
            "[ActNorm]: input should be in shape as `BCHW`,"
            " channels should be {} rather than {}".format(
                self.num_features, input.size(1)))
class LinearZeros(nn.Module):
    """Zero-initialized linear layer with a learned per-output log-scale.

    Computes ``linear(input) * exp(logs * logscale_factor)``. Weight, bias and
    ``logs`` all start at zero, so the layer initially outputs zeros.
    """

    def __init__(self, in_channels, out_channels, logscale_factor=3):
        super().__init__()
        linear = nn.Linear(in_channels, out_channels)
        linear.weight.data.zero_()
        linear.bias.data.zero_()
        self.linear = linear
        self.logscale_factor = logscale_factor
        self.logs = nn.Parameter(torch.zeros(out_channels))

    def forward(self, input):
        gain = torch.exp(self.logs * self.logscale_factor)
        return self.linear(input) * gain
class Conv2d(nn.Module):
    """2-D convolution with Gaussian weight init and optional ActNorm.

    Weights are drawn from N(0, weight_std ** 2). With ``do_actnorm=True`` the
    convolution has no bias and its output is passed through ``ActNorm2d``;
    otherwise a zero-initialized bias is used. ``padding`` may be ``"same"``
    (computed by ``compute_same_pad``), ``"valid"`` (0) or any value accepted
    by ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding="same",
        do_actnorm=True,
        weight_std=0.05,
    ):
        super().__init__()
        pad = padding
        if pad == "same":
            pad = compute_same_pad(kernel_size, stride)
        elif pad == "valid":
            pad = 0
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, pad,
            bias=not do_actnorm,
        )
        self.conv.weight.data.normal_(mean=0.0, std=weight_std)
        if do_actnorm:
            # ActNorm supplies the per-channel offset, so the conv has no bias
            self.actnorm = ActNorm2d(out_channels)
        else:
            self.conv.bias.data.zero_()
        self.do_actnorm = do_actnorm

    def forward(self, input):
        out = self.conv(input)
        if not self.do_actnorm:
            return out
        out, _ = self.actnorm(out)
        return out
class Conv2dZeros(nn.Module):
    """Zero-initialized convolution with a learned per-channel log-scale.

    Computes ``conv(input) * exp(logs * logscale_factor)``. Weight, bias and
    ``logs`` start at zero, so the initial output is all zeros. ``padding``
    may be ``"same"``, ``"valid"`` or any value accepted by ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding="same",
        logscale_factor=3,
    ):
        super().__init__()
        if padding == "same":
            padding = compute_same_pad(kernel_size, stride)
        elif padding == "valid":
            padding = 0
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        for tensor in (conv.weight, conv.bias):
            tensor.data.zero_()
        self.conv = conv
        self.logscale_factor = logscale_factor
        # (C_out, 1, 1) broadcasts over the spatial dimensions
        self.logs = nn.Parameter(torch.zeros(out_channels, 1, 1))

    def forward(self, input):
        return self.conv(input) * torch.exp(self.logs * self.logscale_factor)
class Permute2d(nn.Module):
    """Fixed channel permutation for (B, C, H, W) tensors.

    With ``shuffle=False`` the channel order is reversed; with
    ``shuffle=True`` a random permutation drawn with ``torch.randperm`` at
    construction is used instead. ``reverse=True`` applies the inverse.

    NOTE(review): ``indices``/``indices_inverse`` are plain tensor
    attributes, not registered buffers, so they are not saved in
    ``state_dict``; a reloaded ``shuffle=True`` model draws a new random
    permutation. Left as-is to keep checkpoint keys unchanged.
    """

    def __init__(self, num_channels, shuffle):
        super().__init__()
        self.num_channels = num_channels
        self.indices = torch.arange(self.num_channels - 1, -1, -1, dtype=torch.long)
        # argsort of a permutation is its inverse: inv[perm[i]] == i
        self.indices_inverse = torch.argsort(self.indices)
        if shuffle:
            self.reset_indices()

    def reset_indices(self):
        """Shuffle ``indices`` randomly and recompute ``indices_inverse``."""
        shuffle_idx = torch.randperm(self.indices.shape[0])
        self.indices = self.indices[shuffle_idx]
        # vectorized replacement for the per-channel Python loop
        self.indices_inverse = torch.argsort(self.indices)

    def forward(self, input, reverse=False):
        assert len(input.size()) == 4
        order = self.indices_inverse if reverse else self.indices
        return input[:, order, :, :]
class Split2d(nn.Module):
    """Factor out half of the channels and model them with a learned Gaussian.

    Forward: splits ``input`` into contiguous halves ``z1``/``z2``, predicts
    ``(mean, logs)`` for ``z2`` from ``z1`` with a zero-initialized conv,
    adds the log-likelihood of ``z2`` to ``logdet`` and returns only ``z1``.

    Reverse: samples ``z2`` from the Gaussian predicted from ``input`` (using
    ``temperature``) and concatenates it after ``input``; ``logdet`` is
    returned unchanged.
    """

    def __init__(self, num_channels):
        super().__init__()
        # maps the kept C/2 channels to C channels: mean and log-std interleaved
        self.conv = Conv2dZeros(num_channels // 2, num_channels)

    def split2d_prior(self, z):
        """Return ``(mean, logs)`` as the even/odd channels of ``conv(z)``."""
        return split_feature(self.conv(z), "cross")

    def forward(self, input, logdet=0.0, reverse=False, temperature=None):
        if not reverse:
            z1, z2 = split_feature(input, "split")
            mean, logs = self.split2d_prior(z1)
            return z1, gaussian_likelihood(mean, logs, z2) + logdet
        mean, logs = self.split2d_prior(input)
        z2 = gaussian_sample(mean, logs, temperature)
        return torch.cat((input, z2), dim=1), logdet
class SqueezeLayer(nn.Module):
    """Invertible space-to-depth reshape; ``logdet`` is passed through unchanged."""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, input, logdet=None, reverse=False):
        transform = unsqueeze2d if reverse else squeeze2d
        return transform(input, self.factor), logdet
class InvertibleConv1x1(nn.Module):
    """Invertible 1x1 convolution, optionally LU-parameterized.

    The weight is initialized as the Q factor of a QR decomposition of a
    random Gaussian matrix. With ``LU_decomposed=True`` it is stored as
    ``W = P @ L @ (U + diag(sign_s * exp(log_s)))``. ``P`` and ``sign_s``
    are fixed buffers. ``L`` (unit lower-triangular via ``l_mask`` +
    ``eye``), ``U`` (strictly upper) and ``log_s`` are trained. The
    log-determinant per pixel is then ``sum(log_s)``.
    """

    def __init__(self, num_channels, LU_decomposed):
        super().__init__()
        w_shape = [num_channels, num_channels]
        # NOTE(review): torch.qr / torch.lu are deprecated in recent PyTorch
        # (torch.linalg.qr / torch.linalg.lu_factor); kept for older releases.
        w_init = torch.qr(torch.randn(*w_shape))[0]
        if not LU_decomposed:
            self.weight = nn.Parameter(torch.Tensor(w_init))
        else:
            p, lower, upper = torch.lu_unpack(*torch.lu(w_init))
            s = torch.diag(upper)
            sign_s = torch.sign(s)
            log_s = torch.log(torch.abs(s))
            upper = torch.triu(upper, 1)
            l_mask = torch.tril(torch.ones(w_shape), -1)
            eye = torch.eye(*w_shape)
            self.register_buffer("p", p)
            self.register_buffer("sign_s", sign_s)
            self.lower = nn.Parameter(lower)
            self.log_s = nn.Parameter(log_s)
            self.upper = nn.Parameter(upper)
            # Non-persistent buffers follow .to()/.cuda()/.double() with the
            # module (no per-call device shuffling needed) but stay out of
            # state_dict, so checkpoint keys are unchanged.
            self.register_buffer("l_mask", l_mask, persistent=False)
            self.register_buffer("eye", eye, persistent=False)
        self.w_shape = w_shape
        self.LU_decomposed = LU_decomposed

    def get_weight(self, input, reverse):
        """Return the (C, C, 1, 1) kernel (inverted when ``reverse``) and log|det W| * H * W."""
        b, c, h, w = input.shape
        if not self.LU_decomposed:
            dlogdet = torch.slogdet(self.weight)[1] * h * w
            weight = torch.inverse(self.weight) if reverse else self.weight
        else:
            lower = self.lower * self.l_mask + self.eye
            # the transposed strictly-lower mask is the strictly-upper mask
            u = self.upper * self.l_mask.transpose(0, 1)
            u = u + torch.diag(self.sign_s * torch.exp(self.log_s))
            dlogdet = torch.sum(self.log_s) * h * w
            if reverse:
                u_inv = torch.inverse(u)
                l_inv = torch.inverse(lower)
                p_inv = torch.inverse(self.p)
                weight = torch.matmul(u_inv, torch.matmul(l_inv, p_inv))
            else:
                weight = torch.matmul(self.p, torch.matmul(lower, u))
        weight = weight.view(self.w_shape[0], self.w_shape[1], 1, 1)
        return weight.to(input.device), dlogdet.to(input.device)

    def forward(self, input, logdet=None, reverse=False):
        """
        Apply W (or W^-1 when ``reverse``) as a 1x1 conv.

        log-det = log|abs(|W|)| * pixels; it is added to ``logdet`` going
        forward and subtracted in reverse (``logdet=None`` stays ``None``).
        """
        weight, dlogdet = self.get_weight(input, reverse)
        z = F.conv2d(input, weight)
        if logdet is not None:
            logdet = logdet - dlogdet if reverse else logdet + dlogdet
        return z, logdet
def get_block(in_channels, out_channels, hidden_channels):
    """Build the coupling network used by ``FlowStep``.

    Layout: 3x3 ``Conv2d`` -> ReLU -> 1x1 ``Conv2d`` -> ReLU -> ``Conv2dZeros``.
    The hidden convs use ``Conv2d``'s default ActNorm; the final
    ``Conv2dZeros`` has zero weights, bias and log-scale, so the block
    outputs zeros at initialization.
    """
    layers = [
        Conv2d(in_channels, hidden_channels),
        nn.ReLU(inplace=False),
        Conv2d(hidden_channels, hidden_channels, kernel_size=(1, 1)),
        nn.ReLU(inplace=False),
        Conv2dZeros(hidden_channels, out_channels),
    ]
    return nn.Sequential(*layers)
class FlowStep(nn.Module):
    """One Glow step of flow: ActNorm -> channel permutation -> coupling.

    Args:
        in_channels: number of input channels (must be even).
        hidden_channels: width of the coupling network.
        actnorm_scale: ``scale`` passed to ``ActNorm2d``.
        flow_permutation: ``"invconv"`` (learned 1x1 conv), ``"shuffle"``
            (fixed random permutation); any other value uses a fixed
            channel reversal.
        flow_coupling: ``"additive"`` or ``"affine"``.
        LU_decomposed: forwarded to ``InvertibleConv1x1``.

    Raises:
        ValueError: if ``flow_coupling`` is not ``"additive"`` or
            ``"affine"`` (previously the coupling was silently skipped).
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        actnorm_scale,
        flow_permutation,
        flow_coupling,
        LU_decomposed,
    ):
        super().__init__()
        if flow_coupling not in ("additive", "affine"):
            raise ValueError(
                f"Unknown flow_coupling {flow_coupling!r}; "
                "expected 'additive' or 'affine'"
            )
        self.flow_coupling = flow_coupling
        self.actnorm = ActNorm2d(in_channels, actnorm_scale)
        # 2. permute. Submodule names are unchanged so state_dict keys match.
        # The dispatch lives in the ``flow_permutation`` method rather than in
        # a lambda closing over ``self``: such a lambda breaks pickling
        # (torch.save of the whole model), and copy.deepcopy shares it, so a
        # copied step would keep using the ORIGINAL module's submodules.
        if flow_permutation == "invconv":
            self.invconv = InvertibleConv1x1(in_channels,
                                             LU_decomposed=LU_decomposed)
            self._permutation_kind = "invconv"
        elif flow_permutation == "shuffle":
            self.shuffle = Permute2d(in_channels, shuffle=True)
            self._permutation_kind = "shuffle"
        else:
            self.reverse = Permute2d(in_channels, shuffle=False)
            self._permutation_kind = "reverse"
        # 3. coupling: additive predicts a shift only; affine predicts
        # interleaved shift and scale, hence twice the output channels
        if flow_coupling == "additive":
            self.block = get_block(in_channels // 2, in_channels // 2,
                                   hidden_channels)
        else:
            self.block = get_block(in_channels // 2, in_channels,
                                   hidden_channels)

    def flow_permutation(self, z, logdet, rev):
        """Apply the configured channel permutation (inverse when ``rev``).

        Only ``invconv`` changes ``logdet``; the fixed permutations pass it
        through.
        """
        if self._permutation_kind == "invconv":
            return self.invconv(z, logdet, rev)
        if self._permutation_kind == "shuffle":
            return self.shuffle(z, rev), logdet
        return self.reverse(z, rev), logdet

    def forward(self, input, logdet=None, reverse=False):
        if not reverse:
            return self.normal_flow(input, logdet)
        else:
            return self.reverse_flow(input, logdet)

    def normal_flow(self, input, logdet):
        """Forward pass: actnorm, permute, then couple the second half on the first."""
        assert input.size(1) % 2 == 0
        # 1. actnorm
        z, logdet = self.actnorm(input, logdet=logdet, reverse=False)
        # 2. permute
        z, logdet = self.flow_permutation(z, logdet, False)
        # 3. coupling
        z1, z2 = split_feature(z, "split")
        if self.flow_coupling == "additive":
            z2 = z2 + self.block(z1)
        elif self.flow_coupling == "affine":
            h = self.block(z1)
            shift, scale = split_feature(h, "cross")
            # +2.0 biases the sigmoid so the initial scale is close to 1
            scale = torch.sigmoid(scale + 2.0)
            z2 = z2 + shift
            z2 = z2 * scale
            logdet = torch.sum(torch.log(scale), dim=[1, 2, 3]) + logdet
        z = torch.cat((z1, z2), dim=1)
        return z, logdet

    def reverse_flow(self, input, logdet):
        """Inverse pass: undo the coupling, then the permutation, then actnorm."""
        assert input.size(1) % 2 == 0
        # 1. coupling
        z1, z2 = split_feature(input, "split")
        if self.flow_coupling == "additive":
            z2 = z2 - self.block(z1)
        elif self.flow_coupling == "affine":
            h = self.block(z1)
            shift, scale = split_feature(h, "cross")
            scale = torch.sigmoid(scale + 2.0)
            z2 = z2 / scale
            z2 = z2 - shift
            logdet = -torch.sum(torch.log(scale), dim=[1, 2, 3]) + logdet
        z = torch.cat((z1, z2), dim=1)
        # 2. permute
        z, logdet = self.flow_permutation(z, logdet, True)
        # 3. actnorm
        z, logdet = self.actnorm(z, logdet=logdet, reverse=True)
        return z, logdet
class FlowNet(nn.Module):
    """Multi-scale Glow flow: ``L`` levels of squeeze -> ``K`` FlowSteps -> split.

    ``image_shape`` is ``(H, W, C)``. Every level except the last ends with a
    ``Split2d`` that halves the channel count. ``output_shapes[i]`` is the
    ``[-1, C, H, W]`` shape produced by ``layers[i]``.
    """

    def __init__(
        self,
        image_shape,
        hidden_channels,
        K,
        L,
        actnorm_scale,
        flow_permutation,
        flow_coupling,
        LU_decomposed,
    ):
        super().__init__()
        self.layers = nn.ModuleList()
        self.output_shapes = []
        self.K = K
        self.L = L
        H, W, C = image_shape
        for level in range(L):
            # 1. squeeze: trade spatial resolution for channels
            C, H, W = C * 4, H // 2, W // 2
            self._append(SqueezeLayer(factor=2), [-1, C, H, W])
            # 2. K steps of flow at this resolution
            for _ in range(K):
                step = FlowStep(
                    in_channels=C,
                    hidden_channels=hidden_channels,
                    actnorm_scale=actnorm_scale,
                    flow_permutation=flow_permutation,
                    flow_coupling=flow_coupling,
                    LU_decomposed=LU_decomposed,
                )
                self._append(step, [-1, C, H, W])
            # 3. factor out half of the channels on every level but the last
            if level < L - 1:
                self._append(Split2d(num_channels=C), [-1, C // 2, H, W])
                C = C // 2

    def _append(self, layer, shape):
        """Register ``layer`` and record its output shape in lockstep."""
        self.layers.append(layer)
        self.output_shapes.append(shape)

    def forward(self, input, logdet=0.0, reverse=False, temperature=None):
        if not reverse:
            return self.encode(input, logdet)
        return self.decode(input, temperature)

    def encode(self, z, logdet=0.0):
        """Run every layer forward, accumulating ``logdet``; returns ``(z, logdet)``."""
        for layer in self.layers:
            z, logdet = layer(z, logdet, reverse=False)
        return z, logdet

    def decode(self, z, temperature=None):
        """Invert every layer from last to first and return only ``z``.

        ``temperature`` is forwarded to ``Split2d`` layers, which sample the
        factored-out channels; log-determinants are discarded.
        """
        for layer in reversed(self.layers):
            if isinstance(layer, Split2d):
                z, _ = layer(z, logdet=0, reverse=True, temperature=temperature)
            else:
                z, _ = layer(z, logdet=0, reverse=True)
        return z
class BGlow(nn.Module):
    r'''
    This class implements a normalizing flow model that can generate samples close to the true data distribution. A flow-based model trains an encoder that maps the input to a latent variable following a standard normal distribution. Because the encoder is designed to be invertible, once it is trained the corresponding decoder can generate samples by applying the inverse operation to draws from a Gaussian distribution. In particular, Glow is an easy-to-use flow-based model that replaces the fixed permutation of the channel axes with a learned invertible 1x1 convolution.

    - Paper: Kingma D P, Dhariwal P. Glow: Generative flow with invertible 1x1 convolutions[J]. Advances in neural information processing systems, 2018, 31.
    - URL: https://arxiv.org/abs/1807.03039
    - Related Project: https://github.com/y0ast/Glow-PyTorch/
    - Related Project: https://github.com/ikostrikov/pytorch-flows/

    Below is a recommended suite for use in EEG generation:

    .. code-block:: python

        eeg = torch.randn(1, 4, 32, 32)
        model = BGlow()
        z, nll_loss = model(eeg)
        fake_X = model.sample(num=1, temperature=1.0)

    Args:
        in_channels (int): The feature dimension of each electrode. (default: :obj:`4`)
        grid_size (tuple): Spatial dimensions of grid-like EEG representation. (default: :obj:`(32, 32)`)
        hidden_channels (int): The basic hidden channels in the network blocks. (default: :obj:`64`)
        num_steps (int): The number of steps of flow per block; each step contains an actnorm layer, a channel permutation (an invertible 1x1 conv by default) and a coupling layer. (default: :obj:`32`)
        num_blocks (int): Number of blocks; each block applies a squeeze, :obj:`num_steps` steps of flow and, for every block but the last, a split. (default: :obj:`3`)
        actnorm_scale (float): The pre-defined scale factor in the actnorm layer. (default: :obj:`1.0`)
        flow_permutation (str): The used flow permutation method, options include :obj:`invconv`, :obj:`shuffle` and :obj:`reverse`. (default: :obj:`invconv`)
        flow_coupling (str): The used flow coupling method, options include :obj:`additive` and :obj:`affine`. (default: :obj:`affine`)
        LU_decomposed (bool): Whether to use LU decomposed 1x1 convs. (default: :obj:`True`)
        learn_top (bool): Whether to learn the top-level prior with a zero-initialized convolution. (default: :obj:`True`)

    .. automethod:: log_probs
    .. automethod:: sample
    '''

    def __init__(
        self,
        in_channels: int = 4,
        grid_size: tuple = (32, 32),
        hidden_channels: int = 64,
        num_steps: int = 32,
        num_blocks: int = 3,
        actnorm_scale: float = 1.0,
        flow_permutation: str = "invconv",
        flow_coupling: str = "affine",
        LU_decomposed: bool = True,
        learn_top: bool = True,
    ):
        super().__init__()
        self.flow = FlowNet(
            image_shape=[grid_size[0], grid_size[1], in_channels],
            hidden_channels=hidden_channels,
            K=num_steps,
            L=num_blocks,
            actnorm_scale=actnorm_scale,
            flow_permutation=flow_permutation,
            flow_coupling=flow_coupling,
            LU_decomposed=LU_decomposed,
        )
        self.learn_top = learn_top
        # learned prior over the top-level latent
        if learn_top:
            C = self.flow.output_shapes[-1][1]
            self.learn_top_fn = Conv2dZeros(C * 2, C * 2)
        # all-zero prior input of shape (1, 2 * C_top, H_top, W_top); split into
        # mean and log-std (directly, or after learn_top_fn)
        self.register_buffer(
            "prior_h",
            torch.zeros([
                1,
                self.flow.output_shapes[-1][1] * 2,
                self.flow.output_shapes[-1][2],
                self.flow.output_shapes[-1][3],
            ]),
        )

    def prior(self, num=None):
        r'''Return ``(mean, logs)`` of the top-level Gaussian prior for ``num`` samples.

        ``num=None`` falls back to a hard-coded batch size of 32. With
        ``learn_top`` the all-zero ``prior_h`` buffer is passed through a
        zero-initialized conv; otherwise mean and log-std are both zero.
        '''
        batch = 32 if num is None else num
        h = self.prior_h.repeat(batch, 1, 1, 1)
        if self.learn_top:
            h = self.learn_top_fn(h)
        return split_feature(h, "split")

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r'''
        Args:
            x (torch.Tensor): EEG signal representation. The ideal input shape is :obj:`[n, 4, 32, 32]`. Here, :obj:`n` corresponds to the batch size, :obj:`4` corresponds to the :obj:`in_channels`, and :obj:`(32, 32)` corresponds to the :obj:`grid_size`.

        Returns:
            torch.Tensor: The latent representation.
            torch.Tensor: The bits per dimension (BPD) negative log-likelihood, of shape :obj:`[n]`.
        '''
        _, c, h, w = x.shape
        # dequantize; the objective starts with the discretization term
        x, logdet = uniform_binning_correction(x)
        z, objective = self.flow(x, logdet=logdet, reverse=False)
        mean, logs = self.prior(x.shape[0])
        objective = objective + gaussian_likelihood(mean, logs, z)
        # full objective (nats) converted to bits per dimension
        bpd = (-objective) / (math.log(2.0) * c * h * w)
        return z, bpd

    def reverse(self,
                z: torch.Tensor,
                temperature: float = 1.0) -> torch.Tensor:
        r'''Decode a top-level latent ``z`` back to the input space.

        ``temperature`` scales the standard deviation used when sampling the
        channels factored out by each split.
        '''
        return self.flow(z, temperature=temperature, reverse=True)

    def log_probs(self, x: torch.Tensor) -> torch.Tensor:
        r'''
        Args:
            x (torch.Tensor): EEG signal representation. The ideal input shape is :obj:`[n, 4, 32, 32]`. Here, :obj:`n` corresponds to the batch size, :obj:`4` corresponds to the :obj:`in_channels`, and :obj:`(32, 32)` corresponds to the :obj:`grid_size`.

        Returns:
            torch.Tensor: The bits per dimension (BPD) negative log-likelihood.
        '''
        _, bpd = self.forward(x)
        return bpd

    def sample(self, num: int = 1, temperature: float = 1.0) -> torch.Tensor:
        r'''
        Args:
            num (int): The number of samples to generate. (default: :obj:`1`)
            temperature (float): The hyper-parameter, temperature, to sample from gaussian distributions. (default: :obj:`1.0`)

        Returns:
            torch.Tensor: the generated results, with the same shape as the model input, i.e., :obj:`[n, 4, 32, 32]`. Here, :obj:`n` corresponds to the batch size, :obj:`4` corresponds to :obj:`in_channels`, and :obj:`(32, 32)` corresponds to :obj:`grid_size`.
        '''
        mean, logs = self.prior(num)
        z = gaussian_sample(mean, logs, temperature)
        x = self.reverse(z, temperature=temperature)
        return x

    def set_actnorm_init(self):
        r'''Mark every ``ActNorm2d`` layer as already initialized.

        Call this after loading a checkpoint. ``inited`` is a plain attribute
        that ``load_state_dict`` does not restore. Without this call, in
        training mode the first forward pass would re-initialize the loaded
        ActNorm parameters from data, and in eval mode it would raise.
        '''
        for module in self.modules():
            if isinstance(module, ActNorm2d):
                module.inited = True