Skip to content

Commit 9eac347

Browse files
authored
v2 joint regression (#395)
1 parent 4f8d524 commit 9eac347

File tree

10 files changed

+103
-41
lines changed

10 files changed

+103
-41
lines changed

pgml-extension/examples/joint_regression.sql

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,10 +11,10 @@ SELECT pgml.load_dataset('linnerud');
1111
SELECT * FROM pgml.linnerud LIMIT 10;
1212

1313
-- train a simple model on the data
14-
SELECT * FROM pgml.train_joint('Exercise vs Physiology', 'regression', 'pgml.linnerud', ARRAY['weight', 'waste', 'pulse']);
14+
SELECT * FROM pgml.train_joint('Exercise vs Physiology', 'regression', 'pgml.linnerud', ARRAY['weight', 'waist', 'pulse']);
1515

1616
-- check out the predictions
17-
SELECT weight, waste, pulse, pgml.predict_joint('Exercise vs Physiology', ARRAY[chins, situps, jumps]) AS prediction
17+
SELECT weight, waist, pulse, pgml.predict_joint('Exercise vs Physiology', ARRAY[chins, situps, jumps]) AS prediction
1818
FROM pgml.linnerud
1919
LIMIT 10;
2020

@@ -24,7 +24,7 @@ SELECT * FROM pgml.train_joint('Exercise vs Physiology', algorithm => 'lasso');
2424
SELECT * FROM pgml.train_joint('Exercise vs Physiology', algorithm => 'elastic_net');
2525
SELECT * FROM pgml.train_joint('Exercise vs Physiology', algorithm => 'least_angle');
2626
SELECT * FROM pgml.train_joint('Exercise vs Physiology', algorithm => 'lasso_least_angle');
27-
SELECT * FROM pgml.train_joint('Exercise vs Physiology', algorithm => 'orthoganl_matching_pursuit');
27+
SELECT * FROM pgml.train_joint('Exercise vs Physiology', algorithm => 'orthogonal_matching_pursuit');
2828
SELECT * FROM pgml.train_joint('Exercise vs Physiology', algorithm => 'bayesian_ridge');
2929
SELECT * FROM pgml.train_joint('Exercise vs Physiology', algorithm => 'automatic_relevance_determination');
3030
SELECT * FROM pgml.train_joint('Exercise vs Physiology', algorithm => 'stochastic_gradient_descent');
@@ -77,6 +77,6 @@ SELECT * FROM pgml.deploy('Exercise vs Physiology', 'rollback');
7777
SELECT * FROM pgml.deploy('Exercise vs Physiology', 'best_score', 'svm');
7878

7979
-- check out the improved predictions
80-
SELECT weight, waste, pulse, pgml.predict_joint('Exercise vs Physiology', ARRAY[chins, situps, jumps]) AS prediction
80+
SELECT weight, waist, pulse, pgml.predict_joint('Exercise vs Physiology', ARRAY[chins, situps, jumps]) AS prediction
8181
FROM pgml.linnerud
8282
LIMIT 10;

pgml-extension/src/api.rs

Lines changed: 56 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,6 @@ pub fn python_version() -> String {
102102
fn version() -> String {
103103
crate::VERSION.to_string()
104104
}
105-
106105
#[allow(clippy::too_many_arguments)]
107106
#[pg_extern]
108107
fn train(
@@ -126,6 +125,50 @@ fn train(
126125
name!(algorithm, String),
127126
name!(deployed, bool),
128127
),
128+
> {
129+
train_joint(
130+
project_name,
131+
task,
132+
relation_name,
133+
match y_column_name {
134+
Some(y_column_name) => Some(vec![y_column_name.to_string()]),
135+
None => None,
136+
},
137+
algorithm,
138+
hyperparams,
139+
search,
140+
search_params,
141+
search_args,
142+
test_size,
143+
test_sampling,
144+
runtime,
145+
automatic_deploy,
146+
)
147+
}
148+
149+
#[allow(clippy::too_many_arguments)]
150+
#[pg_extern]
151+
fn train_joint(
152+
project_name: &str,
153+
task: Option<default!(Task, "NULL")>,
154+
relation_name: Option<default!(&str, "NULL")>,
155+
y_column_name: Option<default!(Vec<String>, "NULL")>,
156+
algorithm: default!(Algorithm, "'linear'"),
157+
hyperparams: default!(JsonB, "'{}'"),
158+
search: Option<default!(Search, "NULL")>,
159+
search_params: default!(JsonB, "'{}'"),
160+
search_args: default!(JsonB, "'{}'"),
161+
test_size: default!(f32, 0.25),
162+
test_sampling: default!(Sampling, "'last'"),
163+
runtime: Option<default!(Runtime, "NULL")>,
164+
automatic_deploy: Option<default!(bool, true)>,
165+
) -> impl std::iter::Iterator<
166+
Item = (
167+
name!(project, String),
168+
name!(task, String),
169+
name!(algorithm, String),
170+
name!(deployed, bool),
171+
),
129172
> {
130173
let project = match Project::find_by_name(project_name) {
131174
Some(project) => project,
@@ -364,6 +407,11 @@ fn deploy(
364407

365408
#[pg_extern]
366409
fn predict(project_name: &str, features: Vec<f32>) -> f32 {
410+
predict_joint(project_name, features)[0]
411+
}
412+
413+
#[pg_extern]
414+
fn predict_joint(project_name: &str, features: Vec<f32>) -> Vec<f32> {
367415
let mut projects = PROJECT_NAME_TO_PROJECT_ID.lock();
368416
let project_id = match projects.get(project_name) {
369417
Some(project_id) => *project_id,
@@ -415,7 +463,12 @@ fn snapshot(
415463
test_size: default!(f32, 0.25),
416464
test_sampling: default!(Sampling, "'last'"),
417465
) -> impl std::iter::Iterator<Item = (name!(relation, String), name!(y_column_name, String))> {
418-
Snapshot::create(relation_name, y_column_name, test_size, test_sampling);
466+
Snapshot::create(
467+
relation_name,
468+
vec![y_column_name.to_string()],
469+
test_size,
470+
test_sampling,
471+
);
419472
vec![(relation_name.to_string(), y_column_name.to_string())].into_iter()
420473
}
421474

@@ -442,7 +495,7 @@ fn load_dataset(
442495
#[pg_extern]
443496
fn model_predict(model_id: i64, features: Vec<f32>) -> f32 {
444497
let estimator = crate::orm::file::find_deployed_estimator_by_model_id(model_id);
445-
estimator.predict(&features)
498+
estimator.predict(&features)[0]
446499
}
447500

448501
#[pg_extern]

pgml-extension/src/bindings/lightgbm.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -78,8 +78,8 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Box<dyn Bind
7878

7979
impl Bindings for Estimator {
8080
/// Predict a novel datapoint.
81-
fn predict(&self, features: &[f32]) -> f32 {
82-
self.predict_batch(features)[0]
81+
fn predict(&self, features: &[f32]) -> Vec<f32> {
82+
self.predict_batch(features)
8383
}
8484

8585
/// Predict a novel datapoint.

pgml-extension/src/bindings/linfa.rs

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -52,8 +52,8 @@ impl LinearRegression {
5252

5353
impl Bindings for LinearRegression {
5454
/// Predict a novel datapoint.
55-
fn predict(&self, features: &[f32]) -> f32 {
56-
self.predict_batch(features)[0]
55+
fn predict(&self, features: &[f32]) -> Vec<f32> {
56+
self.predict_batch(features)
5757
}
5858

5959
/// Predict a novel datapoint.
@@ -182,8 +182,8 @@ impl LogisticRegression {
182182

183183
impl Bindings for LogisticRegression {
184184
/// Predict a novel datapoint.
185-
fn predict(&self, features: &[f32]) -> f32 {
186-
self.predict_batch(features)[0]
185+
fn predict(&self, features: &[f32]) -> Vec<f32> {
186+
self.predict_batch(features)
187187
}
188188

189189
/// Predict a novel datapoint.
@@ -290,8 +290,8 @@ impl Svm {
290290

291291
impl Bindings for Svm {
292292
/// Predict a novel datapoint.
293-
fn predict(&self, features: &[f32]) -> f32 {
294-
self.predict_batch(features)[0]
293+
fn predict(&self, features: &[f32]) -> Vec<f32> {
294+
self.predict_batch(features)
295295
}
296296

297297
/// Predict a novel datapoint.

pgml-extension/src/bindings/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ pub type Fit = fn(dataset: &Dataset, hyperparams: &Hyperparams) -> Box<dyn Bindi
2222
/// implement serde.
2323
pub trait Bindings: Send + Sync {
2424
/// Predict a novel datapoint.
25-
fn predict(&self, features: &[f32]) -> f32;
25+
fn predict(&self, features: &[f32]) -> Vec<f32>;
2626

2727
/// Predict a set of datapoints.
2828
fn predict_batch(&self, features: &[f32]) -> Vec<f32>;

pgml-extension/src/bindings/sklearn.py

Lines changed: 17 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -67,19 +67,7 @@
6767
}
6868

6969

70-
def estimator(algorithm, num_features, hyperparams):
71-
"""Returns the correct estimator based on algorithm names
72-
we defined internally.
73-
74-
Parameters:
75-
- algorithm: The human-readable name of the algorithm (see dict above).
76-
- num_features: The number of features in X.
77-
- hyperparams: JSON of hyperparameters.
78-
"""
79-
return estimator_joint(algorithm, num_features, 1, hyperparams)
80-
81-
82-
def estimator_joint(algorithm, num_features, num_targets, hyperparams):
70+
def estimator(algorithm, num_features, num_targets, hyperparams):
8371
"""Returns the correct estimator based on algorithm names we defined
8472
internally (see dict above).
8573
@@ -97,6 +85,22 @@ def estimator_joint(algorithm, num_features, num_targets, hyperparams):
9785

9886
def train(X_train, y_train):
9987
instance = _ALGORITHM_MAP[algorithm](**hyperparams)
88+
if num_targets > 1 and algorithm in [
89+
"bayesian_ridge_regression",
90+
"automatic_relevance_determination_regression",
91+
"stochastic_gradient_descent_regression",
92+
"passive_aggressive_regression",
93+
"theil_sen_regression",
94+
"huber_regression",
95+
"quantile_regression",
96+
"svm_regression",
97+
"nu_svm_regression",
98+
"linear_svm_regression",
99+
"ada_boost_regression",
100+
"gradient_boosting_trees_regression",
101+
"lightgbm_regression",
102+
]:
103+
instance = sklearn.multioutput.MultiOutputRegressor(instance)
100104

101105
X_train = np.asarray(X_train).reshape((-1, num_features))
102106

pgml-extension/src/bindings/sklearn.rs

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -307,6 +307,7 @@ fn fit(
307307
&[
308308
String::from(algorithm_task).into_py(py),
309309
dataset.num_features.into_py(py),
310+
dataset.num_labels.into_py(py),
310311
hyperparams.into_py(py),
311312
],
312313
),
@@ -350,8 +351,8 @@ impl std::fmt::Debug for Estimator {
350351

351352
impl Bindings for Estimator {
352353
/// Predict a novel datapoint.
353-
fn predict(&self, features: &[f32]) -> f32 {
354-
self.predict_batch(features)[0]
354+
fn predict(&self, features: &[f32]) -> Vec<f32> {
355+
self.predict_batch(features)
355356
}
356357

357358
/// Predict a novel datapoint.

pgml-extension/src/bindings/xgboost.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -227,8 +227,8 @@ impl std::fmt::Debug for Estimator {
227227

228228
impl Bindings for Estimator {
229229
/// Predict a novel datapoint.
230-
fn predict(&self, features: &[f32]) -> f32 {
231-
self.predict_batch(features)[0]
230+
fn predict(&self, features: &[f32]) -> Vec<f32> {
231+
self.predict_batch(features)
232232
}
233233

234234
/// Predict a novel datapoint.

pgml-extension/src/orm/model.rs

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -57,13 +57,17 @@ impl Model {
5757
// Set the runtime to one we recommend, unless the user knows better.
5858
let runtime = match runtime {
5959
Some(runtime) => runtime,
60-
None => match algorithm {
61-
Algorithm::xgboost => Runtime::rust,
62-
Algorithm::lightgbm => Runtime::rust,
63-
Algorithm::linear => match project.task {
64-
Task::classification => Runtime::python,
65-
Task::regression => Runtime::rust,
60+
None => match snapshot.y_column_name.len() {
61+
1 => match algorithm {
62+
Algorithm::xgboost => Runtime::rust,
63+
Algorithm::lightgbm => Runtime::rust,
64+
Algorithm::linear => match project.task {
65+
Task::classification => Runtime::python,
66+
Task::regression => Runtime::rust,
67+
},
68+
_ => Runtime::python,
6669
},
70+
// Joint regression is only supported in Python
6771
_ => Runtime::python,
6872
},
6973
};

pgml-extension/src/orm/snapshot.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,7 @@ impl Snapshot {
119119

120120
pub fn create(
121121
relation_name: &str,
122-
y_column_name: &str,
122+
y_column_name: Vec<String>,
123123
test_size: f32,
124124
test_sampling: Sampling,
125125
) -> Snapshot {
@@ -130,7 +130,7 @@ impl Snapshot {
130130
Some(1),
131131
Some(vec![
132132
(PgBuiltInOids::TEXTOID.oid(), relation_name.into_datum()),
133-
(PgBuiltInOids::TEXTARRAYOID.oid(), vec![y_column_name].into_datum()),
133+
(PgBuiltInOids::TEXTARRAYOID.oid(), y_column_name.into_datum()),
134134
(PgBuiltInOids::FLOAT4OID.oid(), test_size.into_datum()),
135135
(PgBuiltInOids::TEXTOID.oid(), test_sampling.to_string().into_datum()),
136136
(PgBuiltInOids::TEXTOID.oid(), status.to_string().into_datum()),

0 commit comments

Comments
 (0)