Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for multi-threading #9

Merged
merged 2 commits into from Mar 29, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Expand Up @@ -7,10 +7,10 @@ else
endif

wcluster: $(files)
g++ -Wall -g -O3 -o wcluster $(files)
g++ -Wall -g -std=c++0x -O3 -o wcluster $(files) -lpthread

%.o: %.cc
g++ -Wall -g -O3 -o $@ -c $<
g++ -Wall -g -O3 -std=c++0x -o $@ -c $<

clean:
rm wcluster basic/*.o *.o
2 changes: 1 addition & 1 deletion README
Expand Up @@ -3,7 +3,7 @@ Percy Liang
Release 1.3
2012.07.24

Input: a sequence of words separated by whitespcae (see input.txt for an example).
Input: a sequence of words separated by whitespace (see input.txt for an example).
Output: for each word type, its cluster (see output.txt for an example).
In particular, each line is:
<cluster represented as a bit string> <word> <number of times word occurs in input>
Expand Down
159 changes: 131 additions & 28 deletions wcluster.cc
Expand Up @@ -16,6 +16,8 @@ The four structures p1, p2, q2, L2 allow this quick computation.
Changes:
* Removed hash tables for efficiency.
* Notation: a is an phrase (sequence of words), c is a cluster, s is a slot.
* Removed hash tables for efficiency.
* Notation: a is an phrase (sequence of words), c is a cluster, s is a slot.

To cut down memory usage:
* Change double to float.
Expand All @@ -35,6 +37,9 @@ To cut down memory usage:
#include "basic/mem-tracker.h"
#include "basic/opt.h"
#include <unistd.h>
#include <condition_variable>
#include <mutex>
#include <thread>

vector< OptInfo<bool> > bool_opts;
vector< OptInfo<int> > int_opts;
Expand All @@ -55,6 +60,7 @@ opt_define_int(initC, "c", 1000, "Number of clusters."
opt_define_int(plen, "plen", 1, "Maximum length of a phrase to consider.");
opt_define_int(min_occur, "min-occur", 1, "Keep phrases that occur at least this many times.");
opt_define_int(rand_seed, "rand", time(NULL)*getpid(), "Number to call srand with.");
opt_define_int(num_threads, "threads", 1, "Number of threads to use in the worker pool.");

opt_define_bool(chk, "chk", false, "Check data structures are valid (expensive).");
opt_define_bool(print_stats, "stats", false, "Just print out stats.");
Expand Down Expand Up @@ -108,6 +114,19 @@ double curr_minfo; // Mutual info, should be sum of all q2's
// Map phrase to the KL divergence to its cluster
DoubleVec kl_map[2];

// Variables used to control the thread pool
mutex * thread_idle;
mutex * thread_start;
thread * threads;
struct Compute_L2_Job {
int s;
int t;
int u;
bool is_type_a;
};
Compute_L2_Job the_job;
bool all_done = false;

#define FOR_SLOT(s) \
for(int s = 0; s < len(slot2cluster); s++) \
for(bool _tmp = true; slot2cluster[s] != -1 && _tmp; _tmp = false)
Expand Down Expand Up @@ -403,7 +422,12 @@ void read_text() {
// O(C) time.
double compute_s1(int s) { // compute s1[s]
double q = 0.0;
FOR_SLOT(t) q += bi_q2(s, t);

for(int t = 0; t < len(slot2cluster); t++) {
if (slot2cluster[t] == -1) continue;
q += bi_q2(s, t);
}

return q;
}

Expand All @@ -413,7 +437,14 @@ double compute_L2(int s, int t) { // compute L2[s, t]
// st is the hypothetical new cluster that combines s and t

// Lose old associations with s and t
double l = compute_s1(s) + compute_s1(t) - bi_q2(s, t);
double l = 0.0;
for (int w = 0; w < len(slot2cluster); w++) {
if ( slot2cluster[w] == -1) continue;
l += q2[s][w] + q2[w][s];
l += q2[t][w] + q2[w][t];
}
l -= q2[s][s] + q2[t][t];
l -= bi_q2(s, t);

// Form new associations with st
FOR_SLOT(u) {
Expand Down Expand Up @@ -641,28 +672,79 @@ void incorporate_new_phrase(int a) {

// Update L2: O(C^2)
track_block("Update L2", "", false) {
FOR_SLOT(t) { // L2[s, *], L2[*, s]
if(s == t) continue;
int S, T;
if(ORDER_VALID(s, t)) S = s, T = t;
else S = t, T = s;
L2[S][T] = compute_L2(S, T);
logs("L2[" << Slot(S) << ", " << Slot(T) << "] = " << L2[S][T]);
}

FOR_SLOT(t) { // L2[not s, not s]
if(t == s) continue;
FOR_SLOT(u) {
if(u == s) continue;
if(!ORDER_VALID(t, u)) continue;
L2[t][u] += bi_q2(t, s) + bi_q2(u, s) - bi_hyp_q2(_(t, u), s);
}
the_job.s = s;
the_job.is_type_a = true;
// start the jobs
for (int ii=0; ii<num_threads; ii++) {
thread_start[ii].unlock(); // the thread waits for this lock to begin
}
// wait for them to be done
for (int ii=0; ii<num_threads; ii++) {
thread_idle[ii].lock(); // the thread releases the lock to finish
}
}

//dump();
}


void update_L2(int thread_id) {

while (true) {

// wait for mutex to unlock to begin the job
thread_start[thread_id].lock();
if ( all_done ) break; // mechanism to close the threads

int num_clusters = len(slot2cluster);

if (the_job.is_type_a) {
int s = the_job.s;

for(int t=thread_id; t < num_clusters; t += num_threads) { // L2[s, *], L2[*, s]
if (slot2cluster[t] == -1) continue;
if (s == t) continue;
int S, T;
if(ORDER_VALID(s, t)) S = s, T = t;
else S = t, T = s;
L2[S][T] = compute_L2(S, T);
}

for(int t=thread_id; t < num_clusters; t += num_threads) {
if (slot2cluster[t] == -1) continue;
if (t == s) continue;
FOR_SLOT(u) {
if(u == s) continue;
if(!ORDER_VALID(t, u)) continue;
L2[t][u] += bi_q2(t, s) + bi_q2(u, s) - bi_hyp_q2(_(t, u), s);
}
}

} else { // this is a type B job
int s = the_job.s;
int t = the_job.t;
int u = the_job.u;

for (int v = thread_id; v < num_clusters; v += num_threads) {
if ( slot2cluster[v] == -1) continue;
for ( int w = 0; w < num_clusters; w++) {
if ( slot2cluster[w] == -1) continue;
if(!ORDER_VALID(v, w)) continue;

if(v == u || w == u)
L2[v][w] = compute_L2(v, w);
else
L2[v][w] = compute_L2_using_old(s, t, u, v, w);
}
}
}

// signal that the thread is done by unlocking the mutex
thread_idle[thread_id].unlock();
}
}

// O(C^2) time.
// Merge clusters a (in slot s) and b (in slot t) into c (in slot u).
void merge_clusters(int s, int t) {
Expand Down Expand Up @@ -714,17 +796,18 @@ void merge_clusters(int s, int t) {

// Compute L2: O(C^2)
track_block("Compute L2", "", false) {
FOR_SLOT(v) {
FOR_SLOT(w) {
if(!ORDER_VALID(v, w)) continue;
double l;
if(v == u || w == u)
l = compute_L2(v, w);
else
l = compute_L2_using_old(s, t, u, v, w);
L2[v][w] = l;
logs("L2[" << Slot(v) << "," << Slot(w) << "] = " << l << ", resulting minfo = " << curr_minfo-l);
}
the_job.s = s;
the_job.t = t;
the_job.u = u;
the_job.is_type_a = false;

// start the jobs
for (int ii=0; ii<num_threads; ii++) {
thread_start[ii].unlock(); // the thread waits for this lock to begin
}
// wait for them to be done
for (int ii=0; ii<num_threads; ii++) {
thread_idle[ii].lock(); // the thread releases the lock to finish
}
}
}
Expand Down Expand Up @@ -940,6 +1023,16 @@ void do_clustering() {
compute_L2();
repcheck();

// start the threads
thread_start = new mutex[num_threads];
thread_idle = new mutex[num_threads];
threads = new thread[num_threads];
for (int ii=0; ii<num_threads; ii++) {
thread_start[ii].lock();
thread_idle[ii].lock();
threads[ii] = thread(update_L2, ii);
}

curr_cluster_id = N; // New cluster ids will start at N, after all the phrases.

// Stage 1: Maintain initC clusters. For each of the phrases initC..N-1, make
Expand Down Expand Up @@ -974,6 +1067,16 @@ void do_clustering() {
}
}

// finish the threads
all_done = true;
for (int ii=0; ii<num_threads; ii++) {
thread_start[ii].unlock(); // thread will grab this to start
threads[ii].join();
}
delete [] thread_start;
delete [] thread_idle;
delete [] threads;

logs("Done: 1 cluster left: mutual info = " << curr_minfo);
mem_tracker.report_mem_usage();
//assert(feq(curr_minfo, 0.0));
Expand Down