diff --git a/matlab/meet/evaluate/evalcv.m b/matlab/meet/evaluate/evalcv.m deleted file mode 100644 index c6ffff2..0000000 --- a/matlab/meet/evaluate/evalcv.m +++ /dev/null @@ -1,37 +0,0 @@ -function cvstat = evalcv(stat) -nfold = length(stat); -train = zeros(2, nfold); -validate = zeros(2, nfold); -if isfield(stat, 'test') - test = zeros(2, 1); -end - -for i = 1 : nfold - train(:, i) = stat{i}.train.error; - validate(:, i) = stat{i}.validate.error; - if isfield(stat, 'test') - test(:, i) = stat{i}.test.error; - end -end - -cvstat.train.meanerror = mean(train, 2); -disp('training error mean:'); -disp(cvstat.train.meanerror); - -cvstat.train.stderror = std(train, 0, 2); -disp('training error std:'); -disp(cvstat.train.stderror); - -cvstat.validate.meanerror = mean(validate, 2); -disp('validation error mean:'); -disp(cvstat.validate.meanerror); - -cvstat.validate.stderror = std(validate, 0, 2); -disp('validation error std:'); -disp(cvstat.validate.stderror); - -if isfield(stat, 'test') - cvstat.test.meanerror = mean(test, 2); - cvstat.test.stderror = std(test, 0, 2); -end -end diff --git a/matlab/meet/learn/ahmm/checkahmm.m b/matlab/meet/learn/ahmm/checkahmm.m new file mode 100644 index 0000000..23720b7 --- /dev/null +++ b/matlab/meet/learn/ahmm/checkahmm.m @@ -0,0 +1,4 @@ +function checkahmm(ahmm) +learned_GCPT = CPD_to_CPT(ahmm.CPD{5}); +learned_GCPT(:, 1, :); +end \ No newline at end of file diff --git a/matlab/meet/learn/ahmm/inferenceahmm.m b/matlab/meet/learn/ahmm/inferenceahmm.m index f9217e9..46bf01e 100644 --- a/matlab/meet/learn/ahmm/inferenceahmm.m +++ b/matlab/meet/learn/ahmm/inferenceahmm.m @@ -7,6 +7,8 @@ for i = 1 : nseq evidence = data{i}; engine = enter_evidence(engine, evidence); - R{i} = mapest(engine, predictNode, length(evidence)); + %R{i} = mapest(engine, predictNode, length(evidence)); + mpe = find_mpe(engine, evidence); + R{i} = mpe(predictNode, :); end end diff --git a/matlab/meet/learn/ahmm/learnahmm.m b/matlab/meet/learn/ahmm/learnahmm.m index cee7b94..624f41d 100644 --- a/matlab/meet/learn/ahmm/learnahmm.m +++ b/matlab/meet/learn/ahmm/learnahmm.m @@ -21,6 +21,7 @@ predictNode = [param.G1 param.F1]; param.onodes = [param.X1]; finalAhmm = sethiddenbit(finalAhmm, param.onodes); +checkahmm(finalAhmm); trainData = createInputData(Y.train, X.train, param); R.train = inferenceahmm(finalAhmm, trainData, predictNode); diff --git a/matlab/meet/learn/ahmm/mapest.m b/matlab/meet/learn/ahmm/mapest.m index fcea80a..8894e53 100644 --- a/matlab/meet/learn/ahmm/mapest.m +++ b/matlab/meet/learn/ahmm/mapest.m @@ -1,4 +1,6 @@ function mapEst = mapest(engine, hnode, T) +% Args: +% - hnode: hidden node that we want to estimate the state. % Return % - mapEst: a cell array. nhnode = length(hnode); diff --git a/matlab/test/TestAHMM.m b/matlab/test/TestAHMM.m index 1a55f77..54a4449 100644 --- a/matlab/test/TestAHMM.m +++ b/matlab/test/TestAHMM.m @@ -150,8 +150,8 @@ function testInference(self) function testLearning(self) T = 20; max_iter = 10; - true_params = self.deterministicParams; - ahmm = createahmm(true_params); + trueParam = self.deterministicParams; + ahmm = createahmm(trueParam); ev = sample_dbn(ahmm, 'length', T); ss = length(ahmm.intra); evidence = cell(1, 1); @@ -173,19 +173,19 @@ function testLearning(self) assertTrue(all(cell2mat(mapS(:)) == trueS(:))); learned_Gstartprob = CPD_to_CPT(final_ahmm.CPD{1}); - assertTrue(all(learned_Gstartprob == true_params.Gstartprob(:))); + assertTrue(all(learned_Gstartprob == trueParam.Gstartprob(:))); learned_Sstartprob = CPD_to_CPT(final_ahmm.CPD{2}); assertTrue(learned_Sstartprob(1, 1) == 1); learned_Stermprob = CPD_to_CPT(final_ahmm.CPD{3}); assertTrue(all(learned_Stermprob(:, 1) == 0)); learned_hand = struct(final_ahmm.CPD{4}).hand; - assertTrue(all(learned_hand(:) == true_params.hand(:))); + assertTrue(all(learned_hand(:) == trueParam.hand(:))); learned_GCPT = CPD_to_CPT(final_ahmm.CPD{5}); learned_GCPT = learned_GCPT(:, 1, :); expected = eye(new_params.nG, new_params.nG); assertTrue(all(learned_GCPT(:) == expected(:))); learned_Gtransprob = struct(final_ahmm.CPD{5}).transprob; - assertTrue(all(learned_Gtransprob(:) == true_params.Gtransprob(:))); + assertTrue(all(learned_Gtransprob(:) == trueParam.Gtransprob(:))); learned_Stransprob = CPD_to_CPT(final_ahmm.CPD{6}); assertTrue(learned_Stransprob(1, 2, 2) == 1); assertTrue(learned_Stransprob(2, 3, 3) == 1); diff --git a/matlab/test/TestEvalClassification.m b/matlab/test/TestEvalClassification.m index 3a991c7..7cd64e6 100644 --- a/matlab/test/TestEvalClassification.m +++ b/matlab/test/TestEvalClassification.m @@ -11,9 +11,11 @@ function testEvalOneFold(self) %#ok R.train = {{1 2; 3 4}}; R.validate = {{5 6; 7 8}}; stat = evalclassification(Y, R, @zerooneloss); - assertTrue(length(stat.train.error) == 2); - assertTrue(all(stat.train.error == 0)); - assertTrue(all(stat.validate.error == 0)); + train = stat('train'); + validate = stat('validate'); + assertTrue(length(train('error')) == 2); + assertTrue(all(train('error') == 0)); + assertTrue(all(validate('error') == 0)); end end end \ No newline at end of file