Skip to content

Commit

Permalink
Rewritten test framework:
Browse files Browse the repository at this point in the history
 * separate 'features' argument for datasets. This allows variants of a dataset with different features to be compared
 * separated writing of results into a separate script

Changed predict_* methods to take a struct of options instead of key,value pairs
  • Loading branch information
twanvl committed Jan 16, 2018
1 parent f7c7525 commit 0fc35b3
Show file tree
Hide file tree
Showing 21 changed files with 1,197 additions and 770 deletions.
12 changes: 6 additions & 6 deletions src/abib/predict_abib.m
Expand Up @@ -13,7 +13,7 @@
if ~isfield(opts,'num_iterations'), opts.num_iterations = 500; end
if ~isfield(opts,'alpha'), opts.alpha = 0.5; end
if ~isfield(opts,'classifier'), opts.classifier = @predict_liblinear_cv; end
if ~isfield(opts,'classifier_opts'), opts.classifier_opts = {}; end
if ~isfield(opts,'classifier_opts'), opts.classifier_opts = struct(); end
if ~isfield(opts,'classifier_opts_source'), opts.classifier_opts_source = opts.classifier_opts; end
if ~isfield(opts,'use_source_C'), opts.use_source_C = false; end
if ~isfield(opts,'use_target_C'), opts.use_target_C = true; end
Expand All @@ -26,13 +26,13 @@
if ~isfield(opts,'combine_models'), opts.combine_models = false; end

% Classifier for the source
[y_tgt,best_opts_src,model_src] = opts.classifier(x_src, y_src, x_tgt, opts.classifier_opts_source{:});
[y_tgt,best_opts_src,model_src] = opts.classifier(x_src, y_src, x_tgt, opts.classifier_opts_source);
y_tgt_from_src = y_tgt;
if opts.use_source_C
opts.classifier = @predict_liblinear;
opts.classifier_opts = best_opts_src;
elseif opts.use_target_C
[~, best_opts_tgt] = opts.classifier(x_tgt,y_tgt,[], opts.classifier_opts{:});
[~, best_opts_tgt] = opts.classifier(x_tgt,y_tgt,[], opts.classifier_opts);
opts.classifier = @predict_liblinear;
opts.classifier_opts = best_opts_tgt;
end
Expand All @@ -55,7 +55,7 @@
else
which = ceil(rand(n,1)*n);
end
[y_tgt1,~,model_src] = opts.classifier(x_src(which,:), y_src(which,:), x_tgt, opts.classifier_opts_source{:});
[y_tgt1,~,model_src] = opts.classifier(x_src(which,:), y_src(which,:), x_tgt, opts.classifier_opts_source);
end

% Classifier for the target
Expand All @@ -74,9 +74,9 @@
else
which = ceil(rand(n,1)*n);
end
[y_tgt2,~,model_tgt] = opts.classifier(x_tgt(which,:), y_tgt(which,:), x_tgt, opts.classifier_opts{:});
[y_tgt2,~,model_tgt] = opts.classifier(x_tgt(which,:), y_tgt(which,:), x_tgt, opts.classifier_opts);
else
[y_tgt2,~,model_tgt] = opts.classifier(x_tgt, y_tgt, x_tgt, opts.classifier_opts{:});
[y_tgt2,~,model_tgt] = opts.classifier(x_tgt, y_tgt, x_tgt, opts.classifier_opts);
end

% Combined classification
Expand Down
8 changes: 1 addition & 7 deletions src/abib/predict_liblinear.m
Expand Up @@ -11,7 +11,6 @@
end
if ~isfield(opts,'type') opts.type = 3; end
if ~isfield(opts,'C') opts.C = 1; end
if ~isfield(opts,'probability') opts.probability = false; end
if ~isfield(opts,'bias') opts.bias = false; end
if opts.bias
bias = 1;
Expand All @@ -22,12 +21,7 @@
model = train(y_src, sparse(x_src), sprintf('-q -s %d -c %g -B %d',opts.type, opts.C, bias));
y_tgt = zeros(size(x_tgt,1),1);
if nargout>1
label_order = model.Label;
if opts.probability
[y_tgt,acc,s_tgt] = predict(y_tgt, sparse(x_tgt), model,'-q -b 1');
else
[y_tgt,acc,s_tgt] = predict(y_tgt, sparse(x_tgt), model,'-q');
end
[y_tgt,acc,s_tgt] = predict(y_tgt, sparse(x_tgt), model,'-q');
else
y_tgt = predict(y_tgt, sparse(x_tgt), model,'-q');
end
Expand Down
15 changes: 7 additions & 8 deletions src/abib/predict_liblinear_cv.m
Expand Up @@ -12,7 +12,11 @@
% 'verbose',i verbosity (default 0)
% 'C',cs list of values of the C parameter to try

