/
DNNRegression.m
171 lines (155 loc) · 5.89 KB
/
DNNRegression.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
function [g,pred] = DNNRegression(stimulus, response, direction, ...
    testdata, g, lags, valid, verbosity)
% function [g,pred] = DNNRegression(stimulus, response, direction, ...
%   testdata, g, lags, valid, verbosity)
%
% Train and/or apply a DNN regression model by running the external
% DNNregression.py script, with the network described by network.yaml.
% Both files must be findable with which().
%
% Inputs:
%   stimulus, response - Training data: either arrays, which are saved to
%       temporary .mat files as the variable 'data', or names of existing
%       .mat files. Training runs only when both are non-empty.
%   direction - > 0 selects the 'forward' model (given stimulus, predict
%       response). Anything else selects 'reverse' (given response, predict
%       stimulus). The same choice is used for training and prediction.
%   testdata - Input for prediction (array or .mat file name). It is the
%       stimulus for forward models and the response for reverse models.
%       Prediction runs only when non-empty. Default [].
%   g - Model file name. If empty, a new temporary .pkl name is used.
%       Default [].
%   lags - Passed to the script as --context. Default 25 (also when []).
%   valid - Optional validation data (array or .mat file name), passed to
%       the training command via --valid. Default [].
%   verbosity - When non-zero, each command is printed and its output goes
%       to the command window. Default 0.
%
% Outputs:
%   g - Name of the model file (the value the script was given with -w).
%   pred - The 'data' variable loaded from the prediction output file, or
%       [] when no testdata was given.
%
% Temporary .mat files created here are deleted when the function returns,
% whether it returns normally or on error.
if nargin < 4, testdata = []; end
if nargin < 5, g = []; end
if nargin < 6 || isempty(lags), lags = 25; end
if nargin < 7, valid = []; end
if nargin < 8 || isempty(verbosity), verbosity = 0; end
pred = [];      % Returned as-is when no prediction is requested.

% Decide the direction once so the training and prediction commands can
% never disagree (previously direction == 0 trained as 'reverse' but
% predicted with forward-style flags).
if direction > 0
    dirArg = 'forward';
    inputFlag = '-s';       % Given stimulus, predict response
    outputFlag = '-r';
else
    dirArg = 'reverse';
    inputFlag = '-r';       % Given response, predict stimulus
    outputFlag = '-s';
end

% strcmp catches an empty string object (""), for which isempty is false.
haveModelFile = ~(isempty(g) || strcmp(g, ''));
if haveModelFile
    modelDataFile = g;
else
    modelDataFile = [tempname() '.pkl'];
end

cmdName = which('DNNregression.py');    % Look for the python script
if isempty(cmdName), error('Can''t find the DNNregression.py file.'); end
yamlName = which('network.yaml');       % Look for network description
if isempty(yamlName), error('Can''t find the network.yaml file.'); end

% onCleanup objects that delete the temporary files when this function
% exits. Each one captures the file name it was created with.
cleaners = {};

doTraining = ~isempty(stimulus) && ~isempty(response);
if doTraining
    % Model training call looks like this:
    %   python DNNregression.py -t -m "./network.yaml" \
    %       -s "trainUnattendedAudio.mat" \
    %       -r "trainResponse.mat" -w "./network_best_3.pkl" --debug
    [stimulusDataFile, cleaners] = dataFileName(stimulus, cleaners);
    [responseDataFile, cleaners] = dataFileName(response, cleaners);
    cmd = sprintf(['python "%s" -t --dir %s -m "%s" -s "%s" -r "%s" ' ...
        '-w "%s" --context %d --verbosity %d'], ...
        cmdName, dirArg, yamlName, stimulusDataFile, responseDataFile, ...
        modelDataFile, lags, 0);
    % The validation data is only used by the training command.
    if ~isempty(valid)
        [validDataFile, cleaners] = dataFileName(valid, cleaners);
        cmd = sprintf('%s --valid "%s"', cmd, validDataFile);
    end
    runCommand(cmd, 'training', verbosity);
elseif ~isempty(stimulus) || ~isempty(response)
    warning('DNNRegression: both stimulus and response are needed for training; skipping training.');
end

if ~isempty(testdata)
    if ~doTraining && ~haveModelFile
        error('DNNRegression: no model to predict with. Supply training data or an existing model file in g.');
    end
    [testDataFile, cleaners] = dataFileName(testdata, cleaners);
    predictionDataFile = [tempname() '.mat'];
    cleaners{end+1} = onCleanup(@() deleteIfExists(predictionDataFile));
    cmd = sprintf(['python "%s" -p --dir %s -m "%s" %s "%s" %s "%s" ' ...
        '-w "%s" --context %d --verbosity %d'], ...
        cmdName, dirArg, yamlName, inputFlag, testDataFile, outputFlag, ...
        predictionDataFile, modelDataFile, lags, 0);
    runCommand(cmd, 'prediction', verbosity);
    predictionResult = load(predictionDataFile);
    pred = predictionResult.data;
