/
ml_base.clj
32 lines (28 loc) · 1.13 KB
/
ml_base.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
(ns tech.ml-base
(:require [tech.ml.registry :as registry]
[tech.ml.protocols :as protocols]
[tech.ml.dataset :as dataset])
(:import [java.util UUID]))
(defn train
  "Train a model with the ML system registered under `system-name`.

  `options` is first merged with the result of `protocols/coalesce-options`
  (entries from the coalesced options win on key conflicts). The merged
  options, `feature-keys`, `label-keys` and `dataset` are then passed to
  `dataset/apply-dataset-options`. The `:options` and `:coalesced-dataset` it
  returns are what the system's `protocols/train` is called with.

  Returns a map with keys:
    :system        - `system-name`, unchanged
    :model         - whatever `protocols/train` returned
    :options       - the `:options` returned by `apply-dataset-options`
    :feature-keys  - `feature-keys`, unchanged
    :label-keys    - `label-keys`, unchanged
    :id            - a freshly generated random `java.util.UUID`"
  [system-name feature-keys label-keys options dataset]
  (let [system         (registry/system system-name)
        merged-options (->> (protocols/coalesce-options system options)
                            (merge options))
        {final-options :options
         prepared-ds   :coalesced-dataset}
        (dataset/apply-dataset-options feature-keys label-keys merged-options dataset)
        ;; Train inside the let so it happens before the UUID is generated.
        trained        (protocols/train system final-options prepared-ds)]
    {:system       system-name
     :model        trained
     :options      final-options
     :feature-keys feature-keys
     :label-keys   label-keys
     :id           (UUID/randomUUID)}))
(defn predict
  "Run the prediction step of the ML system that produced `model` on `dataset`.

  `model` is a map with the keys that `train` returns: `:system`, `:model`,
  `:options` and `:feature-keys`. The ML system is looked up in the registry by
  `:system`.

  `dataset` is passed through `dataset/apply-dataset-options` with the stored
  `:feature-keys` and `:options`. Label keys are passed as nil; presumably no
  labels are expected at prediction time (verify against
  `apply-dataset-options`).

  The resulting `:coalesced-dataset` is handed to `protocols/predict` together
  with the stored options and trained model. Returns whatever
  `protocols/predict` returns."
  [model dataset]
  ;; Destructure once so every field is read a single time. The original bound
  ;; `trained-model`, never used it, and looked up `(:model model)` again.
  (let [{system-name   :system
         trained-model :model
         feature-keys  :feature-keys
         options       :options} model
        ml-system (registry/system system-name)
        {:keys [coalesced-dataset]} (dataset/apply-dataset-options
                                      feature-keys nil options dataset)]
    (protocols/predict ml-system options trained-model coalesced-dataset)))