Skip to content

Commit

Permalink
beginnings of beam search
Browse files Browse the repository at this point in the history
  • Loading branch information
hal3 committed Jun 19, 2012
1 parent 7faa89d commit ac2ed15
Show file tree
Hide file tree
Showing 6 changed files with 242 additions and 218 deletions.
2 changes: 1 addition & 1 deletion Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ nobase_include_HEADERS = vowpalwabbit/accumulate.h vowpalwabbit/comp_io.h vowpal
vowpalwabbit/simple_label.h vowpalwabbit/allreduce.h vowpalwabbit/config.h vowpalwabbit/gd_mf.h vowpalwabbit/lda_core.h vowpalwabbit/oaa.h \
vowpalwabbit/parse_regressor.h vowpalwabbit/sparse_dense.h vowpalwabbit/bfgs.h vowpalwabbit/constant.h vowpalwabbit/global_data.h vowpalwabbit/loss_functions.h \
vowpalwabbit/parse_args.h vowpalwabbit/parser.h vowpalwabbit/unique_sort.h vowpalwabbit/cache.h vowpalwabbit/example.h vowpalwabbit/hash.h vowpalwabbit/network.h \
vowpalwabbit/parse_example.h vowpalwabbit/sender.h vowpalwabbit/v_array.h vowpalwabbit/v_hashmap.h vowpalwabbit/wap.h vowpalwabbit/searn.h \
vowpalwabbit/parse_example.h vowpalwabbit/sender.h vowpalwabbit/v_array.h vowpalwabbit/v_hashmap.h vowpalwabbit/wap.h vowpalwabbit/beam.h vowpalwabbit/searn.h \
vowpalwabbit/searn_sequencetask.h vowpalwabbit/sequence.h vowpalwabbit/csoaa.h vowpalwabbit/ect.h

ACLOCAL_AMFLAGS = -I acinclude.d
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ include_HEADERS = allreduce.h

bin_PROGRAMS = vw active_interactor

vw_SOURCES = hash.cc global_data.cc io.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc simple_label.cc oaa.cc ect.cc csoaa.cc v_hashmap.cc wap.cc searn.cc searn_sequencetask.cc sequence.cc parse_example.cc sparse_dense.cc network.cc parse_args.cc accumulate.cc gd.cc lda_core.cc gd_mf.cc bfgs.cc noop.cc example.cc parser.cc vw.cc loss_functions.cc sender.cc
vw_SOURCES = hash.cc global_data.cc io.cc parse_regressor.cc parse_primitives.cc unique_sort.cc cache.cc simple_label.cc oaa.cc ect.cc csoaa.cc v_hashmap.cc wap.cc beam.cc searn.cc searn_sequencetask.cc sequence.cc parse_example.cc sparse_dense.cc network.cc parse_args.cc accumulate.cc gd.cc lda_core.cc gd_mf.cc bfgs.cc noop.cc example.cc parser.cc vw.cc loss_functions.cc sender.cc
vw_LDADD = allreduce.o
vw_DEPENDENCIES = allreduce.o

Expand Down
195 changes: 195 additions & 0 deletions vowpalwabbit/beam.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
#include <iostream>
#include <float.h>
#include <stdio.h>
#include <math.h>
#include "beam.h"
#include "v_hashmap.h"
#include "v_array.h"

#define MULTIPLIER 5

using namespace std;

