Skip to content

Commit f3dd5ea

Browse files
authored
Hot fixing 2.0 (#409)
1 parent 563852e commit f3dd5ea

File tree

5 files changed: 38 additions and 4 deletions.

pgml-extension/sql/schema.sql

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
--- Validate we have the necessary Python dependencies.
33
---
44
SELECT pgml.validate_python_dependencies();
5+
SELECT pgml.validate_shared_library();
56

67
---
78
--- Track of updates to data

pgml-extension/src/api.rs

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,21 @@ pub fn validate_python_dependencies() {
6060
#[pg_extern]
6161
pub fn validate_python_dependencies() {}
6262

63+
#[pg_extern]
64+
pub fn validate_shared_library() {
65+
let shared_preload_libraries: String = Spi::get_one(
66+
"SELECT setting
67+
FROM pg_settings
68+
WHERE name = 'shared_preload_libraries'
69+
LIMIT 1",
70+
)
71+
.unwrap();
72+
73+
if !shared_preload_libraries.contains("pgml") {
74+
error!("`pgml` must be added to `shared_preload_libraries` setting or models cannot be deployed");
75+
}
76+
}
77+
6378
#[cfg(feature = "python")]
6479
#[pg_extern]
6580
pub fn sklearn_version() -> String {
@@ -102,6 +117,7 @@ pub fn python_version() -> String {
102117
fn version() -> String {
103118
crate::VERSION.to_string()
104119
}
120+
105121
#[allow(clippy::too_many_arguments)]
106122
#[pg_extern]
107123
fn train(

pgml-extension/src/bindings/xgboost.rs

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@ use xgboost::parameters::*;
99
use xgboost::{Booster, DMatrix};
1010

1111
use crate::orm::dataset::Dataset;
12+
use crate::orm::task::Task;
1213
use crate::orm::Hyperparams;
1314

1415
use crate::bindings::Bindings;
@@ -133,20 +134,27 @@ fn get_tree_params(hyperparams: &Hyperparams) -> tree::TreeBoosterParameters {
133134
}
134135

135136
pub fn fit_regression(dataset: &Dataset, hyperparams: &Hyperparams) -> Box<dyn Bindings> {
136-
fit(dataset, hyperparams, learning::Objective::RegLinear)
137+
fit(
138+
dataset,
139+
hyperparams,
140+
Task::regression,
141+
learning::Objective::RegLinear,
142+
)
137143
}
138144

139145
pub fn fit_classification(dataset: &Dataset, hyperparams: &Hyperparams) -> Box<dyn Bindings> {
140146
fit(
141147
dataset,
142148
hyperparams,
149+
Task::classification,
143150
learning::Objective::MultiSoftprob(dataset.num_distinct_labels.try_into().unwrap()),
144151
)
145152
}
146153

147154
fn fit(
148155
dataset: &Dataset,
149156
hyperparams: &Hyperparams,
157+
task: Task,
150158
objective: learning::Objective,
151159
) -> Box<dyn Bindings> {
152160
// split the train/test data into DMatrix
@@ -205,7 +213,11 @@ fn fit(
205213
Box::new(Estimator {
206214
estimator: booster,
207215
num_features: dataset.num_features,
208-
num_classes: dataset.num_distinct_labels,
216+
num_classes: if task == Task::regression {
217+
1
218+
} else {
219+
dataset.num_distinct_labels
220+
},
209221
})
210222
}
211223

pgml-extension/src/orm/dataset.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,8 @@ impl Display for Dataset {
2222
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
2323
write!(
2424
f,
25-
"Dataset {{ num_features: {}, num_labels: {}, num_rows: {}, num_train_rows: {}, num_test_rows: {} }}",
26-
self.num_features, self.num_labels, self.num_rows, self.num_train_rows, self.num_test_rows,
25+
"Dataset {{ num_features: {}, num_labels: {}, num_distinct_labels: {}, num_rows: {}, num_train_rows: {}, num_test_rows: {} }}",
26+
self.num_features, self.num_labels, self.num_distinct_labels, self.num_rows, self.num_train_rows, self.num_test_rows,
2727
)
2828
}
2929
}

pgml-extension/src/orm/snapshot.rs

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -196,6 +196,11 @@ impl Snapshot {
196196
let nullable = row[3].value::<bool>().unwrap();
197197
let position = row[4].value::<i32>().unwrap() as usize;
198198
let label = self.y_column_name.contains(&name);
199+
200+
if nullable {
201+
warning!("Column \"{}\" can contain nulls which can cause errors", name);
202+
}
203+
199204
columns.push(
200205
Column {
201206
name,

0 commit comments

Comments
 (0)