-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtapas_softmax.m
64 lines (51 loc) · 1.77 KB
/
tapas_softmax.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
function [logp, yhat, res] = tapas_softmax(r, infStates, ptrans)
% Calculates the log-probability of responses under the softmax model
%
% INPUT:
%   r          Structure with at least the fields
%                r.c_obs.predorpost  2 selects posteriors; any other value selects
%                                    predictions
%                r.y                 Responses; column 1 holds the index (1..nc) of the
%                                    chosen option on each trial
%                r.irr               Indices of irregular trials (may be empty)
%   infStates  Inferred belief states (see "Assumed structure of infStates" below)
%   ptrans     Observation parameters in transformed space; ptrans(1) is log(beta)
%
% OUTPUT:
%   logp       n-by-1 log-probabilities of the observed choices (NaN on irregular trials)
%   yhat       n-by-1 softmax probabilities of the observed choices (NaN on irregular trials)
%   res        n-by-1 residuals, defined here as -logp (NaN on irregular trials)
%
% The softmax is evaluated in log space after subtracting each row's maximum (log-sum-exp),
% so large values of beta*states neither overflow exp() to Inf nor underflow to log(0).
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2013-2019 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <http://www.gnu.org/licenses/>.
% --------------------------------------------------------------------------------------------------

% Predictions or posteriors?
pop = 1; % Default: predictions
if r.c_obs.predorpost == 2
    pop = 3; % Alternative: posteriors
end

% Transform beta to its native space
be = exp(ptrans(1));

% Initialize returned log-probabilities, predictions,
% and residuals as NaNs so that NaN is returned for all
% irregular trials
n = size(infStates,1);
logp = NaN(n,1);
yhat = NaN(n,1);
res  = NaN(n,1);

% Assumed structure of infStates:
% dim 1: time (ie, input sequence number)
% dim 2: HGF level
% dim 3: choice number
% dim 4: 1: muhat, 2: sahat, 3: mu, 4: sa

% Number of choices
nc = size(infStates,3);

% Belief trajectories at 1st level as an n-by-nc matrix. reshape (rather than squeeze)
% keeps trials on dimension 1 even when n == 1; squeeze would turn a 1x1xnc slice
% into an nc-by-1 column and transpose the problem.
states = reshape(infStates(:,1,:,pop), n, nc);

% Responses
y = r.y(:,1);

% Weed irregular trials out from inferred states and responses
states(r.irr,:) = [];
y(r.irr) = [];

% Log-softmax via log-sum-exp: subtracting each row's maximum leaves the
% probabilities unchanged but keeps every exponent <= 0, so exp() cannot overflow
x = be*states;
x = x - repmat(max(x,[],2), 1, nc);
logZ = log(sum(exp(x),2));
logprob = x - repmat(logZ, 1, nc);

% Extract log-probabilities of chosen options (y holds column indices 1..nc)
logpc = logprob(sub2ind(size(logprob), (1:numel(y))', y));

% Write results back for non-irregular trials
reg = ~ismember(1:n,r.irr);
logp(reg) = logpc;
yhat(reg) = exp(logpc);
res(reg)  = -logpc;
end