Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
tree: 3acdf1e4b3
Fetching contributors…

Cannot retrieve contributors at this time

336 lines (297 sloc) 10.671 kB
#include <iostream>
#include <tr1/memory>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
#include "filelib.h"
#include "dict.h"
#include "sampler.h"
#include "ccrp.h"
#include "ccrp_nt.h"
using namespace std;
using namespace std::tr1;
namespace po = boost::program_options;
// File-wide dictionary instance. ReadCorpus converts input lines into int ids
// through it, and Join / operator<< call d.Convert() on those ids when writing
// phrases back out.
Dict d; // global dictionary
// Renders `phrase` as a single string: each id is passed through the global
// dictionary's Convert() and consecutive entries are separated by `joiner`.
// An empty phrase yields an empty string.
string Join(char joiner, const vector<int>& phrase) {
  ostringstream joined;
  bool first = true;
  for (vector<int>::const_iterator word = phrase.begin(); word != phrase.end(); ++word) {
    if (!first) joined << joiner;  // separator goes between entries only
    first = false;
    joined << d.Convert(*word);
  }
  return joined.str();
}
// Writes one segmented sentence to stdout.
//
//   line  - the word ids of the sentence
//   label - one entry per word; a true-valued label[j] marks a phrase boundary
//           after word j. The last label of a non-empty sentence must be set.
//   d     - dictionary used to convert ids for output (shadows the global `d`)
//
// Each phrase is printed as "{{w1 w2 ...}}:<label>", phrases are separated by a
// single space, and the sentence is terminated with a newline. An empty
// sentence produces an empty output line, so output lines stay aligned with
// input lines.
template <typename BType>
void WriteSeg(const vector<int>& line, const vector<BType>& label, const Dict& d) {
  assert(line.size() == label.size());
  if (line.empty()) {
    // Guard: label.back() below is undefined on an empty vector, and blank
    // input lines do reach this function as empty sentences.
    cout << endl;
    return;
  }
  assert(label.back());
  unsigned prev = 0;  // index of the first word of the phrase being collected
  unsigned cur = 0;
  while (cur < line.size()) {
    if (label[cur]) {
      if (prev) cout << ' ';  // prev != 0 means at least one phrase was already written
      cout << "{{";
      for (unsigned i = prev; i <= cur; ++i)
        cout << (i == prev ? "" : " ") << d.Convert(line[i]);
      cout << "}}:" << label[cur];
      prev = cur + 1;
    }
    ++cur;
  }
  cout << endl;
}
// Streams a phrase as its dictionary-converted entries separated by single
// spaces (no leading or trailing space). Uses the global dictionary `d`.
ostream& operator<<(ostream& os, const vector<int>& phrase) {
  const char* separator = "";
  for (vector<int>::const_iterator word = phrase.begin(); word != phrase.end(); ++word) {
    os << separator << d.Convert(*word);
    separator = " ";  // every entry after the first is preceded by a space
  }
  return os;
}
// Table of per-word values indexed by word id.
// NOTE(review): the constructor only checks that `fname` can be opened; nothing
// in this struct reads the file or fills freqs_, so logprob() can only succeed
// after a caller has populated freqs_ directly.
struct UnigramLM {
  vector<double> freqs_;  // value returned by logprob(), indexed by word id

  // Asserts that the file named by `fname` can be opened for reading.
  explicit UnigramLM(const string& fname) {
    ifstream model_file;
    model_file.open(fname.c_str());
    assert(!model_file.fail());
  }

