-
Notifications
You must be signed in to change notification settings - Fork 0
/
index.js
72 lines (56 loc) · 1.87 KB
/
index.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import _ from 'lodash'
/**
 * Arithmetic mean of a collection: `_.sum(ns)` divided by `_.size(ns)`.
 * An empty collection yields NaN (0 / 0).
 * Wrapped in `_.memoize`, so results are cached on the first argument.
 *
 * @param {number[]} ns - values to average (assumed numeric — TODO confirm).
 * @returns {number} the average of `ns`.
 */
const mean = _.memoize((ns) => {
  const total = _.sum(ns)
  const count = _.size(ns)
  return total / count
})
/**
 * Square root of the mean squared deviation from `mean(ns)`.
 * The squared deviations are averaged with `mean`, i.e. divided by the
 * element count (population form, no n - 1 correction).
 *
 * @param {number[]} ns - values to measure (assumed numeric — TODO confirm).
 * @returns {number} the standard deviation of `ns`.
 */
const getStdDev = _.memoize((ns) => {
  const squaredDeviations = _.map(ns, (n) => {
    const deviation = n - mean(ns)
    return _.multiply(deviation, deviation)
  })
  return Math.sqrt(mean(squaredDeviations))
})
/**
 * Standardizes every value of `ns`: (value - mean) / standard deviation.
 * Any falsy quotient — notably the NaN produced by 0 / 0 when the standard
 * deviation is zero — is replaced by 0.
 *
 * @param {number[]} ns - values to standardize (assumed numeric — TODO confirm).
 * @returns {number[]} the standardized values, in the same order.
 */
const scale = _.memoize((ns) => {
  const standardize = (n) => {
    const quotient = _.divide(n - mean(ns), getStdDev(ns))
    return quotient || 0
  }
  return _.map(ns, (n) => standardize(n))
})
/**
 * Rotates a row-major matrix a quarter turn clockwise: output row `i` is
 * input column `i` read from the last row up to the first.
 * The number of output rows is taken from the first input row.
 *
 * @param {Array[]} x - matrix as an array of rows.
 * @returns {Array[]} the rotated matrix.
 */
const rotate = _.memoize((x) => {
  const columnAt = (i) => _.map(x, (row) => row[i])
  return _.map(_.first(x), (_v, i) => _.reverse(columnAt(i)))
})
/**
 * Rotates a row-major matrix a quarter turn counter-clockwise: output row `i`
 * collects, from every input row, the element `i` positions before its end.
 * The number of output rows is taken from the first input row.
 *
 * @param {Array[]} x - matrix as an array of rows.
 * @returns {Array[]} the rotated matrix.
 */
const rotateCounter = _.memoize((x) => {
  const fromEnd = (row, offset) => row[_.size(row) - 1 - offset]
  return _.map(_.first(x), (_v, i) => _.map(x, (row) => fromEnd(row, i)))
})
/**
 * Dimensions of a row-major matrix as `[rows, columns]`.
 * The column count is the number of rows `rotate(x)` produces.
 *
 * @param {Array[]} x - matrix as an array of rows.
 * @returns {number[]} two-element array `[rowCount, columnCount]`.
 */
const getShape = _.memoize((x) => {
  const rowCount = _.size(x)
  const columnCount = _.size(rotate(x))
  return [rowCount, columnCount]
})
/**
 * Standardizes a row-major matrix column by column.
 * The matrix is rotated so that each column becomes a row, every such row is
 * passed through `scale`, and the result is rotated back.
 *
 * @param {Array[]} x - matrix as an array of rows.
 * @returns {Array[]} matrix of standardized values in the original layout.
 */
const getScaled = _.memoize((x) => {
  const columns = rotate(x)
  const scaledColumns = _.map(columns, (column) => scale(column))
  return rotateCounter(scaledColumns)
})
/**
 * Mean of every row of `x`, computed with `mean`.
 *
 * @param {Array[]} x - array of rows.
 * @returns {number[]} one mean per row, in row order.
 */
const getMean = _.memoize((x) => {
  const rowMean = (row) => mean(row)
  return _.map(x, (row) => rowMean(row))
})
/**
 * Element-wise square root.
 *
 * @param {number[]} x - values to take the root of.
 * @returns {number[]} `Math.sqrt` of every element, in order.
 */
const getSqrt = _.memoize((x) => {
  const root = (n) => Math.sqrt(n)
  return _.map(x, (n) => root(n))
})
/**
 * Sums every row of `x` and divides each sum by `size`.
 * `transform` passes the squared per-column deviations and the row count,
 * which makes the result the per-column population variance.
 *
 * Deliberately not memoized: `_.memoize` keys its cache on the first argument
 * only, so a cached version would ignore `size` and return a stale result
 * when the same `x` is passed with a different divisor.
 *
 * @param {Array[]} x - array of rows to total.
 * @param {number} size - divisor applied to every row total.
 * @returns {number[]} one quotient per row, in row order.
 */
const getVariance = (x, size) => _.map(x, (r) => _.sum(r) / size)
/**
 * Deviation of every value from its column mean.
 * Works on `rotate(x)`, so each returned row corresponds to one column of
 * `x` (in the rotated order) with that column's mean subtracted.
 *
 * @param {Array[]} x - matrix as an array of rows.
 * @returns {Array[]} one array of deviations per column of `x`.
 */