end
g = modelDataFile;

% Example usage, kept for reference; never executed. It assumes variables
% such as recordingT, attendedAudio, unattendedAudio, response,
% impulseLength, fs and attentionDuration exist in the caller's workspace.
if 0
    %%
    lags = round(1.5*impulseLength*fs);
    dnnTrainingWindow = 2; % Seconds on each side of the attention switch
    iTrain = find(recordingT > attentionDuration - dnnTrainingWindow & ...
        recordingT < attentionDuration + dnnTrainingWindow);
    iTest = find(recordingT > 2*attentionDuration & ...
        recordingT < recordingT(end-lags));
    dnnDirection = -1;
    % lags = 1;
    % Now calculate the models for the attended and unattended signals.
    % (valid is the 7th argument, so pass [] before verbosity.)
    verbosity = 1;
    attentionModel = DNNRegression(attendedAudio(iTrain), response(iTrain, :), ...
        dnnDirection, [], [], lags, [], verbosity);
    unattentionModel = DNNRegression(unattendedAudio(iTrain), response(iTrain, :), ...
        dnnDirection, [], [], lags, [], verbosity);
    %%
    [~, attendedPrediction] = DNNRegression([], [], ...
        dnnDirection, response, attentionModel, lags);
    [~, unattendedPrediction] = DNNRegression([], [], ...
        dnnDirection, response, unattentionModel, lags);
    ca = corrcoef([attendedAudio(iTest) attendedPrediction(iTest)]);
    cu = corrcoef([unattendedAudio(iTest) unattendedPrediction(iTest)]);
    fprintf('Attended correlation: %g, Unattended correlation: %g.\n', ...
        ca(1,2), cu(1,2));
    %%
    % Plot the predicted stimuli
    clf
    attentionSwitchPick = 3;
    iPlot = find(recordingT>attentionSwitchPick*attentionDuration-dnnTrainingWindow & ...
        recordingT < attentionSwitchPick*attentionDuration+dnnTrainingWindow);
    plot(recordingT(iPlot), [attendedPrediction(iPlot) unattendedPrediction(iPlot)]')
    legend('Attended Signal', 'Unattended Signal');
    title('Predicted Signals');
    xlabel('time (seconds)'); ylabel('Intensity');
    axis tight
    %%
    % Plot the matrix of measured/predicted and attended/unattended signals
    clf
    subplot(2, 2, 1);
    plot(recordingT(iPlot), attendedAudio(iPlot));
    title('Attended Signal'); axis tight
    subplot(2, 2, 2);
    plot(recordingT(iPlot), unattendedAudio(iPlot));
    title('Unattended Signal'); axis tight
    subplot(2, 2, 3);
    plot(recordingT(iPlot), attendedPrediction(iPlot));
    title('Attended TRF Prediction'); axis tight
    subplot(2, 2, 4);
    plot(recordingT(iPlot), unattendedPrediction(iPlot));
    title('Unattended TRF Prediction'); axis tight
end
end


function [fileName, cleaners] = dataFileName(data, cleaners)
% Return the name of a .mat file that holds DATA.
% A char DATA is taken to be an existing file name and is returned
% unchanged. Anything else is saved, as the variable 'data', to a new
% temporary file. That file's deletion is registered by appending an
% onCleanup object to CLEANERS.
if ischar(data)
    fileName = data;
else
    fileName = [tempname() '.mat'];
    save(fileName, 'data');
    cleaners{end+1} = onCleanup(@() deleteIfExists(fileName));
end
end


function runCommand(cmd, what, verbosity)
% Run shell command CMD; raise an error if it exits with a non-zero status.
% WHAT names the step ('training' or 'prediction') in messages. With
% VERBOSITY set, the command is printed and its output is shown in the
% command window. Otherwise the output is captured and included in the
% error message.
if verbosity
    fprintf('Executing the %s command: %s\n', what, cmd);
    status = system(cmd);
    if status ~= 0
        error('DNN %s command failed (exit status %d).', what, status);
    end
else
    [status, cmdout] = system(cmd);
    if status ~= 0
        error(['DNN %s command failed (exit status %d). ' ...
            'Rerun with verbosity=1 to see the error.\nCommand output:\n%s'], ...
            what, status, cmdout);
    end
end
end


function deleteIfExists(fileName)
% Delete FILENAME if it exists; do nothing (and warn nothing) otherwise.
if exist(fileName, 'file')
    delete(fileName);
end
end