diff --git a/core/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java b/core/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java index 828a470d1..6a55676a3 100644 --- a/core/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java +++ b/core/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java @@ -16,6 +16,7 @@ */ package org.apache.mahout.clustering.lda.cvb; +import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import org.apache.hadoop.conf.Configurable; import org.apache.hadoop.conf.Configuration; @@ -153,6 +154,12 @@ private static Vector viewRowSums(Matrix m) { } private void initializeThreadPool() { + // avoid initializing thread pool if client has specified less than one thread + if (numThreads < 1) { + log.info("Skipping thread pool initialization"); + return; + } + log.info("Initializing thread pool with {} threads", numThreads); ThreadPoolExecutor threadPool = new ThreadPoolExecutor(numThreads, numThreads, 0, TimeUnit.SECONDS, new ArrayBlockingQueue(numThreads * 10)); threadPool.allowCoreThreadTimeOut(false); @@ -323,6 +330,8 @@ public Vector infer(Vector original, Vector docTopicPrior, double minRelPerplexi } public void update(Matrix docTopicCounts) { + Preconditions.checkState(updaters.length > 0, + "Unable to update model; No threads requested during TopicModel instantiation"); for(int x = 0; x < numTopics; x++) { updaters[x % updaters.length].update(x, docTopicCounts.viewRow(x)); }