-
Notifications
You must be signed in to change notification settings - Fork 4
/
getters.py
36 lines (28 loc) · 1.43 KB
/
getters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from sde import SDE, VpSdeSigmoid, VpSdeCos, GeneralizedSubVpSdeCos, SubVpSdeCos
def get_sde(sde_type: str, sde_kwargs) -> SDE:
    """Construct the SDE selected by ``sde_type``.

    Args:
        sde_type: Name of the SDE to build. Supported values are
            ``'vp-sigmoid'``, ``'vp-cos'``, ``'subvp-cos'`` and
            ``'generalized-sub-vp-cos'``.
        sde_kwargs: Mapping of constructor arguments. Only the keys
            ``'sigma_min'``, ``'sigma_max'``, ``'gamma'`` and ``'eta'`` are
            read; absent keys are treated as ``None``. ``None`` itself is
            accepted and treated as an empty mapping.

    Returns:
        The constructed SDE instance.

    Raises:
        ValueError: If an argument required by ``sde_type`` is missing or
            ``None``.
        NotImplementedError: If ``sde_type`` is not one of the supported
            names.
    """
    kwargs = {} if sde_kwargs is None else sde_kwargs
    sigma_min = kwargs.get('sigma_min')
    sigma_max = kwargs.get('sigma_max')
    gamma = kwargs.get('gamma')
    eta = kwargs.get('eta')

    def _require(**required):
        # Explicit check rather than ``assert`` so validation still runs
        # when Python is started with -O (which strips assertions).
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise ValueError(
                f"SDE type {sde_type!r} requires {', '.join(missing)} "
                f"in sde_kwargs")

    if sde_type == 'vp-sigmoid':
        return VpSdeSigmoid()
    if sde_type == 'vp-cos':
        _require(sigma_min=sigma_min, sigma_max=sigma_max)
        return VpSdeCos(sigma_min, sigma_max)
    if sde_type == 'subvp-cos':
        _require(sigma_min=sigma_min, sigma_max=sigma_max)
        return SubVpSdeCos(sigma_min, sigma_max)
    if sde_type == 'generalized-sub-vp-cos':
        _require(sigma_min=sigma_min, sigma_max=sigma_max,
                 gamma=gamma, eta=eta)
        # Positional order (gamma, eta, sigma_min, sigma_max) matches the
        # original call site.
        return GeneralizedSubVpSdeCos(gamma, eta, sigma_min, sigma_max)
    raise NotImplementedError(f"Unknown SDE type: {sde_type!r}")