# multivariable_flow.py
"""
This code is mostly taken from https://github.com/slachapelle/dcdi/tree/master/dcdi.
It implements the Deep Sigmoidal Flow and Gaussian MLP for continuous data prediction.
"""
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import sys
sys.path.append('../')
from causal_discovery.multivariable_mlp import MultivarMLP, InputMask


delta = 1e-6  # Small constant to keep logs and divisions numerically stable
c = - 0.5 * np.log(2 * np.pi)  # Log-normalization constant of a standard normal


def log(x):
    # Stabilized log: scaling by 1e2 before the log reduces underflow for very
    # small inputs; subtracting np.log(1e2) keeps the result mathematically exact.
    return torch.log(x * 1e2) - np.log(1e2)


def log_normal(x, mean, log_var, eps=0.00001):
    # Elementwise Gaussian log-density with the given mean and log-variance.
    return - (x - mean) ** 2 / (2. * torch.exp(log_var) + eps) - log_var / 2. + c
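

# Sanity-check sketch (our addition, not part of the original DCDI code): up to
# the `eps` stabilizer, log_normal should agree with the log-density of
# torch.distributions.Normal. The helper name `_demo_log_normal` is ours.
def _demo_log_normal():
    x, mean, log_var = torch.randn(5), torch.zeros(5), torch.zeros(5)
    ours = log_normal(x, mean, log_var)
    ref = torch.distributions.Normal(mean, torch.exp(0.5 * log_var)).log_prob(x)
    assert torch.allclose(ours, ref, atol=1e-4)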


def logsigmoid(x):
    return -softplus(-x)


def log_sum_exp(A, axis=-1, sum_op=torch.sum):
    # Numerically stable log-sum-exp along `axis`; the reduced dimension is
    # kept with size 1 (like keepdim=True).

    def maximum(x):
        return x.max(axis)[0]

    A_max = oper(A, maximum, axis, True)

    def summation(x):
        return sum_op(torch.exp(x - A_max), axis)

    B = torch.log(oper(A, summation, axis, True)) + A_max
    return B


def oper(array, func, axis=-1, keepdims=False):
    # Apply the reduction `func` along `axis`; with keepdims, restore the
    # reduced dimension (size 1) so the result broadcasts against `array`.
    a_oper = func(array)
    if keepdims:
        shape = list(array.size())
        shape[axis] = -1
        a_oper = a_oper.view(*shape)
    return a_oper


def softplus(x):
    # Softplus shifted by delta so the output is strictly positive.
    return F.softplus(x) + delta


def softmax(x, dim=-1):
    # Numerically stable softmax via the max-subtraction trick.
    e_x = torch.exp(x - x.max(dim=dim, keepdim=True)[0])
    out = e_x / e_x.sum(dim=dim, keepdim=True)
    return out
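

# Sanity-check sketch (our addition): the hand-rolled stable helpers should
# match PyTorch's built-ins. The helper name `_demo_stable_ops` is ours.
def _demo_stable_ops():
    A = torch.randn(2, 3, 5)
    assert torch.allclose(log_sum_exp(A, axis=2),
                          torch.logsumexp(A, dim=2, keepdim=True), atol=1e-5)
    assert torch.allclose(softmax(A, dim=-1), F.softmax(A, dim=-1), atol=1e-6)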


class BaseFlow(torch.nn.Module):

    def sample(self, n=1, context=None, **kwargs):
        # Note: relies on subclasses defining self.dim and self.context_dim;
        # the SigmoidFlow below never sets them, so this path is unused there.
        dim = self.dim
        if isinstance(self.dim, int):
            dim = [dim, ]

        spl = torch.randn(n, *dim)
        lgd = torch.zeros(n)
        if context is None:
            context = torch.ones(n, self.context_dim)

        if getattr(self, 'gpu', False):
            spl = spl.cuda()
            lgd = lgd.cuda()
            context = context.cuda()  # Fixed: the original called Tensor.gpu(), which does not exist

        return self.forward((spl, lgd, context))

    def cuda(self):
        self.gpu = True
        return super(BaseFlow, self).cuda()


class SigmoidFlow(BaseFlow):
    """
    Layer used to build Deep Sigmoidal Flows.

    Parameters
    ----------
    num_ds_dim : int
        The number of hidden units (sigmoid components per variable).
    """

    def __init__(self, num_ds_dim=4):
        super(SigmoidFlow, self).__init__()
        self.num_ds_dim = num_ds_dim

    def act_a(self, x):
        # Slopes must be strictly positive for the flow to be invertible
        return softplus(x)

    def act_b(self, x):
        # Offsets are unconstrained
        return x

    def act_w(self, x):
        # Mixture weights must lie on the probability simplex
        return softmax(x, dim=2)

    def forward(self, x, logdet, dsparams, mollify=0.0, delta=delta):
        ndim = self.num_ds_dim

        # Apply activation functions to the parameters produced by the hypernetwork
        a_ = self.act_a(dsparams[:, :, 0 * ndim: 1 * ndim])
        b_ = self.act_b(dsparams[:, :, 1 * ndim: 2 * ndim])
        w = self.act_w(dsparams[:, :, 2 * ndim: 3 * ndim])

        # Mollification interpolates the parameters towards the identity-like
        # setting a=1, b=0 (no effect for the default mollify=0.0)
        a = a_ * (1 - mollify) + 1.0 * mollify
        b = b_ * (1 - mollify) + 0.0 * mollify

        pre_sigm = a * x[:, :, None] + b  # Affine pre-activation per sigmoid unit
        sigm = torch.sigmoid(pre_sigm)
        x_pre = torch.sum(w * sigm, dim=2)  # Convex combination of sigmoids, in (0, 1)
        x_pre_clipped = x_pre * (1 - delta) + delta * 0.5  # Keep away from 0 and 1
        xnew = log(x_pre_clipped) - log(1 - x_pre_clipped)  # Logit maps back to the real line

        # Log-derivative of the transform, accumulated into the running log-determinant
        logj = F.log_softmax(dsparams[:, :, 2 * ndim: 3 * ndim], dim=2) + \
            logsigmoid(pre_sigm) + \
            logsigmoid(-pre_sigm) + log(a)
        logj = log_sum_exp(logj, 2).sum(2)

        logdet_ = logj + np.log(1 - delta) - (log(x_pre_clipped) + log(-x_pre_clipped + 1))
        logdet = logdet + logdet_
        return xnew, logdet
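

# Sanity-check sketch (our addition, not part of the original DCDI code): run a
# single SigmoidFlow layer on random inputs and verify that the returned
# log-determinant matches the autograd derivative of the elementwise transform.
# The helper name `_demo_sigmoid_flow` is ours.
def _demo_sigmoid_flow():
    ndim, batch_size, n_vars = 4, 2, 3
    flow = SigmoidFlow(num_ds_dim=ndim)
    x = torch.randn(batch_size, n_vars, requires_grad=True)
    dsparams = torch.randn(batch_size, n_vars, 3 * ndim)  # Stand-in for hypernetwork output
    y, logdet = flow(x, torch.zeros(batch_size, n_vars), dsparams)
    # The transform is elementwise, so d(sum y)/dx recovers each dy_ij/dx_ij
    grad = torch.autograd.grad(y.sum(), x)[0]
    assert torch.allclose(logdet, grad.log(), atol=1e-4)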


class DeepSigmoidalFlowModel(nn.Module):

    def __init__(self, num_vars, hidden_dims, flow_n_layers=1, flow_hid_dim=16):
        """
        Deep Sigmoidal Flow model.

        :param int num_vars: number of variables
        :param list hidden_dims: hidden layer sizes of the conditioner MLP
        :param int flow_n_layers: number of DSF layers
        :param int flow_hid_dim: number of hidden units in the DSF layers
        """
        super().__init__()
        flow_n_conditioned = flow_hid_dim

        # Conditioner model initialization
        n_conditioned_params = flow_n_conditioned * 3 * flow_n_layers  # Conditional params per variable
        self.condition_mlp = MultivarMLP(input_dims=num_vars * 2,  # Times two because of the mask
                                         hidden_dims=hidden_dims,
                                         output_dims=n_conditioned_params,
                                         extra_dims=[num_vars],
                                         actfn=lambda: nn.LeakyReLU(inplace=True),
                                         pre_layers=InputMask(None, concat_mask=True))

        # Flow model initialization
        self.flow_n_layers = flow_n_layers
        self.flow_hid_dim = flow_hid_dim
        self.flow_n_params_per_var = flow_hid_dim * 3 * flow_n_layers  # Total number of params
        self.flow_n_cond_params_per_var = n_conditioned_params  # Number of conditional params
        self.flow_n_params_per_layer = flow_hid_dim * 3  # Number of params in each flow layer
        self.flow = SigmoidFlow(flow_hid_dim)
        self.num_vars = num_vars

        # Shared density parameters (i.e., those that are not produced by the conditioner)
        self.shared_density_params = torch.nn.Parameter(torch.zeros(self.flow_n_params_per_var -
                                                                    self.flow_n_cond_params_per_var))
        self.reset_params()

    def reset_params(self):
        # Note: nn.Module stores submodules and parameters in _modules/_parameters
        # rather than __dict__, so the original `"flow" in self.__dict__` style
        # guards never fired; hasattr-based checks make the intended init run.
        if hasattr(self.flow, 'reset_parameters'):
            self.flow.reset_parameters()
        self.shared_density_params.data.uniform_(-0.001, 0.001)

    def forward(self, x, mask):
        """
        Compute the log likelihood of x given some density specification.

        :param x: torch.Tensor, shape=(batch_size, num_vars), the input for which to compute the likelihood.
        :param mask: torch.Tensor, binary mask specifying which variables the conditioner may use as inputs.
        :return: pseudo joint negative log-likelihood, shape=(batch_size, num_vars)
        """
        # Conditioner output has shape (batch_size, n_vars, n_cond_params_per_var)
        density_params = self.condition_mlp(x, mask)
        assert len(density_params.shape) == 3
        assert density_params.shape[0] == x.shape[0]
        assert density_params.shape[1] == self.num_vars
        assert density_params.shape[2] == self.flow_n_cond_params_per_var

        # Append the shared density parameters (those not produced by the
        # conditioner) to each layer's parameter vector. They differ per layer
        # but are identical across batch elements and variables.
        conditional = density_params.view(density_params.shape[0], density_params.shape[1],
                                          self.flow_n_layers, 3, -1)
        shared = self.shared_density_params.view(self.flow_n_layers, 3, -1)[None, None, :, :, :] \
            .repeat(conditional.shape[0], conditional.shape[1], 1, 1, 1)
        density_params = torch.cat((conditional, shared), -1).view(conditional.shape[0],
                                                                   conditional.shape[1], -1)
        assert density_params.shape[2] == self.flow_n_params_per_var

        logdet = torch.zeros((x.shape[0], self.num_vars), device=x.device)
        h = x.view(x.size(0), -1)
        for i in range(self.flow_n_layers):
            # Params of the current flow layer, shape (batch_size, n_vars, flow_n_params_per_layer)
            params = density_params[:, :, i * self.flow_n_params_per_layer:
                                    (i + 1) * self.flow_n_params_per_layer]
            h, logdet = self.flow(h, logdet, params)

        assert x.shape[0] == h.shape[0]
        assert x.shape[1] == h.shape[1]
        zeros = torch.zeros(x.shape[0], self.num_vars, device=x.device)
        # Not the joint NLL until we have a DAG
        pseudo_joint_nll = - log_normal(h, zeros, zeros + 1.0) - logdet
        return pseudo_joint_nll

    @property
    def device(self):
        return self.condition_mlp.device


class GaussianNoiseModel(nn.Module):

    def __init__(self, num_vars, hidden_dims):
        super().__init__()
        self.condition_mlp = MultivarMLP(input_dims=num_vars * 2,  # Times two because of the mask
                                         hidden_dims=hidden_dims * 2,
                                         output_dims=2,  # Per variable: predicted mean and spread
                                         extra_dims=[num_vars],
                                         actfn=lambda: nn.LeakyReLU(inplace=True),
                                         pre_layers=InputMask(None, concat_mask=True))

    def forward(self, x, mask):
        # Predict a Gaussian per variable and return its negative log-likelihood
        y = self.condition_mlp(x, mask)
        nll = - log_normal(x, y[..., 0], torch.exp(y[..., 1]))
        return nll

    @property
    def device(self):
        return self.condition_mlp.device


def create_continuous_model(num_vars, hidden_dims, use_flow_model=False):
    # Factory for the two continuous noise models defined above
    if use_flow_model:
        model = DeepSigmoidalFlowModel(num_vars, hidden_dims)
    else:
        model = GaussianNoiseModel(num_vars, hidden_dims)
    return model
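

# Usage sketch (our addition, not part of the original file): build both model
# variants and evaluate the per-variable NLL on random data. We assume here a
# binary adjacency mask of shape (batch_size, num_vars, num_vars), matching how
# InputMask is used in this project; adapt the mask if your layout differs.
def _demo_create_continuous_model():
    num_vars, batch_size = 4, 8
    x = torch.randn(batch_size, num_vars)
    mask = torch.ones(batch_size, num_vars, num_vars)  # Assumed shape: all parents allowed
    for use_flow_model in (False, True):
        model = create_continuous_model(num_vars, hidden_dims=[32, 32],
                                        use_flow_model=use_flow_model)
        nll = model(x, mask)
        assert nll.shape == (batch_size, num_vars)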