/
cnf.py
95 lines (76 loc) · 3.02 KB
/
cnf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint
from .wrappers.cnf_regularization import RegularizedODEfunc
__all__ = ["CNF"]
class CNF(nn.Module):
    """Module that integrates ``odefunc`` with ``odeint`` over a time interval.

    The state handed to the solver is the tuple ``(z, logpz)``; in training
    mode it is extended by one scalar accumulator per regularization
    function. The end of the default interval is stored as ``sqrt_end_time``
    and squared on use, so the default end time can never be negative.

    Args:
        odefunc: Dynamics module passed to ``odeint``. It must provide a
            ``before_odeint()`` method and, for :meth:`num_evals`, a
            ``_num_evals`` tensor attribute.
        T: End of the default integration interval ``[0, T]``.
        train_T: If true, ``sqrt_end_time`` is registered as a parameter;
            otherwise it is registered as a buffer.
        regularization_fns: Optional sequence of regularization functions.
            When given, ``odefunc`` is wrapped in ``RegularizedODEfunc``.
        solver: Value passed to ``odeint`` as ``method``.
        atol: Absolute tolerance passed to ``odeint``.
        rtol: Relative tolerance passed to ``odeint``.
    """

    def __init__(self, odefunc, T=1.0, train_T=False, regularization_fns=None, solver='dopri5', atol=1e-5, rtol=1e-5):
        super(CNF, self).__init__()
        sqrt_end_time = torch.sqrt(torch.tensor(T))
        if train_T:
            self.register_parameter("sqrt_end_time", nn.Parameter(sqrt_end_time))
        else:
            self.register_buffer("sqrt_end_time", sqrt_end_time)

        nreg = 0
        if regularization_fns is not None:
            odefunc = RegularizedODEfunc(odefunc, regularization_fns)
            nreg = len(regularization_fns)
        self.odefunc = odefunc
        self.nreg = nreg
        self.regularization_states = None

        # Solver settings used when ``self.training`` is true.
        self.solver = solver
        self.atol = atol
        self.rtol = rtol
        # Settings used in evaluation mode; they start out equal to the
        # training settings and are stored as separate attributes.
        self.test_solver = solver
        self.test_atol = atol
        self.test_rtol = rtol
        # Extra ``options`` forwarded to ``odeint`` in training mode only.
        self.solver_options = {}

    def _default_integration_times(self, like):
        """Return ``[0, sqrt_end_time ** 2]`` matching ``like``'s dtype/device.

        The tensor is assembled with ``torch.stack`` rather than
        ``torch.tensor([...])``: the latter copies the values into a fresh
        leaf tensor, which would cut ``sqrt_end_time`` out of the autograd
        graph and leave it without gradients when ``train_T`` is enabled.
        """
        end_time = self.sqrt_end_time * self.sqrt_end_time
        start_time = torch.zeros_like(end_time)
        return torch.stack([start_time, end_time]).to(like)

    def forward(self, z, logpz=None, integration_times=None, reverse=False):
        """Integrate ``(z, logpz)`` through the ODE defined by ``odefunc``.

        Args:
            z: Input state tensor; its first dimension sizes the default
                ``logpz``.
            logpz: Optional second state component. Defaults to zeros of
                shape ``(z.shape[0], 1)`` with ``z``'s dtype and device.
            integration_times: Optional tensor of solver times. Defaults to
                ``[0, sqrt_end_time ** 2]``.
            reverse: If true, the integration times are flipped along dim 0.

        Returns:
            ``(z_t, logpz_t)`` when ``logpz`` was supplied, otherwise ``z_t``
            alone. With exactly two integration times only the state at the
            second time is returned; otherwise the states are returned as
            produced by ``odeint``.

        Side effects:
            Calls ``odefunc.before_odeint()`` and stores any state components
            beyond the first two in ``self.regularization_states``.
        """
        if logpz is None:
            _logpz = torch.zeros(z.shape[0], 1).to(z)
        else:
            _logpz = logpz

        if integration_times is None:
            integration_times = self._default_integration_times(z)
        if reverse:
            integration_times = _flip(integration_times, 0)

        # Refresh the odefunc statistics.
        self.odefunc.before_odeint()

        if self.training:
            # One scalar accumulator per regularization function.
            reg_states = tuple(torch.tensor(0).to(z) for _ in range(self.nreg))
            if self.solver == 'dopri5':
                # Per-state tolerances: the regularization accumulators get a
                # huge tolerance (1e20) -- presumably so that they do not
                # influence the adaptive step size.
                atol = [self.atol, self.atol] + [1e20] * len(reg_states)
                rtol = [self.rtol, self.rtol] + [1e20] * len(reg_states)
            else:
                atol = self.atol
                rtol = self.rtol
            state_t = odeint(
                self.odefunc,
                (z, _logpz) + reg_states,
                integration_times.to(z),
                atol=atol,
                rtol=rtol,
                method=self.solver,
                options=self.solver_options,
            )
        else:
            # Evaluation integrates only (z, logpz), without accumulators.
            state_t = odeint(
                self.odefunc,
                (z, _logpz),
                integration_times.to(z),
                atol=self.test_atol,
                rtol=self.test_rtol,
                method=self.test_solver,
            )

        if len(integration_times) == 2:
            # Keep only the state at the second (final) time point.
            state_t = tuple(s[1] for s in state_t)

        z_t, logpz_t = state_t[:2]
        self.regularization_states = state_t[2:]

        if logpz is not None:
            return z_t, logpz_t
        return z_t

    def get_regularization_states(self):
        """Return the stored regularization states and reset them to None."""
        reg_states = self.regularization_states
        self.regularization_states = None
        return reg_states

    def num_evals(self):
        """Return ``odefunc._num_evals`` as a Python number."""
        return self.odefunc._num_evals.item()
def _flip(x, dim):
    """Return a copy of tensor ``x`` with its entries reversed along ``dim``.

    Uses the built-in ``torch.flip`` instead of building an explicit index
    tensor and applying advanced indexing.
    """
    return torch.flip(x, [dim])