@@ -102,7 +102,6 @@ pub fn python_version() -> String {
102102fn version ( ) -> String {
103103 crate :: VERSION . to_string ( )
104104}
105-
106105#[ allow( clippy:: too_many_arguments) ]
107106#[ pg_extern]
108107fn train (
@@ -126,6 +125,50 @@ fn train(
126125 name ! ( algorithm, String ) ,
127126 name ! ( deployed, bool ) ,
128127 ) ,
128+ > {
129+ train_joint (
130+ project_name,
131+ task,
132+ relation_name,
133+ match y_column_name {
134+ Some ( y_column_name) => Some ( vec ! [ y_column_name. to_string( ) ] ) ,
135+ None => None ,
136+ } ,
137+ algorithm,
138+ hyperparams,
139+ search,
140+ search_params,
141+ search_args,
142+ test_size,
143+ test_sampling,
144+ runtime,
145+ automatic_deploy,
146+ )
147+ }
148+
149+ #[ allow( clippy:: too_many_arguments) ]
150+ #[ pg_extern]
151+ fn train_joint (
152+ project_name : & str ,
153+ task : Option < default ! ( Task , "NULL" ) > ,
154+ relation_name : Option < default ! ( & str , "NULL" ) > ,
155+ y_column_name : Option < default ! ( Vec <String >, "NULL" ) > ,
156+ algorithm : default ! ( Algorithm , "'linear'" ) ,
157+ hyperparams : default ! ( JsonB , "'{}'" ) ,
158+ search : Option < default ! ( Search , "NULL" ) > ,
159+ search_params : default ! ( JsonB , "'{}'" ) ,
160+ search_args : default ! ( JsonB , "'{}'" ) ,
161+ test_size : default ! ( f32 , 0.25 ) ,
162+ test_sampling : default ! ( Sampling , "'last'" ) ,
163+ runtime : Option < default ! ( Runtime , "NULL" ) > ,
164+ automatic_deploy : Option < default ! ( bool , true ) > ,
165+ ) -> impl std:: iter:: Iterator <
166+ Item = (
167+ name ! ( project, String ) ,
168+ name ! ( task, String ) ,
169+ name ! ( algorithm, String ) ,
170+ name ! ( deployed, bool ) ,
171+ ) ,
129172> {
130173 let project = match Project :: find_by_name ( project_name) {
131174 Some ( project) => project,
@@ -364,6 +407,11 @@ fn deploy(
364407
365408#[ pg_extern]
366409fn predict ( project_name : & str , features : Vec < f32 > ) -> f32 {
410+ predict_joint ( project_name, features) [ 0 ]
411+ }
412+
413+ #[ pg_extern]
414+ fn predict_joint ( project_name : & str , features : Vec < f32 > ) -> Vec < f32 > {
367415 let mut projects = PROJECT_NAME_TO_PROJECT_ID . lock ( ) ;
368416 let project_id = match projects. get ( project_name) {
369417 Some ( project_id) => * project_id,
@@ -415,7 +463,12 @@ fn snapshot(
415463 test_size : default ! ( f32 , 0.25 ) ,
416464 test_sampling : default ! ( Sampling , "'last'" ) ,
417465) -> impl std:: iter:: Iterator < Item = ( name ! ( relation, String ) , name ! ( y_column_name, String ) ) > {
418- Snapshot :: create ( relation_name, y_column_name, test_size, test_sampling) ;
466+ Snapshot :: create (
467+ relation_name,
468+ vec ! [ y_column_name. to_string( ) ] ,
469+ test_size,
470+ test_sampling,
471+ ) ;
419472 vec ! [ ( relation_name. to_string( ) , y_column_name. to_string( ) ) ] . into_iter ( )
420473}
421474
@@ -442,7 +495,7 @@ fn load_dataset(
442495#[ pg_extern]
443496fn model_predict ( model_id : i64 , features : Vec < f32 > ) -> f32 {
444497 let estimator = crate :: orm:: file:: find_deployed_estimator_by_model_id ( model_id) ;
445- estimator. predict ( & features)
498+ estimator. predict ( & features) [ 0 ]
446499}
447500
448501#[ pg_extern]
0 commit comments