-
Notifications
You must be signed in to change notification settings - Fork 4
/
classification.clj
97 lines (87 loc) · 3.51 KB
/
classification.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
(ns tech.verify.ml.classification
(:require [clojure.string :as s]
[clojure.java.io :as io]
[camel-snake-kebab.core :refer [->kebab-case]]
[tech.ml.dataset :as ds]
[tech.ml :as ml]
[tech.ml.dataset.pipeline :as dsp]
[tech.ml.dataset.pipeline.pipeline-operators
:refer [without-recording
pipeline-train-context
pipeline-inference-context]]
[tech.ml.dataset.pipeline.column-filters :as cf]
[tech.ml.loss :as loss]
[clojure.test :refer :all]))
(defn mapseq-dataset
  "Load the fruit_data_with_colors.txt classpath resource as a lazy sequence
  of maps, one map per data row.

  The first non-blank line is the header. Each header token is converted to a
  kebab-case keyword and used as the key for that column. Every following
  non-blank line is split on whitespace. Tokens that parse as a double become
  doubles; any other token becomes a kebab-case keyword.

  Throws ex-info if the resource cannot be found on the classpath."
  []
  (let [resource-name "fruit_data_with_colors.txt"
        resource (or (io/resource resource-name)
                     ;; io/resource returns nil for a missing resource; fail
                     ;; with a clear message instead of an opaque slurp error.
                     (throw (ex-info "Resource not found on classpath"
                                     {:resource resource-name})))
        ;; split-lines also handles CRLF files. Trimming prevents leading
        ;; whitespace from producing an empty first token, and blank lines
        ;; are dropped instead of becoming junk rows.
        rows (->> (s/split-lines (slurp resource))
                  (map s/trim)
                  (remove s/blank?)
                  (mapv #(s/split % #"\s+")))
        ds-keys (mapv (comp keyword ->kebab-case) (first rows))
        parse-token (fn [^String ds-val]
                      (try
                        (Double/parseDouble ds-val)
                        ;; Only a failed numeric parse marks a categorical
                        ;; token; other errors are no longer swallowed.
                        (catch NumberFormatException _
                          (keyword (->kebab-case ds-val)))))]
    (->> (rest rows)
         (map (fn [ds-line]
                (zipmap ds-keys (map parse-token ds-line)))))))
(def fruit-dataset
  "Zero-argument function returning the fruit data (from `mapseq-dataset`)
  as a dataset. It is wrapped in `memoize`, so the resource is parsed on the
  first call and that cached dataset is returned on every later call."
  (memoize #(ds/->dataset (mapseq-dataset))))
(defn fruit-pipeline
  "Preprocessing pipeline for the fruit dataset.

  Removes the :fruit-subtype and :fruit-label columns, then range-scales the
  columns selected by the `(cf/not cf/categorical?)` filter. When `training?`
  is truthy it also converts :fruit-name to numbers and sets it as the
  inference target; those two label steps run inside `without-recording`."
  [dataset training?]
  (-> dataset
      (ds/remove-columns [:fruit-subtype :fruit-label])
      ;; NOTE(review): presumably selects the non-categorical (numeric)
      ;; feature columns for scaling - confirm against column-filters.
      (dsp/range-scale #(cf/not cf/categorical?))
      ;; Label handling only applies on the training path (the inference
      ;; pathway calls this with training? = false).
      (dsp/pwhen
       training?
       ;; NOTE(review): presumably without-recording keeps these label steps
       ;; out of the recorded pipeline context so they are not replayed at
       ;; inference time - confirm.
       #(without-recording
         (-> %
             (dsp/string->number :fruit-name)
             (ds/set-inference-target :fruit-name))))))
(defn classify-fruit
  "Train a classifier on the fruit dataset and assert its accuracy on two
  paths.

  `options` is passed to `ml/train` with :target forced to :fruit-name. The
  minimum accepted accuracy is (:classification-accuracy options), defaulting
  to 0.7.

  1. Held-out check: the pipelined dataset is split into train/test, the model
     is trained on the train split, and its test-split predictions are
     scored.
  2. Production pathway: the raw dataset, with its label columns removed, is
     run through the pipeline in the recorded inference context. The model's
     predictions on that data are scored against the labels of the full
     pipelined dataset, whose rows the inference dataset mirrors.

  Returns the result of the final `is` assertion."
  [options]
  (let [options (assoc options :target :fruit-name)
        min-accuracy (or (:classification-accuracy options)
                         0.7)
        pipeline-data (pipeline-train-context
                       (fruit-pipeline (fruit-dataset) true))
        ds (:dataset pipeline-data)
        {:keys [train-ds test-ds]} (ds/->train-test-split ds {})
        model (ml/train options train-ds)
        test-output (ml/predict model test-ds)
        labels (ds/labels test-ds)]
    ;;Accuracy gets *better* as it increases. This is the opposite of a loss!!
    (is (> (loss/classification-accuracy test-output labels)
           min-accuracy))
    ;;Now here is the production pathway
    (let [inference-src-ds (ds/remove-columns
                            (fruit-dataset)
                            [:fruit-name :fruit-subtype :fruit-label])
          inference-ds (-> (pipeline-inference-context
                            (:context pipeline-data)
                            (fruit-pipeline inference-src-ds false))
                           :dataset)
          inference-output (ml/predict model inference-ds)]
      ;; Score the production-path predictions themselves; previously this
      ;; re-checked test-output and never looked at inference-output. The
      ;; inference rows are all rows of fruit-dataset, so the matching
      ;; labels are those of the full pipelined dataset `ds`, not test-ds.
      ;; These rows include training data, so this is a sanity check of the
      ;; pathway rather than a generalization estimate.
      (is (> (loss/classification-accuracy inference-output (ds/labels ds))
             min-accuracy)))))
(defn auto-gridsearch-fruit
  "Run an automatic gridsearch over the pipelined fruit dataset.

  `options` gets :target :fruit-name and :k-fold 3. The options are then
  annotated by `ml/auto-gridsearch-options` (with :k-fold 3 pinned again),
  and the search is scored with `loss/classification-loss`. Asserts that the
  :average-loss of the first result is below (:classification-loss options),
  defaulting to 0.2.
  NOTE(review): presumably the first result is the best one, assuming
  gridsearch sorts by loss - confirm.

  Returns the full gridsearch result sequence."
  [options]
  (let [folds 3
        options (assoc options :target :fruit-name :k-fold folds)
        ds (fruit-pipeline (fruit-dataset) true)
        ;; Annotate options with gridsearch information, then re-pin the
        ;; fold count on the annotated map.
        gs-options (-> options
                       ml/auto-gridsearch-options
                       (assoc :k-fold folds))
        retval (ml/gridsearch gs-options loss/classification-loss ds)]
    (is (< (double (:average-loss (first retval)))
           (double (or (:classification-loss options)
                       0.2))))
    retval))