  // Returns the stored value for `word`; asserts that `word` is a valid index.
  double logprob(unsigned word) const {
    assert(word < freqs_.size());
    return *(freqs_.begin() + word);
  }
};
// Parses the command line (and, if --config is given, a configuration file)
// into *conf.
//
//   argc, argv - the program arguments as received by main()
//   conf       - output map; receives every parsed option
//
// Command-line values are stored first, so they take precedence over values
// from the configuration file. The process exits with status 1 after printing
// the option summary when --help is given or --input is missing, and exits
// with status 1 when the named configuration file cannot be opened.
void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
  // Options accepted both on the command line and in the configuration file.
  po::options_description opts("Configuration options");
  opts.add_options()
        ("samples,s",po::value<unsigned>()->default_value(1000),"Number of samples")
        ("input,i",po::value<string>(),"Read file from")
        ("random_seed,S",po::value<uint32_t>(), "Random seed")
        ("write_cdec_grammar,g", po::value<string>(), "Write cdec grammar to this file")
        ("write_cdec_weights,w", po::value<string>(), "Write cdec weights to this file")
        ("poisson_length,p", "Use a Poisson distribution as the length of a phrase in the base distribution")
        ("no_hyperparameter_inference,N", "Disable hyperparameter inference");

  // Options that only make sense on the command line.
  po::options_description clo("Command line options");
  clo.add_options()
        ("config", po::value<string>(), "Configuration file")
        ("help,h", "Print this help message and exit");

  po::options_description dconfig_options, dcmdline_options;
  dconfig_options.add(opts);
  dcmdline_options.add(opts).add(clo);

  po::store(po::parse_command_line(argc, argv, dcmdline_options), *conf);
  if (conf->count("config")) {
    const string config_path = (*conf)["config"].as<string>();
    ifstream config(config_path.c_str());
    if (!config) {
      // An unreadable stream would otherwise be parsed as an empty file and
      // the user's configuration silently dropped.
      cerr << "Failed to open configuration file: " << config_path << endl;
      exit(1);
    }
    po::store(po::parse_config_file(config, dconfig_options), *conf);
  }
  po::notify(*conf);

  if (conf->count("help") || (conf->count("input") == 0)) {
    cerr << dcmdline_options << endl;
    exit(1);
  }
}
// Reads a corpus, one sentence per line, converting each line to int ids with
// the global dictionary.
//
//   filename - path of the corpus file, or "-" to read from standard input
//   c        - output; cleared, then receives one vector per input line
//              (blank lines are kept as empty vectors)
//   vocab    - output; every id seen is inserted (not cleared first)
//
// Asserts that the input stream is usable before reading.
void ReadCorpus(const string& filename, vector<vector<int> >* c, set<int>* vocab) {
  c->clear();

  // The file stream lives on the stack so it is closed on every exit path,
  // including an exception thrown while converting a line.
  ifstream file;
  const bool use_stdin = (filename == "-");
  if (!use_stdin) file.open(filename.c_str());
  istream& in = use_stdin ? static_cast<istream&>(cin) : static_cast<istream&>(file);
  assert(in);

  string line;
  // getline fails only when no characters could be extracted, so a final line
  // without a trailing newline is still processed.
  while (getline(in, line)) {
    c->push_back(vector<int>());
    vector<int>& v = c->back();
    d.ConvertWhitespaceDelimitedLine(line, &v);
    for (unsigned i = 0; i < v.size(); ++i) vocab->insert(v[i]);
  }
}
// Natural log of the Poisson probability mass at `x` for rate `lambda`:
//   x * log(lambda) - log(x!) - lambda
// where log(x!) is computed as lgamma(x + 1). Asserts lambda > 0.
double log_poisson(unsigned x, const double& lambda) {
  assert(lambda > 0.0);
  const double log_rate = log(lambda);
  return log_rate * x - lgamma(x + 1) - lambda;
}
// Phrase-segmentation model over a corpus of int-encoded sentences.
//
// State:
//   phrases_ - CCRP_NoTable over phrases (vectors of word ids)
//   gen_     - CCRP_NoTable over a bool event: `false` is recorded once per
//              phrase, `true` once per sentence end
//   z_       - per-sentence boundary indicators (see member comment)
//
// The base probability of a phrase depends only on its length: every word gets
// probability 1/|vocab| and the length is scored either by a Poisson with rate
// 1 (--poisson_length) or by a geometric-style term using p_end_.
struct UniphraseLM {
  // corpus - sentences to segment; held by reference, must outlive this object
  // vocab  - set of word ids; only its size is used
  // conf   - parsed options; only the presence of "poisson_length" is read
  UniphraseLM(const vector<vector<int> >& corpus,
              const set<int>& vocab,
              const po::variables_map& conf) :
      phrases_(1,1),
      gen_(1,1),
      corpus_(corpus),
      uniform_word_(1.0 / vocab.size()),
      gen_p0_(0.5),
      p_end_(0.5),
      use_poisson_(conf.count("poisson_length") > 0),
      p0_cache_(10000, 0.0) {}

  // Base probability of `phrase`. Because log_p0 depends only on the phrase
  // length, results are cached per length; a cached value of 0 means "not yet
  // computed". If the probability underflows to 0 it is replaced by the
  // smallest positive normalized double and a message is written to stderr.
  double p0(const vector<int>& phrase) const {
    assert(phrase.size() < p0_cache_.size());
    double& p = p0_cache_[phrase.size()];
    if (p) return p;
    p = exp(log_p0(phrase));
    if (!p) {
      cerr << "0 prob phrase: " << phrase << "\nAssigning std::numeric_limits<double>::min()\n";
      p = std::numeric_limits<double>::min();
    }
    return p;
  }

