Skip to content

Commit 0df498e

Browse files
authored
Starting lightgbm (#335)
1 parent 46f4936 commit 0df498e

File tree

10 files changed

+184
-28
lines changed

10 files changed

+184
-28
lines changed

pgml-extension/examples/regression.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
-- Exit on error (psql)
1313
\set ON_ERROR_STOP true
14+
\timing
1415

1516
SELECT pgml.load_dataset('diabetes');
1617

pgml-extension/pgml_rust/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pg_test = []
1818
[dependencies]
1919
pgx = { git="https://github.com/postgresml/pgx.git", branch="master" }
2020
xgboost = { git="https://github.com/postgresml/rust-xgboost.git" }
21-
smartcore = { git="https://github.com/smartcorelib/smartcore.git", branch="development", features = ["serde", "ndarray-bindings"] }
21+
smartcore = { git="https://github.com/smartcorelib/smartcore.git", branch="main", features = ["serde", "ndarray-bindings"] }
2222
once_cell = "1"
2323
rand = "0.8"
2424
ndarray = { version = "0.15.6", features = ["serde", "blas"] }
@@ -31,6 +31,7 @@ rmp-serde = { version = "1.1.0" }
3131
typetag = "0.2"
3232
pyo3 = { version = "0.17", features = ["auto-initialize"] }
3333
heapless = "0.7.13"
34+
lightgbm = { git="https://github.com/postgresml/lightgbm-rs" }
3435
parking_lot = "0.12"
3536

3637
[dev-dependencies]

pgml-extension/pgml_rust/sql/schema.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ CREATE TABLE IF NOT EXISTS pgml_rust.models(
7676
id BIGSERIAL PRIMARY KEY,
7777
project_id BIGINT NOT NULL,
7878
snapshot_id BIGINT NOT NULL,
79+
num_features INT NOT NULL,
7980
algorithm TEXT NOT NULL,
8081
engine TEXT DEFAULT 'sklearn',
8182
hyperparams JSONB NOT NULL,

pgml-extension/pgml_rust/src/engines/engine.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use serde::Deserialize;
66
pub enum Engine {
77
xgboost,
88
torch,
9-
lightdbm,
9+
lightgbm,
1010
sklearn,
1111
smartcore,
1212
linfa,
@@ -19,7 +19,7 @@ impl std::str::FromStr for Engine {
1919
match input {
2020
"xgboost" => Ok(Engine::xgboost),
2121
"torch" => Ok(Engine::torch),
22-
"lightdbm" => Ok(Engine::lightdbm),
22+
"lightgbm" => Ok(Engine::lightgbm),
2323
"sklearn" => Ok(Engine::sklearn),
2424
"smartcore" => Ok(Engine::smartcore),
2525
"linfa" => Ok(Engine::linfa),
@@ -33,7 +33,7 @@ impl std::string::ToString for Engine {
3333
match *self {
3434
Engine::xgboost => "xgboost".to_string(),
3535
Engine::torch => "torch".to_string(),
36-
Engine::lightdbm => "lightdbm".to_string(),
36+
Engine::lightgbm => "lightgbm".to_string(),
3737
Engine::sklearn => "sklearn".to_string(),
3838
Engine::smartcore => "smartcore".to_string(),
3939
Engine::linfa => "linfa".to_string(),
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
use lightgbm;
2+
3+
use crate::engines::Hyperparams;
4+
use crate::orm::dataset::Dataset;
5+
use crate::orm::estimator::LightgbmBox;
6+
use crate::orm::task::Task;
7+
use serde_json::json;
8+
9+
pub fn lightgbm_train(task: Task, dataset: &Dataset, hyperparams: &Hyperparams) -> LightgbmBox {
10+
let x_train = dataset.x_train();
11+
let y_train = dataset.y_train();
12+
let objective = match task {
13+
Task::regression => "regression",
14+
Task::classification => {
15+
let distinct_labels = dataset.distinct_labels();
16+
17+
if distinct_labels > 2 {
18+
"multiclass"
19+
} else {
20+
"binary"
21+
}
22+
}
23+
};
24+
25+
let dataset =
26+
lightgbm::Dataset::from_vec(x_train, y_train, dataset.num_features as i32).unwrap();
27+
28+
let bst = lightgbm::Booster::train(
29+
dataset,
30+
&json! {{
31+
"objective": objective,
32+
}},
33+
)
34+
.unwrap();
35+
36+
LightgbmBox::new(bst)
37+
}
38+
39+
/// Serialize an LightGBm estimator into bytes.
40+
pub fn lightgbm_save(estimator: &LightgbmBox) -> Vec<u8> {
41+
let r: u64 = rand::random();
42+
let path = format!("/tmp/pgml_rust_{}.bin", r);
43+
44+
estimator.save_file(&path).unwrap();
45+
46+
let bytes = std::fs::read(&path).unwrap();
47+
48+
std::fs::remove_file(&path).unwrap();
49+
50+
bytes
51+
}
52+
53+
/// Load an LightGBM estimator from bytes.
54+
pub fn lightgbm_load(data: &Vec<u8>) -> LightgbmBox {
55+
// Oh boy
56+
let r: u64 = rand::random();
57+
let path = format!("/tmp/pgml_rust_{}.bin", r);
58+
59+
std::fs::write(&path, &data).unwrap();
60+
61+
let bst = lightgbm::Booster::from_file(&path).unwrap();
62+
LightgbmBox::new(bst)
63+
}
64+
65+
/// Validate a trained estimator against the test dataset.
66+
pub fn lightgbm_test(estimator: &LightgbmBox, dataset: &Dataset) -> Vec<f32> {
67+
let x_test = dataset.x_test();
68+
let num_features = dataset.num_features;
69+
70+
estimator.predict(&x_test, num_features as i32).unwrap()
71+
}
72+
73+
/// Predict a novel datapoint using the LightGBM estimator.
74+
pub fn lightgbm_predict(estimator: &LightgbmBox, x: &[f32]) -> f32 {
75+
estimator.predict(&x, x.len() as i32).unwrap()[0]
76+
}

pgml-extension/pgml_rust/src/engines/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pub mod engine;
2+
pub mod lightgbm;
23
pub mod sklearn;
34
pub mod smartcore;
45
pub mod xgboost;

pgml-extension/pgml_rust/src/engines/sklearn.rs

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -163,25 +163,12 @@ pub fn sklearn_test(estimator: &SklearnBox, dataset: &Dataset) -> Vec<f32> {
163163
}
164164

165165
pub fn sklearn_predict(estimator: &SklearnBox, x: &[f32]) -> Vec<f32> {
166-
let module = include_str!(concat!(
167-
env!("CARGO_MANIFEST_DIR"),
168-
"/src/engines/wrappers.py"
169-
));
170-
171166
let y_hat: Vec<f32> = Python::with_gil(|py| -> Vec<f32> {
172-
let module = PyModule::from_code(py, module, "", "").unwrap();
173-
let predictor = module.getattr("predictor").unwrap();
174-
let predict = predictor
175-
.call1(PyTuple::new(
176-
py,
177-
&[estimator.contents.as_ref(), &x.len().into_py(py)],
178-
))
179-
.unwrap();
180-
181-
predict
182-
.call1(PyTuple::new(py, &[x]))
167+
estimator
168+
.contents
169+
.call1(py, PyTuple::new(py, &[x]))
183170
.unwrap()
184-
.extract()
171+
.extract(py)
185172
.unwrap()
186173
});
187174

@@ -204,7 +191,7 @@ pub fn sklearn_save(estimator: &SklearnBox) -> Vec<u8> {
204191
})
205192
}
206193

207-
pub fn sklearn_load(data: &Vec<u8>) -> SklearnBox {
194+
pub fn sklearn_load(data: &Vec<u8>, num_features: i32) -> SklearnBox {
208195
let module = include_str!(concat!(
209196
env!("CARGO_MANIFEST_DIR"),
210197
"/src/engines/wrappers.py"
@@ -218,6 +205,13 @@ pub fn sklearn_load(data: &Vec<u8>) -> SklearnBox {
218205
.unwrap()
219206
.extract()
220207
.unwrap();
208+
let predict = module.getattr("predictor").unwrap();
209+
let estimator = predict
210+
.call1(PyTuple::new(py, &[estimator, num_features.into_py(py)]))
211+
.unwrap()
212+
.extract()
213+
.unwrap();
214+
221215
SklearnBox::new(estimator)
222216
})
223217
}

pgml-extension/pgml_rust/src/orm/algorithm.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ pub enum Algorithm {
3636
gradient_boosting_trees,
3737
hist_gradient_boosting,
3838
linear_svm,
39+
lightgbm,
3940
}
4041

4142
impl std::str::FromStr for Algorithm {
@@ -75,6 +76,7 @@ impl std::str::FromStr for Algorithm {
7576
"gradient_boosting_trees" => Ok(Algorithm::gradient_boosting_trees),
7677
"hist_gradient_boosting" => Ok(Algorithm::hist_gradient_boosting),
7778
"linear_svm" => Ok(Algorithm::linear_svm),
79+
"lightgbm" => Ok(Algorithm::lightgbm),
7880
_ => Err(()),
7981
}
8082
}
@@ -117,6 +119,7 @@ impl std::string::ToString for Algorithm {
117119
Algorithm::gradient_boosting_trees => "gradient_boosting_trees".to_string(),
118120
Algorithm::hist_gradient_boosting => "hist_gradient_boosting".to_string(),
119121
Algorithm::linear_svm => "linear_svm".to_string(),
122+
Algorithm::lightgbm => "lightgbm".to_string(),
120123
}
121124
}
122125
}

pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use once_cell::sync::Lazy;
99
use pgx::*;
1010
use pyo3::prelude::*;
1111

12+
use crate::engines::lightgbm::{lightgbm_load, lightgbm_predict, lightgbm_test};
1213
use crate::engines::sklearn::{sklearn_load, sklearn_predict, sklearn_test};
1314
use crate::engines::smartcore::{smartcore_load, smartcore_predict, smartcore_test};
1415
use crate::engines::xgboost::{xgboost_load, xgboost_predict, xgboost_test};
@@ -32,9 +33,9 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Arc<Box<dyn Estimat
3233
}
3334
}
3435

35-
let (task, algorithm) = Spi::get_two_with_args::<String, String>(
36+
let (task, algorithm, num_features) = Spi::get_three_with_args::<String, String, i32>(
3637
"
37-
SELECT projects.task::TEXT, models.algorithm::TEXT
38+
SELECT projects.task::TEXT, models.algorithm::TEXT, models.num_features
3839
FROM pgml_rust.models
3940
JOIN pgml_rust.projects
4041
ON projects.id = models.project_id
@@ -59,6 +60,8 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Arc<Box<dyn Estimat
5960
}))
6061
.unwrap();
6162

63+
let num_features = num_features.unwrap();
64+
6265
let (data, hyperparams, engine) = Spi::get_three_with_args::<Vec<u8>, JsonB, String>(
6366
"SELECT data, hyperparams, engine::TEXT FROM pgml_rust.models
6467
INNER JOIN pgml_rust.files
@@ -83,7 +86,8 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Arc<Box<dyn Estimat
8386
let estimator: Box<dyn Estimator> = match engine {
8487
Engine::xgboost => Box::new(xgboost_load(&data)),
8588
Engine::smartcore => smartcore_load(&data, task, algorithm, &hyperparams),
86-
Engine::sklearn => Box::new(sklearn_load(&data)),
89+
Engine::sklearn => Box::new(sklearn_load(&data, num_features)),
90+
Engine::lightgbm => Box::new(lightgbm_load(&data)),
8791
_ => todo!(),
8892
};
8993

@@ -336,3 +340,67 @@ impl Estimator for SklearnBox {
336340
score[0]
337341
}
338342
}
343+
344+
/// LightGBM implementation of the Estimator trait.
345+
pub struct LightgbmBox {
346+
contents: Box<lightgbm::Booster>,
347+
}
348+
349+
impl LightgbmBox {
350+
pub fn new(contents: lightgbm::Booster) -> Self {
351+
LightgbmBox {
352+
contents: Box::new(contents),
353+
}
354+
}
355+
}
356+
357+
impl std::ops::Deref for LightgbmBox {
358+
type Target = lightgbm::Booster;
359+
360+
fn deref(&self) -> &Self::Target {
361+
self.contents.as_ref()
362+
}
363+
}
364+
365+
impl std::ops::DerefMut for LightgbmBox {
366+
fn deref_mut(&mut self) -> &mut Self::Target {
367+
self.contents.as_mut()
368+
}
369+
}
370+
371+
// Declares that a `LightgbmBox` may be moved to another thread.
// NOTE(review): nothing visible here shows why the wrapped booster is safe to
// send across threads -- confirm against the LightGBM bindings and record the
// invariant in a `SAFETY` comment.
unsafe impl Send for LightgbmBox {}
372+
// Declares that `&LightgbmBox` may be shared between threads.
// NOTE(review): nothing visible here shows that concurrent use of the wrapped
// booster is sound -- confirm against the LightGBM bindings and record the
// invariant in a `SAFETY` comment.
unsafe impl Sync for LightgbmBox {}
373+
374+
impl std::fmt::Debug for LightgbmBox {
375+
fn fmt(
376+
&self,
377+
formatter: &mut std::fmt::Formatter<'_>,
378+
) -> std::result::Result<(), std::fmt::Error> {
379+
formatter.debug_struct("LightgbmBox").finish()
380+
}
381+
}
382+
383+
/// Placeholder `Serialize` implementation.
///
/// This implementation never produces output: calling it panics. The message
/// explains that LightGBM models are persisted through LightGBM's own
/// save/load mechanism rather than through Serde.
impl serde::Serialize for LightgbmBox {
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        panic!("This is not used because we don't use Serde to serialize or deserialize LightGBM, it comes with its own.")
    }
}
391+
392+
#[typetag::serialize]
393+
impl Estimator for LightgbmBox {
394+
fn test(&self, task: Task, dataset: &Dataset) -> HashMap<String, f32> {
395+
let y_hat =
396+
Array1::from_shape_vec(dataset.num_test_rows, lightgbm_test(self, dataset)).unwrap();
397+
let y_test =
398+
Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap();
399+
400+
calc_metrics(&y_test, &y_hat, dataset.distinct_labels(), task)
401+
}
402+
403+
fn predict(&self, features: Vec<f32>) -> f32 {
404+
lightgbm_predict(self, &features)
405+
}
406+
}

pgml-extension/pgml_rust/src/orm/model.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use crate::orm::Project;
1111
use crate::orm::Search;
1212
use crate::orm::Snapshot;
1313

14+
use crate::engines::lightgbm::{lightgbm_save, lightgbm_train};
1415
use crate::engines::sklearn::{sklearn_save, sklearn_search, sklearn_train};
1516
use crate::engines::smartcore::{smartcore_save, smartcore_train};
1617
use crate::engines::xgboost::{xgboost_save, xgboost_train};
@@ -51,15 +52,18 @@ impl Model {
5152
Some(engine) => engine,
5253
None => match algorithm {
5354
Algorithm::xgboost => Engine::xgboost,
55+
Algorithm::lightgbm => Engine::lightgbm,
5456
_ => Engine::sklearn,
5557
},
5658
};
5759

60+
let dataset = snapshot.dataset();
61+
5862
// Create the model record.
5963
Spi::connect(|client| {
6064
let result = client.select("
61-
INSERT INTO pgml_rust.models (project_id, snapshot_id, algorithm, hyperparams, status, search, search_params, search_args, engine)
62-
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
65+
INSERT INTO pgml_rust.models (project_id, snapshot_id, algorithm, hyperparams, status, search, search_params, search_args, engine, num_features)
66+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
6367
RETURNING id, project_id, snapshot_id, algorithm, hyperparams, status, metrics, search, search_params, search_args, created_at, updated_at;",
6468
Some(1),
6569
Some(vec![
@@ -75,6 +79,7 @@ impl Model {
7579
(PgBuiltInOids::JSONBOID.oid(), search_params.into_datum()),
7680
(PgBuiltInOids::JSONBOID.oid(), search_args.into_datum()),
7781
(PgBuiltInOids::TEXTOID.oid(), engine.to_string().into_datum()),
82+
(PgBuiltInOids::INT4OID.oid(), dataset.num_features.into_datum()),
7883
])
7984
).first();
8085
if !result.is_empty() {
@@ -100,7 +105,6 @@ impl Model {
100105
});
101106

102107
let mut model = model.unwrap();
103-
let dataset = snapshot.dataset();
104108

105109
model.fit(project, &dataset);
106110
model.test(project, &dataset);
@@ -159,6 +163,13 @@ impl Model {
159163
(estimator, bytes)
160164
}
161165

166+
Engine::lightgbm => {
167+
let estimator = lightgbm_train(project.task, dataset, &hyperparams);
168+
let bytes = lightgbm_save(&estimator);
169+
170+
(Box::new(estimator), bytes)
171+
}
172+
162173
_ => todo!(),
163174
};
164175

0 commit comments

Comments
 (0)