Skip to content

Commit 2aed319

Browse files
feat(k-means): update initKmeanspp(), add configurable distance exponent
- add KMeansOpts.exponent - update doc strings
1 parent 84d3aba commit 2aed319

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

packages/k-means/src/api.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ export interface KMeansOpts {
1515
* Distance function/metric to use for finding nearest centroid.
1616
*/
1717
dist: IDistance<ReadonlyVec>;
18+
/**
19+
* Number of dimensions per sample. If omitted, the length of the first sample vector is used.
20+
*/
21+
dim: number;
1822
/**
1923
* Max. iteration count
2024
*/
@@ -27,6 +31,15 @@ export interface KMeansOpts {
2731
* Centroid refinement strategy (default: {@link means}).
2832
*/
2933
strategy: CentroidStrategy;
34+
/**
35+
* Only used if no {@link KMeansOpts.initial} is given and the default
36+
* {@link initKmeanspp} is used (default: 2). There, each sample's distance
37+
* to its nearest centroid is raised to this `exponent` and the result is
38+
* used as its weight when randomly choosing the next centroid. A higher
39+
* exponent gives points with larger distances a higher priority in the
40+
* random selection.
41+
*/
42+
exponent: number;
3043
}
3144

3245
/**

packages/k-means/src/kmeans.ts

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,23 @@ import type { CentroidStrategy, Cluster, KMeansOpts } from "./api.js";
2828
export const kmeans = <T extends ReadonlyVec>(
2929
k: number,
3030
samples: T[],
31-
opts?: Partial<KMeansOpts>
31+
opts: Partial<KMeansOpts> = {}
3232
) => {
33-
let { dist, initial, maxIter, rnd, strategy } = {
34-
dist: DIST_SQ,
35-
maxIter: 32,
36-
strategy: means,
37-
...opts,
38-
};
33+
let {
34+
dim = samples[0].length,
35+
dist = DIST_SQ,
36+
maxIter = 32,
37+
strategy = means,
38+
exponent,
39+
initial,
40+
rnd,
41+
} = opts;
3942
const num = samples.length;
40-
const dim = samples[0].length;
4143
const centroidIDs = Array.isArray(initial)
4244
? initial
4345
: initial
4446
? initial(k, samples, dist, rnd)
45-
: initKmeanspp(k, samples, dist, rnd);
47+
: initKmeanspp(k, samples, dist, rnd, exponent);
4648
assert(centroidIDs.length > 0, `missing initial centroids`);
4749
k = centroidIDs.length;
4850
const centroids: Vec[] = centroidIDs.map((i) => samples[i]);
@@ -72,6 +74,11 @@ export const kmeans = <T extends ReadonlyVec>(
7274
* fulfilled (e.g. due to lower number of samples and/or distance metric).
7375
* Throws an error if `samples` are empty.
7476
*
77+
* Each sample's distance to its nearest centroid is raised to the optional
78+
* `exponent` (default: 2) and the result is used as the sample's weight when
79+
* randomly choosing the next centroid. A higher exponent gives points with
80+
* larger distances a higher priority in the random selection.
81+
*
7582
* References:
7683
*
7784
* - https://en.wikipedia.org/wiki/K-means%2B%2B
@@ -82,12 +89,14 @@ export const kmeans = <T extends ReadonlyVec>(
8289
* @param samples -
8390
* @param dist -
8491
* @param rnd -
92+
* @param exponent -
8593
*/
8694
export const initKmeanspp = <T extends ReadonlyVec>(
8795
k: number,
8896
samples: T[],
8997
dist: IDistance<ReadonlyVec> = DIST_SQ,
90-
rnd: IRandom = SYSTEM
98+
rnd: IRandom = SYSTEM,
99+
exponent = 2
91100
) => {
92101
const num = samples.length;
93102
assert(num > 0, `missing samples`);
@@ -101,7 +110,7 @@ export const initKmeanspp = <T extends ReadonlyVec>(
101110
const probs = samples.map((p) => {
102111
const d =
103112
dist.from(metric(p, centroids[argmin(p, centroids, dist)!])) **
104-
2;
113+
exponent;
105114
psum += d;
106115
return d;
107116
});

0 commit comments

Comments
 (0)