  // Log base probability of `phrase`: length term plus
  // phrase.size() * log(uniform word probability).
  double log_p0(const vector<int>& phrase) const {
    double len_logprob;
    if (use_poisson_)
      len_logprob = log_poisson(phrase.size(), 1.0);
    else
      len_logprob = log(1 - p_end_) * (phrase.size() -1) + log(p_end_);
    return log(uniform_word_) * phrase.size() + len_logprob;
  }

  // Log-likelihood of the current state: the log_crp_prob() of both processes,
  // the log of both gen_ base probabilities, and log_p0 of every phrase
  // currently stored in phrases_. Aborts with diagnostics as soon as either
  // partial sum stops being finite.
  double llh() const {
    double gen_llh = gen_.log_crp_prob();
    gen_llh += log(gen_p0_) + log(1 - gen_p0_);
    double llhr = phrases_.log_crp_prob();
    for (CCRP_NoTable<vector<int> >::const_iterator it = phrases_.begin(); it != phrases_.end(); ++it) {
      llhr += log_p0(it->first);
      // Check the accumulator that actually changes in this loop so the
      // offending phrase is the one reported.
      if (!isfinite(gen_llh) || !isfinite(llhr)) {
        cerr << it->first << endl;
        cerr << log_p0(it->first) << endl;
        abort();
      }
    }
    return gen_llh + llhr;
  }

  // Initializes a random segmentation of the corpus and then repeatedly
  // resamples every boundary indicator.
  //
  //   samples - sweep counter bound; the sweep loop runs for s = 1 .. samples-1
  //   hyp_inf - when true, hyperparameters are resampled every 10th sweep
  //   rng     - random source; rng->next() is compared against 0.5 and
  //             rng->SelectSample(d1, d2) picks between "merge" and "split"
  //             (presumably returning 0 or 1 in proportion to d1 and d2 —
  //             confirm against sampler.h)
  //
  // Progress and diagnostics are written to stderr.
  void Sample(unsigned int samples, bool hyp_inf, MT19937* rng) {
    cerr << "Initializing...\n";
    z_.resize(corpus_.size());
    int tc = 0;  // number of gen_ events recorded during initialization
    for (unsigned i = 0; i < corpus_.size(); ++i) {
      const vector<int>& line = corpus_[i];
      const int ls = line.size();
      const int last_pos = ls - 1;
      vector<bool>& z = z_[i];
      z.resize(ls);
      int prev = 0;  // start of the phrase currently being built
      for (int j = 0; j < ls; ++j) {
        z[j] = rng->next() < 0.5;
        if (j == last_pos) z[j] = true; // break phrase at the end of the sentence
        if (z[j]) {
          const vector<int> p(line.begin() + prev, line.begin() + j + 1);
          phrases_.increment(p);
          prev = j + 1;
          gen_.increment(false);
          ++tc;
        }
      }
      ++tc;
      gen_.increment(true); // end of utterance
    }
    cerr << "TC: " << tc << endl;
    cerr << "Initial LLH: " << llh() << endl;
    cerr << "Sampling...\n";
    cerr << gen_ << endl;
    for (unsigned s = 1; s < samples; ++s) {
      cerr << '.';
      if (s % 10 == 0) {
        cerr << " [" << s;
        if (hyp_inf) ResampleHyperparameters(rng);
        cerr << " LLH=" << llh() << "]\n";
      }
      for (unsigned i = 0; i < corpus_.size(); ++i) {
        const vector<int>& line = corpus_[i];
        const int ls = line.size();
        const int last_pos = ls - 1;
        vector<bool>& z = z_[i];
        int prev = 0;  // start of the phrase that contains position j
        for (int j = 0; j < last_pos; ++j) { // don't resample last position
          // Find the next boundary to the right; terminates because the last
          // position of every sentence is always a boundary.
          int next = j+1; while(!z[next]) { ++next; }
          const vector<int> p1p2(line.begin() + prev, line.begin() + next + 1);
          const vector<int> p1(line.begin() + prev, line.begin() + j + 1);
          const vector<int> p2(line.begin() + j + 1, line.begin() + next + 1);
          // Remove the current analysis of this span from the counts.
          if (z[j]) {
            phrases_.decrement(p1);
            phrases_.decrement(p2);
            gen_.decrement(false);
            gen_.decrement(false);
          } else {
            phrases_.decrement(p1p2);
            gen_.decrement(false);
          }
          // d1: weight of keeping the span as one phrase.
          const double d1 = phrases_.prob(p1p2, p0(p1p2)) * gen_.prob(false, gen_p0_);
          // d2: weight of splitting it; p2 is scored with p1 temporarily
          // counted, and the temporary counts are removed again afterwards.
          double d2 = phrases_.prob(p1, p0(p1)) * gen_.prob(false, gen_p0_);
          phrases_.increment(p1);
          gen_.increment(false);
          d2 *= phrases_.prob(p2, p0(p2)) * gen_.prob(false, gen_p0_);
          phrases_.decrement(p1);
          gen_.decrement(false);
          z[j] = rng->SelectSample(d1, d2);
          // Add the chosen analysis back into the counts.
          if (z[j]) {
            phrases_.increment(p1);
            phrases_.increment(p2);
            gen_.increment(false);
            gen_.increment(false);
            prev = j + 1;
          } else {
            phrases_.increment(p1p2);
            gen_.increment(false);
          }
        }
      }
    }
    cerr << gen_.prob(false, gen_p0_) << " " << gen_.prob(true, 1 - gen_p0_) << endl;
  }

