@@ -9,6 +9,7 @@ use once_cell::sync::Lazy;
99use pgx:: * ;
1010use pyo3:: prelude:: * ;
1111
12+ use crate :: engines:: lightgbm:: { lightgbm_load, lightgbm_predict, lightgbm_test} ;
1213use crate :: engines:: sklearn:: { sklearn_load, sklearn_predict, sklearn_test} ;
1314use crate :: engines:: smartcore:: { smartcore_load, smartcore_predict, smartcore_test} ;
1415use crate :: engines:: xgboost:: { xgboost_load, xgboost_predict, xgboost_test} ;
@@ -32,9 +33,9 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Arc<Box<dyn Estimat
3233 }
3334 }
3435
35- let ( task, algorithm) = Spi :: get_two_with_args :: < String , String > (
36+ let ( task, algorithm, num_features ) = Spi :: get_three_with_args :: < String , String , i32 > (
3637 "
37- SELECT projects.task::TEXT, models.algorithm::TEXT
38+ SELECT projects.task::TEXT, models.algorithm::TEXT, models.num_features
3839 FROM pgml_rust.models
3940 JOIN pgml_rust.projects
4041 ON projects.id = models.project_id
@@ -59,6 +60,8 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Arc<Box<dyn Estimat
5960 } ) )
6061 . unwrap ( ) ;
6162
63+ let num_features = num_features. unwrap ( ) ;
64+
6265 let ( data, hyperparams, engine) = Spi :: get_three_with_args :: < Vec < u8 > , JsonB , String > (
6366 "SELECT data, hyperparams, engine::TEXT FROM pgml_rust.models
6467 INNER JOIN pgml_rust.files
@@ -83,7 +86,8 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Arc<Box<dyn Estimat
8386 let estimator: Box < dyn Estimator > = match engine {
8487 Engine :: xgboost => Box :: new ( xgboost_load ( & data) ) ,
8588 Engine :: smartcore => smartcore_load ( & data, task, algorithm, & hyperparams) ,
86- Engine :: sklearn => Box :: new ( sklearn_load ( & data) ) ,
89+ Engine :: sklearn => Box :: new ( sklearn_load ( & data, num_features) ) ,
90+ Engine :: lightgbm => Box :: new ( lightgbm_load ( & data) ) ,
8791 _ => todo ! ( ) ,
8892 } ;
8993
@@ -336,3 +340,67 @@ impl Estimator for SklearnBox {
336340 score[ 0 ]
337341 }
338342}
343+
344+ /// LightGBM implementation of the Estimator trait.
345+ pub struct LightgbmBox {
346+ contents : Box < lightgbm:: Booster > ,
347+ }
348+
349+ impl LightgbmBox {
350+ pub fn new ( contents : lightgbm:: Booster ) -> Self {
351+ LightgbmBox {
352+ contents : Box :: new ( contents) ,
353+ }
354+ }
355+ }
356+
357+ impl std:: ops:: Deref for LightgbmBox {
358+ type Target = lightgbm:: Booster ;
359+
360+ fn deref ( & self ) -> & Self :: Target {
361+ self . contents . as_ref ( )
362+ }
363+ }
364+
365+ impl std:: ops:: DerefMut for LightgbmBox {
366+ fn deref_mut ( & mut self ) -> & mut Self :: Target {
367+ self . contents . as_mut ( )
368+ }
369+ }
370+
// NOTE(review): these impls assert that `LightgbmBox` may be moved to and shared
// between threads. Nothing visible here shows why that is sound for the wrapped
// `lightgbm::Booster` — confirm against the lightgbm crate and record the
// invariant in a `SAFETY:` comment.
371+ unsafe impl Send for LightgbmBox { }
372+ unsafe impl Sync for LightgbmBox { }
373+
374+ impl std:: fmt:: Debug for LightgbmBox {
375+ fn fmt (
376+ & self ,
377+ formatter : & mut std:: fmt:: Formatter < ' _ > ,
378+ ) -> std:: result:: Result < ( ) , std:: fmt:: Error > {
379+ formatter. debug_struct ( "LightgbmBox" ) . finish ( )
380+ }
381+ }
382+
383+ impl serde:: Serialize for LightgbmBox {
384+ fn serialize < S > ( & self , _serializer : S ) -> Result < S :: Ok , S :: Error >
385+ where
386+ S : serde:: Serializer ,
387+ {
388+ panic ! ( "This is not used because we don't use Serde to serialize or deserialize XGBoost, it comes with its own." )
389+ }
390+ }
391+
392+ #[ typetag:: serialize]
393+ impl Estimator for LightgbmBox {
394+ fn test ( & self , task : Task , dataset : & Dataset ) -> HashMap < String , f32 > {
395+ let y_hat =
396+ Array1 :: from_shape_vec ( dataset. num_test_rows , lightgbm_test ( self , dataset) ) . unwrap ( ) ;
397+ let y_test =
398+ Array1 :: from_shape_vec ( dataset. num_test_rows , dataset. y_test ( ) . to_vec ( ) ) . unwrap ( ) ;
399+
400+ calc_metrics ( & y_test, & y_hat, dataset. distinct_labels ( ) , task)
401+ }
402+
403+ fn predict ( & self , features : Vec < f32 > ) -> f32 {
404+ lightgbm_predict ( self , & features)
405+ }
406+ }
0 commit comments