Skip to content

Commit fbbb67a

Browse files
authored
multiclass in rust (#315)
1 parent df443f0 commit fbbb67a

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

pgml-extension/pgml_rust/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ pgx = "0.4.5"
2020
once_cell = "1"
2121
rand = "0.8"
2222
xgboost = { path = "rust-xgboost" }
23-
smartcore = { version = "0.2.0", features = ["serde", "ndarray-bindings"] }
23+
smartcore = { git="https://github.com/postgresml/smartcore.git", branch="montana/multiclass", features = ["serde", "ndarray-bindings"] }
2424
ndarray = { version = "0.15.6", features = ["serde", "blas"] }
2525
blas = { version = "0.22.0" }
2626
blas-src = { version = "0.8", features = ["openblas"] }

pgml-extension/pgml_rust/src/orm/estimator.rs

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ fn test_smartcore(
114114
.unwrap();
115115
let y_test = Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap();
116116
let y_hat = smartcore::api::Predictor::predict(predictor, &x_test).unwrap();
117-
calc_metrics(&y_test, &y_hat, task)
117+
calc_metrics(&y_test, &y_hat, dataset.distinct_labels(), task)
118118
}
119119

120120
fn predict_smartcore(
@@ -125,7 +125,7 @@ fn predict_smartcore(
125125
smartcore::api::Predictor::predict(predictor, &features).unwrap()[0]
126126
}
127127

128-
fn calc_metrics(y_test: &Array1<f32>, y_hat: &Array1<f32>, task: Task) -> HashMap<String, f32> {
128+
fn calc_metrics(y_test: &Array1<f32>, y_hat: &Array1<f32>, distinct_labels: u32, task: Task) -> HashMap<String, f32> {
129129
let mut results = HashMap::new();
130130
match task {
131131
Task::regression => {
@@ -148,18 +148,20 @@ fn calc_metrics(y_test: &Array1<f32>, y_hat: &Array1<f32>, task: Task) -> HashMa
148148
"precision".to_string(),
149149
smartcore::metrics::precision(y_test, y_hat),
150150
);
151-
results.insert(
152-
"accuracy".to_string(),
153-
smartcore::metrics::accuracy(y_test, y_hat),
154-
);
155-
results.insert(
156-
"roc_auc_score".to_string(),
157-
smartcore::metrics::roc_auc_score(y_test, y_hat),
158-
);
159151
results.insert(
160152
"recall".to_string(),
161153
smartcore::metrics::recall(y_test, y_hat),
162154
);
155+
results.insert(
156+
"accuracy".to_string(),
157+
smartcore::metrics::accuracy(y_test, y_hat),
158+
);
159+
if distinct_labels == 2 {
160+
results.insert(
161+
"roc_auc_score".to_string(),
162+
smartcore::metrics::roc_auc_score(y_test, y_hat),
163+
);
164+
}
163165
}
164166
}
165167
results
@@ -247,7 +249,7 @@ impl Estimator for BoosterBox {
247249
Array1::from_shape_vec(dataset.num_test_rows, dataset.y_test().to_vec()).unwrap();
248250
let y_hat = self.contents.predict(&features).unwrap();
249251
let y_hat = Array1::from_shape_vec(dataset.num_test_rows, y_hat).unwrap();
250-
calc_metrics(&y_test, &y_hat, task)
252+
calc_metrics(&y_test, &y_hat, dataset.distinct_labels(), task)
251253
}
252254

253255
fn predict(&self, features: Vec<f32>) -> f32 {

0 commit comments

Comments
 (0)