-
Notifications
You must be signed in to change notification settings - Fork 1
/
kaldi-to-svm.cpp
111 lines (79 loc) · 2.46 KB
/
kaldi-to-svm.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <cmdparser.h>
#include <kaldi-io.h>
// Reads per-frame integer labels from a text file whose lines start with a
// document id followed by that document's labels; an empty filename yields an
// empty map.
map<string, vector<int> > readLabels(const string& filename);
// Writes every frame of `ark` to `filename` (empty or "-" means stdout) in
// LIBSVM text format, labelling each frame from `labels`. For documents with no
// entry in `labels`, `skip` chooses between dropping their frames (true) and
// writing them with label 0 (false).
void saveAsLibSvmFormat(const string& filename, const KaldiArchive& ark, const map<string, vector<int> >& labels, bool skip);
int main (int argc, char* argv[]) {
CmdParser cmd(argc, argv);
cmd.add("kaldi-ark-in")
.add("label-in")
.add("svm-out", false)
.add("mapping-out", false);
cmd.addGroup("Options:")
.add("--skip", "Skip over missing labels (--skip=true )\n"
"Print 0 as label (--skip=false)", "false");
cmd.addGroup("Example: ./kaldi-to-svm data/example.39.ark example.labels");
if(!cmd.isOptionLegal())
cmd.showUsageAndExit();
string input_fn = cmd[1];
string label_fn = cmd[2];
string output_fn = cmd[3];
string mapping_fn = cmd[4];
bool skip = cmd["--skip"];
map<string, vector<int> > labels = readLabels(label_fn);
KaldiArchive ark(input_fn);
saveAsLibSvmFormat(output_fn, ark, labels, skip);
saveFrameCounts(mapping_fn, ark);
return 0;
}
void saveAsLibSvmFormat(const string& filename, const KaldiArchive& ark, const map<string, vector<int> >& labels, bool skip) {
size_t N = ark.docid().size();
size_t dim = ark.dim();
const vector<string>& docids = ark.docid();
const vector<float>& data = ark.data();
const vector<size_t>& offset = ark.offset();
FILE* fid;
if (filename.empty() || filename == "-")
fid = stdout;
else
fid = fopen(filename.c_str(), "w");
for (size_t i=0; i<N; ++i) {
size_t length = (offset[i+1] - offset[i]) / dim;
for (size_t j=0; j<length; ++j) {
int y = 0;
if (labels.count(docids[i]) > 0)
y = labels.at(docids[i])[j];
else {
if (skip)
continue;
}
fprintf(fid, "%d ", y);
for (size_t k=0; k<dim; ++k) {
float x = data[offset[i] + j*dim + k];
if (x != 0)
fprintf(fid, "%lu:%g ", k + 1, x);
}
fprintf(fid, "\n");
}
}
if (fid != stdout)
fclose(fid);
}
/**
 * Reads per-frame labels from a whitespace-separated text file.
 *
 * Each non-blank line has the form "<docid> <label_0> <label_1> ...". Tokens
 * may be separated by any whitespace (spaces, tabs, a trailing '\r' from CRLF
 * files). Each label is parsed with std::stoi. If a docid appears on several
 * lines, the last such line wins. Blank or whitespace-only lines are ignored.
 *
 * @param filename Path to the label file; an empty string yields an empty map.
 * @return Map from document id to its frame labels, in file order.
 * @throws std::runtime_error if a non-empty filename cannot be opened.
 * @throws std::invalid_argument / std::out_of_range (from std::stoi) on a
 *         malformed or out-of-range label.
 */
std::map<std::string, std::vector<int> > readLabels(const std::string& filename) {
  std::map<std::string, std::vector<int> > labels;
  if (filename.empty())
    return labels;

  std::ifstream fin(filename.c_str());
  if (!fin)
    throw std::runtime_error("cannot open label file \"" + filename + "\"");

  std::string line;
  while (std::getline(fin, line)) {
    // Split on any whitespace, so repeated separators never yield empty tokens.
    std::istringstream tokens(line);
    std::string docid;
    if (!(tokens >> docid))
      continue;  // blank line: there is no document id to read

    std::vector<int>& L = labels[docid];
    L.clear();  // a repeated docid replaces, rather than appends to, earlier labels
    std::string tok;
    while (tokens >> tok)
      L.push_back(std::stoi(tok));
  }
  return labels;
}