Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP

Adds initial cut of TopicModelUtils class #18

Merged
merged 4 commits into from about 2 years ago

2 participants

Andy Schlaikjer Jake Mannix
Andy Schlaikjer
Collaborator

This class contains a basic impl of topic-term matrix sparsification routine discussed with Jake.

Does not yet support parameterization of sparsification logic. Wanted to get this written in the most direct way first, test it, then consider options for generalization.

sagemintblue added some commits
Andy Schlaikjer sagemintblue Adds initial cut of TopicModelUtils class
This class contains a basic impl of topic-term sparsification routine
discussed with Jake.

Does not yet support parameterization of sparsification logic. Wanted
to get this written in the most direct way first, test it, then
consider options for generalization.
7bbfa65
Andy Schlaikjer sagemintblue Missed static modifier on sparsifyTopicTermCounts method 7ea1b1d
Andy Schlaikjer sagemintblue Ensures sparsified topic-term vectors allow sequential access
Also:
- Clarifies javadoc of sparsityTopicTermCounts method.
ef8d9b6
Andy Schlaikjer sagemintblue Adds unit test for TopicModeUtils.sparsifyTopicTermCounts method f9acbe5
Jake Mannix
Collaborator

I like the unit test, looking good! We may want to migrate this class out of clustering.lda.cvb at some point, as it's pretty generic.

Jake Mannix
Collaborator

feel free to merge it in yourself if you want to start using it on twitter/master

Andy Schlaikjer sagemintblue merged commit 417d98b into from
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Showing 4 unique commits by 1 author.

Apr 04, 2012
Andy Schlaikjer sagemintblue Adds initial cut of TopicModelUtils class
This class contains a basic impl of topic-term sparsification routine
discussed with Jake.