opts = struct(varargin{:});
if length(varargin) == 1 && isstruct(varargin{1})
opts = varargin{1};
else
opts = struct(varargin{:});
end
if ~isfield(opts,'type'), opts.type = 3; end
if ~isfield(opts,'C'), opts.C = [0.001 0.01 0.1 1.0 10 100 1000 10000]; end
if ~isfield(opts,'num_folds'), opts.num_folds = 2; end
Expand All @@ -31,7 +35,7 @@
end
[best_acc,best_i] = max(acc);
best_C = opts.C(best_i);
best_opts = {'C', best_C, 'type', opts.type, 'bias', opts.bias, 'probability',opts.probability};
best_opts = struct('C', best_C, 'type', opts.type, 'bias', opts.bias, 'probability',opts.probability);

if opts.verbose
fprintf('[best C: %g]', best_C);
Expand All @@ -40,12 +44,7 @@
model = train(y_src, sparse(x_src), sprintf('-q -s %d -c %g -B %d',opts.type,best_C,bias));
y_tgt = zeros(size(x_tgt,1),1);
if nargout>1
label_order = model.Label;
if opts.probability
[y_tgt,acc,s_tgt] = predict(y_tgt, sparse(x_tgt), model,'-q -b 1');
else
[y_tgt,acc,s_tgt] = predict(y_tgt, sparse(x_tgt), model,'-q');
end
[y_tgt,acc,s_tgt] = predict(y_tgt, sparse(x_tgt), model,'-q');
else
y_tgt = predict(y_tgt, sparse(x_tgt), model,'-q');
end
Expand Down
8 changes: 7 additions & 1 deletion src/comparison_methods/predict_coral.m
@@ -1,8 +1,14 @@
function Ytt = predict_coral(Xr,Yr,Xtt)
function Ytt = predict_coral(Xr,Yr,Xtt, varargin)
% CORAL
%
% Code adapted from https://github.com/VisionLearningGroup/CORAL

if length(varargin) == 1 && isstruct(varargin{1})
opts = varargin{1};
else
opts = struct(varargin{:});
end

% don't run on high dimensional data
if size(Xr,2) > 10000
Ytt = [];
Expand Down
6 changes: 5 additions & 1 deletion src/comparison_methods/predict_flda.m
Expand Up @@ -3,7 +3,11 @@
%
% Based on code from https://github.com/wmkouw/da-fl

opts = struct(varargin{:});
if length(varargin) == 1 && isstruct(varargin{1})
opts = varargin{1};
else
opts = struct(varargin{:});
end
if ~isfield(opts,'lambda'), opts.lambda = 1; end
if ~isfield(opts,'distribution'), opts.distribution = 'dropout'; end % dropout or blankout
if ~isfield(opts,'loss'), opts.loss = 'log'; end % qd or log
Expand Down
6 changes: 5 additions & 1 deletion src/comparison_methods/predict_flda_cv.m
Expand Up @@ -3,7 +3,11 @@
%
% Use cross-validation on the source domain to pick lambda

opts = struct(varargin{:});
if length(varargin) == 1 && isstruct(varargin{1})
opts = varargin{1};
else
opts = struct(varargin{:});
end
if ~isfield(opts,'lambdas'), opts.lambdas = 10.^(-4:0.5:5); end
if ~isfield(opts,'distribution'), opts.distribution = 'dropout'; end % dropout or blankout
if ~isfield(opts,'loss'), opts.loss = 'log'; end % qd or log
Expand Down
8 changes: 6 additions & 2 deletions src/comparison_methods/predict_gfk.m
@@ -1,8 +1,12 @@
function y_tgt = predict_GFK(x_src, y_src, x_tgt, varargin)
function y_tgt = predict_gfk(x_src, y_src, x_tgt, varargin)
% GFK
% See http://www-scf.usc.edu/~boqinggo/domainadaptation.html

opts = struct(varargin{:});
if length(varargin) == 1 && isstruct(varargin{1})
opts = varargin{1};
else
opts = struct(varargin{:});
end
if ~isfield(opts,'use_pls'), opts.use_pls = false; end;
if ~isfield(opts,'d'), opts.d = 10; end;
if ~isfield(opts,'svm'), opts.svm = false; end;
Expand Down
6 changes: 5 additions & 1 deletion src/comparison_methods/predict_sa.m
Expand Up @@ -11,7 +11,11 @@
% year = {2013},
% }

opts = struct(varargin{:});
if length(varargin) == 1 && isstruct(varargin{1})
opts = varargin{1};
else
opts = struct(varargin{:});
end
if ~isfield(opts,'svm_sqrt'), opts.svm_sqrt = false; end;
if ~isfield(opts,'subspace_dim'), opts.subspace_dim = 80; end;

