Permalink
Browse files

Merge tag 'online' into dev

basic online multithreaded decoding + vad
  • Loading branch information...
2 parents 7a50625 + 9777442 commit ed015db8749e807944af6e73d80a6d1a43da3359 @pckben committed Mar 27, 2013
Showing with 24,620 additions and 86 deletions.
  1. +1 −1 Makefile
  2. +3 −3 make.mk
  3. +164 −27 speechsvr/DecodeWorker.cc
  4. +11 −1 speechsvr/DecodeWorker.h
  5. +12 −31 speechsvr/Makefile
  6. +40 −0 speechsvr/RemoteAudioSource.cc
  7. +22 −0 speechsvr/RemoteAudioSource.h
  8. +2 −1 speechsvr/SpeechServer.cc
  9. +26 −0 speechsvr/kaldi.mk
  10. +20 −0 speechsvr/online-decoder/Makefile
  11. +35 −0 speechsvr/online-decoder/online-audio-source-interface.h
  12. +156 −0 speechsvr/online-decoder/online-audio-source.cc
  13. +134 −0 speechsvr/online-decoder/online-audio-source.h
  14. +159 −0 speechsvr/online-decoder/online-cmn.cc
  15. +63 −0 speechsvr/online-decoder/online-cmn.h
  16. +56 −0 speechsvr/online-decoder/online-decodable.cc
  17. +75 −0 speechsvr/online-decoder/online-decodable.h
  18. +283 −0 speechsvr/online-decoder/online-faster-decoder.cc
  19. +143 −0 speechsvr/online-decoder/online-faster-decoder.h
  20. +272 −0 speechsvr/online-decoder/online-feat-input.cc
  21. +202 −0 speechsvr/online-decoder/online-feat-input.h
  22. +248 −0 speechsvr/online-decoder/online-vad.cc
  23. +136 −0 speechsvr/online-decoder/online-vad.h
  24. +74 −0 speechsvr/online-decoder/onlinebin-util.cc
  25. +45 −0 speechsvr/online-decoder/onlinebin-util.h
  26. +11 −0 speechsvr/protocol.cc
  27. +57 −6 speechsvr/protocol.h
  28. BIN speechsvr/speechclient
  29. +89 −11 speechsvr/speechclient.cc
  30. BIN speechsvr/speechsvr
  31. BIN speechsvr/test/tests
  32. +22,076 −0 speechsvr/vad_file
  33. BIN speechsvr/vad_input
  34. BIN speechsvr/vad_output
  35. +3 −3 src/Socket.cc
  36. +2 −2 src/Socket.h
  37. BIN test/test_client
  38. BIN test/test_server
  39. BIN test/test_worker
