-
Notifications
You must be signed in to change notification settings - Fork 3
/
ensemble.clj
96 lines (75 loc) · 3.02 KB
/
ensemble.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
(ns scicloj.metamorph.ml.ensemble
(:require
[scicloj.metamorph.core :as morph]
[tablecloth.api :as tc]
[tech.v3.dataset :as ds]
[tech.v3.dataset.column-filters :as cf]))
(defn- majority
  "Returns the most frequent element of `l`, or nil when `l` is empty.
  When several elements share the highest count, the one that comes last in
  the iteration order of `(frequencies l)` is returned."
  [l]
  (let [counts (frequencies l)]
    (when (seq counts)
      ;; max-key returns the last of several equally-maximal entries.
      (key (apply max-key val counts)))))
(defn ensemble-pipe
  "Creates an ensemble pipeline function out of various pipelines. The
  predictions of the individual pipelines get combined via majority voting.
  Can be used in the same way as any other pipeline.

  `pipes` is a sequence of metamorph pipeline functions; each one is fitted
  and transformed independently on the same input data.

  Assumptions made by this implementation:
  * every pipe stores its model in the context under the key `:model`; the
    target columns, the target categorical maps and the target dataset are
    all read from there;
  * only the first target column of the first pipe (`:pipe-0`) is voted on.

  `:fit` mode stores `{:fitted-ctxs {:pipe-0 ctx-0, :pipe-1 ctx-1, ...}}`
  under the step id.

  `:transform` mode sets `:metamorph/data` to a dataset containing one column
  per pipe (`:model-0`, `:model-1`, ...) plus the voted target column (tagged
  with `:column-type :prediction`), sets `:model` to a map holding the first
  pipe's target dataset under `:scicloj.metamorph.ml/target-ds`, and replaces
  the state under the step id with a map from pipe key to transformed context
  (the `:fitted-ctxs` entry is not kept)."
  [pipes]
  ;; :pipe-0, :pipe-1, ... -- one key per pipe, positionally aligned with `pipes`.
  (let [pipe-keys (map-indexed (fn [index _] (keyword (str "pipe-" index))) pipes)]
    (morph/pipeline
     {:metamorph/id :ensemble}
     (fn [{:metamorph/keys [id data mode] :as ctx}]
       (case mode
         :fit
         (assoc ctx id {:fitted-ctxs (zipmap pipe-keys
                                             (map #(morph/fit-pipe data %) pipes))})

         :transform
         (let [fitted-ctxs (get-in ctx [id :fitted-ctxs])
               ;; Target metadata is taken from the first pipe only.
               first-model (get-in fitted-ctxs [:pipe-0 :model])
               target-column (first (:target-columns first-model))
               target-categorical-map (:target-categorical-maps first-model)
               transformed-ctxs (map (fn [pipe-key pipe]
                                       (morph/transform-pipe data pipe (get fitted-ctxs pipe-key)))
                                     pipe-keys
                                     pipes)
               ;; One column of votes per pipe, named :model-<index>.
               vote-columns (map-indexed
                             (fn [index transformed-ctx]
                               (ds/new-column
                                (keyword (str "model-" index))
                                (get (cf/prediction (:metamorph/data transformed-ctx))
                                     target-column)))
                             transformed-ctxs)
               target-ds (get-in (first transformed-ctxs) [:model :scicloj.metamorph.ml/target-ds])
               prediction-ds (-> (ds/new-dataset vote-columns)
                                 ;; Row-wise majority vote across the :model-<index> columns.
                                 (tc/add-column target-column
                                                (fn [vote-ds] (map majority (tc/rows vote-ds))))
                                 (ds/assoc-metadata [target-column]
                                                    :column-type :prediction
                                                    :categorical-map (get target-categorical-map target-column)))]
           (assoc ctx
                  :model {:scicloj.metamorph.ml/target-ds target-ds}
                  :metamorph/data prediction-ds
                  id (zipmap pipe-keys transformed-ctxs))))))))