Skip to content

Commit

Permalink
Merge pull request #9 from ajaech/master
Browse files Browse the repository at this point in the history
add support for multi-threading
  • Loading branch information
percyliang committed Mar 29, 2014
2 parents fbf6dc8 + dffcc96 commit 4080a44
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 31 deletions.
4 changes: 2 additions & 2 deletions Makefile
Expand Up @@ -7,10 +7,10 @@ else
endif

wcluster: $(files)
g++ -Wall -g -O3 -o wcluster $(files)
g++ -Wall -g -std=c++0x -O3 -o wcluster $(files) -lpthread

%.o: %.cc
g++ -Wall -g -O3 -o $@ -c $<
g++ -Wall -g -O3 -std=c++0x -o $@ -c $<

clean:
rm wcluster basic/*.o *.o
2 changes: 1 addition & 1 deletion README
Expand Up @@ -3,7 +3,7 @@ Percy Liang
Release 1.3
2012.07.24

Input: a sequence of words separated by whitespcae (see input.txt for an example).
Input: a sequence of words separated by whitespace (see input.txt for an example).
Output: for each word type, its cluster (see output.txt for an example).
In particular, each line is:
<cluster represented as a bit string> <word> <number of times word occurs in input>
Expand Down
159 changes: 131 additions & 28 deletions wcluster.cc
Expand Up @@ -16,6 +16,8 @@ The four structures p1, p2, q2, L2 allow this quick computation.
Changes:
* Removed hash tables for efficiency.
* Notation: a is an phrase (sequence of words), c is a cluster, s is a slot.
* Removed hash tables for efficiency.
* Notation: a is a phrase (sequence of words), c is a cluster, s is a slot.
To cut down memory usage:
* Change double to float.
Expand All @@ -35,6 +37,9 @@ To cut down memory usage:
#include "basic/mem-tracker.h"
#include "basic/opt.h"
#include <unistd.h>
#include <condition_variable>
#include <mutex>
#include <thread>

vector< OptInfo<bool> > bool_opts;
vector< OptInfo<int> > int_opts;
Expand All @@ -55,6 +60,7 @@ opt_define_int(initC, "c", 1000, "Number of clusters."
opt_define_int(plen, "plen", 1, "Maximum length of a phrase to consider.");
opt_define_int(min_occur, "min-occur", 1, "Keep phrases that occur at least this many times.");
opt_define_int(rand_seed, "rand", time(NULL)*getpid(), "Number to call srand with.");
opt_define_int(num_threads, "threads", 1, "Number of threads to use in the worker pool.");

opt_define_bool(chk, "chk", false, "Check data structures are valid (expensive).");
opt_define_bool(print_stats, "stats", false, "Just print out stats.");
Expand Down Expand Up @@ -108,6 +114,19 @@ double curr_minfo; // Mutual info, should be sum of all q2's
// Map phrase to the KL divergence to its cluster
DoubleVec kl_map[2];

// Variables used to control the thread pool
//
// The pool uses two arrays of mutexes, one entry per worker, as start/finish
// signals between the dispatching thread and the workers running update_L2().

// thread_idle[i]: worker i unlocks this when it has finished its share of the
// current job; the dispatching thread locks every entry to wait for completion.
mutex * thread_idle;
// thread_start[i]: worker i blocks locking this until the dispatching thread
// unlocks it to announce that the_job is ready (or that all_done was set).
mutex * thread_start;
// The worker threads themselves; each one runs update_L2(i).
thread * threads;
// Description of one L2 recomputation to be split across the workers.
struct Compute_L2_Job {
// Slot indices. A type A job reads only s. A type B job reads s, t and u and
// forwards them to compute_L2_using_old(); it also recomputes from scratch
// every pair that involves u.
int s;
int t;
int u;
// true: type A job (dispatched from the "Update L2" step).
// false: type B job (dispatched from the "Compute L2" step of a merge).
bool is_type_a;
};
// The single job descriptor shared by all workers. It is filled in by the
// dispatching thread before the start mutexes are released.
Compute_L2_Job the_job;
// Set to true when clustering is over; a worker that wakes up and sees it
// leaves its loop instead of running a job.
bool all_done = false;

#define FOR_SLOT(s) \
for(int s = 0; s < len(slot2cluster); s++) \
for(bool _tmp = true; slot2cluster[s] != -1 && _tmp; _tmp = false)
Expand Down Expand Up @@ -403,7 +422,12 @@ void read_text() {
// O(C) time.
// Compute s1[s]: the sum of bi_q2(s, t) over every occupied slot t.
// A slot t is skipped when slot2cluster[t] == -1 (no cluster lives there).
//
// The accumulation is written as one explicit loop. Having both the
// FOR_SLOT one-liner and this loop in the body would add every term twice.
double compute_s1(int s) {
  double q = 0.0;

  for (int t = 0; t < len(slot2cluster); t++) {
    if (slot2cluster[t] == -1) continue;
    q += bi_q2(s, t);
  }

  return q;
}

Expand All @@ -413,7 +437,14 @@ double compute_L2(int s, int t) { // compute L2[s, t]
// st is the hypothetical new cluster that combines s and t

// Lose old associations with s and t
double l = compute_s1(s) + compute_s1(t) - bi_q2(s, t);
double l = 0.0;
for (int w = 0; w < len(slot2cluster); w++) {
if ( slot2cluster[w] == -1) continue;
l += q2[s][w] + q2[w][s];
l += q2[t][w] + q2[w][t];
}
l -= q2[s][s] + q2[t][t];
l -= bi_q2(s, t);

// Form new associations with st
FOR_SLOT(u) {
Expand Down Expand Up @@ -641,28 +672,79 @@ void incorporate_new_phrase(int a) {

// Update L2: O(C^2)
track_block("Update L2", "", false) {
FOR_SLOT(t) { // L2[s, *], L2[*, s]
if(s == t) continue;
int S, T;
if(ORDER_VALID(s, t)) S = s, T = t;
else S = t, T = s;
L2[S][T] = compute_L2(S, T);
logs("L2[" << Slot(S) << ", " << Slot(T) << "] = " << L2[S][T]);
}

FOR_SLOT(t) { // L2[not s, not s]
if(t == s) continue;
FOR_SLOT(u) {
if(u == s) continue;
if(!ORDER_VALID(t, u)) continue;
L2[t][u] += bi_q2(t, s) + bi_q2(u, s) - bi_hyp_q2(_(t, u), s);
}
the_job.s = s;
the_job.is_type_a = true;
// start the jobs
for (int ii=0; ii<num_threads; ii++) {
thread_start[ii].unlock(); // the thread waits for this lock to begin
}
// wait for them to be done
for (int ii=0; ii<num_threads; ii++) {
thread_idle[ii].lock(); // the thread releases the lock to finish
}
}

//dump();
}


// Worker-thread body of the L2 thread pool.
//
// thread_id: index of this worker, in [0, num_threads). It selects the
//   worker's entries in thread_start / thread_idle and the slots it owns:
//   thread_id, thread_id + num_threads, thread_id + 2 * num_threads, ...
//
// Each iteration:
//   1. blocks on thread_start[thread_id] until the dispatching thread
//      releases it;
//   2. exits if all_done is set;
//   3. otherwise runs its share of the job described by the_job;
//   4. unlocks thread_idle[thread_id] to report completion.
//
// Work is partitioned by the first slot index, so two workers never write
// the same L2[x][y] entry in the strided loops below.
//
// NOTE(review): thread_start is locked here but unlocked by the dispatching
// thread, and thread_idle is unlocked here but locked by the dispatching
// thread. std::mutex requires unlock() to be called by the owning thread, so
// this semaphore-style use is formally undefined; a condition_variable (or a
// real semaphore) would be the conforming replacement. That change has to be
// made together with the dispatch and setup code, so it is not done here.
void update_L2(int thread_id) {

  while (true) {

    // wait for mutex to unlock to begin the job
    thread_start[thread_id].lock();
    if (all_done) {
      // Shutdown request. This thread owns the start mutex at this point;
      // release it before leaving so that the thread does not terminate
      // while holding a mutex and the mutex is not destroyed while locked.
      thread_start[thread_id].unlock();
      break;
    }

    int num_clusters = len(slot2cluster);

    if (the_job.is_type_a) {
      int s = the_job.s;

      // Recompute from scratch every pair that involves s.
      for (int t = thread_id; t < num_clusters; t += num_threads) { // L2[s, *], L2[*, s]
        if (slot2cluster[t] == -1) continue;
        if (s == t) continue;
        int S, T;
        if (ORDER_VALID(s, t)) S = s, T = t;
        else S = t, T = s;
        L2[S][T] = compute_L2(S, T);
      }

      // Adjust every pair that does not involve s: L2[not s, not s].
      for (int t = thread_id; t < num_clusters; t += num_threads) {
        if (slot2cluster[t] == -1) continue;
        if (t == s) continue;
        FOR_SLOT(u) {
          if (u == s) continue;
          if (!ORDER_VALID(t, u)) continue;
          L2[t][u] += bi_q2(t, s) + bi_q2(u, s) - bi_hyp_q2(_(t, u), s);
        }
      }

    } else { // this is a type B job
      int s = the_job.s;
      int t = the_job.t;
      int u = the_job.u;

      for (int v = thread_id; v < num_clusters; v += num_threads) {
        if (slot2cluster[v] == -1) continue;
        for (int w = 0; w < num_clusters; w++) {
          if (slot2cluster[w] == -1) continue;
          if (!ORDER_VALID(v, w)) continue;

          // Pairs touching u are computed from scratch; all other pairs
          // are derived from their previous value.
          if (v == u || w == u)
            L2[v][w] = compute_L2(v, w);
          else
            L2[v][w] = compute_L2_using_old(s, t, u, v, w);
        }
      }
    }

    // signal that the thread is done by unlocking the mutex
    thread_idle[thread_id].unlock();
  }
}

// O(C^2) time.
// Merge clusters a (in slot s) and b (in slot t) into c (in slot u).
void merge_clusters(int s, int t) {
Expand Down Expand Up @@ -714,17 +796,18 @@ void merge_clusters(int s, int t) {

// Compute L2: O(C^2)
track_block("Compute L2", "", false) {
FOR_SLOT(v) {
FOR_SLOT(w) {
if(!ORDER_VALID(v, w)) continue;
double l;
if(v == u || w == u)
l = compute_L2(v, w);
else
l = compute_L2_using_old(s, t, u, v, w);
L2[v][w] = l;
logs("L2[" << Slot(v) << "," << Slot(w) << "] = " << l << ", resulting minfo = " << curr_minfo-l);
}
the_job.s = s;
the_job.t = t;
the_job.u = u;
the_job.is_type_a = false;

// start the jobs
for (int ii=0; ii<num_threads; ii++) {
thread_start[ii].unlock(); // the thread waits for this lock to begin
}
// wait for them to be done
for (int ii=0; ii<num_threads; ii++) {
thread_idle[ii].lock(); // the thread releases the lock to finish
}
}
}
Expand Down Expand Up @@ -940,6 +1023,16 @@ void do_clustering() {
compute_L2();
repcheck();

// start the threads
thread_start = new mutex[num_threads];
thread_idle = new mutex[num_threads];
threads = new thread[num_threads];
for (int ii=0; ii<num_threads; ii++) {
thread_start[ii].lock();
thread_idle[ii].lock();
threads[ii] = thread(update_L2, ii);
}

curr_cluster_id = N; // New cluster ids will start at N, after all the phrases.

// Stage 1: Maintain initC clusters. For each of the phrases initC..N-1, make
Expand Down Expand Up @@ -974,6 +1067,16 @@ void do_clustering() {
}
}

// finish the threads
all_done = true;
for (int ii=0; ii<num_threads; ii++) {
thread_start[ii].unlock(); // thread will grab this to start
threads[ii].join();
}
delete [] thread_start;
delete [] thread_idle;
delete [] threads;

logs("Done: 1 cluster left: mutual info = " << curr_minfo);
mem_tracker.report_mem_usage();
//assert(feq(curr_minfo, 0.0));
Expand Down

0 comments on commit 4080a44

Please sign in to comment.