  // Writes one grammar rule per phrase currently in phrases_ to *os, in the
  // form "[X] ||| w1 w2 ||| w1_w2 ||| C=1 P=<log prob>".
  void WriteCdecGrammarForCurrentSample(ostream* os) const {
    CCRP_NoTable<vector<int> >::const_iterator it = phrases_.begin();
    for (; it != phrases_.end(); ++it) {
      (*os) << "[X] ||| " << Join(' ', it->first) << " ||| "
            << Join('_', it->first) << " ||| C=1 P="
            << log(phrases_.prob(it->first, p0(it->first))) << endl;
    }
  }

  // Log probability, under phrases_, of a one-word phrase whose id is
  // 99999999 (used as a stand-in for a word that is not in the corpus).
  double OOVUnigramLogProb() const {
    vector<int> x(1,99999999);
    return log(phrases_.prob(x, p0(x)));
  }

  // Resamples the hyperparameters of both processes and writes
  // phrases_.alpha() to stderr.
  void ResampleHyperparameters(MT19937* rng) {
    phrases_.resample_hyperparameters(rng);
    gen_.resample_hyperparameters(rng);
    cerr << " " << phrases_.alpha();
  }

  CCRP_NoTable<vector<int> > phrases_;
  CCRP_NoTable<bool> gen_;
  vector<vector<bool> > z_;   // z_[i][j]: is there a phrase boundary after word j of sentence i
  const vector<vector<int> >& corpus_;
  const double uniform_word_; // 1 / vocabulary size
  const double gen_p0_;       // base probability passed to gen_.prob() for `false`
  const double p_end_;        // in base length distribution, p of the end of a phrase
  const bool use_poisson_;    // use the Poisson length term in log_p0
  // Per-instance cache for p0(), indexed by phrase length; 0.0 = not computed.
  // Kept per instance because the values depend on uniform_word_ and
  // use_poisson_.
  mutable vector<double> p0_cache_;
};
int main(int argc, char** argv) {
po::variables_map conf;
InitCommandLine(argc, argv, &conf);
boost::shared_ptr<MT19937> prng;
if (conf.count("random_seed"))
prng.reset(new MT19937(conf["random_seed"].as<uint32_t>()));
else
prng.reset(new MT19937);
MT19937& rng = *prng;
vector<vector<int> > corpus;
set<int> vocab;
ReadCorpus(conf["input"].as<string>(), &corpus, &vocab);
cerr << "Corpus size: " << corpus.size() << " sentences\n";
cerr << "Vocabulary size: " << vocab.size() << " types\n";
UniphraseLM ulm(corpus, vocab, conf);
ulm.Sample(conf["samples"].as<unsigned>(), conf.count("no_hyperparameter_inference") == 0, &rng);
cerr << "OOV unigram prob: " << ulm.OOVUnigramLogProb() << endl;
for (unsigned i = 0; i < corpus.size(); ++i)
WriteSeg(corpus[i], ulm.z_[i], d);
if (conf.count("write_cdec_grammar")) {
string fname = conf["write_cdec_grammar"].as<string>();
cerr << "Writing model to " << fname << " ...\n";
WriteFile wf(fname);
ulm.WriteCdecGrammarForCurrentSample(wf.stream());
}
if (conf.count("write_cdec_weights")) {
string fname = conf["write_cdec_weights"].as<string>();
cerr << "Writing weights to " << fname << " .\n";
WriteFile wf(fname);
ostream& os = *wf.stream();
os << "# make C smaller to use more phrases\nP 1\nPassThrough " << ulm.OOVUnigramLogProb() << "\nC -3\n";
}
return 0;
}
Jump to Line
Something went wrong with that request. Please try again.