-
Notifications
You must be signed in to change notification settings - Fork 1
/
dbnMDP_BNT.m
86 lines (63 loc) · 2.3 KB
/
dbnMDP_BNT.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
% Make and train a discrete DBN from data using BNT and EM.
%
% The network has intraLength discrete nodes per time slice, no edges
% inside a slice, and a single temporal edge from node 1 in slice t to
% node 1 in slice t+1. All CPTs start random and are fitted with
% learn_params_dbn_em using a junction-tree smoother engine.
%
% Input:
% intraLength: integer number of variables (nodes) per time slice
% interLength: size of the inter-slice adjacency matrix; must equal
%              intraLength (mk_dbn needs a slice-size square matrix)
% ns: node sizes passed to mk_dbn (number of discrete states of each
%     variable in a slice)
% horizon: integer number of time points (slices) per case
% data: 2d array of training data, one case per row. Row entry j is stored
%       at linear index j of an intraLength-by-horizon cell, so columns are
%       ordered [slice 1 vars 1..intraLength, slice 2 vars ..., ...].
%       NaN or -1 entries are treated as missing. At most
%       intraLength*horizon columns; cells beyond the last column stay
%       empty (unobserved).
% max_iter: maximum number of iterations of the EM algorithm
%
% Output:
% C: 1-by-(intraLength+1) cell array; C{i} is the learned CPT of bnet CPD i
%    (CPDs 1..intraLength belong to the slice-1 nodes, CPD intraLength+1 to
%    node 1 of slice 2, the node with the temporal parent).
%
% Side effects: clears the command window (clc), disables an Octave
% short-circuit warning, and adds the BNT folder tree to the path.
function C = dbnMDP_BNT(intraLength, interLength, ns, horizon, data, max_iter)
%%%%%%%%%% clear output & turn off matlab-octave short circuit warnings %%%%%%%%
clc;
warning('off', 'Octave:possible-matlab-short-circuit-operator');
%%%%%%%%%%%%%%%%%%%% validate inputs %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if interLength ~= intraLength
  error('dbnMDP_BNT:interLength', ...
        'interLength (%d) must equal intraLength (%d).', interLength, intraLength);
end
%%%%%%%%%%%%%%%%%%%% get path to BNT %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Prefer ./BNT relative to the current folder; fall back to a BNT folder
% next to this file so the function also works from other directories.
origPath = pwd;
bntDir = fullfile(origPath, 'BNT');
if ~exist(bntDir, 'dir')
  bntDir = fullfile(fileparts(mfilename('fullpath')), 'BNT');
end
% genpathKPM ships in the BNT root and is presumably not on the path yet,
% hence the cd. Always return to the caller's folder, even on failure.
cd(bntDir);
try
  addpath(genpathKPM(pwd));
catch err
  cd(origPath);
  rethrow(err);
end
cd(origPath);
%%%%%%%%%%%%%%%%%%%% define in slice edges %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
intra = zeros(intraLength);            % no edges inside a slice
%%%%%%%%%%%%%%%%%%%% define edges between slices %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
inter = zeros(interLength);
inter(1,1) = 1;                        % node 1 at t -> node 1 at t+1
%%%%%%%%%%%%%%%%%%%%% DBN definition %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
dnodes = 1:intraLength;
bnet = mk_dbn(intra, inter, ns, 'discrete', dnodes);
% nodes initialization - randomized CPTs.
% Nodes 1..intraLength are slice 1; node intraLength+1 is node 1 of slice 2,
% which has the only temporal edge and so needs its own CPD. This count
% relies on inter containing exactly that one edge.
nCPD = intraLength + 1;
for i = 1:nCPD
  bnet.CPD{i} = tabular_CPD(bnet, i, 'CPT', 'rnd');
end
%%%%%%%%%%%%%%%%%%% engine definition %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
engine = smoother_engine(jtree_2TBN_inf_engine(bnet));
%%%%%%%%%%%%%%%%%%%%%%%% create cases from dataset %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
ss = intraLength;                      % slice size
T = horizon;
[ncases, ncolumns] = size(data);
if ncolumns > ss*T
  error('dbnMDP_BNT:dataSize', ...
        'data has %d columns but intraLength*horizon is only %d.', ncolumns, ss*T);
end
% NaN and -1 both mark missing (hidden) observations.
missing = isnan(data) | (data == -1);
cases = cell(1, ncases);
for i = 1:ncases
  values = num2cell(data(i, :));
  values(missing(i, :)) = {[]};        % empty cell = unobserved in BNT
  cases{i} = cell(ss, T);
  % Column-major linear indexing: column j -> variable mod(j-1,ss)+1,
  % slice ceil(j/ss).
  cases{i}(1:ncolumns) = values;
end
%%%%%%%%%%%%%%%%%%% learn dbn %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
bnet2 = learn_params_dbn_em(engine, cases, 'max_iter', max_iter);
%%%%%%%%%%%%%%%%%%% return dbn components %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
C = cell(1, nCPD);
for i = 1:nCPD
  % struct() exposes the CPD object's fields; a temporary is used because
  % MATLAB (unlike Octave) cannot index directly into a call result.
  cpd = struct(bnet2.CPD{i});
  C{i} = cpd.CPT;
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%