const getCorrection = _.memoize((x) => {
  const columns = rotate(x)
  return _.map(columns, (column, i) => {
    const center = _.get(getMean(columns), i)
    return _.map(column, (n) => n - center)
  })
})
/**
 * Squares every deviation returned by `getCorrection(x)`.
 *
 * @param {Array[]} x - matrix as an array of rows.
 * @returns {Array[]} squared deviations, laid out like `getCorrection(x)`.
 */
const getCorrectionSquared = _.memoize((x) => {
  const squareAll = (deviations) => _.map(deviations, (n) => n ** 2)
  return _.map(getCorrection(x), (deviations) => squareAll(deviations))
})
// Per-matrix list of results already produced by `transform`, one entry per
// distinct label set. Keyed on the matrix (the memoized function's first
// argument); the wrapped function only supplies the initial empty list.
const transformResultsFor = _.memoize(() => [])

/**
 * Standardizes a row-major matrix column by column and returns the scaled
 * data together with the statistics needed to undo the scaling.
 *
 * Results are cached per matrix AND per label set. A plain `_.memoize` keys
 * on the first argument only, so it would return the labels of the first call
 * for every later call with the same matrix, whatever `labels` was passed.
 * Repeated calls with the same matrix and equal labels still return the same
 * object.
 *
 * @param {Array[]} x - matrix as an array of rows (assumed numeric — TODO confirm).
 * @param {Array} labels - column labels; only the first `columns` entries are kept.
 * @returns {{shape: number[], scaled: Array[], mean: number[], variance: number[], scale: number[], labels: Array}}
 *   `shape` is `[rows, columns]`, `scaled` the standardized matrix, `mean` /
 *   `variance` / `scale` the per-column mean, variance and its square root.
 */
export const transform = (x, labels) => {
  const shape = getShape(x)
  const columnLabels = _.take(labels, _.last(shape))

  const known = transformResultsFor(x)
  const hit = _.find(known, (result) => _.isEqual(result.labels, columnLabels))
  if (hit) return hit

  // Computed once so `variance` and `scale` are derived from the same array.
  const variance = getVariance(getCorrectionSquared(x), _.size(x))
  const result = {
    shape,
    scaled: getScaled(x),
    mean: getMean(rotate(x)),
    variance,
    scale: getSqrt(variance),
    labels: columnLabels
  }
  known.push(result)
  return result
}
/**
 * Undoes the scaling stored in a `transform` result: every value of
 * `x.scaled` is multiplied by the `x.scale` entry of its column and the
 * matching `x.mean` entry is added.
 * Memoized on the argument, so the same input object yields the same output array.
 *
 * @param {{scaled: Array[], scale: number[], mean: number[]}} x - object shaped like a `transform` result.
 * @returns {Array[]} matrix of de-standardized values, laid out like `x.scaled`.
 */
export const inverseTransform = _.memoize((x) => {
  const restore = (n, i) => {
    const stretched = _.multiply(n, _.get(x.scale, i))
    return stretched + _.get(x.mean, i)
  }
  return _.map(x.scaled, (row) => _.map(row, (n, i) => restore(n, i)))
})
/**
 * Extracts the scaled column labelled `col` as a single-column matrix: one
 * one-element row per row of `x.scaled`, in the original row order.
 * Returns `[]` when `col` is not among `x.labels`.
 *
 * The column is read straight from the rows instead of going through
 * `rotate` / `rotateCounter`: `rotateCounter` expects a matrix, and when it is
 * handed a single flat column it maps over a number and returns `[]`, which
 * made this function return `[]` for every input.
 *
 * NOTE(review): the n x 1 matrix shape is chosen to match `dropFeature`, which
 * also returns rows — confirm callers do not expect a flat array.
 *
 * @param {{scaled: Array[], labels: Array}} x - object shaped like a `transform` result.
 * @param {*} col - label of the column to extract.
 * @returns {Array[]} the selected column as rows of length one, or `[]`.
 */
export const getFeature = (x, col) => {
  const index = _.indexOf(x.labels, col)
  if (index < 0) return []
  return _.map(x.scaled, (row) => [_.get(row, index)])
}
/**
 * Returns `x.scaled` without the column labelled `col`.
 * The matrix is rotated so columns become rows, the row at the label's index
 * is removed, and the remainder is rotated back. When `col` is not among
 * `x.labels` no column matches and the full matrix is returned.
 *
 * @param {{scaled: Array[], labels: Array}} x - object shaped like a `transform` result.
 * @param {*} col - label of the column to remove.
 * @returns {Array[]} the scaled matrix minus the selected column.
 */
export const dropFeature = (x, col) => {
  const target = _.indexOf(x.labels, col)
  const columns = rotate(x.scaled)
  // The dropped column is marked with null and stripped by _.compact;
  // the kept entries are arrays, which are always truthy.
  const kept = _.compact(
    _.map(columns, (column, i) => (i === target ? null : column))
  )
  return rotateCounter(kept)
}