
Commit 3d02750

Tests and checks (#341)
1 parent 5a6ab85 commit 3d02750

6 files changed: 125 additions & 2 deletions

pgml-extension/pgml_rust/sql/schema.sql

Lines changed: 5 additions & 0 deletions

@@ -1,3 +1,8 @@
+---
+--- Validate we have the necessary Python dependencies.
+---
+SELECT pgml_rust.validate_python_dependencies();
+
 ---
 --- Track of updates to data
 ---

pgml-extension/pgml_rust/src/engines/sklearn.rs

Lines changed: 19 additions & 0 deletions

@@ -19,6 +19,23 @@ use crate::orm::task::Task;
 
 use pgx::*;
 
+#[pg_extern]
+pub fn validate_python_dependencies() {
+    Python::with_gil(|py| {
+        for module in ["xgboost", "lightgbm", "numpy", "sklearn"] {
+            match py.import(module) {
+                Ok(_) => (),
+                Err(_) => {
+                    panic!(
+                        "The {} package is missing. Install it with `sudo pip3 install {}`",
+                        module, module
+                    );
+                }
+            }
+        }
+    });
+}
+
 #[pg_extern]
 pub fn sklearn_version() -> String {
     let mut version = String::new();
@@ -64,6 +81,7 @@ fn sklearn_algorithm_name(task: Task, algorithm: Algorithm) -> &'static str {
             Algorithm::least_angle => "least_angle_regression",
             Algorithm::lasso_least_angle => "lasso_least_angle_regression",
             Algorithm::linear_svm => "linear_svm_regression",
+            Algorithm::lightgbm => "lightgbm_regression",
             _ => panic!("{:?} does not support regression", algorithm),
         },
 
@@ -85,6 +103,7 @@ fn sklearn_algorithm_name(task: Task, algorithm: Algorithm) -> &'static str {
             Algorithm::gradient_boosting_trees => "gradient_boosting_trees_classification",
            Algorithm::hist_gradient_boosting => "hist_gradient_boosting_classification",
             Algorithm::linear_svm => "linear_svm_classification",
+            Algorithm::lightgbm => "lightgbm_classification",
             _ => panic!("{:?} does not support classification", algorithm),
         },
     }
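
Because validate_python_dependencies() is exposed via #[pg_extern], the same check that schema.sql runs at install time can also be invoked by hand when debugging an environment. A minimal sketch from psql, assuming the extension schema is already loaded (per the panic message above, a failure names the missing module and suggests a `sudo pip3 install <module>` fix):

-- Errors if xgboost, lightgbm, numpy, or sklearn cannot be imported
-- by the Python interpreter used by the extension.
SELECT pgml_rust.validate_python_dependencies();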

pgml-extension/pgml_rust/src/engines/wrappers.py

Lines changed: 3 additions & 0 deletions

@@ -11,6 +11,7 @@
 import sklearn.gaussian_process
 import sklearn.model_selection
 import xgboost as xgb
+import lightgbm
 import numpy as np
 import pickle
 import json
@@ -61,6 +62,8 @@
     "xgboost_classification": xgb.XGBClassifier,
     "xgboost_random_forest_regression": xgb.XGBRFRegressor,
     "xgboost_random_forest_classification": xgb.XGBRFClassifier,
+    "lightgbm_regression": lightgbm.LGBMRegressor,
+    "lightgbm_classification": lightgbm.LGBMClassifier,
 }
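
With the algorithm name mapped to lightgbm_regression / lightgbm_classification in sklearn.rs above, and to lightgbm.LGBMRegressor / lightgbm.LGBMClassifier here, LightGBM can be selected like any other algorithm from SQL. A sketch with a hypothetical project and table (names are illustrative, not from this commit), assuming the hyperparams JSON is forwarded to the estimator as keyword arguments the way it is for the other wrappers:

-- Hypothetical project and table names, for illustration only.
SELECT * FROM pgml_rust.train('Fraud Detection', 'classification', 'public.transactions', 'is_fraud');
SELECT * FROM pgml_rust.train('Fraud Detection', algorithm => 'lightgbm', hyperparams => '{"n_estimators": 100}');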

pgml-extension/pgml_rust/tests/binary_classification.sql

Lines changed: 1 addition & 1 deletion

@@ -78,7 +78,7 @@ SELECT * FROM pgml_rust.train('Breast Cancer Detection', algorithm => 'random_fo
 -- Gradient Boosting
 SELECT * FROM pgml_rust.train('Breast Cancer Detection', algorithm => 'xgboost', hyperparams => '{"n_estimators": 10}');
 SELECT * FROM pgml_rust.train('Breast Cancer Detection', algorithm => 'xgboost_random_forest', hyperparams => '{"n_estimators": 10}');
--- SELECT * FROM pgml_rust.train('Breast Cancer Detection', algorithm => 'lightgbm', hyperparams => '{"n_estimators": 1}');
+SELECT * FROM pgml_rust.train('Breast Cancer Detection', algorithm => 'lightgbm', hyperparams => '{"n_estimators": 100}');
 -- Histogram Gradient Boosting is too expensive for normal tests on even a toy dataset
 -- SELECT * FROM pgml_rust.train('Breast Cancer Detection', algorithim => 'hist_gradient_boosting', hyperparams => '{"max_iter": 2}');

Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
+-- This example trains models on the sklearn digits dataset,
+-- which is a copy of the test set of the UCI ML hand-written digits datasets
+-- https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits
+--
+-- This demonstrates using a table with a single array feature column
+-- for classification.
+--
+-- The final result after a few seconds of training is not terrible. Maybe not perfect
+-- enough for mission critical applications, but it's telling how quickly "off the shelf"
+-- solutions can solve problems these days.
+
+-- Exit on error (psql)
+\set ON_ERROR_STOP true
+
+SELECT pgml_rust.load_dataset('digits');
+
+-- view the dataset
+SELECT left(image::text, 40) || ',...}', target FROM pgml_rust.digits LIMIT 10;
+
+-- train a simple model to classify the data
+SELECT * FROM pgml_rust.train('Handwritten Digits', 'classification', 'pgml_rust.digits', 'target');
+
+-- check out the predictions
+SELECT target, pgml_rust.predict('Handwritten Digits', image) AS prediction
+FROM pgml_rust.digits
+LIMIT 10;
+
+--
+-- After a project has been trained, omitted parameters will be reused from previous training runs.
+-- In these examples we'll reuse the training data snapshots from the initial call.
+--
+
+-- linear models
+SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'ridge');
+SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'stochastic_gradient_descent');
+SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'perceptron');
+SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'passive_aggressive');
+
+-- support vector machines
+SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'svm');
+SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'nu_svm');
+SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'linear_svm');
+
+-- ensembles
+SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'ada_boost');
+SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'bagging');
+SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'extra_trees', hyperparams => '{"n_estimators": 10}');
+SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'gradient_boosting_trees', hyperparams => '{"n_estimators": 10}');
+SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'random_forest', hyperparams => '{"n_estimators": 10}');
+
+-- other
+-- Gaussian Process is too expensive for normal tests on even a toy dataset
+-- SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'gaussian_process', hyperparams => '{"max_iter_predict": 100, "warm_start": true}');
+
+-- gradient boosting
+SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'xgboost', hyperparams => '{"n_estimators": 10}');
+SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'xgboost_random_forest', hyperparams => '{"n_estimators": 10}');
+SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'lightgbm', hyperparams => '{"n_estimators": 100}');
+-- Histogram Gradient Boosting is too expensive for normal tests on even a toy dataset
+-- SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'hist_gradient_boosting', hyperparams => '{"max_iter": 2}');
+
+
+-- check out all that hard work
+SELECT trained_models.* FROM pgml_rust.trained_models
+JOIN pgml_rust.models on models.id = trained_models.id
+ORDER BY models.metrics->>'f1' DESC LIMIT 5;
+
+-- deploy the random_forest model for prediction use
+SELECT * FROM pgml_rust.deploy('Handwritten Digits', 'most_recent', 'random_forest');
+-- check out that throughput
+SELECT * FROM pgml_rust.deployed_models ORDER BY deployed_at DESC LIMIT 5;
+
+-- do a hyperparam search on your favorite algorithm
+SELECT pgml_rust.train(
+    'Handwritten Digits',
+    algorithm => 'svm',
+    hyperparams => '{"random_state": 0}',
+    search => 'grid',
+    search_params => '{
+        "kernel": ["linear", "poly", "sigmoid"],
+        "shrinking": [true, false]
+    }'
+);
+
+-- TODO SELECT pgml_rust.hypertune(100, 'Handwritten Digits', 'classification', 'pgml_rust.digits', 'target', 'gradient_boosted_trees');
+
+-- deploy the "best" model for prediction use
+SELECT * FROM pgml_rust.deploy('Handwritten Digits', 'best_score');
+SELECT * FROM pgml_rust.deploy('Handwritten Digits', 'most_recent');
+SELECT * FROM pgml_rust.deploy('Handwritten Digits', 'rollback');
+SELECT * FROM pgml_rust.deploy('Handwritten Digits', 'best_score', 'svm');
+
+-- check out the improved predictions
+SELECT target, pgml_rust.predict('Handwritten Digits', image) AS prediction
+FROM pgml_rust.digits
+LIMIT 10;
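
Beyond eyeballing ten rows, a quick way to quantify the deployed model is to score it over the whole table. A sketch, assuming pgml_rust.predict returns the predicted class label for classification projects (the ::int cast is only illustrative, in case the return type is floating point):

-- Overall accuracy of the currently deployed 'Handwritten Digits' model.
SELECT count(*) FILTER (WHERE pgml_rust.predict('Handwritten Digits', image)::int = target)::float
       / count(*) AS accuracy
FROM pgml_rust.digits;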

pgml-extension/pgml_rust/tests/regression.sql

Lines changed: 1 addition & 1 deletion

@@ -81,7 +81,7 @@ SELECT * FROM pgml_rust.train('Diabetes Progression', algorithm => 'random_fores
 -- gradient boosting
 SELECT * FROM pgml_rust.train('Diabetes Progression', algorithm => 'xgboost', hyperparams => '{"n_estimators": 10}');
 SELECT * FROM pgml_rust.train('Diabetes Progression', algorithm => 'xgboost_random_forest', hyperparams => '{"n_estimators": 10}');
--- SELECT * FROM pgml_rust.train('Diabetes Progression', algorithm => 'lightgbm', hyperparams => '{"n_estimators": 1}');
+SELECT * FROM pgml_rust.train('Diabetes Progression', algorithm => 'lightgbm', hyperparams => '{"n_estimators": 100}');
 -- Histogram Gradient Boosting is too expensive for normal tests on even a toy dataset
 -- SELECT * FROM pgml_rust.train('Diabetes Progression', algorithm => 'hist_gradient_boosting', hyperparams => '{"max_iter": 10}');
