Skip to content

Commit

Permalink
Used BNT's mpe instead of node marginals to do map estimation.
Browse files Browse the repository at this point in the history
  • Loading branch information
ushadow committed Apr 4, 2013
1 parent e4b6e63 commit 1e6a04d
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 46 deletions.
37 changes: 0 additions & 37 deletions matlab/meet/evaluate/evalcv.m

This file was deleted.

4 changes: 4 additions & 0 deletions matlab/meet/learn/ahmm/checkahmm.m
@@ -0,0 +1,4 @@
function checkahmm(ahmm)
learned_GCPT = CPD_to_CPT(ahmm.CPD{5});
learned_GCPT(:, 1, :);
end
4 changes: 3 additions & 1 deletion matlab/meet/learn/ahmm/inferenceahmm.m
Expand Up @@ -7,6 +7,8 @@
for i = 1 : nseq
evidence = data{i};
engine = enter_evidence(engine, evidence);
R{i} = mapest(engine, predictNode, length(evidence));
%R{i} = mapest(engine, predictNode, length(evidence));
mpe = find_mpe(engine, evidence);
R{i} = mpe(predictNode, :);
end
end
1 change: 1 addition & 0 deletions matlab/meet/learn/ahmm/learnahmm.m
Expand Up @@ -21,6 +21,7 @@
predictNode = [param.G1 param.F1];
param.onodes = [param.X1];
finalAhmm = sethiddenbit(finalAhmm, param.onodes);
checkahmm(finalAhmm);

trainData = createInputData(Y.train, X.train, param);
R.train = inferenceahmm(finalAhmm, trainData, predictNode);
Expand Down
2 changes: 2 additions & 0 deletions matlab/meet/learn/ahmm/mapest.m
@@ -1,4 +1,6 @@
function mapEst = mapest(engine, hnode, T)
% Args:
% - hnode: hidden node that we want to estimate the state.
% Return
% - mapEst: a cell array.
nhnode = length(hnode);
Expand Down
10 changes: 5 additions & 5 deletions matlab/test/TestAHMM.m
Expand Up @@ -150,8 +150,8 @@ function testInference(self)
function testLearning(self)
T = 20;
max_iter = 10;
true_params = self.deterministicParams;
ahmm = createahmm(true_params);
trueParam = self.deterministicParams;
ahmm = createahmm(trueParam);
ev = sample_dbn(ahmm, 'length', T);
ss = length(ahmm.intra);
evidence = cell(1, 1);
Expand All @@ -173,19 +173,19 @@ function testLearning(self)
assertTrue(all(cell2mat(mapS(:)) == trueS(:)));

learned_Gstartprob = CPD_to_CPT(final_ahmm.CPD{1});
assertTrue(all(learned_Gstartprob == true_params.Gstartprob(:)));
assertTrue(all(learned_Gstartprob == trueParam.Gstartprob(:)));
learned_Sstartprob = CPD_to_CPT(final_ahmm.CPD{2});
assertTrue(learned_Sstartprob(1, 1) == 1);
learned_Stermprob = CPD_to_CPT(final_ahmm.CPD{3});
assertTrue(all(learned_Stermprob(:, 1) == 0));
learned_hand = struct(final_ahmm.CPD{4}).hand;
assertTrue(all(learned_hand(:) == true_params.hand(:)));
assertTrue(all(learned_hand(:) == trueParam.hand(:)));
learned_GCPT = CPD_to_CPT(final_ahmm.CPD{5});
learned_GCPT = learned_GCPT(:, 1, :);
expected = eye(new_params.nG, new_params.nG);
assertTrue(all(learned_GCPT(:) == expected(:)));
learned_Gtransprob = struct(final_ahmm.CPD{5}).transprob;
assertTrue(all(learned_Gtransprob(:) == true_params.Gtransprob(:)));
assertTrue(all(learned_Gtransprob(:) == trueParam.Gtransprob(:)));
learned_Stransprob = CPD_to_CPT(final_ahmm.CPD{6});
assertTrue(learned_Stransprob(1, 2, 2) == 1);
assertTrue(learned_Stransprob(2, 3, 3) == 1);
Expand Down
8 changes: 5 additions & 3 deletions matlab/test/TestEvalClassification.m
Expand Up @@ -11,9 +11,11 @@ function testEvalOneFold(self) %#ok<MANU>
R.train = {{1 2; 3 4}};
R.validate = {{5 6; 7 8}};
stat = evalclassification(Y, R, @zerooneloss);
assertTrue(length(stat.train.error) == 2);
assertTrue(all(stat.train.error == 0));
assertTrue(all(stat.validate.error == 0));
train = stat('train');
validate = stat('validate');
assertTrue(length(train('error')) == 2);
assertTrue(all(train('error') == 0));
assertTrue(all(validate('error') == 0));
end
end
end

0 comments on commit 1e6a04d

Please sign in to comment.