View
@@ -4,7 +4,7 @@ all: $(SUBDIRS)
.PHONY: $(SUBDIRS)
-subdirs: $(SUBDIRS)
+#subdirs: $(SUBDIRS)
$(SUBDIRS):
$(MAKE) -C $@
View
@@ -1,6 +1,6 @@
CC = g++
-CCFLAGS = -g -O2 -Wall
+CCFLAGS = -g -O2 -Wall \
+ -Wno-sign-compare -Winit-self \
+
GTEST_ROOT = /opt/local
-FSTROOT = /Users/Ben/projects/kaldi/tools/openfst
-KALDIROOT = /Users/Ben/projects/kaldi/src
View
@@ -1,4 +1,5 @@
#include "DecodeWorker.h"
+#include "RemoteAudioSource.h"
#include <base/kaldi-common.h>
#include <util/common-utils.h>
@@ -13,6 +14,13 @@
#include <feat/feature-mfcc.h>
#include <online/online-cmn.h>
+//#include "online/online-feat-input.h"
+#include "online-decoder/online-decodable.h"
+#include "online-decoder/online-faster-decoder.h"
+#include "online-decoder/onlinebin-util.h"
+#include "online-decoder/online-cmn.h"
+#include "online-decoder/online-vad.h"
+
#include <SocketTask.h>
#include <Socket.h>
@@ -22,6 +30,15 @@ using namespace std;
using namespace kaldi;
using namespace fst;
+void SendOutput(Socket *socket, bool end_utterance, string str) {
+ PacketHeader header;
+ header.payload_length = sizeof(bool) + str.length();
+ header.type = DECODE_OUTPUT;
+ socket->Send(&header, sizeof(PacketHeader));
+ socket->Send(&end_utterance, sizeof(bool));
+ socket->Send(str.c_str(), str.length());
+}
+
string SymbolsToWords(vector<int32> symbols, SymbolTable *symbol_table) {
string words = "";
for (vector<int32>::iterator it = symbols.begin();
@@ -34,6 +51,18 @@ string SymbolsToWords(vector<int32> symbols, SymbolTable *symbol_table) {
return words;
}
+string GetPartialResult(const std::vector<int32>& words,
+ const fst::SymbolTable *word_syms) {
+ KALDI_ASSERT(word_syms != NULL);
+ string result = "";
+ for (size_t i = 0; i < words.size(); i++) {
+ string word = word_syms->Find(words[i]);
+ if (word == "")
+ KALDI_ERR << "Word-id " << words[i] <<" not in symbol table.";
+ result += word + " ";
+ }
+ return result;
+}
void LdaTransform(Matrix<BaseFloat> &cmvn,
Matrix<BaseFloat> &output,
@@ -87,15 +116,17 @@ DecodeWorker::DecodeWorker(LatticeFasterDecoder *decoder,
SymbolTable *symbol_table,
float acoustic_scale,
float left_context,
- float right_context)
+ float right_context,
+ VectorFst<StdArc> *decode_fst)
: decoder_(decoder),
gmm_(gmm),
trans_model_(trans_model),
lda_transform_(lda_transform),
symbol_table_(symbol_table),
acoustic_scale_(acoustic_scale),
left_context_(left_context),
- right_context_(right_context)
+ right_context_(right_context),
+ decode_fst_(decode_fst)
{
MfccOptions mfcc_opts;
mfcc_opts.use_energy = false;
@@ -119,28 +150,39 @@ void DecodeWorker::Work() {
Socket *sock = ((SocketTask *)GetTask())->GetSocket();
// Receive header
- Packet packet;
- sock->Receive((char *)&packet.header, sizeof(PacketHeader));
+ PacketHeader header;
+ sock->Receive((char *)&header, sizeof(PacketHeader));
+
+ if (header.type != DECODE_REQUEST)
+ exit(1);
- cout << "Packet received: type=" << packet.header.type
- << ", length=" << packet.header.payload_length << endl;
+ DecodeRequest req;
+ sock->Receive((char *)&req, header.payload_length);
+
+ cout << "Decode request: "
+ << (req.online_mode ? "online mode" : "file mode") << endl;
+
+ DecodeResponse res = { true };
+ SendPacket(sock, DECODE_ACCEPT, sizeof(DecodeResponse), &res);
Vector<BaseFloat> wave; // raw wave data
Matrix<BaseFloat> mfcc_output;
Matrix<BaseFloat> cmvn_output;
Matrix<BaseFloat> lda_output;
string words;
- PacketHeader response_header;
-
- switch (packet.header.type) {
- case DATA_WAVE:
+ if (req.online_mode) {
+ OnlineDecode(sock);
+ }
+ else {
+ sock->Receive((char *)&header, sizeof(PacketHeader));
+ if (header.type == DATA_WAVE) {
// Receive wave
- wave.Resize(packet.header.payload_length / sizeof(float));
- cout << "Waiting for " << packet.header.payload_length << " bytes ("
- << packet.header.payload_length / sizeof(float) << " floats)\n";
- sock->Receive((char *)wave.Data(), packet.header.payload_length);
- cout << packet.header.payload_length << " bytes received.\n";
+ wave.Resize(header.payload_length / sizeof(float));
+ cout << "Waiting for " << header.payload_length << " bytes ("
+ << header.payload_length / sizeof(float) << " floats)\n";
+ sock->Receive((char *)wave.Data(), header.payload_length);
+ cout << header.payload_length << " bytes received.\n";
// Feature extraction
cout << "Feature extraction...\n";
mfcc_->Compute(wave, &mfcc_output, NULL);
@@ -153,18 +195,13 @@ void DecodeWorker::Work() {
cout << "Output: " << words << endl;
// Send back result
- response_header.payload_length = words.length();
- response_header.type = DECODE_OUTPUT;
- sock->Send((char *)&response_header, sizeof(PacketHeader));
- sock->Send(words.c_str(), words.length());
- break;
-
- case DATA_FEATURE:
- break;
-
- default:
- cerr << "Invalid request received.\n";
- break;
+ SendOutput(sock, true, words);
+ }
+ else if (header.type == DATA_FEATURE) {
+ }
+ else {
+ cerr << "Invalid request received: " << header.type << endl;
+ }
}
}
@@ -189,3 +226,103 @@ string DecodeWorker::Decode(Matrix<BaseFloat> &features) {
delete decodable_;
return words;
}
+
+void DecodeWorker::OnlineDecode(Socket *sock) {
+
+ const int32 kSampleFreq = 16000;
+ int32 vad_buffer_length_ms = 500;
+ int32 vad_hangover_ms = 500;
+ int32 batch_size = 27;
+ int32 frame_length_ms = 25;
+ int32 frame_shift_ms = 10;
+ float vad_onset_threshold = 400;
+ float vad_offset_threshold = 50;
+ float vad_recover_threshold = 100;
+ int frame_length_samples = frame_length_ms * (kSampleFreq/1000);
+ int frame_shift_samples = frame_shift_ms * (kSampleFreq/1000);
+ int vad_hangover_samples = vad_hangover_ms * (kSampleFreq/1000);
+ int vad_hangover_frames = (vad_hangover_samples - frame_length_samples) / frame_shift_samples + 1;
+ MfccOptions mfcc_opts;
+
+ OnlineFasterDecoderOpts config;
+ string silence_phones_str = "1:2:3:4:5:6:7:8:9:10:11:12:13:14:15";
+ std::vector<int32> silence_phones;
+ if (!SplitStringToIntegers(silence_phones_str, ":", false, &silence_phones))
+ KALDI_ERR << "Invalid silence-phones string " << silence_phones_str;
+ if (silence_phones.empty())
+ KALDI_ERR << "No silence phones given!";
+ VectorFst<LatticeArc> out_fst;
+
+ OnlineFasterDecoder decoder(*decode_fst_, config,
+ silence_phones, *trans_model_);
+
+ RemoteAudioSource audioSource(sock);
+
+ SimpleEnergyVad vad(vad_onset_threshold, vad_offset_threshold,
+ vad_recover_threshold, vad_hangover_frames);
+
+ OnlineVadFeInput vadMfccInput(&audioSource, mfcc_, &vad,
+ frame_length_samples,
+ frame_shift_samples,
+ vad_buffer_length_ms * (kSampleFreq / 1000));
+
+ OnlineCmvnInput cmvn_input(&vadMfccInput, mfcc_opts.num_ceps, 600);
+ OnlineLdaInput lda_input(&cmvn_input, mfcc_opts.num_ceps, *lda_transform_,
+ left_context_, right_context_);
+ int32 feat_dim = lda_transform_->NumRows();
+ OnlineDecodableDiagGmmScaled decodable(&lda_input, *gmm_, *trans_model_,
+ acoustic_scale_, batch_size,
+ feat_dim, -1);
+
+ // setup VAD-->Decoder event notification
+ OnlineFasterDecoderVadListener vadListener(&decoder);
+ vadMfccInput.AddVadListener(&vadListener);
+
+ OnlineFasterDecoder::DecodeState dstate;
+ string utterance = "";
+
+ while (true) {
+ dstate = decoder.Decode(&decodable);
+
+ if (dstate & (decoder.kEndUtt | decoder.kEndFeats)) {
+ std::vector<int32> word_ids;
+ decoder.FinishTraceBack(&out_fst);
+ fst::GetLinearSymbolSequence(out_fst,
+ static_cast<vector<int32> *>(0), &word_ids,
+ static_cast<LatticeArc::Weight*>(0));
+ string result = GetPartialResult(word_ids, symbol_table_);
+ utterance += result;
+ utterance = "";
+
+ SendOutput(sock, true, result);
+ cerr << result << endl;
+
+ } else if (dstate & (decoder.kEndBatch)) {
+ std::vector<int32> word_ids;
+ if (decoder.PartialTraceback(&out_fst)) {
+ fst::GetLinearSymbolSequence(out_fst,
+ static_cast<vector<int32> *>(0), &word_ids,
+ static_cast<LatticeArc::Weight*>(0));
+ string result = GetPartialResult(word_ids, symbol_table_);
+ utterance += result;
+
+ SendOutput(sock, false, result);
+
+ cerr << result;
+ }
+ }
+
+ /*
+ if (dstate == decoder.kEndFeats)
+ std::cerr << "*";
+ else if (dstate == decoder.kEndBatch) {
+ } else if (dstate == decoder.kEndUtt) {
+ std::cerr << "\n";
+ } else
+ std::cerr << "?";
+ */
+
+ if (dstate == decoder.kEndFeats)
+ break;
+ }
+}
View
@@ -4,6 +4,7 @@
#include "protocol.h"
#include <Worker.h>
#include <matrix/kaldi-matrix.h>
+#include <fst/vector-fst.h>
namespace kaldi {
class LatticeFasterDecoder;
@@ -12,12 +13,17 @@ namespace kaldi {
class DecodableInterface;
class Mfcc;
class OnlineCMN;
+ class OnlineFasterDecoder;
}
namespace fst {
class SymbolTable;
}
+namespace pckben {
+ class Socket;
+}
+
namespace speechsvr {
class DecodeWorker : public pckben::BackgroundWorker {
public:
@@ -28,14 +34,17 @@ namespace speechsvr {
fst::SymbolTable *symbol_table,
float acoustic_scale,
float left_context,
- float right_context);
+ float right_context,
+ fst::VectorFst<fst::StdArc> *decode_fst);
virtual ~DecodeWorker();
protected:
void Work();
std::string Decode(kaldi::Matrix<kaldi::BaseFloat> &features);
+ void OnlineDecode(pckben::Socket *sock);
private:
+ kaldi::OnlineFasterDecoder *online_decoder_;
kaldi::LatticeFasterDecoder *decoder_;
kaldi::DecodableInterface *decodable_;
kaldi::AmDiagGmm *gmm_;
@@ -47,6 +56,7 @@ namespace speechsvr {
kaldi::OnlineCMN *cmvn_;
float left_context_;
float right_context_;
+ fst::VectorFst<fst::StdArc> *decode_fst_;
};
}
#endif // SPEECHSVR_DECODEWORKER_H
View
@@ -1,34 +1,15 @@
-include ../make.mk
-OBJ = Config.o DecodeWorker.o SpeechServer.o
+include kaldi.mk
+SUBDIRS = online test
+OBJ = Config.o DecodeWorker.o protocol.o SpeechServer.o RemoteAudioSource.o
OUT = speechsvr speechclient
-LIBS = ../threaded-server.a \
- $(KALDIROOT)/lat/kaldi-lat.a \
- $(KALDIROOT)/decoder/kaldi-decoder.a \
- $(KALDIROOT)/feat/kaldi-feature.a \
- $(KALDIROOT)/transform/kaldi-transform.a \
- $(KALDIROOT)/gmm/kaldi-gmm.a \
- $(KALDIROOT)/hmm/kaldi-hmm.a \
- $(KALDIROOT)/tree/kaldi-tree.a \
- $(KALDIROOT)/matrix/kaldi-matrix.a \
- $(KALDIROOT)/util/kaldi-util.a \
- $(KALDIROOT)/base/kaldi-base.a \
- $(KALDIROOT)/online/kaldi-online.a \
- $(FSTROOT)/lib/libfst.a \
- -framework Accelerate \
- -lm -ldl
-
-CCFLAGS += -I../src \
- -DKALDI_DOUBLEPRECISION=0 -DHAVE_POSIX_MEMALIGN \
- -Wno-sign-compare -Winit-self \
- -DHAVE_EXECINFO_H=1 -DHAVE_CXXABI_H -rdynamic \
- -DHAVE_CLAPACK \
- -I$(KALDIROOT) \
- -I$(FSTROOT)/include \
-
-.PHONY: all test
-
-all: test speechsvr speechclient
+LIBS += ../threaded-server.a online-decoder/online-decoder.a \
+
+CCFLAGS += -I../src -Ionline-decoder \
+
+.PHONY: all $(SUBDIRS)
+
+all: $(SUBDIRS) speechsvr speechclient
speechsvr: $(OBJ) speechsvr.o
$(CC) $(CCFLAGS) -o $@ $(OBJ) speechsvr.o $(LIBS)
@@ -39,9 +20,9 @@ speechclient: $(OBJ) speechclient.o
%.o: %.cc
$(CC) $(CCFLAGS) -c $< -o $@
-test:
+$(SUBDIR):
$(MAKE) -C $@
clean:
-rm -f $(OBJ) $(OUT) speechsvr.o speechclient.o
- $(MAKE) -C test clean
+ -for x in $(SUBDIRS); do $(MAKE) -C $x clean; done
Oops, something went wrong.

0 comments on commit ed015db

Please sign in to comment.