Does not yet support parameterization of sparsification logic. Wanted
to get this written in the most direct way first, test it, then
consider options for generalization.
7bbfa65
Andy Schlaikjer sagemintblue Missed static modifier on sparsifyTopicTermCounts method 7ea1b1d
Andy Schlaikjer sagemintblue Ensures sparsified topic-term vectors allow sequential access
Also:
- Clarifies javadoc of sparsityTopicTermCounts method.
ef8d9b6
Apr 05, 2012
Andy Schlaikjer sagemintblue Adds unit test for TopicModeUtils.sparsifyTopicTermCounts method f9acbe5
This page is out of date. Refresh to see the latest.
218 core/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModelUtils.java
... ... @@ -0,0 +1,218 @@
  1 +package org.apache.mahout.clustering.lda.cvb;
  2 +
  3 +import java.util.Iterator;
  4 +import java.util.List;
  5 +import java.util.PriorityQueue;
  6 +
  7 +import org.apache.mahout.math.DenseVector;
  8 +import org.apache.mahout.math.Matrix;
  9 +import org.apache.mahout.math.RandomAccessSparseVector;
  10 +import org.apache.mahout.math.SequentialAccessSparseVector;
  11 +import org.apache.mahout.math.SparseRowMatrix;
  12 +import org.apache.mahout.math.Vector;
  13 +import org.apache.mahout.math.Vector.Element;
  14 +
  15 +import com.google.common.base.Preconditions;
  16 +import com.google.common.collect.Iterators;
  17 +import com.google.common.collect.Lists;
  18 +
  19 +/**
  20 + * Utilities for {@link TopicModel}s.
  21 + */
  22 +public class TopicModelUtils {
  23 + /**
  24 + * Generates a sparse version of input topic-term matrix. Sparsification is
  25 + * performed as follows: For each topic (row), the sum of all entries is found
  26 + * (L1 norm). This sum is then scaled by the threshold argument to find a
  27 + * target count threshold for the current topic. Then the set of term counts
  28 + * with largest weight whose total weight is less than or equal to the count
  29 + * threshold is determined. These term counts are added to the output
  30 + * topic-term count vector for the current topic. The counts for all other
  31 + * terms for the current topic are added to global term count sums to keep
  32 + * track of lost term count mass. Once all truncated topic-term count vectors
  33 + * have been built, the removed term count mass is added evenly to remaining
  34 + * non-zero topic-term count entries: For each term (column), if removed term
  35 + * count mass is greater than zero we find the set of topics (rows) for which
  36 + * term count is still non-zero. We divide removed term count mass by this
  37 + * number and add this fraction of term count mass to each non-zero entry.
  38 + *
  39 + * @param topicTermCounts
  40 + * matrix containing topic-term counts to sparsify.
  41 + * @param threshold
  42 + * relative threshold on each topic's total term count.
  43 + * @return sparsified version of topicTermCounts.
  44 + */
  45 + public static Matrix sparsifyTopicTermCounts(Matrix topicTermCounts,
  46 + double threshold) {
  47 + Preconditions.checkNotNull(topicTermCounts);
  48 + Preconditions.checkArgument(0 < threshold,
  49 + "Expected threshold > 0 but found %s", threshold);
  50 + final int numTopics = topicTermCounts.rowSize();
  51 + final int numTerms = topicTermCounts.columnSize();
  52 + // storage for sparsified topic-term count vectors
  53 + final Vector[] sparseTopicTermCounts = new Vector[numTopics];
  54 + // storage for sums of truncated term counts
  55 + final Vector truncatedTermCounts = new DenseVector(numTerms);
  56 + // priority queue used to collect top-weighted vector entries
  57 + final PriorityQueue<Entry> topTermCountEntries = new PriorityQueue<Entry>(
  58 + (int) (numTerms * threshold));
  59 +
  60 + /*
  61 + * Truncate topic-term vectors while keeping track of lost term count mass.
  62 + * We use the lost count mass to perform term (column) normalization after
  63 + * truncation.
  64 + */
  65 +
  66 + // for each topic index
  67 + for (int t = 0; t < numTopics; ++t) {
  68 + // reset state
  69 + topTermCountEntries.clear();
  70 +
  71 + // fetch term counts and iterator over non-zero elements
  72 + final Vector termCounts = topicTermCounts.viewRow(t);
  73 + final Iterator<Element> itr = termCounts.iterateNonZero();
  74 +
  75 + // determine term count threshold
  76 + final double totalTermCount = termCounts.norm(1);
  77 + final double termCountThreshold = totalTermCount * threshold;
  78 +
  79 + // iterate over non-zero term counts, keeping track of total term count
  80 + double termCount = 0;
  81 + while (itr.hasNext()) {
  82 + Element e = itr.next();
  83 + termCount += e.get();
  84 + topTermCountEntries.add(new Entry(e.index(), e.get()));
  85 +
  86 + // remove elements with smallest count from queue till threshold is met
  87 + while (termCount > termCountThreshold && !topTermCountEntries.isEmpty()) {
  88 + Entry entry = topTermCountEntries.poll();
  89 + int index = entry.getIndex();
  90 + double count = entry.getValue();
  91 + termCount -= count;
  92 + // keep track of truncated mass for this term
  93 + truncatedTermCounts.setQuick(index,
  94 + truncatedTermCounts.getQuick(index) + count);
  95 + }
  96 + }
  97 +
  98 + // initialize output topic-term count vector
  99 + Vector sparseTermCounts = new RandomAccessSparseVector(numTerms,
  100 + topTermCountEntries.size());
  101 + for (Entry e : topTermCountEntries) {
  102 + sparseTermCounts.setQuick(e.getIndex(), e.getValue());
  103 + }
  104 + // ensure sequential access for output vectors
  105 + sparseTermCounts = new SequentialAccessSparseVector(sparseTermCounts);
  106 + sparseTopicTermCounts[t] = sparseTermCounts;
  107 + }
  108 +
  109 + /*
  110 + * now iterate over terms, spreading each term's truncated mass evenly among
  111 + * those topics for which the term still has non-zero count. To improve
  112 + * feature-wise iteration efficiency, we keep track of current non-zero
  113 + * iterator and element for each topic.
  114 + */
  115 +
  116 + // non-zero topic-term count vector iterators
  117 + final List<Iterator<Element>> topicTermElementIters = Lists
  118 + .newArrayListWithCapacity(numTopics);
  119 + // current non-zero topic-term count vector element for each topic
  120 + final List<Element> topicTermElements = Lists
  121 + .newArrayListWithCapacity(numTopics);
  122 + // initialize topic iterators and elements
  123 + for (int t = 0; t < numTopics; ++t) {
  124 + Iterator<Element> itr = sparseTopicTermCounts[t].iterateNonZero();
  125 + if (itr == null) {
  126 + itr = Iterators.emptyIterator();
  127 + }
  128 + topicTermElementIters.add(itr);
  129 + topicTermElements.add(itr.hasNext() ? itr.next() : null);
  130 + }
  131 + // current column of topic-term count elements
  132 + final List<Element> nonzeroTopicElements = Lists
  133 + .newArrayListWithCapacity(numTopics);
  134 +
  135 + // for each term index
  136 + for (int f = 0; f < numTerms; ++f) {
  137 + final double truncatedTermCount = truncatedTermCounts.get(f);
  138 + if (truncatedTermCount == 0) {
  139 + // no truncation occurred for this term; no normalization necessary
  140 + continue;
  141 + }
  142 +
  143 + // find topics for which current term has non-zero count
  144 + nonzeroTopicElements.clear();
  145 + for (int t = 0; t < numTopics; ++t) {
  146 + Element e = topicTermElements.get(t);
  147 + if (e == null) {
  148 + continue;
  149 + }
  150 + final Iterator<Element> itr = topicTermElementIters.get(t);
  151 + while (e != null && e.index() < f) {
  152 + if (!itr.hasNext()) {
  153 + e = null;
  154 + } else {
  155 + e = itr.next();
  156 + }
  157 + }
  158 + topicTermElements.set(t, e);
  159 + if (e == null || e.index() > f) {
  160 + continue;
  161 + }
  162 + nonzeroTopicElements.add(e);
  163 + }
  164 +
  165 + // deal with case where term has been removed from *all* topics
  166 + if (nonzeroTopicElements.isEmpty()) {
  167 + // TODO(Andy Schlaikjer): What should be done?
  168 + continue;
  169 + }
  170 +
  171 + // term count mass to add to each topic-term count
  172 + final double termCountDelta = truncatedTermCount
  173 + / nonzeroTopicElements.size();
  174 +
  175 + // update topic-term counts
  176 + for (Element e : nonzeroTopicElements) {
  177 + e.set(e.get() + termCountDelta);
  178 + }
  179 + }
  180 +
  181 + // create the sparse matrix
  182 + return new SparseRowMatrix(numTopics, numTerms, sparseTopicTermCounts,
  183 + true, true);
  184 + }
  185 +
  186 + /**
  187 + * Comparable struct for {@link Element} data. Natural ordering of
  188 + * {@link Entry} instances is value desc, index asc.
  189 + */
  190 + private static final class Entry implements Comparable<Entry> {
  191 + private final int index;
  192 + private final double value;
  193 +
  194 + public Entry(int index, double value) {
  195 + this.index = index;
  196 + this.value = value;
  197 + }
  198 +
  199 + public int getIndex() {
  200 + return index;
  201 + }
  202 +
  203 + public double getValue() {
  204 + return value;
  205 + }
  206 +
  207 + @Override
  208 + public int compareTo(Entry o) {
  209 + if (this == o) return 0;
  210 + if (o == null) return 1;
  211 + if (value > o.value) return -1;
  212 + if (value < o.value) return 1;
  213 + if (index < o.index) return -1;
  214 + if (index > o.index) return 1;
  215 + return 0;
  216 + }
  217 + }
  218 +}