namespace Beam
{
int compare_elem(const void *va, const void *vb) {
// first sort on hash, then on loss
elem* a = (elem*)va;
elem* b = (elem*)vb;
if (a->hash < b->hash) { return -1; }
if (a->hash > b->hash) { return 1; }
return b->loss - a->loss; // if b is greater, it should go second
}

beam::beam(bool (*eq)(state,state), size_t (*hs)(state), size_t max_beam_size) {
equivalent = eq;
hash = hs;
empty_bucket = new v_array<elem>();
last_retrieved = NULL;
max_size = max_beam_size;
losses = (float*)calloc(max_size, sizeof(float));
dat = new v_hashmap<size_t,bucket>(8, empty_bucket, NULL);
}

beam::~beam() {
// TODO: really free the elements
delete dat;
free(empty_bucket->begin);
delete empty_bucket;
}

size_t hash_bucket(size_t id) { return 1043221*(893901 + id); }

void beam::put(size_t id, state s, size_t hs, float loss) {
elem e = { s, hs, loss, id, last_retrieved };
// check to see if we have this bucket yet
bucket b = dat->get(id, hash_bucket(id));
if (b->index() > 0) { // this one exists: just add to it
push(*b, e);
//dat->put_after_get(id, hash_bucket(id), b);
if (b->index() >= max_size * MULTIPLIER)
prune(id);
} else {
bucket bnew = new v_array<elem>();
push(*bnew, e);
dat->put_after_get(id, hash_bucket(id), bnew);
}
}

void beam::iterate(size_t id, void (*f)(beam*,size_t,state,float)) {
bucket b = dat->get(id, hash_bucket(id));
if (b->index() == 0) return;

cout << "before prune" << endl;
prune(id);
cout << "after prune" << endl;

for (elem*e=b->begin; e!=b->end; e++) {
cout << "element" << endl;
if (e->alive) {
last_retrieved = e;
f(this, id, e->s, e->loss);
}
}
}

#define SWAP(a,b) temp=(a);(a)=(b);(b)=temp;
float quickselect(float *arr, size_t n, size_t k) {
size_t i,ir,j,l,mid;
float a,temp;

l=0;
ir=n-1;
for(;;) {
if (ir <= l+1) {
if (ir == l+1 && arr[ir] < arr[l]) {
SWAP(arr[l],arr[ir]);
}
return arr[k];
}
else {
mid=(l+ir) >> 1;
SWAP(arr[mid],arr[l+1]);
if (arr[l] > arr[ir]) {
SWAP(arr[l],arr[ir]);
}
if (arr[l+1] > arr[ir]) {
SWAP(arr[l+1],arr[ir]);
}
if (arr[l] > arr[l+1]) {
SWAP(arr[l],arr[l+1]);
}
i=l+1;
j=ir;
a=arr[l+1];
for (;;) {
do i++; while (arr[i] < a);
do j--; while (arr[j] > a);
if (j < i) break;
SWAP(arr[i],arr[j]);
}
arr[l+1]=arr[j];
arr[j]=a;
if (j >= k) ir=j-1;
if (j <= k) l=i;
}
}
}


void beam::prune(size_t id) {
bucket b = dat->get(id, hash_bucket(id));
if (b->index() == 0) return;

size_t num_alive = 0;
if (equivalent == NULL) {
for (size_t i=1; i<b->index(); i++) {
(*b)[i].alive = true;
}
num_alive = b->index();
} else {
// first, sort on hash, backing off to loss
qsort(b->begin, b->index(), sizeof(elem), compare_elem);

// now, check actual equivalence
size_t last_pos = 0;
size_t last_hash = (*b)[0].hash;
for (size_t i=1; i<b->index(); i++) {
(*b)[i].alive = true;
if ((*b)[i].hash != last_hash) {
last_pos = i;
last_hash = (*b)[i].hash;
} else {
for (size_t j=last_pos; j<i; j++) {
if ((*b)[j].alive && equivalent((*b)[j].s, (*b)[i].s)) {
(*b)[i].alive = false;
break;
}
}
}

if ((*b)[i].alive) {
losses[num_alive] = (*b)[i].loss;
num_alive++;
}
}
}

if (num_alive <= max_size) return;

// sort the remaining items on loss
float cutoff = quickselect(losses, num_alive, max_size);
bucket bnew = new v_array<elem>();
for (elem*e=b->begin; e!=b->end; e++) {
if (e->loss > cutoff) continue;
push(*bnew, *e);
num_alive--;
if (num_alive < 0) break;
}
dat->put_after_get(id, hash_bucket(id), bnew);
}


struct test_beam_state {
size_t id;
};
bool state_eq(state a,state b) { return ((test_beam_state*)a)->id == ((test_beam_state*)b)->id; }
size_t state_hash(state a) { return 381049*(3820+((test_beam_state*)a)->id); }
void expand_state(beam*b, size_t old_id, state old_state, float old_loss) {
test_beam_state* new_state = (test_beam_state*)calloc(1, sizeof(test_beam_state));
new_state->id = old_id + ((test_beam_state*)old_state)->id * 2;
float new_loss = old_loss + 0.5;
cout << "expand_state " << old_loss << " -> " << new_state->id << " , " << new_loss << endl;
b->put(old_id+1, new_state, new_loss);
}
void test_beam() {
beam*b = new beam(&state_eq, &state_hash, 5);
for (size_t i=0; i<25; i++) {
test_beam_state* s = (test_beam_state*)calloc(1, sizeof(test_beam_state));
s->id = i / 3;
b->put(0, s, 0. - (float)i);
cout << "added " << s->id << endl;
}
b->iterate(0, expand_state);
}
}
44 changes: 44 additions & 0 deletions vowpalwabbit/beam.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#ifndef BEAM_H
#define BEAM_H

#include <stdio.h>
#include "v_hashmap.h"
#include "v_array.h"

typedef void* state;

namespace Beam
{
struct elem {
state s;
size_t hash;
float loss;
size_t bucket_id;
elem* backpointer;
bool alive;
};

typedef v_array<elem>* bucket;

class beam {
public:
bool (*equivalent)(state, state);
size_t (*hash)(state);

v_hashmap<size_t, bucket>* dat;

bucket empty_bucket;
elem* last_retrieved;
size_t max_size;
float* losses;

beam(bool (*eq)(state,state), size_t (*hs)(state), size_t max_beam_size);
~beam();
void put(size_t id, state s, size_t hs, float loss);
void put(size_t id, state s, float loss) { put(id, s, hash(s), loss); }
void iterate(size_t id, void (*f)(beam*,size_t,state,float));
void prune(size_t id);
};
}

#endif
Loading

0 comments on commit ac2ed15

Please sign in to comment.