@@ -28,21 +28,23 @@ import type { CentroidStrategy, Cluster, KMeansOpts } from "./api.js";
28
28
export const kmeans = < T extends ReadonlyVec > (
29
29
k : number ,
30
30
samples : T [ ] ,
31
- opts ? : Partial < KMeansOpts >
31
+ opts : Partial < KMeansOpts > = { }
32
32
) => {
33
- let { dist, initial, maxIter, rnd, strategy } = {
34
- dist : DIST_SQ ,
35
- maxIter : 32 ,
36
- strategy : means ,
37
- ...opts ,
38
- } ;
33
+ let {
34
+ dim = samples [ 0 ] . length ,
35
+ dist = DIST_SQ ,
36
+ maxIter = 32 ,
37
+ strategy = means ,
38
+ exponent,
39
+ initial,
40
+ rnd,
41
+ } = opts ;
39
42
const num = samples . length ;
40
- const dim = samples [ 0 ] . length ;
41
43
const centroidIDs = Array . isArray ( initial )
42
44
? initial
43
45
: initial
44
46
? initial ( k , samples , dist , rnd )
45
- : initKmeanspp ( k , samples , dist , rnd ) ;
47
+ : initKmeanspp ( k , samples , dist , rnd , exponent ) ;
46
48
assert ( centroidIDs . length > 0 , `missing initial centroids` ) ;
47
49
k = centroidIDs . length ;
48
50
const centroids : Vec [ ] = centroidIDs . map ( ( i ) => samples [ i ] ) ;
@@ -72,6 +74,11 @@ export const kmeans = <T extends ReadonlyVec>(
72
74
* fulfilled (e.g. due to lower number of samples and/or distance metric).
73
75
* Throws an error if `samples` are empty.
74
76
*
77
+ * Each sample's distance to its nearest centroid is raised to the optional
78
+ * `exponent` (default: 2), and the results are used as the weights for randomly
79
+ * choosing the next centroid. A higher exponent means that points with larger
80
+ * distances are more likely to be selected.
81
+ *
75
82
* References:
76
83
*
77
84
* - https://en.wikipedia.org/wiki/K-means%2B%2B
@@ -82,12 +89,14 @@ export const kmeans = <T extends ReadonlyVec>(
82
89
* @param samples -
83
90
* @param dist -
84
91
* @param rnd -
92
+ * @param exponent -
85
93
*/
86
94
export const initKmeanspp = < T extends ReadonlyVec > (
87
95
k : number ,
88
96
samples : T [ ] ,
89
97
dist : IDistance < ReadonlyVec > = DIST_SQ ,
90
- rnd : IRandom = SYSTEM
98
+ rnd : IRandom = SYSTEM ,
99
+ exponent = 2
91
100
) => {
92
101
const num = samples . length ;
93
102
assert ( num > 0 , `missing samples` ) ;
@@ -101,7 +110,7 @@ export const initKmeanspp = <T extends ReadonlyVec>(
101
110
const probs = samples . map ( ( p ) => {
102
111
const d =
103
112
dist . from ( metric ( p , centroids [ argmin ( p , centroids , dist ) ! ] ) ) **
104
- 2 ;
113
+ exponent ;
105
114
psum += d ;
106
115
return d ;
107
116
} ) ;
0 commit comments