-
Notifications
You must be signed in to change notification settings - Fork 4
/
train.clj
83 lines (73 loc) · 3.23 KB
/
train.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
(ns tech.ml.train
(:require [tech.ml.dataset :as dataset]
[tech.parallel :as parallel]
[tech.datatype :as dtype]))
(defn dataset-seq->dataset-model-seq
  "Lazily train one model per entry of dataset-seq.

  Each entry is a map containing :train-ds.  The returned entry is the same
  map with :train-ds removed and :model set to (train-fn train-ds); every
  other key is passed through untouched.  Dropping :train-ds keeps memory
  usage as low as possible.
  See dataset/dataset->k-fold-datasets"
  [train-fn dataset-seq]
  (for [{:keys [train-ds] :as entry} dataset-seq]
    (let [without-train (dissoc entry :train-ds)
          model (train-fn train-ds)]
      (assoc without-train :model model))))
(defn average-prediction-error
  "Mean loss over the models trained from dataset-seq.

  For every entry a model is trained via dataset-seq->dataset-model-seq,
  (predict-fn model test-ds) produces the predictions, ds-entry->label-fn is
  mapped over test-ds to produce the labels, and (loss-fn predictions labels)
  scores them.  The losses are summed and multiplied by 1/(count dataset-seq).
  Page 242, https://web.stanford.edu/~hastie/ElemStatLearn/"
  [train-fn predict-fn ds-entry->label-fn loss-fn dataset-seq]
  (let [scale (/ 1.0 (count dataset-seq))
        fold-losses (for [{:keys [test-ds model]}
                          (dataset-seq->dataset-model-seq train-fn dataset-seq)]
                      (loss-fn (predict-fn model test-ds)
                               (map ds-entry->label-fn test-ds)))
        total-loss (apply + fold-losses)]
    (* scale total-loss)))
(defn- expand-parameter-sequence
  "Recursive worker for the cartesian expansion of param-seq-map.

  For every value stored under param-key, associates it into base-options and
  then either recurses on the first key left in the map (with param-key
  removed) or, when no key is left, yields the finished options map."
  [base-options param-key param-seq-map]
  (let [remaining (dissoc param-seq-map param-key)
        next-key (first (keys remaining))]
    (mapcat (fn [candidate]
              (let [opts (assoc base-options param-key candidate)]
                (if next-key
                  ;; Defer the deeper levels so the product is built on demand.
                  (lazy-seq (expand-parameter-sequence opts next-key remaining))
                  [opts])))
            (get param-seq-map param-key))))
(defn options-seq
  "Given base options map and a map of parameter keyword -> value sequence
  produce a sequence of options maps that does a cartesian join across all of
  the parameter sequences.

  When parameter-sequence-map is empty (or nil) there is nothing to expand and
  the result is a one-element sequence holding base-options, so the return
  value is always a sequence of options maps."
  [base-options parameter-sequence-map]
  (if-let [first-key (first (keys parameter-sequence-map))]
    (expand-parameter-sequence base-options first-key parameter-sequence-map)
    ;; Wrap in a vector: returning the bare map would make callers that
    ;; iterate the result walk its map entries instead of options maps.
    [base-options]))
(defn find-best-options
  "Given a sequence of options and a sequence of datasets (for k-fold),
  evaluate every options map and return the best ones.

  train-fn: (train-fn options train-dataset) -> model
  predict-fn: (predict-fn model test-dataset) -> prediction-sequence
  label-key: applied to each entry of a test dataset to obtain its label.
  loss-fn: (loss-fn prediction-sequence label-sequence) -> double
  Fifth argument: map with optional keys
    :parallelism - handed to parallel/queued-pmap; defaults to the number of
                   available processors.
    :top-n       - how many results to keep; defaults to 5.

  Returns a vector of at most top-n maps {:options ... :error ...} sorted by
  ascending :error.  Lowest number wins.  If no results are produced the
  return value is an empty vector."
  [train-fn predict-fn label-key loss-fn {:keys [parallelism top-n]
                                          :or {parallelism (.availableProcessors
                                                            (Runtime/getRuntime))
                                               top-n 5}}
   option-seq dataset-seq]
  (->> option-seq
       (parallel/queued-pmap
        parallelism
        (fn [options]
          {:options options
           :error (average-prediction-error
                   (partial train-fn options)
                   predict-fn
                   label-key
                   loss-fn
                   dataset-seq)}))
       ;; Seed the reduction with an empty vector.  Without an initial value
       ;; the first result map itself becomes the accumulator, conj merges the
       ;; second result into it, and the ranking is corrupted.
       (reduce (fn [best-models next-map]
                 (->> (conj best-models next-map)
                      (sort-by :error)
                      (take top-n)
                      vec))
               [])))