55 core/src/test/java/org/apache/mahout/clustering/lda/cvb/TestTopicModelUtils.java
... ... @@ -0,0 +1,55 @@
  1 +package org.apache.mahout.clustering.lda.cvb;
  2 +
  3 +import static org.junit.Assert.assertEquals;
  4 +import static org.junit.Assert.assertNotNull;
  5 +import static org.junit.Assert.assertTrue;
  6 +
  7 +import org.apache.mahout.clustering.ClusteringTestUtils;
  8 +import org.apache.mahout.math.Matrix;
  9 +import org.apache.mahout.math.Vector;
  10 +import org.apache.mahout.math.function.VectorFunction;
  11 +import org.junit.Test;
  12 +
  13 +/**
  14 + * Tests for {@link TopicModelUtils}.
  15 + */
  16 +public class TestTopicModelUtils {
  17 + public void assertColumnNormsEqualOrZero(Matrix expected, Matrix actual) {
  18 + assertNotNull(actual);
  19 + assertEquals(expected.columnSize(), actual.columnSize());
  20 + for (int c = 0; c < expected.columnSize(); ++c) {
  21 + Vector expectedColumn = expected.viewColumn(c);
  22 + Vector actualColumn = actual.viewColumn(c);
  23 + assertNotNull(actualColumn);
  24 + double expectedNorm = expectedColumn.norm(1);
  25 + double actualNorm = actualColumn.norm(1);
  26 + if (actualNorm == 0) {
  27 + continue;
  28 + }
  29 + assertEquals(expectedNorm, actualNorm, 1e-10);
  30 + }
  31 + }
  32 +
  33 + public long numNonzeros(Matrix matrix) {
  34 + return (long) matrix.aggregateRows(new VectorFunction() {
  35 + @Override
  36 + public double apply(Vector v) {
  37 + return v.getNumNondefaultElements();
  38 + }
  39 + }).norm(1);
  40 + }
  41 +
  42 + public void assertFewerNonzeros(Matrix expected, Matrix actual) {
  43 + assertTrue(numNonzeros(expected) > numNonzeros(actual));
  44 + }
  45 +
  46 + @Test
  47 + public void test() {
  48 + double threshold = 0.5;
  49 + Matrix topicTermCounts = ClusteringTestUtils.randomStructuredModel(20, 100);
  50 + Matrix sparseTopicTermCounts = TopicModelUtils.sparsifyTopicTermCounts(
  51 + topicTermCounts, threshold);
  52 + assertColumnNormsEqualOrZero(topicTermCounts, sparseTopicTermCounts);
  53 + assertFewerNonzeros(topicTermCounts, sparseTopicTermCounts);
  54 + }
  55 +}

Tip: You can add notes to lines in a file. Hover to the left of a line to make a note

Something went wrong with that request. Please try again.