From 56b384434b9120e14b745ce0cd3d9fef31e54df8 Mon Sep 17 00:00:00 2001 From: Chris Nuernberger Date: Thu, 8 Nov 2018 18:28:28 -0700 Subject: [PATCH] 0.9 --- project.clj | 2 +- src/tech/ml/train.clj | 8 +++++--- src/tech/verify/ml/train.clj | 30 ++++++++++++++++++++++++++++-- 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/project.clj b/project.clj index a304e28..a71b6da 100644 --- a/project.clj +++ b/project.clj @@ -1,4 +1,4 @@ -(defproject techascent/tech.ml-base "0.9-SNAPSHOT" +(defproject techascent/tech.ml-base "0.9" :description "Base concepts of the techascent ml suite" :url "http://github.com/tech-ascent/tech.ml-base" :license {:name "Eclipse Public License" diff --git a/src/tech/ml/train.clj b/src/tech/ml/train.clj index 80b2790..5d323f1 100644 --- a/src/tech/ml/train.clj +++ b/src/tech/ml/train.clj @@ -70,12 +70,14 @@ loss-fn: (loss-fn label-sequence prediction-sequence)-> double {:options options :error (average-prediction-error (partial train-fn options) - (partial predict-fn options) + predict-fn label-key loss-fn - dataset-seq) })) + dataset-seq)})) (reduce (fn [best-map {:keys [options error] :as next-map}] (if (or (not best-map) (< (double error) (double (:error best-map)))) - {}))))) + next-map + best-map)) + nil))) diff --git a/src/tech/verify/ml/train.clj b/src/tech/verify/ml/train.clj index 76550c9..1f91f6b 100644 --- a/src/tech/verify/ml/train.clj +++ b/src/tech/verify/ml/train.clj @@ -19,7 +19,8 @@ (take 1000)) test-dataset (for [x (range -9.9 10 0.1)] {:x x}) test-labels (map (comp f :x) test-dataset) - model (ml/train system-name [:x] :y {:model-type (or model-type :regression)} train-dataset) + model (ml/train system-name [:x] :y + {:model-type (or model-type :regression)} train-dataset) test-output (ml/predict model test-dataset) mse (loss/mse test-output test-labels)] (is (< mse (double accuracy))))) @@ -55,10 +56,35 @@ (take 1000)) feature-keys [:x] label :y - train-fn (partial ml/train system-name feature-keys label {:model-type :regression}) + train-fn (partial ml/train system-name feature-keys label + {:model-type :regression}) predict-fn ml/predict mse (->> dataset (dataset/dataset->k-fold-datasets 10 {}) (train/average-prediction-error train-fn predict-fn label loss/mse))] (is (< mse 0.01)))) + + +(defn gridsearch + [system-name options] + (let [f (partial * 2) + observe (fn [] + (let [x (- (* 20 (rand)) 10) + y (f x)] + {:x x :y y})) + dataset (->> (repeatedly observe) + (take 1000)) + feature-keys [:x] + label :y + train-fn (partial ml/train system-name feature-keys label) + predict-fn ml/predict + k-fold-ds (dataset/dataset->k-fold-datasets 5 {} dataset) + option-seq [(merge {:model-type :regression} options) + (merge {:model-type :regression} options)] + {:keys [error options]} (train/find-best-options train-fn predict-fn + label + loss/mse {} + option-seq k-fold-ds) + mse (or (:mse options) 0.01)] + (is (< error mse))))