-
Notifications
You must be signed in to change notification settings - Fork 4
/
regression.clj
80 lines (69 loc) · 2.62 KB
/
regression.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
(ns tech.verify.ml.regression
(:require [tech.ml-base :as ml]
[tech.ml.loss :as loss]
[tech.ml.train :as train]
[tech.ml.dataset :as dataset]
[clojure.test :refer :all]))
(defn datasets
  "Builds a synthetic regression problem whose target is y = 2x.

  Returns a map with two keys:
    :train-ds - a lazy seq of 1000 {:x x :y y} maps; each x is computed as
                (- (* 20 (rand)) 10), i.e. drawn from [-10, 10).
    :test-ds  - a lazy seq of {:x x :y y} maps for x produced by
                (range -9.9 10 0.1).

  NOTE(review): the test x values are built by repeated floating-point
  addition of 0.1, so they are only approximately -9.9, -9.8, ..."
  []
  (let [target (fn [x] (* 2 x))
        sample (fn []
                 (let [x (- (* 20 (rand)) 10)]
                   {:x x :y (target x)}))]
    {:train-ds (take 1000 (repeatedly sample))
     :test-ds  (map (fn [x] {:x x :y (target x)})
                    (range -9.9 10 0.1))}))
(defn basic-regression
  "Trains a model on the :train-ds split of (datasets) and checks its
  test-set error.

  `options` is forwarded unchanged to `ml/train` as its options argument,
  with [:x] as the feature keys and :y as the label.
  NOTE(review): the model kind is presumably selected by a :model-type key
  read inside ml/train; confirm there.

  Predicts on the :test-ds split with `ml/predict` and computes `loss/mse`
  against the true :y values. It then asserts, via clojure.test `is`, that
  this MSE is below `:accuracy`, which defaults to 0.01."
  [{:keys [accuracy]
    :or {accuracy 0.01} :as options}]
  (let [{train-dataset :train-ds
         test-dataset :test-ds} (datasets)
        test-labels (map :y test-dataset)
        model (ml/train options [:x] :y train-dataset)
        test-output (ml/predict model test-dataset)
        mse (loss/mse test-output test-labels)]
    (is (< mse (double accuracy)))))
(defn scaled-features
  "Trains on the :train-ds split of (datasets) with a feature range map and
  checks the test-set error.

  The options passed to `ml/train` are
  (merge {:range-map {::dataset/features [-1 1]}} options), so a
  caller-supplied :range-map replaces the default.
  NOTE(review): the exact scaling semantics of :range-map live in ml/train;
  confirm there.

  Predicts on the :test-ds split and computes `loss/mse` against the true
  :y values. It then asserts, via clojure.test `is`, that this MSE is below
  `:accuracy`, which defaults to 0.01 (the previously hard-coded
  threshold)."
  [{:keys [accuracy]
    :or {accuracy 0.01} :as options}]
  (let [{train-dataset :train-ds
         test-dataset :test-ds} (datasets)
        test-labels (map :y test-dataset)
        model (ml/train (merge {:range-map {::dataset/features [-1 1]}}
                               options)
                        [:x] :y
                        train-dataset)
        test-output (ml/predict model test-dataset)
        mse (loss/mse test-output test-labels)]
    (is (< mse (double accuracy)))))
(defn k-fold-regression
  "Runs 10-fold cross-validation on the :train-ds split of (datasets).

  Folds are built with (dataset/->k-fold-datasets 10 {}). For each fold,
  training uses (ml/train options [:x] :y ...) and prediction uses
  `ml/predict`. The folds are scored with `train/average-prediction-error`
  using `loss/mse` on label :y.

  Asserts, via clojure.test `is`, that the resulting :average-loss is below
  `:accuracy`, which defaults to 0.01 (the previously hard-coded threshold).
  The test split of (datasets) is not used here."
  [{:keys [accuracy]
    :or {accuracy 0.01} :as options}]
  (let [train-dataset (:train-ds (datasets))
        feature-keys [:x]
        label :y
        train-fn (partial ml/train options feature-keys label)
        mse (->> train-dataset
                 (dataset/->k-fold-datasets 10 {})
                 (train/average-prediction-error train-fn ml/predict
                                                 label loss/mse)
                 :average-loss)]
    (is (< mse (double accuracy)))))
(defn auto-gridsearch-simple
  "Grid-searches the option space that `ml/auto-gridsearch-options` derives
  from `options`.

  The search fits label :y from features [:x] on the :train-ds split of
  (datasets) and scores each candidate with `loss/mse`. It also passes
  :scalar-labels? true and a :range-map of {::dataset/features [-1 1]} to
  `ml/gridsearch`.

  :gridsearch-depth is taken from `options`, falling back to 100 when it is
  absent or nil.

  Asserts, via clojure.test `is`, that the :average-loss of the first
  gridsearch result is below (:mse-loss options), which defaults to 0.2.
  NOTE(review): this presumes ml/gridsearch returns results best-first;
  confirm there.

  Returns the full gridsearch result."
  [options]
  (let [search-space (ml/auto-gridsearch-options options)
        depth        (or (:gridsearch-depth options) 100)
        results      (ml/gridsearch [search-space]
                                    [:x] :y
                                    loss/mse
                                    (:train-ds (datasets))
                                    :scalar-labels? true
                                    :gridsearch-depth depth
                                    :range-map {::dataset/features [-1 1]})
        best         (first results)
        loss-limit   (or (:mse-loss options) 0.2)]
    (is (< (double (:average-loss best))
           (double loss-limit)))
    results))