-
Notifications
You must be signed in to change notification settings - Fork 4
/
sparse_logreg.clj
82 lines (66 loc) · 2.68 KB
/
sparse_logreg.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
(ns scicloj.ml.smile.sparse-logreg
(:require
[tech.v3.datatype :as dt]
[scicloj.metamorph.ml :as ml]
[tech.v3.dataset :as ds]
[tech.v3.dataset.modelling :as ds-mod]
[scicloj.ml.smile.discrete-nb :as nb]
[scicloj.ml.smile.nlp :as nlp]
[tech.v3.datatype.errors :as errors]
)
(:import [smile.classification SparseLogisticRegression]
[smile.data SparseDataset]
[smile.util SparseArray]))
(defn train
  "Training function of the sparse logistic regression model.

  The column named `(:sparse-column options)` of `feature-ds` needs to contain
  the text as SparseArrays over the vocabulary. The class labels are read from
  the first inference target column of `target-ds` and converted to an int
  array.

  Options:
  * `:sparse-column`: name of the column which contains the sparse data as a
    seq of SparseArrays (required)
  * `:n-sparse-columns`: number of columns / dimensions of the sparse vectors
    (required)
  * `:lambda`: passed as the lambda argument of `SparseLogisticRegression/fit`
    (default 0.1)
  * `:tolerance`: passed as the tolerance argument of the fit (default 1e-5)
  * `:max-iterations`: passed as the maximum-iterations argument of the fit
    (default 500)

  Returns the value returned by `SparseLogisticRegression/fit`.
  Signals an error through `errors/when-not-error` when `:sparse-column` or
  `:n-sparse-columns` is missing."
  [feature-ds target-ds options]
  ;; Fail early with a readable message instead of a failure deep inside Smile.
  (errors/when-not-error (:sparse-column options) ":sparse-column need to be given")
  (errors/when-not-error (:n-sparse-columns options) ":n-sparse-columns need to be given")
  (let [;; into-array enforces that every element really is a SparseArray.
        train-array   (into-array SparseArray
                                  (get feature-ds (:sparse-column options)))
        train-dataset (SparseDataset/of (seq train-array) (:n-sparse-columns options))
        ;; Labels come from the first column marked as inference target.
        score         (get target-ds (first (ds-mod/inference-target-column-names target-ds)))]
    (SparseLogisticRegression/fit train-dataset
                                  (dt/->int-array score)
                                  (get options :lambda 0.1)
                                  (get options :tolerance 1e-5)
                                  (get options :max-iterations 500))))
(defn predict
  "Predict function for sparse logistic regression model.

  Delegates to `scicloj.ml.smile.discrete-nb/predict`, passing `feature-ds`,
  `thawed-model` and `model` through unchanged, and returns its result."
  [feature-ds thawed-model model]
  (nb/predict feature-ds thawed-model model))
;; Declares the model via `ml/define-model!` under the key
;; :smile.classification/sparse-logistic-regression, handing over `train` and
;; `predict` from this namespace as its train / predict functions.
;; The :default values listed below are the same values `train` falls back to
;; when the corresponding option is absent from the options map.
;; NOTE(review): the required `train` options :sparse-column and
;; :n-sparse-columns are not listed here -- confirm whether they should be.
(ml/define-model!
:smile.classification/sparse-logistic-regression
train
predict
{:options [{:name :lambda
:type :float32
:default 0.1}
{:name :tolerance
:type :float32
:default 1e-5}
{:name :max-iterations
:type :int32
:default 500}
]})
(comment
  ;; REPL scratch pad: nothing inside this form is evaluated when the
  ;; namespace is loaded. Evaluate the forms one by one, top to bottom.

  ;; Builds the example dataset: keeps :Text and :Score, decrements every
  ;; score, turns the text into a bag-of-words column and then into the
  ;; :bow-sparse column, and marks :Score as the inference target.
  (defn get-reviews []
    (->
     (ds/->dataset "test/data/reviews.csv.gz" {:key-fn keyword })
     (ds/select-columns [:Text :Score])
     (ds/update-column :Score #(map dec %))
     (nlp/count-vectorize :Text :bow nlp/default-text->bow)
     (nb/bow->SparseArray :bow :bow-sparse 100)
     (ds-mod/set-inference-target :Score)))

  ;; `reviews` is used by the forms below, so it has to be defined first.
  (def reviews (get-reviews))

  ;; :model-type must be the key the model is registered under in this
  ;; namespace. :n-sparse-columns is mandatory for `train`.
  ;; NOTE(review): 100 assumes the value passed to nb/bow->SparseArray above
  ;; is the dimension of the sparse vectors -- confirm.
  (def trained-model
    (ml/train reviews {:model-type :smile.classification/sparse-logistic-regression
                       :sparse-column :bow-sparse
                       :n-sparse-columns 100}))

  (ml/predict reviews trained-model)
  )