Browse files

Fixed DelegatingVector bug

Cleaned up small issues in StreamingKmeans
Allowed for small slop on weight in WeightedVector comparator
Improved kmeans test
Added another searcher test and added special test for Brute
  • Loading branch information...
1 parent 9e7d0fc commit c54a3f4580fa94e0d4379b991dfb406506ac7352 @tdunning committed Apr 4, 2012
View
230 src/main/java/org/apache/mahout/knn/DelegatingVector.java
@@ -17,56 +17,236 @@
package org.apache.mahout.knn;
-import org.apache.mahout.math.AbstractVector;
+import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
import java.util.Iterator;
/**
* A delegating vector provides an easy way to decorate vectors with weights or id's and such while
* keeping all of the Vector functionality.
*/
-public class DelegatingVector extends AbstractVector {
+public class DelegatingVector implements Vector {
protected Vector delegate;
- protected DelegatingVector(int size) {
- super(size);
+ public DelegatingVector(int size) {
+ delegate = new DenseVector(size);
}
public DelegatingVector(Vector v) {
- super(v.size());
delegate = v;
}
+ public Vector getVector() {
+ return delegate;
+ }
+
@Override
- protected Matrix matrixLike(int i, int i1) {
- throw new UnsupportedOperationException("Can't make a matrix like this");
+ public double aggregate(DoubleDoubleFunction aggregator, DoubleFunction map) {
+ return delegate.aggregate(aggregator, map);
}
@Override
- public boolean isDense() {
- return delegate.isDense();
+ public double aggregate(Vector other, DoubleDoubleFunction aggregator, DoubleDoubleFunction combiner) {
+ return delegate.aggregate(other, aggregator, combiner);
}
@Override
- public boolean isSequentialAccess() {
- return delegate.isSequentialAccess();
+ public Vector viewPart(int offset, int length) {
+ return delegate.viewPart(offset, length);
}
@Override
- public Iterator<Element> iterator() {
- return delegate.iterator();
+ public Vector clone() {
+ return delegate.clone();
}
@Override
- public Iterator<Element> iterateNonZero() {
- return delegate.iterateNonZero();
+ public Vector divide(double x) {
+ return delegate.divide(x);
+ }
+
+ @Override
+ public double dot(Vector x) {
+ return delegate.dot(x);
+ }
+
+ @Override
+ public double get(int index) {
+ return delegate.get(index);
+ }
+
+ @Override
+ public Element getElement(int index) {
+ return delegate.getElement(index);
+ }
+
+ @Override
+ public Vector minus(Vector that) {
+ return delegate.minus(that);
+ }
+
+ @Override
+ public Vector normalize() {
+ return delegate.normalize();
+ }
+
+ @Override
+ public Vector normalize(double power) {
+ return delegate.normalize(power);
+ }
+
+ @Override
+ public Vector logNormalize() {
+ return delegate.logNormalize();
+ }
+
+ @Override
+ public Vector logNormalize(double power) {
+ return delegate.logNormalize(power);
+ }
+
+ @Override
+ public double norm(double power) {
+ return delegate.norm(power);
+ }
+
+ @Override
+ public double getLengthSquared() {
+ return delegate.getLengthSquared();
+ }
+
+ @Override
+ public double getDistanceSquared(Vector v) {
+ return delegate.getDistanceSquared(v);
+ }
+
+ @Override
+ public double maxValue() {
+ return delegate.maxValue();
+ }
+
+ @Override
+ public int maxValueIndex() {
+ return delegate.maxValueIndex();
+ }
+
+ @Override
+ public double minValue() {
+ return delegate.minValue();
+ }
+
+ @Override
+ public int minValueIndex() {
+ return delegate.minValueIndex();
+ }
+
+ @Override
+ public Vector plus(double x) {
+ return delegate.plus(x);
+ }
+
+ @Override
+ public Vector plus(Vector x) {
+ return delegate.plus(x);
+ }
+
+ @Override
+ public void set(int index, double value) {
+ delegate.set(index, value);
}
@Override
- public double getQuick(int i) {
- return delegate.getQuick(i);
+ public Vector times(double x) {
+ return delegate.times(x);
+ }
+
+ @Override
+ public Vector times(Vector x) {
+ return delegate.times(x);
+ }
+
+ @Override
+ public double zSum() {
+ return delegate.zSum();
+ }
+
+ @Override
+ public Vector assign(double value) {
+ return delegate.assign(value);
+ }
+
+ @Override
+ public Vector assign(double[] values) {
+ return delegate.assign(values);
+ }
+
+ @Override
+ public Vector assign(Vector other) {
+ return delegate.assign(other);
+ }
+
+ @Override
+ public Vector assign(DoubleDoubleFunction f, double y) {
+ return delegate.assign(f, y);
+ }
+
+ @Override
+ public Vector assign(DoubleFunction function) {
+ return delegate.assign(function);
+ }
+
+ @Override
+ public Vector assign(Vector other, DoubleDoubleFunction function) {
+ return delegate.assign(other, function);
+ }
+
+ @Override
+ public Matrix cross(Vector other) {
+ return delegate.cross(other);
+ }
+
+ @Override
+ public int size() {
+ return delegate.size();
+ }
+
+ @Override
+ public String asFormatString() {
+ return delegate.asFormatString();
+ }
+
+ @Override
+ public int hashCode() {
+ return delegate.hashCode();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return delegate.equals(o);
+ }
+
+ @Override
+ public String toString() {
+ return delegate.toString();
+ }
+
+ @Override
+ public boolean isDense() {
+ return delegate.isDense();
+ }
+
+ @Override
+ public boolean isSequentialAccess() {
+ return delegate.isSequentialAccess();
+ }
+
+ @Override
+ public double getQuick(int index) {
+ return delegate.getQuick(index);
}
@Override
@@ -75,16 +255,22 @@ public Vector like() {
}
@Override
- public void setQuick(int i, double v) {
- delegate.setQuick(i, v);
+ public void setQuick(int index, double value) {
+ delegate.setQuick(index, value);
}
@Override
public int getNumNondefaultElements() {
return delegate.getNumNondefaultElements();
}
- public Vector getVector() {
- return delegate;
+ @Override
+ public Iterator<Element> iterateNonZero() {
+ return delegate.iterateNonZero();
+ }
+
+ @Override
+ public Iterator<Element> iterator() {
+ return delegate.iterator();
}
-}
+}
View
90 src/main/java/org/apache/mahout/knn/VectorIterableView.java
@@ -1,90 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.knn;
-
-import com.google.common.collect.Iterables;
-import org.apache.mahout.math.MatrixSlice;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorIterable;
-
-import java.util.Iterator;
-
-public class VectorIterableView implements VectorIterable {
- private VectorIterable data;
- private int start;
- private int rows;
-
- public VectorIterableView(VectorIterable data, int start, int rows) {
- this.data = data;
- this.start = start;
- this.rows = rows;
- }
-
- @Override
- public Iterator<MatrixSlice> iterateAll() {
- return Iterables.limit(Iterables.skip(data, start), rows).iterator();
- }
-
- @Override
- public int numSlices() {
- return rows;
- }
-
- @Override
- public int numRows() {
- return rows;
- }
-
- @Override
- public int numCols() {
- return this.iterateAll().next().vector().size();
- }
-
- /**
- * Return a new vector with cardinality equal to getNumRows() of this matrix which is the matrix product of the
- * recipient and the argument
- *
- * @param v a vector with cardinality equal to getNumCols() of the recipient
- * @return a new vector (typically a DenseVector)
- * @throws org.apache.mahout.math.CardinalityException
- * if this.getNumRows() != v.size()
- */
- @Override
- public Vector times(Vector v) {
- throw new UnsupportedOperationException("Default operation");
- }
-
- /**
- * Convenience method for producing this.transpose().times(this.times(v)), which can be implemented with only one pass
- * over the matrix, without making the transpose() call (which can be expensive if the matrix is sparse)
- *
- * @param v a vector with cardinality equal to getNumCols() of the recipient
- * @return a new vector (typically a DenseVector) with cardinality equal to that of the argument.
- * @throws org.apache.mahout.math.CardinalityException
- * if this.getNumCols() != v.size()
- */
- @Override
- public Vector timesSquared(Vector v) {
- throw new UnsupportedOperationException("Default operation");
- }
-
- @Override
- public Iterator<MatrixSlice> iterator() {
- return iterateAll();
- }
-}
View
2 src/main/java/org/apache/mahout/knn/WeightedVector.java
@@ -64,7 +64,7 @@ public int compareTo(WeightedVector other) {
return 0;
}
int r = Double.compare(weight, other.getWeight());
- if (r == 0) {
+ if (r == 0 || Math.abs(weight - other.getWeight()) < 1e-8) {
double diff = this.minus(other).norm(1);
if (diff < 1e-12) {
return 0;
View
18 src/main/java/org/apache/mahout/knn/means/StreamingKmeans.java
@@ -43,7 +43,7 @@ public UpdatableSearcher cluster(DistanceMeasure distance, Iterable<MatrixSlice>
this.distance = distance;
// cluster the data
- UpdatableSearcher centroids = clusterInternal(data, maxClusters);
+ UpdatableSearcher centroids = clusterInternal(data, maxClusters, 1);
// how make a clean set of empty centroids to get ready for final pass through the data
int width = data.iterator().next().vector().size();
@@ -84,13 +84,14 @@ public double estimateCutoff(Iterable<MatrixSlice> data) {
return distanceCutoff;
}
- private UpdatableSearcher clusterInternal(Iterable<MatrixSlice> data, int maxClusters) {
+ private UpdatableSearcher clusterInternal(Iterable<MatrixSlice> data, int maxClusters, int depth) {
int width = data.iterator().next().vector().size();
UpdatableSearcher centroids = new ProjectionSearch(width, distance, 4, 10);
// now we scan the data and either add each point to the nearest group or create a new group
// when we get too many groups, then we need to increase the threshold and rescan our current groups
Random rand = RandomUtils.getRandom();
+ int n = 0;
for (MatrixSlice row : data) {
if (centroids.size() == 0) {
// add first centroid on first vector
@@ -111,13 +112,20 @@ private UpdatableSearcher clusterInternal(Iterable<MatrixSlice> data, int maxClu
}
}
- if (centroids.size() > maxClusters) {
- distanceCutoff *= 1.5;
+ if (depth < 2 && centroids.size() > maxClusters) {
+ maxClusters = (int) Math.max(maxClusters, 10 * Math.log(n));
// TODO does shuffling help?
List<MatrixSlice> shuffled = Lists.newArrayList(centroids);
Collections.shuffle(shuffled);
- centroids = clusterInternal(shuffled, maxClusters);
+ centroids = clusterInternal(shuffled, maxClusters, depth + 1);
+ // for distributions with sharp scale effects, the distanceCutoff can grow to
+ // excessive size leading sub-clustering to collapse the centroids set too much.
+ // This test prevents that collapse from getting too severe.
+ if (centroids.size() > 0.1 * maxClusters) {
+ distanceCutoff *= 1.5;
+ }
}
+ n++;
}
return centroids;
}
View
5 src/main/java/org/apache/mahout/knn/search/Brute.java
@@ -26,8 +26,6 @@
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.function.DoubleDoubleFunction;
-import org.apache.mahout.math.function.Functions;
import java.util.Collections;
import java.util.Iterator;
@@ -150,8 +148,7 @@ public void setSearchSize(int size) {
List<List<WeightedVector>> r = Lists.newArrayList();
for (MatrixSlice row : query) {
- q.add(new PriorityQueue<WeightedVector>());
- r.add(Lists.reverse(Lists.newArrayList(searchInternal(row.vector(), reference, n, q.get(row.index())))));
+ r.add(search(row.vector(), n));
}
return r;
}
View
3 src/main/java/org/apache/mahout/knn/search/ProjectionSearch.java
@@ -34,6 +34,7 @@
import java.util.Collection;
import java.util.Collections;
+import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
@@ -144,7 +145,7 @@ protected MatrixSlice computeNext() {
if (!data.hasNext()) {
return endOfData();
} else {
- return new MatrixSlice(data.next(), index++);
+ return new MatrixSlice(data.next().getVector(), index++);
}
}
};
View
3 src/test/java/org/apache/mahout/knn/CentroidTest.java
@@ -43,7 +43,8 @@ public void testUpdate() {
x1.update(c);
// check for correct value
- assertEquals(0, x1.getVector().minus(a.plus(b).plus(c).assign(Functions.div(3))).norm(1), 1e-8);
+ final Vector mean = a.plus(b).plus(c).assign(Functions.div(3));
+ assertEquals(0, x1.getVector().minus(mean).norm(1), 1e-8);
assertEquals(3, x1.getWeight(), 0);
assertEquals(0, x2.minus(a.plus(b).divide(2)).norm(1), 1e-8);
View
2 src/test/java/org/apache/mahout/knn/SampleSequenceFileWriterTest.java
@@ -27,7 +27,7 @@
public class SampleSequenceFileWriterTest {
@Test
public void testWrite() throws IOException {
- List<Vector> data = SampleSequenceFileWriter.writeTestFile("foo", 30, 1000000, false);
+ List<Vector> data = SampleSequenceFileWriter.writeTestFile("foo", 30, 10000, true);
List<Vector> actual = SampleSequenceFileWriter.readTestFile("foo");
Assert.assertEquals(data.size(), actual.size());
Assert.assertTrue(data.size() > 0);
View
38 src/test/java/org/apache/mahout/knn/WeightedVectorTest.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.knn;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+
+public class WeightedVectorTest {
+ @Test
+ public void testLength() {
+ Vector v = new DenseVector(new double[]{0.9921337470551008, 1.0031004325833064, 0.9963963182745947});
+ Centroid c = new Centroid(3, new DenseVector(v), 2);
+ assertEquals(c.getVector().getLengthSquared(), c.getLengthSquared(), 1e-17);
+ // previously, this wouldn't clear the cached squared length value correctly which would cause bad distances
+ c.set(0, -1);
+ System.out.printf("c = %.9f\nv = %.9f\n", c.getLengthSquared(), c.getVector().getLengthSquared());
+ assertEquals(c.getVector().getLengthSquared(), c.getLengthSquared(), 1e-17);
+ }
+}
View
34 src/test/java/org/apache/mahout/knn/means/StreamingKmeansTest.java
@@ -18,18 +18,19 @@
package org.apache.mahout.knn.means;
import com.google.common.collect.Lists;
-import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.knn.WeightedVector;
import org.apache.mahout.knn.generate.MultiNormal;
-import org.apache.mahout.knn.search.UpdatableSearcher;
+import org.apache.mahout.knn.search.Searcher;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.junit.Assert;
import org.junit.Test;
import java.util.List;
-import java.util.Random;
+
+import static org.junit.Assert.assertTrue;
public class StreamingKmeansTest {
@@ -44,23 +45,30 @@ public void testEstimateBeta() {
@Test
public void testClustering1() {
- Matrix data = new DenseMatrix(800, 3);
+ // construct data samplers centered on the corners of a unit cube
Matrix mean = new DenseMatrix(8, 3);
List<MultiNormal> rowSamplers = Lists.newArrayList();
for (int i = 0; i < 8; i++) {
mean.viewRow(i).assign(new double[]{0.25 * (i & 4), 0.5 * (i & 2), i & 1});
- MultiNormal gen = new MultiNormal(0.1, mean.viewRow(i));
+ MultiNormal gen = new MultiNormal(0.01, mean.viewRow(i));
rowSamplers.add(gen);
}
-
- Random rowSelector = RandomUtils.getRandom();
- for (MatrixSlice slice : data) {
- slice.vector().assign(rowSamplers.get(rowSelector.nextInt(8)).sample());
+ // sample a bunch of data points
+ long t0 = System.currentTimeMillis();
+ Matrix data = new DenseMatrix(10000, 3);
+ for (MatrixSlice row : data) {
+ row.vector().assign(rowSamplers.get(row.index() % 8).sample());
}
-
-
- UpdatableSearcher r = new StreamingKmeans().cluster(new EuclideanDistanceMeasure(), data, 30);
-
+ long t1 = System.currentTimeMillis();
+ // cluster the data
+ Searcher r = new StreamingKmeans().cluster(new EuclideanDistanceMeasure(), data, 1000);
+ long t2 = System.currentTimeMillis();
+ // and verify that each corner of the cube has a centroid very nearby
+ for (MatrixSlice row : mean) {
+ WeightedVector v = r.search(row.vector(), 1).get(0);
+ assertTrue(v.getWeight() < 0.05);
+ }
+ System.out.printf("%.2f s for data generation\n%.2f for clustering\n", (t1 - t0) / 1000.0, (t2 - t1) / 1000.0);
}
}
View
27 src/test/java/org/apache/mahout/knn/search/AbstractSearchTest.java
@@ -48,12 +48,12 @@ protected static Matrix randomData() {
public abstract Iterable<MatrixSlice> testData();
- public abstract Searcher getSearch();
+ public abstract Searcher getSearch(int n);
@Test
public void testExactMatch() {
List<WeightedVector> queries = subset(testData(), 100);
- Searcher s = getSearch();
+ Searcher s = getSearch(20);
s.addAll(testData());
assertEquals(Iterables.size(testData()), s.size());
@@ -68,7 +68,7 @@ public void testExactMatch() {
@Test
public void testNearMatch() {
List<WeightedVector> queries = subset(testData(), 100);
- Searcher s = getSearch();
+ Searcher s = getSearch(20);
s.addAll(testData());
MultiNormal noise = new MultiNormal(0.01, new DenseVector(20));
@@ -90,7 +90,7 @@ public void testOrdering() {
queries.viewRow(i).assign(gen.sample());
}
- Searcher s = getSearch();
+ Searcher s = getSearch(20);
s.setSearchSize(200);
s.addAll(testData());
@@ -105,8 +105,25 @@ public void testOrdering() {
}
@Test
+ public void testSmallSearch() {
+ Matrix m = new DenseMatrix(8, 3);
+ for (int i = 0; i < 8; i++) {
+ m.viewRow(i).assign(new double[]{0.125 * (i & 4), i & 2, i & 1});
+ }
+
+ Searcher s = getSearch(3);
+ s.addAll(m);
+ for (MatrixSlice row : m) {
+ final List<WeightedVector> r = s.search(row.vector(), 3);
+ assertEquals(0, r.get(0).getWeight(), 1e-8);
+ assertEquals(0, r.get(1).getWeight(), 0.5);
+ assertEquals(0, r.get(2).getWeight(), 1);
+ }
+ }
+
+ @Test
public void testRemoval() {
- Searcher s = getSearch();
+ Searcher s = getSearch(20);
s.addAll(testData());
if (s instanceof UpdatableSearcher) {
List<WeightedVector> x = subset(s, 2);
View
29 src/test/java/org/apache/mahout/knn/search/BruteTest.java
@@ -18,9 +18,16 @@
package org.apache.mahout.knn.search;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.knn.WeightedVector;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.junit.Before;
-import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
public class BruteTest extends AbstractSearchTest {
private static Iterable<MatrixSlice> data;
@@ -36,7 +43,25 @@ public void fillData() {
}
@Override
- public UpdatableSearcher getSearch() {
+ public UpdatableSearcher getSearch(int n) {
return new Brute(new EuclideanDistanceMeasure());
}
+
+ @Test
+ public void testMatrixSearch() {
+ Matrix m = new DenseMatrix(8, 3);
+ for (int i = 0; i < 8; i++) {
+ m.viewRow(i).assign(new double[]{0.125 * (i & 4), i & 2, i & 1});
+ }
+
+ Brute s = (Brute) getSearch(3);
+ s.addAll(m);
+
+ final List<List<WeightedVector>> searchResults = s.search(m, 3);
+ for (List<WeightedVector> r : searchResults) {
+ assertEquals(0, r.get(0).getWeight(), 1e-8);
+ assertEquals(0.5, r.get(1).getWeight(), 1e-8);
+ assertEquals(1, r.get(2).getWeight(), 1e-8);
+ }
+ }
}
View
7 src/test/java/org/apache/mahout/knn/search/ProjectionSearch3Test.java
@@ -24,13 +24,10 @@
public class ProjectionSearch3Test extends AbstractSearchTest {
private static Matrix data;
- private static ProjectionSearch3 searcher;
@BeforeClass
public static void setUp() {
data = randomData();
-
- searcher = new ProjectionSearch3(20, new EuclideanDistanceMeasure(), 4, 20);
}
@Override
@@ -39,7 +36,7 @@ public static void setUp() {
}
@Override
- public Searcher getSearch() {
- return searcher;
+ public Searcher getSearch(int n) {
+ return new ProjectionSearch3(n, new EuclideanDistanceMeasure(), 4, 20);
}
}
View
6 src/test/java/org/apache/mahout/knn/search/ProjectionSearchTest.java
@@ -26,7 +26,6 @@
public class ProjectionSearchTest extends AbstractSearchTest {
private static Matrix data;
- private static ProjectionSearch searcher;
@BeforeClass
public static void setUp() {
@@ -36,7 +35,6 @@ public static void setUp() {
slice.vector().assign(gen.sample());
}
- searcher = new ProjectionSearch(20, new EuclideanDistanceMeasure(), 4, 20);
}
@Override
@@ -45,7 +43,7 @@ public static void setUp() {
}
@Override
- public UpdatableSearcher getSearch() {
- return searcher;
+ public UpdatableSearcher getSearch(int n) {
+ return new ProjectionSearch(n, new EuclideanDistanceMeasure(), 4, 20);
}
}

0 comments on commit c54a3f4

Please sign in to comment.