-
Notifications
You must be signed in to change notification settings - Fork 19
/
gaussianization.py
150 lines (125 loc) · 4.11 KB
/
gaussianization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
r"""Gaussianization flows.

This module exposes :class:`GF`, a flow that stacks element-wise gaussianization
transformations interleaved with rotations, and :class:`ElementWiseTransform`,
the lazy element-wise transformation it is built from.
"""
__all__ = [
'ElementWiseTransform',
'GF',
]
import torch
import torch.nn as nn
from math import prod
from torch import Tensor, Size
from torch.distributions import Transform
from typing import *
from .core import *
from ..distributions import DiagNormal
from ..transforms import *
from ..nn import MLP
from ..utils import unpack
class ElementWiseTransform(LazyTransform):
    r"""Creates a lazy element-wise transformation.

    Every feature is mapped by its own instance of the ``univariate``
    transformation. Without context, the per-feature parameters are free
    learnable tensors. With context, they are predicted by a hyper-network
    (:class:`zuko.nn.MLP`) from the context vector.

    Arguments:
        features: The number of features.
        context: The number of context features.
        univariate: The univariate transformation constructor.
        shapes: The shapes of the univariate transformation parameters.
        kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`. They are only
            used when ``context > 0``.

    Example:
        >>> t = ElementWiseTransform(3, 4)
        >>> t
        ElementWiseTransform(
          (base): MonotonicAffineTransform()
          (hyper): MLP(
            (0): Linear(in_features=4, out_features=64, bias=True)
            (1): ReLU()
            (2): Linear(in_features=64, out_features=64, bias=True)
            (3): ReLU()
            (4): Linear(in_features=64, out_features=6, bias=True)
          )
        )
        >>> x = torch.randn(3)
        >>> x
        tensor([2.1983, -1.3182, 0.0329])
        >>> c = torch.randn(4)
        >>> y = t(c)(x)
        >>> t(c).inv(y)
        tensor([2.1983, -1.3182, 0.0329])
    """

    def __init__(
        self,
        features: int,
        context: int = 0,
        univariate: Callable[..., Transform] = MonotonicAffineTransform,
        shapes: Sequence[Size] = ((), ()),
        **kwargs,
    ):
        super().__init__()

        self.univariate = univariate
        self.shapes = shapes
        # Number of scalar parameters each feature needs (prod(()) == 1 for scalars).
        self.total = sum(prod(shape) for shape in shapes)

        if context > 0:
            # One hyper-network emits the parameters of every feature at once.
            self.hyper = MLP(context, features * self.total, **kwargs)
        else:
            # One free tensor per parameter shape, with a leading feature axis.
            self.phi = nn.ParameterList(
                [torch.randn(features, *shape) for shape in shapes]
            )

    def extra_repr(self) -> str:
        # Build a throw-away instance from random parameters, only to print it.
        dummy = self.univariate(*(torch.randn(shape) for shape in self.shapes))

        return f'(base): {dummy}'

    def forward(self, c: Tensor = None) -> Transform:
        if c is not None:
            # (..., features * total) -> (..., features, total) -> one tensor per shape
            flat = self.hyper(c).unflatten(-1, (-1, self.total))
            params = unpack(flat, self.shapes)
        else:
            params = self.phi

        return DependentTransform(self.univariate(*params), 1)
class GF(Flow):
    r"""Creates a gaussianization flow (GF).

    The flow stacks ``transforms`` element-wise gaussianization transformations
    (:class:`ElementWiseTransform` with a
    :class:`zuko.transforms.GaussianizationTransform` per feature). A rotation
    transformation is inserted between every two consecutive element-wise
    layers. The base distribution is a diagonal normal with zero mean and unit
    scale.

    Warning:
        Invertibility is only guaranteed for features within the interval
        :math:`[-10, 10]`. It is recommended to standardize features (zero mean,
        unit variance) before training.

    See also:
        :class:`zuko.transforms.GaussianizationTransform`

    References:
        | Gaussianization Flows (Meng et al., 2020)
        | https://arxiv.org/abs/2003.01941

    Arguments:
        features: The number of features.
        context: The number of context features.
        transforms: The number of element-wise gaussianization transformations.
            ``transforms - 1`` rotations are interleaved between them.
        components: The number of mixture components in each transformation.
        kwargs: Keyword arguments passed to :class:`ElementWiseTransform`.
    """

    def __init__(
        self,
        features: int,
        context: int = 0,
        transforms: int = 3,
        components: int = 8,
        **kwargs,
    ):
        # Use a separate name for the layer list instead of rebinding the int
        # parameter `transforms`. Each univariate GaussianizationTransform gets
        # two parameter tensors of shape (components,).
        layers = [
            ElementWiseTransform(
                features=features,
                context=context,
                univariate=GaussianizationTransform,
                shapes=[(components,), (components,)],
                **kwargs,
            )
            for _ in range(transforms)
        ]

        # Put a rotation between each pair of consecutive element-wise layers:
        # E, E, E -> E, R, E, R, E. Walking the indices backwards keeps the
        # remaining insertion points valid as the list grows. It also preserves
        # the original order of random draws.
        for i in reversed(range(1, len(layers))):
            layers.insert(
                i,
                Unconditional(
                    RotationTransform,
                    torch.randn(features, features),
                ),
            )

        # Zero-mean, unit-scale diagonal normal base. NOTE(review): `buffer=True`
        # presumably registers these tensors as non-trainable buffers; confirm
        # against `Unconditional`.
        base = Unconditional(
            DiagNormal,
            torch.zeros(features),
            torch.ones(features),
            buffer=True,
        )

        super().__init__(layers, base)