|
| 1 | +-- This example trains models on the sklearn digits dataset |
| 2 | +-- which is a copy of the test set of the UCI ML hand-written digits datasets |
| 3 | +-- https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits |
| 4 | +-- |
| 5 | +-- This demonstrates using a table with a single array feature column |
| 6 | +-- for classification. |
| 7 | +-- |
| 8 | +-- The final result after a few seconds of training is not terrible. Maybe not perfect |
| 9 | +-- enough for mission critical applications, but it's telling how quickly "off the shelf" |
| 10 | +-- solutions can solve problems these days. |
| 11 | + |
| 12 | +-- Stop at the first failing statement instead of carrying on (psql variable) |
| 13 | +\set ON_ERROR_STOP true |
| 14 | + |
| 15 | +SELECT pgml_rust.load_dataset('digits'); |
| 16 | + |
| 17 | +-- Preview 10 rows, showing only the first 40 characters of each image array next to its target |
| 18 | +SELECT left(image::text, 40) || ',...}', target FROM pgml_rust.digits LIMIT 10; |
| 19 | + |
| 20 | +-- Train a first classification model for the Handwritten Digits project from the digits table and its target column |
| 21 | +SELECT * FROM pgml_rust.train('Handwritten Digits', 'classification', 'pgml_rust.digits', 'target'); |
| 22 | + |
| 23 | +-- Compare the known target with the prediction for 10 rows |
| 24 | +SELECT target, pgml_rust.predict('Handwritten Digits', image) AS prediction |
| 25 | +FROM pgml_rust.digits |
| 26 | +LIMIT 10; |
| 27 | + |
| 28 | +-- |
| 29 | +-- After a project has been trained, omitted parameters will be reused from previous training runs |
| 30 | +-- In these examples we'll reuse the training data snapshots from the initial call. |
| 31 | +-- |
| 32 | + |
| 33 | +-- Linear models: only the project name and the algorithm are passed |
| 34 | +SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'ridge'); |
| 35 | +SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'stochastic_gradient_descent'); |
| 36 | +SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'perceptron'); |
| 37 | +SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'passive_aggressive'); |
| 38 | + |
| 39 | +-- Support vector machines: three variants, no explicit hyperparams |
| 40 | +SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'svm'); |
| 41 | +SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'nu_svm'); |
| 42 | +SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'linear_svm'); |
| 43 | + |
| 44 | +-- Ensembles: the last three pass n_estimators = 10 through hyperparams |
| 45 | +SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'ada_boost'); |
| 46 | +SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'bagging'); |
| 47 | +SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'extra_trees', hyperparams => '{"n_estimators": 10}'); |
| 48 | +SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'gradient_boosting_trees', hyperparams => '{"n_estimators": 10}'); |
| 49 | +SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'random_forest', hyperparams => '{"n_estimators": 10}'); |
| 50 | + |
| 51 | +-- Other algorithms (currently disabled, see the note below) |
| 52 | +-- Gaussian Process is too expensive for normal tests on even a toy dataset |
| 53 | +-- SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'gaussian_process', hyperparams => '{"max_iter_predict": 100, "warm_start": true}'); |
| 54 | + |
| 55 | +-- Gradient boosting: xgboost variants with 10 estimators, lightgbm with 100 |
| 56 | +SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'xgboost', hyperparams => '{"n_estimators": 10}'); |
| 57 | +SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'xgboost_random_forest', hyperparams => '{"n_estimators": 10}'); |
| 58 | +SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'lightgbm', hyperparams => '{"n_estimators": 100}'); |
| 59 | +-- Histogram Gradient Boosting is too expensive for normal tests on even a toy dataset |
| 60 | +-- SELECT * FROM pgml_rust.train('Handwritten Digits', algorithm => 'hist_gradient_boosting', hyperparams => '{"max_iter": 2}'); |
| 61 | + |
| 62 | + |
| 63 | +-- check out all that hard work: the 5 best models by F1, compared as numbers with missing scores last |
| 64 | +SELECT trained_models.* FROM pgml_rust.trained_models |
| 65 | +JOIN pgml_rust.models on models.id = trained_models.id |
| 66 | +ORDER BY (models.metrics->>'f1')::float8 DESC NULLS LAST LIMIT 5; |
| 67 | + |
| 68 | +-- Deploy the most recent random_forest model for prediction use |
| 69 | +SELECT * FROM pgml_rust.deploy('Handwritten Digits', 'most_recent', 'random_forest'); |
| 70 | +-- List the 5 latest deployments, newest first |
| 71 | +SELECT * FROM pgml_rust.deployed_models ORDER BY deployed_at DESC LIMIT 5; |
| 72 | + |
| 73 | +-- Do a grid hyperparam search on your favorite algorithm: svm with random_state fixed at 0, over three kernels and both shrinking settings |
| 74 | +SELECT pgml_rust.train( |
| 75 | + 'Handwritten Digits', |
| 76 | + algorithm => 'svm', |
| 77 | + hyperparams => '{"random_state": 0}', |
| 78 | + search => 'grid', |
| 79 | + search_params => '{ |
| 80 | + "kernel": ["linear", "poly", "sigmoid"], |
| 81 | + "shrinking": [true, false] |
| 82 | + }' |
| 83 | +); |
| 84 | + |
| 85 | +-- TODO SELECT pgml_rust.hypertune(100, 'Handwritten Digits', 'classification', 'pgml_rust.digits', 'target', 'gradient_boosted_trees'); |
| 86 | + |
| 87 | +-- Exercise each deployment strategy in turn (best_score, most_recent, rollback), finishing with the best-scoring svm model |
| 88 | +SELECT * FROM pgml_rust.deploy('Handwritten Digits', 'best_score'); |
| 89 | +SELECT * FROM pgml_rust.deploy('Handwritten Digits', 'most_recent'); |
| 90 | +SELECT * FROM pgml_rust.deploy('Handwritten Digits', 'rollback'); |
| 91 | +SELECT * FROM pgml_rust.deploy('Handwritten Digits', 'best_score', 'svm'); |
| 92 | + |
| 93 | +-- Check out the predictions again after the redeployments: known target next to the prediction for 10 rows |
| 94 | +SELECT target, pgml_rust.predict('Handwritten Digits', image) AS prediction |
| 95 | +FROM pgml_rust.digits |
| 96 | +LIMIT 10; |
0 commit comments