Skip to content

Commit

Permalink
port kmeans to new API, using lazily evaluated 'get'
Browse files Browse the repository at this point in the history
  • Loading branch information
karlnapf committed Jul 6, 2018
1 parent 8560748 commit e4740e3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
10 changes: 5 additions & 5 deletions examples/meta/src/clustering/kmeans.sg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
File f_feats_train = csv_file("../../data/classifier_binary_2d_linear_features_train.dat")
Math:init_random(1)
File f_feats_train = csv_file("../../data/classifier_binary_2d_linear_features_train.dat")

#![create_features]
Features features_train = features(f_feats_train)
Expand All @@ -10,18 +10,18 @@ Distance d = distance("EuclideanDistance", lhs=features_train, rhs=features_trai
#![choose_distance]

#![create_instance_lloyd]
KMeans kmeans(2, d)
Machine kmeans = machine("KMeans", k=2, distance=d)
#![create_instance_lloyd]

#![train_dataset]
kmeans.train()
#![train_dataset]

#![extract_centers_and_radius]
RealMatrix c = kmeans.get_cluster_centers()
RealVector r = kmeans.get_radiuses()
RealMatrix c = kmeans.get_real_matrix("cluster_centers")
RealVector r = kmeans.get_real_vector("radiuses")
#![extract_centers_and_radius]

#![create_instance_mb]
KMeansMiniBatch kmeans_mb(k=2, distance=d, batch_size=4, max_iter=1000)
Machine kmeans_mini_batch = machine("KMeansMiniBatch", k=2, distance=d, batch_size=4, max_iter=100)
#![create_instance_mb]
4 changes: 3 additions & 1 deletion src/shogun/clustering/KMeansBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,8 @@ void CKMeansBase::init()
SG_ADD(&max_iter, "max_iter", "Maximum number of iterations", MS_AVAILABLE);
SG_ADD(&k, "k", "k, the number of clusters", MS_AVAILABLE);
SG_ADD(&dimensions, "dimensions", "Dimensions of data", MS_NOT_AVAILABLE);
SG_ADD(&R, "R", "Cluster radiuses", MS_NOT_AVAILABLE);
SG_ADD(&R, "radiuses", "Cluster radiuses", MS_NOT_AVAILABLE);

watch_method("cluster_centers", &CKMeansBase::get_cluster_centers);
}

0 comments on commit e4740e3

Please sign in to comment.