Expand Down
23 changes: 3 additions & 20 deletions src/evaluation/all_methods.m
Expand Up @@ -14,31 +14,14 @@
methods{end+1} = struct(...
'name', 'Source SVM',...
'method', @predict_liblinear_cv,...
'args', {{}},...
'preferred_preprocessing', {{'truncate,joint-std','joint-zscore','joint-std','zscore','std','none'}});
'args', struct());
end

if 1
methods{end+1} = struct(...
'name', 'Source LR',...
'method', @predict_liblinear_cv,...
'args', {{'type',7}},...
'preferred_preprocessing', {{'truncate,joint-std','joint-zscore','joint-std','zscore','std','none'}});
end

if 1
methods{end+1} = struct(...
'name', 'ABiB-SVM',...
'name', 'ABiB',...
'method', @predict_abib,...
'args', {{}},...
'preferred_preprocessing', {{'truncate,joint-std','joint-std','std','zscore','none'}});
'args', struct());
end

if 1
methods{end+1} = struct(...
'name', 'ABiB-LR',...
'method', @predict_abib,...
'args', {{'classifier_opts',{{'type',7}}}},...
'preferred_preprocessing', {{'truncate,joint-std','joint-std','std','zscore','none'}});
end
end
50 changes: 32 additions & 18 deletions src/evaluation/all_methods_literature.m
@@ -1,27 +1,45 @@
function methods = all_methods_literature()
function methods = all_methods_literature(varargin)
% Methods from the literature

if length(varargin) == 1 && isstruct(varargin{1})
opts = varargin{1};
else
opts = struct(varargin{:});
end
if ~isfield(opts,'toolbox_sa') opts.toolbox_sa = false; end
if ~isfield(opts,'include_sa') opts.include_sa = true; end
if ~isfield(opts,'include_tca') opts.include_tca = true; end
if ~isfield(opts,'include_gfk') opts.include_gfk = true; end
if ~isfield(opts,'include_flda') opts.include_flda = true; end

methods = {};

if 1
methods{end+1} = struct(...
'name', 'SA',...
'method', @predict_sa,...
'args', {{'subspace_dim',80}});
end
if 1
methods{end+1} = struct(...
'name', 'GFK svm',...
'method', @predict_gfk,...
'args', {{'svm',1}});
if opts.include_sa
if opts.toolbox_sa
methods{end+1} = struct(...
'name', 'SA',...
'method', @predict_da_toolbox,...
'args', {{'ftTrans_sa'}});
else
methods{end+1} = struct(...
'name', 'SA',...
'method', @predict_sa,...
'args', {{'subspace_dim',80}});
end
end
if 1
if opts.include_tca
methods{end+1} = struct(...
'name', 'TCA',...
'method', @predict_da_toolbox,...
'args', {{'ftTrans_tca'}});
end
if 1
if opts.include_gfk
methods{end+1} = struct(...
'name', 'GFK',...
'method', @predict_gfk,...
'args', {{'svm',1}});
end
if opts.include_flda
methods{end+1} = struct(...
'name', 'FLDA-L',...
'method', @predict_flda_cv,...
Expand All @@ -37,8 +55,4 @@
'method', @predict_coral,...
'args', {{}});
end

if 1
methods = [methods, results_from_coral_paper()];
end
end
13 changes: 13 additions & 0 deletions src/evaluation/dummy_methods.m
@@ -0,0 +1,13 @@
function methods = dummy_methods(varargin)
% Placeholder method entries ("Methods for which to use results_from_papers").
% Each returned entry is a struct with fields:
%   dummy - true, marking the entry as a placeholder with no method/args
%   name  - the method's display name
% Called with no arguments, a default name set is used; otherwise every
% argument is taken as one method name.
if nargin == 0
    names = {'GFK','SA','CORAL'};
else
    names = varargin;
end

% Preallocate one cell per name instead of growing with {end+1}.
methods = cell(1,numel(names));
for k = 1:numel(names)
    methods{k} = struct('dummy',true, 'name',names{k});
end
end
8 changes: 7 additions & 1 deletion src/evaluation/encode_parameters.m
Expand Up @@ -12,13 +12,19 @@
end
elseif isstruct(x)
out = '';
fields = fieldnames(x)
fields = fieldnames(x);
for i=1:numel(fields)
if i>1, out = [out,'-']; end;
out = [out,fields{i},'=',encode_parameters(getfield(x,fields{i}))];
end
elseif ischar(x)
out = x;
elseif isnumeric(x) && ~isscalar(x)
out = '';
for i=1:numel(x)
if i>1, out = [out,',']; end;
out = [out,encode_parameters(x(i))];
end
elseif isscalar(x) && isnumeric(x)
out = sprintf('%g',x);
elseif isscalar(x) && islogical(x)
Expand Down

0 comments on commit 0fc35b3

Please sign in to comment.