Skip to content

Commit

Permalink
Added online filtering inference.
Browse files Browse the repository at this point in the history
  • Loading branch information
ushadow committed Apr 11, 2013
1 parent d2801db commit 13cf429
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
23 changes: 22 additions & 1 deletion matlab/meet/learn/ahmm/inferenceahmm.m
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
function R = inferenceahmm(ahmm, data, predictNode, method)

engine = smoother_engine(jtree_2TBN_inf_engine(ahmm));
switch method
case 'fixed-interval-smoothing'
case 'viterbi'
engine = smoother_engine(jtree_2TBN_inf_engine(ahmm));
case 'filtering'
engine = filter_engine(jtree_2TBN_inf_engine(ahmm));
otherwise
error(['Inference method not implemented: ' method]);
end

nseq = length(data);
R = cell(1, nseq);
Expand All @@ -10,6 +18,19 @@
case 'fixed-interval-smoothing'
engine = enter_evidence(engine, evidence);
R{i} = mapest(engine, predictNode, length(evidence));
case 'filtering'
T = size(evidence, 2);
nhnode = length(predictNode);
mapEst = cell(nhnode, T);
for t = 1 : T
for n = 1 : nhnode
engine = enter_evidence(engine, evidence, t);
m = marginal_nodes(engine, predictNode, t);
[~, ndx] = max(m.T);
mapEst{n, t} = ndx;
end
end
R{i} = mapEst;
case 'viterbi'
% Find the most probable explanation (Viterbi).
mpe = find_mpe(engine, evidence);
Expand Down
2 changes: 1 addition & 1 deletion matlab/meet/learn/ahmm/learnahmm.m
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
trainData = createInputData(Y.train, X.train, param);
finalAhmm = trainahmm(param, trainData);

predictNode = [param.G1 param.F1];
predictNode = [param.G1];
param.onodes = [param.X1];
finalAhmm = sethiddenbit(finalAhmm, param.onodes);
checkahmm(finalAhmm);
Expand Down

0 comments on commit 13cf429

Please sign in to comment.