-
Notifications
You must be signed in to change notification settings - Fork 4
/
ml_base.clj
127 lines (119 loc) · 5.32 KB
/
ml_base.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
(ns tech.ml-base
(:require [tech.ml.registry :as registry]
[tech.ml.protocols :as protocols]
[tech.ml.dataset :as dataset]
[tech.ml.gridsearch :as ml-gs]
[tech.ml.train :as train]
[tech.parallel :as parallel]
[tech.datatype :as dtype]
[clojure.set :as c-set])
(:import [java.util UUID]))
(defn train
  "Train a model using the ml system registered under `system-name`.

  `options` is first merged with whatever `protocols/coalesce-options` returns
  for the system (the system's values win on key collisions).  The merged
  options are then handed to `dataset/apply-dataset-options` together with
  `feature-keys`, `label-keys` and `dataset`; the options and coalesced dataset
  it returns are what the system is actually trained with.

  Returns a map with:
    :system        the `system-name` argument
    :model         the value returned by `protocols/train`
    :options       the options returned by `dataset/apply-dataset-options`
    :feature-keys  the `feature-keys` argument
    :label-keys    the `label-keys` argument
    :id            a freshly generated random UUID"
  [system-name feature-keys label-keys options dataset]
  (let [ml-system (registry/system system-name)
        ;; System-provided coalescing options override the caller's options.
        merged-options (merge options
                              (protocols/coalesce-options ml-system options))
        {final-options :options
         training-data :coalesced-dataset}
        (dataset/apply-dataset-options feature-keys
                                       label-keys
                                       merged-options
                                       dataset)
        trained (protocols/train ml-system final-options training-data)]
    {:system system-name
     :model trained
     :options final-options
     :feature-keys feature-keys
     :label-keys label-keys
     :id (UUID/randomUUID)}))
(defn predict
  "Run inference with a model map of the shape returned by `train`.

  `model` must carry :system, :model, :options and :feature-keys.  The dataset
  is passed through `dataset/apply-dataset-options` with the model's feature
  keys and options, and with nil label keys.  The resulting coalesced dataset
  is handed to `protocols/predict` along with the model's options and trained
  model.

  Returns whatever `protocols/predict` returns for the system."
  [model dataset]
  (let [{system-name :system
         trained-model :model
         model-options :options
         feature-keys :feature-keys} model
        ml-system (registry/system system-name)
        ;; Label keys are nil: only the feature side is coalesced here.
        {:keys [coalesced-dataset]} (dataset/apply-dataset-options
                                     feature-keys nil model-options dataset)]
    (protocols/predict ml-system
                       model-options
                       trained-model
                       coalesced-dataset)))
(defn auto-gridsearch-options
  "Return `options` merged with the gridsearch options that the ml system
  registered under `system-name` reports via `protocols/gridsearch-options`.
  On key collisions the system-provided values win."
  [system-name options]
  (->> options
       (protocols/gridsearch-options (registry/system system-name))
       (merge options)))
(defn gridsearch
  "Gridsearch every system/options pair in `system-name->options-seq` against
  `dataset` and return the best-scoring candidates.

  This is deliberately 'easy' rather than 'simple': it is an opinionated
  pipeline meant to give a good result on the first try for most cases.

  Steps:
    1. The dataset is run through `dataset/apply-dataset-options` once, using
       the keyword arguments of this call as the options map.
    2. Each pair's options are expanded with `ml-gs/gridsearch`, limited to
       `gridsearch-depth` candidates per pair, and merged over the dataset
       options from step 1 (gridsearch values win on key collisions).
    3. Each candidate is scored with `train/average-prediction-error` using
       `loss-fn`, over the k-fold datasets, or over the single coalesced
       dataset when `k-fold` is falsey.
    4. Candidates are reduced to the `top-n` with the lowest :average-loss.

  Keyword arguments:
    :parallelism       passed to `parallel/queued-pmap`; defaults to the number
                       of available processors.
    :top-n             number of results to keep; defaults to 5.
    :gridsearch-depth  max candidates taken per system/options pair;
                       defaults to 50.
    :k-fold            passed to `dataset/->k-fold-datasets`; defaults to 5.
    :scalar-labels?    when truthy, each entry's label is replaced by its
                       element at index 0 before folding.

  Returns at most `top-n` maps in ascending :average-loss order.  Each is the
  result of `train/average-prediction-error` merged with :system, :options and
  :k-fold.  A candidate whose evaluation throws any Throwable is silently
  dropped, so an empty result can mean that every candidate failed."
  [system-name->options-seq feature-keys label-keys
   loss-fn dataset
   & {:keys [parallelism top-n gridsearch-depth k-fold
             scalar-labels?]
      :or {parallelism (.availableProcessors
                        (Runtime/getRuntime))
           top-n 5
           gridsearch-depth 50
           k-fold 5}
      :as options}]
  (let [;; Apply the dataset options a single time up front; callers may
        ;; supply their own scale map through the options instead.
        {ds-options :options
         prepared-ds :coalesced-dataset}
        (dataset/apply-dataset-options feature-keys label-keys
                                       options
                                       dataset)
        first-label-value (fn [label] (dtype/get-value label 0))
        ;; Scalar labels make losses such as mse work out later.
        labelled-ds (if-not scalar-labels?
                      prepared-ds
                      (map #(update % ::dataset/label first-label-value)
                           prepared-ds))
        fold-datasets (if k-fold
                        (dataset/->k-fold-datasets k-fold ds-options labelled-ds)
                        [labelled-ds])
        label-key (first (dataset/normalize-keys label-keys))
        label-map (get-in ds-options [:label-map label-key])
        ;; Extracts the expected answer from a dataset entry.  When a label
        ;; map exists for the first label key (classification), the stored
        ;; numeric value is mapped back to its original label.
        entry->expected (if label-map
                          (let [value->label (c-set/map-invert label-map)]
                            (fn [{label ::dataset/label}]
                              (get value->label
                                   (long (first-label-value label)))))
                          (fn [{label ::dataset/label}]
                            (first-label-value label)))
        ;; The dataset is already coalesced, so training addresses it through
        ;; the ::dataset/features and ::dataset/label keys.
        train-candidate (fn [[candidate-system candidate-options] train-ds]
                          (train candidate-system
                                 ::dataset/features ::dataset/label
                                 candidate-options train-ds))
        ;; One system/options pair -> up to gridsearch-depth candidates.
        expand-pair (fn [[sys-name base-options]]
                      (for [gs-options (take gridsearch-depth
                                             (ml-gs/gridsearch base-options))]
                        [sys-name (merge ds-options gs-options)]))
        ;; Best effort: a failing candidate yields nil and is filtered out.
        score-candidate (fn [[sys-name sys-options :as candidate]]
                          (try
                            (merge (train/average-prediction-error
                                    (partial train-candidate candidate)
                                    predict
                                    entry->expected
                                    loss-fn
                                    fold-datasets)
                                   {:system sys-name
                                    :options sys-options
                                    :k-fold k-fold})
                            (catch Throwable _
                              nil)))
        keep-best (fn [best chunk]
                    (take top-n
                          (sort-by :average-loss (concat best chunk))))
        candidates (mapcat expand-pair system-name->options-seq)
        scored (parallel/queued-pmap parallelism score-candidate candidates)]
    ;; Fold the results in top-n sized chunks to keep each sort small.
    (reduce keep-best
            []
            (partition-all top-n (remove nil? scored)))))