Skip to content

Commit eb65899

Browse files
authored
fix clippy lints (#838)
1 parent 322b72e commit eb65899

File tree

11 files changed

+79
-94
lines changed

11 files changed

+79
-94
lines changed

pgml-extension/src/api.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -584,8 +584,6 @@ pub fn embed_batch(
584584
/// ```
585585
#[pg_extern(immutable, parallel_safe, name = "clear_gpu_cache")]
586586
pub fn clear_gpu_cache(memory_usage: default!(Option<f32>, "NULL")) -> bool {
587-
let memory_usage: Option<f32> =
588-
memory_usage.map(|memory_usage| memory_usage.try_into().unwrap());
589587
crate::bindings::transformers::clear_gpu_cache(memory_usage)
590588
}
591589

pgml-extension/src/bindings/langchain.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ pub fn chunk(splitter: &str, text: &str, kwargs: &serde_json::Value) -> Vec<Stri
2020
let kwargs = serde_json::to_string(kwargs).unwrap();
2121

2222
Python::with_gil(|py| -> Vec<String> {
23-
let chunk: Py<PyAny> = PY_MODULE.getattr(py, "chunk").unwrap().into();
23+
let chunk: Py<PyAny> = PY_MODULE.getattr(py, "chunk").unwrap();
2424

2525
chunk
2626
.call1(

pgml-extension/src/bindings/lightgbm.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ impl Bindings for Estimator {
118118
{
119119
let r: u64 = rand::random();
120120
let path = format!("/tmp/pgml_{}.bin", r);
121-
std::fs::write(&path, &bytes).unwrap();
121+
std::fs::write(&path, bytes).unwrap();
122122
let mut estimator = lightgbm::Booster::from_file(&path);
123123
if estimator.is_err() {
124124
// backward compatibility w/ 2.0.0

pgml-extension/src/bindings/sklearn.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ fn fit(
355355

356356
let (estimator, predict, predict_proba) =
357357
Python::with_gil(|py| -> (Py<PyAny>, Py<PyAny>, Py<PyAny>) {
358-
let estimator: Py<PyAny> = PY_MODULE.getattr(py, "estimator").unwrap().into();
358+
let estimator: Py<PyAny> = PY_MODULE.getattr(py, "estimator").unwrap();
359359

360360
let train: Py<PyAny> = estimator
361361
.call1(
@@ -373,21 +373,21 @@ fn fit(
373373
.unwrap();
374374

375375
let estimator: Py<PyAny> = train
376-
.call1(py, PyTuple::new(py, &[&dataset.x_train, &dataset.y_train]))
376+
.call1(py, PyTuple::new(py, [&dataset.x_train, &dataset.y_train]))
377377
.unwrap();
378378

379379
let predict: Py<PyAny> = PY_MODULE
380380
.getattr(py, "predictor")
381381
.unwrap()
382-
.call1(py, PyTuple::new(py, &[&estimator]))
382+
.call1(py, PyTuple::new(py, [&estimator]))
383383
.unwrap()
384384
.extract(py)
385385
.unwrap();
386386

387387
let predict_proba: Py<PyAny> = PY_MODULE
388388
.getattr(py, "predictor_proba")
389389
.unwrap()
390-
.call1(py, PyTuple::new(py, &[&estimator]))
390+
.call1(py, PyTuple::new(py, [&estimator]))
391391
.unwrap()
392392
.extract(py)
393393
.unwrap();
@@ -425,7 +425,7 @@ impl Bindings for Estimator {
425425
fn predict(&self, features: &[f32], _num_features: usize, _num_classes: usize) -> Vec<f32> {
426426
Python::with_gil(|py| -> Vec<f32> {
427427
self.predict
428-
.call1(py, PyTuple::new(py, &[features]))
428+
.call1(py, PyTuple::new(py, [features]))
429429
.unwrap()
430430
.extract(py)
431431
.unwrap()
@@ -435,7 +435,7 @@ impl Bindings for Estimator {
435435
fn predict_proba(&self, features: &[f32], _num_features: usize) -> Vec<f32> {
436436
Python::with_gil(|py| -> Vec<f32> {
437437
self.predict_proba
438-
.call1(py, PyTuple::new(py, &[features]))
438+
.call1(py, PyTuple::new(py, [features]))
439439
.unwrap()
440440
.extract(py)
441441
.unwrap()
@@ -446,7 +446,7 @@ impl Bindings for Estimator {
446446
fn to_bytes(&self) -> Vec<u8> {
447447
Python::with_gil(|py| -> Vec<u8> {
448448
let save = PY_MODULE.getattr(py, "save").unwrap();
449-
save.call1(py, PyTuple::new(py, &[&self.estimator]))
449+
save.call1(py, PyTuple::new(py, [&self.estimator]))
450450
.unwrap()
451451
.extract(py)
452452
.unwrap()
@@ -461,23 +461,23 @@ impl Bindings for Estimator {
461461
Python::with_gil(|py| -> Box<dyn Bindings> {
462462
let load = PY_MODULE.getattr(py, "load").unwrap();
463463
let estimator: Py<PyAny> = load
464-
.call1(py, PyTuple::new(py, &[bytes]))
464+
.call1(py, PyTuple::new(py, [bytes]))
465465
.unwrap()
466466
.extract(py)
467467
.unwrap();
468468

469469
let predict: Py<PyAny> = PY_MODULE
470470
.getattr(py, "predictor")
471471
.unwrap()
472-
.call1(py, PyTuple::new(py, &[&estimator]))
472+
.call1(py, PyTuple::new(py, [&estimator]))
473473
.unwrap()
474474
.extract(py)
475475
.unwrap();
476476

477477
let predict_proba: Py<PyAny> = PY_MODULE
478478
.getattr(py, "predictor_proba")
479479
.unwrap()
480-
.call1(py, PyTuple::new(py, &[&estimator]))
480+
.call1(py, PyTuple::new(py, [&estimator]))
481481
.unwrap()
482482
.extract(py)
483483
.unwrap();
@@ -495,13 +495,13 @@ fn sklearn_metric(name: &str, ground_truth: &[f32], y_hat: &[f32]) -> f32 {
495495
Python::with_gil(|py| -> f32 {
496496
let calculate_metric = PY_MODULE.getattr(py, "calculate_metric").unwrap();
497497
let wrapper: Py<PyAny> = calculate_metric
498-
.call1(py, PyTuple::new(py, &[name]))
498+
.call1(py, PyTuple::new(py, [name]))
499499
.unwrap()
500500
.extract(py)
501501
.unwrap();
502502

503503
let score: f32 = wrapper
504-
.call1(py, PyTuple::new(py, &[ground_truth, y_hat]))
504+
.call1(py, PyTuple::new(py, [ground_truth, y_hat]))
505505
.unwrap()
506506
.extract(py)
507507
.unwrap();
@@ -530,13 +530,13 @@ pub fn confusion_matrix(ground_truth: &[f32], y_hat: &[f32]) -> Vec<Vec<f32>> {
530530
Python::with_gil(|py| -> Vec<Vec<f32>> {
531531
let calculate_metric = PY_MODULE.getattr(py, "calculate_metric").unwrap();
532532
let wrapper: Py<PyAny> = calculate_metric
533-
.call1(py, PyTuple::new(py, &["confusion_matrix"]))
533+
.call1(py, PyTuple::new(py, ["confusion_matrix"]))
534534
.unwrap()
535535
.extract(py)
536536
.unwrap();
537537

538538
let matrix: Vec<Vec<f32>> = wrapper
539-
.call1(py, PyTuple::new(py, &[ground_truth, y_hat]))
539+
.call1(py, PyTuple::new(py, [ground_truth, y_hat]))
540540
.unwrap()
541541
.extract(py)
542542
.unwrap();
@@ -549,7 +549,7 @@ pub fn regression_metrics(ground_truth: &[f32], y_hat: &[f32]) -> HashMap<String
549549
Python::with_gil(|py| -> HashMap<String, f32> {
550550
let calculate_metric = PY_MODULE.getattr(py, "regression_metrics").unwrap();
551551
let scores: HashMap<String, f32> = calculate_metric
552-
.call1(py, PyTuple::new(py, &[ground_truth, y_hat]))
552+
.call1(py, PyTuple::new(py, [ground_truth, y_hat]))
553553
.unwrap()
554554
.extract(py)
555555
.unwrap();
@@ -566,7 +566,7 @@ pub fn classification_metrics(
566566
let mut scores = Python::with_gil(|py| -> HashMap<String, f32> {
567567
let calculate_metric = PY_MODULE.getattr(py, "classification_metrics").unwrap();
568568
let scores: HashMap<String, f32> = calculate_metric
569-
.call1(py, PyTuple::new(py, &[ground_truth, y_hat]))
569+
.call1(py, PyTuple::new(py, [ground_truth, y_hat]))
570570
.unwrap()
571571
.extract(py)
572572
.unwrap();
@@ -591,7 +591,7 @@ pub fn cluster_metrics(
591591
let calculate_metric = PY_MODULE.getattr(py, "cluster_metrics").unwrap();
592592

593593
let scores: HashMap<String, f32> = calculate_metric
594-
.call1(py, (num_features, PyTuple::new(py, &[inputs, labels])))
594+
.call1(py, (num_features, PyTuple::new(py, [inputs, labels])))
595595
.unwrap()
596596
.extract(py)
597597
.unwrap();

pgml-extension/src/bindings/transformers.rs

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
use std::collections::HashMap;
21
use std::io::Write;
32
use std::path::PathBuf;
43
use std::str::FromStr;
4+
use std::{collections::HashMap, path::Path};
55

66
use once_cell::sync::Lazy;
77
use pgrx::*;
@@ -33,7 +33,7 @@ pub fn transform(
3333
let inputs = serde_json::to_string(&inputs).unwrap();
3434

3535
let results = Python::with_gil(|py| -> String {
36-
let transform: Py<PyAny> = PY_MODULE.getattr(py, "transform").unwrap().into();
36+
let transform: Py<PyAny> = PY_MODULE.getattr(py, "transform").unwrap();
3737

3838
let result = transform.call1(
3939
py,
@@ -61,7 +61,7 @@ pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) -
6161

6262
let kwargs = serde_json::to_string(kwargs).unwrap();
6363
Python::with_gil(|py| -> Vec<Vec<f32>> {
64-
let embed: Py<PyAny> = PY_MODULE.getattr(py, "embed").unwrap().into();
64+
let embed: Py<PyAny> = PY_MODULE.getattr(py, "embed").unwrap();
6565
let result = embed.call1(
6666
py,
6767
PyTuple::new(
@@ -90,14 +90,14 @@ pub fn tune(
9090
task: &Task,
9191
dataset: TextDataset,
9292
hyperparams: &JsonB,
93-
path: &std::path::PathBuf,
93+
path: &Path,
9494
) -> HashMap<String, f64> {
9595
crate::bindings::venv::activate();
9696

9797
let task = task.to_string();
9898
let hyperparams = serde_json::to_string(&hyperparams.0).unwrap();
9999

100-
let metrics = Python::with_gil(|py| -> HashMap<String, f64> {
100+
Python::with_gil(|py| -> HashMap<String, f64> {
101101
let tune = PY_MODULE.getattr(py, "tune").unwrap();
102102
let result = tune.call1(
103103
py,
@@ -119,8 +119,7 @@ pub fn tune(
119119
Ok(o) => o,
120120
};
121121
result.extract(py).unwrap()
122-
});
123-
metrics
122+
})
124123
}
125124

126125
pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Vec<String> {
@@ -190,7 +189,7 @@ fn dump_model(model_id: i64, dir: PathBuf) {
190189
.append(true)
191190
.open(path)
192191
.unwrap();
193-
file.write(&data).unwrap();
192+
let _num_bytes = file.write(&data).unwrap();
194193
file.flush().unwrap();
195194
}
196195
});
@@ -207,7 +206,7 @@ pub fn load_dataset(
207206
let kwargs = serde_json::to_string(kwargs).unwrap();
208207

209208
let dataset = Python::with_gil(|py| -> String {
210-
let load_dataset: Py<PyAny> = PY_MODULE.getattr(py, "load_dataset").unwrap().into();
209+
let load_dataset: Py<PyAny> = PY_MODULE.getattr(py, "load_dataset").unwrap();
211210
load_dataset
212211
.call1(
213212
py,
@@ -273,9 +272,8 @@ pub fn load_dataset(
273272
let table_count = Spi::get_one_with_args::<i64>("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![
274273
(PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum())
275274
]).unwrap().unwrap();
276-
match table_count {
277-
1 => Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#)).unwrap(),
278-
_ => (),
275+
if table_count == 1 {
276+
Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#)).unwrap()
279277
}
280278

281279
Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#)).unwrap();
@@ -320,7 +318,7 @@ pub fn load_dataset(
320318

321319
pub fn clear_gpu_cache(memory_usage: Option<f32>) -> bool {
322320
Python::with_gil(|py| -> bool {
323-
let clear_gpu_cache: Py<PyAny> = PY_MODULE.getattr(py, "clear_gpu_cache").unwrap().into();
321+
let clear_gpu_cache: Py<PyAny> = PY_MODULE.getattr(py, "clear_gpu_cache").unwrap();
324322
clear_gpu_cache
325323
.call1(py, PyTuple::new(py, &[memory_usage.into_py(py)]))
326324
.unwrap()

pgml-extension/src/bindings/venv.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use pyo3::prelude::*;
77
use pyo3::types::PyTuple;
88
use std::ffi::CStr;
99

10-
static CONFIG_NAME: &'static str = "pgml.venv";
10+
static CONFIG_NAME: &str = "pgml.venv";
1111

1212
static PY_MODULE: Lazy<Py<PyModule>> = Lazy::new(|| {
1313
Python::with_gil(|py| -> Py<PyModule> {
@@ -19,7 +19,7 @@ static PY_MODULE: Lazy<Py<PyModule>> = Lazy::new(|| {
1919

2020
pub fn activate_venv(venv: &str) -> bool {
2121
Python::with_gil(|py| -> bool {
22-
let activate_venv: Py<PyAny> = PY_MODULE.getattr(py, "activate_venv").unwrap().into();
22+
let activate_venv: Py<PyAny> = PY_MODULE.getattr(py, "activate_venv").unwrap();
2323
let result: Py<PyAny> = activate_venv
2424
.call1(py, PyTuple::new(py, &[venv.to_string().into_py(py)]))
2525
.unwrap();

pgml-extension/src/orm/dataset.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,8 @@ fn drop_table_if_exists(table_name: &str) {
9797
let table_count = Spi::get_one_with_args::<i64>("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![
9898
(PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum())
9999
]).unwrap().unwrap();
100-
match table_count {
101-
1 => Spi::run(&format!(r#"DROP TABLE pgml.{table_name} CASCADE"#)).unwrap(),
102-
_ => (),
100+
if table_count == 1 {
101+
Spi::run(&format!(r#"DROP TABLE pgml.{table_name} CASCADE"#)).unwrap();
103102
}
104103
}
105104

pgml-extension/src/orm/model.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -405,9 +405,7 @@ impl Model {
405405
Algorithm::svm => linfa::Svm::fit,
406406
_ => todo!(),
407407
},
408-
Task::cluster => match self.algorithm {
409-
_ => todo!(),
410-
},
408+
Task::cluster => todo!(),
411409
_ => error!("use pgml.tune for transformers tasks"),
412410
},
413411

@@ -570,7 +568,7 @@ impl Model {
570568
#[cfg(all(feature = "python", any(test, feature = "pg_test")))]
571569
{
572570
let sklearn_metrics =
573-
crate::bindings::sklearn::regression_metrics(&y_test, &y_hat);
571+
crate::bindings::sklearn::regression_metrics(y_test, &y_hat);
574572
metrics.insert("sklearn_r2".to_string(), sklearn_metrics["r2"]);
575573
metrics.insert(
576574
"sklearn_mean_absolute_error".to_string(),
@@ -599,7 +597,7 @@ impl Model {
599597
#[cfg(all(feature = "python", any(test, feature = "pg_test")))]
600598
{
601599
let sklearn_metrics = crate::bindings::sklearn::classification_metrics(
602-
&y_test,
600+
y_test,
603601
&y_hat,
604602
dataset.num_distinct_labels,
605603
);
@@ -662,7 +660,7 @@ impl Model {
662660
metrics.insert("mcc".to_string(), confusion_matrix.mcc());
663661
}
664662
Task::cluster => {
665-
#[cfg(all(feature = "python"))]
663+
#[cfg(feature = "python")]
666664
{
667665
let sklearn_metrics = crate::bindings::sklearn::cluster_metrics(
668666
dataset.num_features,
@@ -1104,7 +1102,7 @@ impl Model {
11041102
let element: Result<Option<Vec<f32>>, TryFromDatumError> =
11051103
tuple.get_by_index(index.try_into().unwrap());
11061104
for j in element.as_ref().unwrap().as_ref().unwrap() {
1107-
features.push(*j as f32);
1105+
features.push(*j);
11081106
}
11091107
}
11101108
pgrx_pg_sys::FLOAT8ARRAYOID => {

0 commit comments

Comments
 (0)