Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Tree: 7527592aaf
Fetching contributors…

Cannot retrieve contributors at this time

679 lines (639 sloc) 21.586 kB
#include "hg_io.h"
#include <fstream>
#include <sstream>
#include <iostream>
#include "fast_lexical_cast.hpp"
#include "tdict.h"
#include "json_parse.h"
#include "hg.h"
using namespace std;
struct HGReader : public JSONParser {
HGReader(Hypergraph* g) : rp("[X] ||| "), state(-1), hg(*g), nodes_needed(true), edges_needed(true) { nodes = 0; edges = 0; }
void CreateNode(const string& cat, const vector<int>& in_edges) {
WordID c = TD::Convert("X") * -1;
if (!cat.empty()) c = TD::Convert(cat) * -1;
Hypergraph::Node* node = hg.AddNode(c);
for (int i = 0; i < in_edges.size(); ++i) {
if (in_edges[i] >= hg.edges_.size()) {
cerr << "JSONParser: in_edges[" << i << "]=" << in_edges[i]
<< ", but hg only has " << hg.edges_.size() << " edges!\n";
abort();
}
hg.ConnectEdgeToHeadNode(&hg.edges_[in_edges[i]], node);
}
}
void CreateEdge(const TRulePtr& rule, FeatureVector* feats, const SmallVectorUnsigned& tail) {
Hypergraph::Edge* edge = hg.AddEdge(rule, tail);
feats->swap(edge->feature_values_);
edge->i_ = spans[0];
edge->j_ = spans[1];
edge->prev_i_ = spans[2];
edge->prev_j_ = spans[3];
}
bool HandleJSONEvent(int type, const JSON_value* value) {
switch(state) {
case -1:
assert(type == JSON_T_OBJECT_BEGIN);
state = 0;
break;
case 0:
if (type == JSON_T_OBJECT_END) {
//cerr << "HG created\n"; // TODO, signal some kind of callback
} else if (type == JSON_T_KEY) {
string val = value->vu.str.value;
if (val == "features") { assert(fdict.empty()); state = 1; }
else if (val == "is_sorted") { state = 3; }
else if (val == "rules") { assert(rules.empty()); state = 4; }
else if (val == "node") { state = 8; }
else if (val == "edges") { state = 13; }
else { cerr << "Unexpected key: " << val << endl; return false; }
}
break;
// features
case 1:
if(type == JSON_T_NULL) { state = 0; break; }
assert(type == JSON_T_ARRAY_BEGIN);
state = 2;
break;
case 2:
if(type == JSON_T_ARRAY_END) { state = 0; break; }
assert(type == JSON_T_STRING);
fdict.push_back(FD::Convert(value->vu.str.value));
assert(fdict.back() > 0);
break;
// is_sorted
case 3:
assert(type == JSON_T_TRUE || type == JSON_T_FALSE);
is_sorted = (type == JSON_T_TRUE);
if (!is_sorted) { cerr << "[WARNING] is_sorted flag is ignored\n"; }
state = 0;
break;
// rules
case 4:
if(type == JSON_T_NULL) { state = 0; break; }
assert(type == JSON_T_ARRAY_BEGIN);
state = 5;
break;
case 5:
if(type == JSON_T_ARRAY_END) { state = 0; break; }
assert(type == JSON_T_INTEGER);
state = 6;
rule_id = value->vu.integer_value;
break;
case 6:
assert(type == JSON_T_STRING);
rules[rule_id] = TRulePtr(new TRule(value->vu.str.value));
state = 5;
break;
// Nodes
case 8:
assert(type == JSON_T_OBJECT_BEGIN);
++nodes;
in_edges.clear();
cat.clear();
state = 9; break;
case 9:
if (type == JSON_T_OBJECT_END) {
//cerr << "Creating NODE\n";
CreateNode(cat, in_edges);
state = 0; break;
}
assert(type == JSON_T_KEY);
cur_key = value->vu.str.value;
if (cur_key == "cat") { assert(cat.empty()); state = 10; break; }
if (cur_key == "in_edges") { assert(in_edges.empty()); state = 11; break; }
cerr << "Syntax error: unexpected key " << cur_key << " in node specification.\n";
return false;
case 10:
assert(type == JSON_T_STRING || type == JSON_T_NULL);
cat = value->vu.str.value;
state = 9; break;
case 11:
if (type == JSON_T_NULL) { state = 9; break; }
assert(type == JSON_T_ARRAY_BEGIN);
state = 12; break;
case 12:
if (type == JSON_T_ARRAY_END) { state = 9; break; }
assert(type == JSON_T_INTEGER);
//cerr << "in_edges: " << value->vu.integer_value << endl;
in_edges.push_back(value->vu.integer_value);
break;
// "edges": [ { "tail": null, "feats" : [0,1.63,1,-0.54], "rule": 12},
// { "tail": null, "feats" : [0,0.87,1,0.02], "spans":[1,2,3,4], "rule": 17},
// { "tail": [0], "feats" : [1,2.3,2,15.3,"ExtraFeature",1.2], "rule": 13}]
case 13:
assert(type == JSON_T_ARRAY_BEGIN);
state = 14;
break;
case 14:
if (type == JSON_T_ARRAY_END) { state = 0; break; }
assert(type == JSON_T_OBJECT_BEGIN);
//cerr << "New edge\n";
++edges;
cur_rule.reset(); feats.clear(); tail.clear();
state = 15; break;
case 15:
if (type == JSON_T_OBJECT_END) {
CreateEdge(cur_rule, &feats, tail);
state = 14; break;
}
assert(type == JSON_T_KEY);
cur_key = value->vu.str.value;
//cerr << "edge key " << cur_key << endl;
if (cur_key == "rule") { assert(!cur_rule); state = 16; break; }
if (cur_key == "spans") { assert(!cur_rule); state = 22; break; }
if (cur_key == "feats") { assert(feats.empty()); state = 17; break; }
if (cur_key == "tail") { assert(tail.empty()); state = 20; break; }
cerr << "Unexpected key " << cur_key << " in edge specification\n";
return false;
case 16: // edge.rule
if (type == JSON_T_INTEGER) {
int rule_id = value->vu.integer_value;
if (rules.find(rule_id) == rules.end()) {
// rules list must come before the edge definitions!
cerr << "Rule_id " << rule_id << " given but only loaded " << rules.size() << " rules\n";
return false;
}
cur_rule = rules[rule_id];
} else if (type == JSON_T_STRING) {
cur_rule.reset(new TRule(value->vu.str.value));
} else {
cerr << "Rule must be either a rule id or a rule string" << endl;
return false;
}
// cerr << "Edge: rule=" << cur_rule->AsString() << endl;
state = 15;
break;
case 17: // edge.feats
if (type == JSON_T_NULL) { state = 15; break; }
assert(type == JSON_T_ARRAY_BEGIN);
state = 18; break;
case 18:
if (type == JSON_T_ARRAY_END) { state = 15; break; }
if (type != JSON_T_INTEGER && type != JSON_T_STRING) {
cerr << "Unexpected feature id type\n"; return false;
}
if (type == JSON_T_INTEGER) {
fid = value->vu.integer_value;
assert(fid < fdict.size());
fid = fdict[fid];
} else if (JSON_T_STRING) {
fid = FD::Convert(value->vu.str.value);
} else { abort(); }
state = 19;
break;
case 19:
{
assert(type == JSON_T_INTEGER || type == JSON_T_FLOAT);
double val = (type == JSON_T_INTEGER ? static_cast<double>(value->vu.integer_value) :
strtod(value->vu.str.value, NULL));
feats.set_value(fid, val);
state = 18;
break;
}
case 20: // edge.tail
if (type == JSON_T_NULL) { state = 15; break; }
assert(type == JSON_T_ARRAY_BEGIN);
state = 21; break;
case 21:
if (type == JSON_T_ARRAY_END) { state = 15; break; }
assert(type == JSON_T_INTEGER);
tail.push_back(value->vu.integer_value);
break;
case 22: // edge.spans
assert(type == JSON_T_ARRAY_BEGIN);
state = 23;
spans[0] = spans[1] = spans[2] = spans[3] = -1;
spanc = 0;
break;
case 23:
if (type == JSON_T_ARRAY_END) { state = 15; break; }
assert(type == JSON_T_INTEGER);
assert(spanc < 4);
spans[spanc] = value->vu.integer_value;
++spanc;
}
return true;
}
string rp;
string cat;
SmallVectorUnsigned tail;
vector<int> in_edges;
TRulePtr cur_rule;
map<int, TRulePtr> rules;
vector<int> fdict;
SparseVector<double> feats;
int state;
int fid;
int nodes;
int edges;
int spans[4];
int spanc;
string cur_key;
Hypergraph& hg;
int rule_id;
bool nodes_needed;
bool edges_needed;
bool is_sorted;
};
bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) {
hg->clear();
HGReader reader(hg);
return reader.Parse(in);
}
static void WriteRule(const TRule& r, ostream* out) {
if (!r.lhs_) { (*out) << "[X] ||| "; }
JSONParser::WriteEscapedString(r.AsString(), out);
}
bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) {
if (hg.empty()) { *out << "{}\n"; return true; }
map<const TRule*, int> rid;
ostream& o = *out;
rid[NULL] = 0;
o << '{';
if (!remove_rules) {
o << "\"rules\":[";
for (int i = 0; i < hg.edges_.size(); ++i) {
const TRule* r = hg.edges_[i].rule_.get();
int &id = rid[r];
if (!id) {
id=rid.size() - 1;
if (id > 1) o << ',';
o << id << ',';
WriteRule(*r, &o);
};
}
o << "],";
}
const bool use_fdict = FD::NumFeats() < 1000;
if (use_fdict) {
o << "\"features\":[";
for (int i = 1; i < FD::NumFeats(); ++i) {
o << (i==1 ? "":",");
JSONParser::WriteEscapedString(FD::Convert(i), &o);
}
o << "],";
}
vector<int> edgemap(hg.edges_.size(), -1); // edges may be in non-topo order
int edge_count = 0;
for (int i = 0; i < hg.nodes_.size(); ++i) {
const Hypergraph::Node& node = hg.nodes_[i];
if (i > 0) { o << ","; }
o << "\"edges\":[";
for (int j = 0; j < node.in_edges_.size(); ++j) {
const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]];
edgemap[edge.id_] = edge_count;
++edge_count;
o << (j == 0 ? "" : ",") << "{";
o << "\"tail\":[";
for (int k = 0; k < edge.tail_nodes_.size(); ++k) {
o << (k > 0 ? "," : "") << edge.tail_nodes_[k];
}
o << "],";
o << "\"spans\":[" << edge.i_ << "," << edge.j_ << "," << edge.prev_i_ << "," << edge.prev_j_ << "],";
o << "\"feats\":[";
bool first = true;
for (SparseVector<double>::const_iterator it = edge.feature_values_.begin(); it != edge.feature_values_.end(); ++it) {
if (!it->second) continue; // don't write features that have a zero value
if (!it->first) continue; // if the feature set was frozen this might happen
if (!first) o << ',';
if (use_fdict)
o << (it->first - 1);
else {
JSONParser::WriteEscapedString(FD::Convert(it->first), &o);
}
o << ',' << it->second;
first = false;
}
o << "]";
if (!remove_rules) { o << ",\"rule\":" << rid[edge.rule_.get()]; }
o << "}";
}
o << "],";
o << "\"node\":{\"in_edges\":[";
for (int j = 0; j < node.in_edges_.size(); ++j) {
int mapped_edge = edgemap[node.in_edges_[j]];
assert(mapped_edge >= 0);
o << (j == 0 ? "" : ",") << mapped_edge;
}
o << "]";
if (node.cat_ < 0) {
o << ",\"cat\":";
JSONParser::WriteEscapedString(TD::Convert(node.cat_ * -1), &o);
}
o << "}";
}
o << "}\n";
return true;
}
bool needs_escape[128];
void InitEscapes() {
memset(needs_escape, false, 128);
needs_escape[static_cast<size_t>('\'')] = true;
needs_escape[static_cast<size_t>('\\')] = true;
}
string HypergraphIO::Escape(const string& s) {
size_t len = s.size();
for (int i = 0; i < s.size(); ++i) {
unsigned char c = s[i];
if (c < 128 && needs_escape[c]) ++len;
}
if (len == s.size()) return s;
string res(len, ' ');
size_t o = 0;
for (int i = 0; i < s.size(); ++i) {
unsigned char c = s[i];
if (c < 128 && needs_escape[c])
res[o++] = '\\';
res[o++] = c;
}
assert(o == len);
return res;
}
string HypergraphIO::AsPLF(const Hypergraph& hg, bool include_global_parentheses) {
static bool first = true;
if (first) { InitEscapes(); first = false; }
if (hg.nodes_.empty()) return "()";
ostringstream os;
if (include_global_parentheses) os << '(';
static const string EPS="*EPS*";
for (int i = 0; i < hg.nodes_.size()-1; ++i) {
if (hg.nodes_[i].out_edges_.empty()) abort();
const bool last_node = (i == hg.nodes_.size() - 2);
const int out_edges_size = hg.nodes_[i].out_edges_.size();
// compound splitter adds an extra goal transition which we suppress with
// the following conditional
if (!last_node || out_edges_size != 1 ||
hg.edges_[hg.nodes_[i].out_edges_[0]].rule_->EWords() == 1) {
os << '(';
for (int j = 0; j < out_edges_size; ++j) {
const Hypergraph::Edge& e = hg.edges_[hg.nodes_[i].out_edges_[j]];
const string output = e.rule_->e_.size() ==2 ? Escape(TD::Convert(e.rule_->e_[1])) : EPS;
double prob = log(e.edge_prob_);
if (isinf(prob)) { prob = -9e20; }
if (isnan(prob)) { prob = 0; }
os << "('" << output << "'," << prob << "," << e.head_node_ - i << "),";
}
os << "),";
}
}
if (include_global_parentheses) os << ')';
return os.str();
}
string HypergraphIO::AsPLF(const Lattice& lat, bool include_global_parentheses) {
static bool first = true;
if (first) { InitEscapes(); first = false; }
if (lat.empty()) return "()";
ostringstream os;
if (include_global_parentheses) os << '(';
static const string EPS="*EPS*";
for (int i = 0; i < lat.size(); ++i) {
const vector<LatticeArc> arcs = lat[i];
os << '(';
for (int j = 0; j < arcs.size(); ++j) {
os << "('" << Escape(TD::Convert(arcs[j].label)) << "',"
<< arcs[j].cost << ',' << arcs[j].dist2next << "),";
}
os << "),";
}
if (include_global_parentheses) os << ')';
return os.str();
}
namespace PLF {
const string chars = "'\\";
const char& quote = chars[0];
const char& slash = chars[1];
// safe get
inline char get(const std::string& in, int c) {
if (c < 0 || c >= (int)in.size()) return 0;
else return in[(size_t)c];
}
// consume whitespace
inline void eatws(const std::string& in, int& c) {
while (get(in,c) == ' ') { c++; }
}
// from 'foo' return foo
std::string getEscapedString(const std::string& in, int &c)
{
eatws(in,c);
if (get(in,c++) != quote) return "ERROR";
std::string res;
char cur = 0;
do {
cur = get(in,c++);
if (cur == slash) { res += get(in,c++); }
else if (cur != quote) { res += cur; }
} while (get(in,c) != quote && (c < (int)in.size()));
c++;
eatws(in,c);
return res;
}
// basically atof
float getFloat(const std::string& in, int &c)
{
std::string tmp;
eatws(in,c);
while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') {
tmp += get(in,c++);
}
eatws(in,c);
if (tmp.empty()) {
cerr << "Syntax error while reading number! col=" << c << endl;
abort();
}
return atof(tmp.c_str());
}
// basically atoi
int getInt(const std::string& in, int &c)
{
std::string tmp;
eatws(in,c);
while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') {
tmp += get(in,c++);
}
eatws(in,c);
return atoi(tmp.c_str());
}
// maximum number of nodes permitted
#define MAX_NODES 100000000
// parse ('foo', 0.23)
void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) {
if (get(in,c++) != '(') { cerr << "PCN/PLF parse error: expected (\n"; abort(); }
vector<WordID> ewords(2, 0);
ewords[1] = TD::Convert(getEscapedString(in,c));
TRulePtr r(new TRule(ewords));
r->ComputeArity();
// cerr << "RULE: " << r->AsString() << endl;
if (get(in,c++) != ',') { cerr << in << endl; cerr << "PCN/PLF parse error: expected , after string\n"; abort(); }
size_t cnNext = 1;
std::vector<float> probs;
probs.push_back(getFloat(in,c));
while (get(in,c) == ',') {
c++;
float val = getFloat(in,c);
probs.push_back(val);
// cerr << val << endl; //REMO
}
//if we read more than one prob, this was a lattice, last item was column increment
if (probs.size()>1) {
cnNext = static_cast<size_t>(probs.back());
probs.pop_back();
if (cnNext < 1) { cerr << cnNext << endl << "PCN/PLF parse error: bad link length at last element of cn alt block\n"; abort(); }
}
if (get(in,c++) != ')') { cerr << "PCN/PLF parse error: expected ) at end of cn alt block\n"; abort(); }
eatws(in,c);
Hypergraph::TailNodeVector tail(1, cur_node);
Hypergraph::Edge* edge = hg->AddEdge(r, tail);
//cerr << " <--" << cur_node << endl;
int head_node = cur_node + cnNext;
assert(head_node < MAX_NODES); // prevent malicious PLFs from using all the memory
if (hg->nodes_.size() < (head_node + 1)) { hg->ResizeNodes(head_node + 1); }
hg->ConnectEdgeToHeadNode(edge, &hg->nodes_[head_node]);
for (int i = 0; i < probs.size(); ++i)
edge->feature_values_.set_value(FD::Convert("Feature_" + boost::lexical_cast<string>(i)), probs[i]);
}
// parse (('foo', 0.23), ('bar', 0.77))
void ReadPLFNode(const std::string& in, int &c, int cur_node, int line, Hypergraph* hg) {
//cerr << "PLF READING NODE " << cur_node << endl;
if (hg->nodes_.size() < (cur_node + 1)) { hg->ResizeNodes(cur_node + 1); }
if (get(in,c++) != '(') { cerr << line << ": Syntax error 1\n"; abort(); }
eatws(in,c);
while (1) {
if (c > (int)in.size()) { break; }
if (get(in,c) == ')') {
c++;
eatws(in,c);
break;
}
if (get(in,c) == ',' && get(in,c+1) == ')') {
c+=2;
eatws(in,c);
break;
}
if (get(in,c) == ',') { c++; eatws(in,c); }
ReadPLFEdge(in, c, cur_node, hg);
}
}
} // namespace PLF
void HypergraphIO::ReadFromPLF(const std::string& in, Hypergraph* hg, int line) {
hg->clear();
int c = 0;
int cur_node = 0;
if (in[c++] != '(') { cerr << line << ": Syntax error!\n"; abort(); }
while (1) {
if (c > (int)in.size()) { break; }
if (PLF::get(in,c) == ')') {
c++;
PLF::eatws(in,c);
break;
}
if (PLF::get(in,c) == ',' && PLF::get(in,c+1) == ')') {
c+=2;
PLF::eatws(in,c);
break;
}
if (PLF::get(in,c) == ',') { c++; PLF::eatws(in,c); }
PLF::ReadPLFNode(in, c, cur_node, line, hg);
++cur_node;
}
assert(cur_node == hg->nodes_.size() - 1);
}
void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) {
Lattice& l = *pl;
Hypergraph g;
ReadFromPLF(plf, &g, 0);
const int num_nodes = g.nodes_.size() - 1;
l.resize(num_nodes);
int fid0=FD::Convert("Feature_0");
for (int i = 0; i < num_nodes; ++i) {
vector<LatticeArc>& alts = l[i];
const Hypergraph::Node& node = g.nodes_[i];
const int num_alts = node.out_edges_.size();
alts.resize(num_alts);
for (int j = 0; j < num_alts; ++j) {
const Hypergraph::Edge& edge = g.edges_[node.out_edges_[j]];
alts[j].label = edge.rule_->e_[1];
alts[j].cost = edge.feature_values_.get(fid0);
alts[j].dist2next = edge.head_node_ - node.id_;
}
}
}
void HypergraphIO::WriteAsCFG(const Hypergraph& hg) {
vector<int> cats(hg.nodes_.size());
// each node in the translation forest becomes a "non-terminal" in the new
// grammar, create the labels here
const string kSEP = "_";
for (int i = 0; i < hg.nodes_.size(); ++i) {
string pstr = "CAT";
if (hg.nodes_[i].cat_ < 0)
pstr = TD::Convert(-hg.nodes_[i].cat_);
cats[i] = TD::Convert(pstr + kSEP + boost::lexical_cast<string>(i)) * -1;
}
for (int i = 0; i < hg.edges_.size(); ++i) {
const Hypergraph::Edge& edge = hg.edges_[i];
const vector<WordID>& tgt = edge.rule_->e();
const vector<WordID>& src = edge.rule_->f();
TRulePtr rule(new TRule);
rule->prev_i = edge.i_;
rule->prev_j = edge.j_;
rule->lhs_ = cats[edge.head_node_];
vector<WordID>& f = rule->f_;
vector<WordID>& e = rule->e_;
f.resize(tgt.size()); // swap source and target, since the parser
e.resize(src.size()); // parses using the source side!
Hypergraph::TailNodeVector tn(edge.tail_nodes_.size());
int ntc = 0;
for (int j = 0; j < tgt.size(); ++j) {
const WordID& cur = tgt[j];
if (cur > 0) {
f[j] = cur;
} else {
tn[ntc++] = cur;
f[j] = cats[edge.tail_nodes_[-cur]];
}
}
ntc = 0;
for (int j = 0; j < src.size(); ++j) {
const WordID& cur = src[j];
if (cur > 0) {
e[j] = cur;
} else {
e[j] = tn[ntc++];
}
}
rule->scores_ = edge.feature_values_;
rule->parent_rule_ = edge.rule_;
rule->ComputeArity();
cout << rule->AsString() << endl;
}
}
/* Output format:
* #vertices
* for each vertex in bottom-up topological order:
* #downward_edges
* for each downward edge:
* RHS with [vertex_index] for NTs ||| scores
*/
void HypergraphIO::WriteTarget(const std::string &base, unsigned int id, const Hypergraph& hg) {
std::string name(base);
name += '/';
name += boost::lexical_cast<std::string>(id);
std::fstream out(name.c_str(), std::fstream::out);
out << hg.nodes_.size() << ' ' << hg.edges_.size() << '\n';
for (unsigned int i = 0; i < hg.nodes_.size(); ++i) {
const Hypergraph::EdgesVector &edges = hg.nodes_[i].in_edges_;
out << edges.size() << '\n';
for (unsigned int j = 0; j < edges.size(); ++j) {
const Hypergraph::Edge &edge = hg.edges_[edges[j]];
const std::vector<WordID> &e = edge.rule_->e();
for (std::vector<WordID>::const_iterator word = e.begin(); word != e.end(); ++word) {
if (*word <= 0) {
out << '[' << edge.tail_nodes_[-*word] << "] ";
} else {
out << TD::Convert(*word) << ' ';
}
}
out << "||| " << edge.rule_->scores_ << '\n';
}
}
}
Jump to Line
Something went wrong with that request. Please try again.