-
Notifications
You must be signed in to change notification settings - Fork 0
/
MCRegressor.py
95 lines (84 loc) · 2.58 KB
/
MCRegressor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import numpy as np
from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
from sklearn.metrics import r2_score
from sklearn.utils import check_array, check_X_y
from sklearn.utils.extmath import safe_sparse_dot
from sklearn.utils.validation import check_is_fitted
from optimization import (
W_init,
check_alpha,
check_kappa,
check_t,
get_lambda,
get_sigma,
mc_gd_descent,
)
class MCRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
    """Multi-output linear regressor trained with ``mc_gd_descent``.

    The estimator learns a weight matrix ``coef_`` of shape
    ``(dim_H_, n_features_)`` and predicts with ``X @ coef_.T``.

    Mixins are listed to the left of ``BaseEstimator`` as scikit-learn
    requires, so that the mixins' estimator-type / tag overrides take effect.

    Parameters
    ----------
    alpha, t, kappa
        Hyperparameters validated in ``fit`` by ``check_alpha``, ``check_t``
        and ``check_kappa`` and used to derive ``sigma_`` and ``lambda_``.
    random_key
        Forwarded to ``mc_gd_descent``; presumably a JAX PRNG key or ``None``
        (TODO confirm against ``optimization.mc_gd_descent``).
    a_hat, M, M_prime, lr, N_steps
        Forwarded unchanged to ``mc_gd_descent``.
    init_method
        Forwarded to ``W_init`` to build the initial weight matrix.

    Attributes
    ----------
    n_features_ : int
        Number of columns of ``X`` seen in ``fit``.
    m_ : int
        Number of rows of ``X`` seen in ``fit``.
    dim_H_ : int
        Number of columns of ``phi_y`` seen in ``fit``.
    sigma_, lambda_
        Values returned by ``get_sigma`` and ``get_lambda``.
    coef_ : numpy.ndarray
        Learned weights, converted to a NumPy array after training.
    """

    def __init__(
        self,
        alpha=0.01,
        t=0.9,
        kappa=21,
        random_key=None,
        a_hat=0.00,
        M=10,
        M_prime=0,
        lr=1e-2,
        N_steps=2000,
        init_method="zeros",
    ):
        # scikit-learn convention: store constructor arguments verbatim and
        # defer all validation to ``fit`` so get_params/set_params/clone work.
        self.alpha = alpha
        self.t = t
        self.kappa = kappa
        self.M = M
        self.M_prime = M_prime
        self.a_hat = a_hat
        self.lr = lr
        self.N_steps = N_steps
        self.init_method = init_method
        self.random_key = random_key

    def fit(self, X, phi_y):
        """Fit the weight matrix on ``(X, phi_y)``.

        Parameters
        ----------
        X : array-like, 2-D
            Training inputs.
        phi_y : array-like
            Training targets; ``phi_y.shape[1]`` is read, so a 2-D array is
            expected.

        Returns
        -------
        self
            The fitted estimator.
        """
        # Validate and convert the inputs; multi_output allows 2-D targets.
        X, phi_y = check_X_y(X, phi_y, multi_output=True)

        self.n_features_ = X.shape[1]
        self.m_ = X.shape[0]
        self.dim_H_ = phi_y.shape[1]

        # Check hyperparameters (kappa is checked against the data itself).
        check_kappa(self.kappa, X)
        check_alpha(self.alpha)
        check_t(self.t)

        # Pre-compute the quantities derived from the hyperparameters.
        self.sigma_ = get_sigma(self.m_, self.alpha, self.t, self.kappa)
        self.lambda_ = get_lambda(self.t, self.sigma_, self.m_, self.alpha)
        print(f"Lambda: {self.lambda_}")
        print(f"Sigma: {self.sigma_}")

        # Original note: arrays should be jnp from here on --> add checks.
        predictor_shape = (self.dim_H_, self.n_features_)
        W_0 = W_init(predictor_shape, self.init_method)

        print("MC descent")
        # NOTE(review): this overwrites the constructor parameter
        # ``random_key`` with the key returned by the optimizer, which goes
        # against the scikit-learn rule that ``fit`` must not modify init
        # parameters. Kept as-is because callers may rely on the updated key.
        W, self.random_key = mc_gd_descent(
            X,
            phi_y,
            W_0,
            self.sigma_,
            self.a_hat,
            self.lambda_,
            self.lr,
            self.N_steps,
            self.M,
            self.M_prime,
            self.random_key,
        )

        # Store the weights as a plain NumPy array for prediction.
        self.coef_ = np.array(W)
        return self

    def predict(self, X):
        """Return ``X @ coef_.T`` for validated input ``X``.

        Requires a prior call to ``fit`` (enforced via ``check_is_fitted``).
        """
        check_is_fitted(self, ["coef_"])
        X = check_array(X)
        return safe_sparse_dot(X, self.coef_.T, dense_output=True)

    def score(self, X, phi_y):
        """Return the R^2 score of ``predict(X)`` against the targets ``phi_y``."""
        X, phi_y = check_X_y(X, phi_y, multi_output=True)
        phi_y_pred = self.predict(X)
        # r2_score expects (y_true, y_pred); R^2 is not symmetric, so the
        # ground truth must come first.
        return r2_score(phi_y, phi_y_pred)