Skip to content

Commit

Permalink
fix(k-means): use dist metric in initKmeanspp()
Browse files Browse the repository at this point in the history
- update to use user-provided distance metric
  • Loading branch information
postspectacular committed Apr 19, 2021
1 parent 3a9a77a commit 37bd6c6
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions packages/k-means/src/kmeans.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
import { assert } from "@thi.ng/api";
import { argmin, DIST_SQ, IDistance } from "@thi.ng/distance";
import { SYSTEM, uniqueIndices, weightedRandom } from "@thi.ng/random";
import {
add,
distSq,
median,
mulN,
ReadonlyVec,
Vec,
zeroes,
} from "@thi.ng/vectors";
import { add, median, mulN, ReadonlyVec, Vec, zeroes } from "@thi.ng/vectors";
import type { CentroidStrategy, Cluster, KMeansOpts } from "./api";

export const kmeans = <T extends ReadonlyVec>(
Expand All @@ -25,7 +17,7 @@ export const kmeans = <T extends ReadonlyVec>(
};
const num = samples.length;
const dim = samples[0].length;
const centroidIDs = initial || initKmeanspp(k, samples, rnd);
const centroidIDs = initial || initKmeanspp(k, samples, dist, rnd);
assert(centroidIDs.length === k, `wrong number of initial centroids`);
const centroids: Vec[] = centroidIDs.map((i) => samples[i]);
const clusters: number[] = [];
Expand Down Expand Up @@ -63,20 +55,23 @@ export const kmeans = <T extends ReadonlyVec>(
*
* @param k
* @param samples
* @param dist
* @param rnd
*/
export const initKmeanspp = <T extends ReadonlyVec>(
k: number,
samples: T[],
dist: IDistance<ReadonlyVec> = DIST_SQ,
rnd = SYSTEM
) => {
const num = samples.length;
const centroidIDs = [rnd.int() % num];
const centroids = [samples[centroidIDs[0]]];
const indices = new Array(num).fill(0).map((_, i) => i);
while (centroidIDs.length < k) {
let probs = samples.map((p) =>
distSq(p, centroids[argmin(p, centroids)!])
let probs = samples.map(
(p) =>
dist.from(dist.metric(p, centroids[argmin(p, centroids)!])) ** 2
);
const id = weightedRandom(indices, probs)();
centroidIDs.push(id);
Expand Down

0 comments on commit 37bd6c6

Please sign in to comment.