Skip to content

Commit

Permalink
0.9
Browse files Browse the repository at this point in the history
  • Loading branch information
cnuernber committed Nov 9, 2018
1 parent a6f1978 commit 56b3844
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 6 deletions.
2 changes: 1 addition & 1 deletion project.clj
@@ -1,4 +1,4 @@
(defproject techascent/tech.ml-base "0.9-SNAPSHOT"
(defproject techascent/tech.ml-base "0.9"
:description "Base concepts of the techascent ml suite"
:url "http://github.com/tech-ascent/tech.ml-base"
:license {:name "Eclipse Public License"
Expand Down
8 changes: 5 additions & 3 deletions src/tech/ml/train.clj
Expand Up @@ -70,12 +70,14 @@ loss-fn: (loss-fn label-sequence prediction-sequence)-> double
{:options options
:error (average-prediction-error
(partial train-fn options)
(partial predict-fn options)
predict-fn
label-key
loss-fn
dataset-seq) }))
dataset-seq)}))
(reduce (fn [best-map {:keys [options error] :as next-map}]
(if (or (not best-map)
(< (double error)
(double (:error best-map))))
{})))))
next-map
best-map))
nil)))
30 changes: 28 additions & 2 deletions src/tech/verify/ml/train.clj
Expand Up @@ -19,7 +19,8 @@
(take 1000))
test-dataset (for [x (range -9.9 10 0.1)] {:x x})
test-labels (map (comp f :x) test-dataset)
model (ml/train system-name [:x] :y {:model-type (or model-type :regression)} train-dataset)
model (ml/train system-name [:x] :y
{:model-type (or model-type :regression)} train-dataset)
test-output (ml/predict model test-dataset)
mse (loss/mse test-output test-labels)]
(is (< mse (double accuracy)))))
Expand Down Expand Up @@ -55,10 +56,35 @@
(take 1000))
feature-keys [:x]
label :y
train-fn (partial ml/train system-name feature-keys label {:model-type :regression})
train-fn (partial ml/train system-name feature-keys label
{:model-type :regression})
predict-fn ml/predict
mse (->> dataset
(dataset/dataset->k-fold-datasets 10 {})
(train/average-prediction-error train-fn predict-fn
label loss/mse))]
(is (< mse 0.01))))


(defn gridsearch
[system-name options]
(let [f (partial * 2)
observe (fn []
(let [x (- (* 20 (rand)) 10)
y (f x)]
{:x x :y y}))
dataset (->> (repeatedly observe)
(take 1000))
feature-keys [:x]
label :y
train-fn (partial ml/train system-name feature-keys label)
predict-fn ml/predict
k-fold-ds (dataset/dataset->k-fold-datasets 5 {} dataset)
option-seq [(merge {:model-type :regression} options)
(merge {:model-type :regression} options)]
{:keys [error options]} (train/find-best-options train-fn predict-fn
label
loss/mse {}
option-seq k-fold-ds)
mse (or (:mse options) 0.01)]
(is (< error mse))))

0 comments on commit 56b3844

Please sign in to comment.