Skip to content
This repository has been archived by the owner on Apr 20, 2021. It is now read-only.

Commit

Permalink
Browse files Browse the repository at this point in the history
Adding tracking precision-recall analysis.
  • Loading branch information
lukacu committed Apr 25, 2018
1 parent 46de65f commit 7d59554
Show file tree
Hide file tree
Showing 2 changed files with 368 additions and 0 deletions.
212 changes: 212 additions & 0 deletions analysis/analyze_precision_recall.m
@@ -0,0 +1,212 @@
function [result] = analyze_precision_recall(experiment, trackers, sequences, varargin)
% analyze_precision_recall Performs tracking precision and recall analysis
%
% Performs tracking precision-recall analysis for a given experiment on a set trackers and sequences.
%
% Input:
% - experiment (structure): A valid experiment structures.
% - trackers (cell): A cell array of valid tracker descriptor structures.
% - sequences (cell): A cell array of valid sequence descriptor structures.
% - varargin[Tags] (cell): An array of tag names that should be used
% instead of sequences.
%
% Output:
% - result (structure): A structure with the following fields
% -
% - lengths: number of frames for individual selectors
% - tags: names of individual selectors

resolution = [];
tags = {};

for i = 1:2:length(varargin)
switch lower(varargin{i})
case 'tags'
tags = varargin{i+1};
case 'resolution'
resolution = varargin{i+1};
otherwise
error(['Unknown switch ', varargin{i},'!']) ;
end
end

print_text('Tracking precision-recall analysis for experiment %s ...', experiment.name);

print_indent(1);

experiment_sequences = convert_sequences(sequences, experiment.converter);

if ~isempty(tags)

tags = unique(tags); % Remove any potential duplicates.

selectors = sequence_tag_selectors(experiment, ...
experiment_sequences, tags);

else

selectors = sequence_selectors(experiment, experiment_sequences);

end;

result.curves = cell(numel(trackers), numel(selectors));
result.fmeasure = zeros(numel(trackers), numel(selectors));
result.selectors = cellfun(@(x) x.name, selectors, 'UniformOutput', false);


for i = 1:numel(trackers)

print_text('Tracker %s', trackers{i}.identifier);

thresholds = [];

if ~isempty(resolution)
thresholds = determine_thresholds(experiment, trackers{i}, experiment_sequences, resolution);
end

for s = 1:numel(selectors)

[curves, fmeasure] = calculate_tpr_fscore(selectors{s}, experiment, trackers{i}, experiment_sequences, thresholds);

result.curves{i, s} = curves;
result.fmeasure(i, s) = fmeasure;

end;

end;

print_indent(-1);

end

function [thresholds] = determine_thresholds(experiment, tracker, sequences, resolution)

confidence_name = 'confidence';

if isfield(tracker.metadata, 'confidence')

confidence_name = tracker.metadata.confidence;

end

selector = sequence_tag_selectors(experiment, sequences, {'all'});

values = selector{1}.results_values(experiment, tracker, sequences, confidence_name);

certanty = zeros(sum(cellfun(@numel, values(:, 1), 'UniformOutput', true)), size(values, 2));

i = 1;

for s = 1:size(values, 1)

for r = 1:size(values, 2)

if isempty(values{s, r})
continue;
end;

certanty(i:i+size(values{s, r})-1, r) = values{s, r};

end;
end;

thresholds = unique(certanty(~isnan(certanty)));

if numel(thresholds) > resolution
delta = floor(numel(thresholds) / (resolution - 2));
idxs = round(linspace(delta, numel(thresholds)-delta, resolution-2));
thresholds = thresholds(idxs);
end

thresholds = [-Inf; thresholds; Inf];
end

function [curve, fmeasure, fbest] = calculate_tpr_fscore(selector, experiment, tracker, sequences, thresholds)

confidence_name = 'confidence';
confidence_inverse = false;

if isfield(tracker.metadata, 'confidence')

confidence_name = tracker.metadata.confidence;

end

if isfield(tracker.metadata, 'confidence_inverse')

confidence_inverse = tracker.metadata.confidence_inverse;

end

groundtruth = selector.groundtruth(sequences);
trajectories = selector.results(experiment, tracker, sequences);

values = selector.results_values(experiment, tracker, sequences, confidence_name);

overlaps = zeros(sum(cellfun(@numel, groundtruth, 'UniformOutput', true)), size(trajectories, 2));
certanty = zeros(size(overlaps));

i = 1;

N = 0;

for s = 1:numel(groundtruth)

for r = 1:size(trajectories, 2)

if isempty(trajectories{s, r})
continue;
end;

[~, frames] = estimate_accuracy(trajectories{s, r}, groundtruth{s}, ...
'BindWithin', [sequences{s}.width, sequences{s}.height]);

frames(isnan(frames)) = 0;

overlaps(i:i+size(groundtruth{s})-1, r) = frames;
certanty(i:i+size(groundtruth{s})-1, r) = values{s, r};

end;

i = i + size(groundtruth{s});

if ~isempty(groundtruth{s})
N = N + sum(cellfun(@(x) numel(x) > 1, groundtruth{s}, 'UniformOutput', true));
end;
end;

if isempty(thresholds)
thresholds = certanty(~isnan(certanty));
end

thresholds = sort(thresholds, iff(confidence_inverse, 'descend', 'ascend'));

curve = zeros(numel(thresholds), 3);

curve(:, 3) = thresholds;

for k = 1:numel(thresholds)

% indicator vector where to calculate Pr-Re
subset = certanty >= thresholds(k);

if sum(subset) == 0
% special case - no prediction is made:
% Precision is 1 and Recall is 0
curve(k,1) = 1;
curve(k,2) = 0;
else
curve(k, 1) = mean(overlaps(subset));
curve(k, 2) = sum(overlaps(subset)) ./ N;
end

end

f = 2 * (curve(:, 1) .* curve(:, 2)) ./ (curve(:, 1) + curve(:, 2));

[fmax, fidx] = max(f);

fmeasure = fmax;
fbest = thresholds(fidx);

end
156 changes: 156 additions & 0 deletions report/report_precision_recall.m
@@ -0,0 +1,156 @@
function [document, scores] = report_precision_recall(context, experiment, trackers, sequences, varargin)
% report_overlap Generate a report using tracking precision recall methodology
%
% Performs tracking precision-recall analysis and generates a report based on the results.
%
% Input:
% - context (structure): Report context structure.
% - experiment (struct): An experiment structure.
% - trackers (cell): An array of tracker structures.
% - sequences (cell): An array of sequence structures.
% - varargin[UseTags] (boolean): Analyze according to tags (otherwise according to sequences).
% - varargin[HideLegend] (boolean): Hide legend in plots.
%
% Output:
% - document (structure): Resulting document structure.
% - scores (struct): A scores structure.
%

usetags = get_global_variable('report_tags', true);
hidelegend = get_global_variable('report_lagend_hide', false);
resolution = 100;

for i = 1:2:length(varargin)
switch lower(varargin{i})
case 'usetags'
usetags = varargin{i+1};
case 'resolutuion'
resolution = varargin{i+1};
case 'hidelegend'
hidelegend = varargin{i+1};
otherwise
error(['Unknown switch ', varargin{i}, '!']) ;
end
end


if ~strcmp(experiment.type, 'unsupervised')
error('Tracking precision-recall analysis only suitable for unsupervised experiments!');
end

document = document_create(context, 'tpr', 'title', 'Tracking precision recall');

trackers_hash = md5hash(strjoin((cellfun(@(x) x.identifier, trackers, 'UniformOutput', false)), '-'), 'Char', 'hex');
parameters_hash = md5hash(sprintf('%d%d', usetags, resolution));

tags = {};

if isempty(experiment.tags)
usetags = false;
end;

if usetags && isfield(experiment, 'tags')
tags = union(experiment.tags, {'all'});
sequences_hash = md5hash(strjoin(tags, '-'), 'Char', 'hex');
else
sequences_hash = md5hash(strjoin((cellfun(@(x) x.name, sequences, 'UniformOutput', false)), '-'), 'Char', 'hex');
end;

cache_identifier = sprintf('tpr_%s_%s_%s_%s', experiment.name, trackers_hash, sequences_hash, parameters_hash);

result = document_cache(context, cache_identifier, @analyze_precision_recall, experiment, trackers, ...
sequences, 'Tags', tags, 'Resolution', resolution);

if usetags
% When using tags we have inserted a separate one for this
mask = strcmp('tag_all', result.selectors);

average_curve = result.curves(:, mask);
average_fmeasure = result.fmeasure(:, mask);

% Now remove the 'all' tag from results
tag_curve = result.curves(:, ~mask);
tag_fmeasure = result.fmeasure(:, ~mask);

selector_tags = cat(2, result.selectors(~mask), result.selectors(mask));

else

average_curve = cell(numel(trackers), 1);
average_fmeasure = zeros(numel(trackers), 1);

for t = 1:numel(trackers)
average_curve{t} = mean(cat(3, result.curves{t, :}), 3);
f = 2 * (average_curve{t}(:, 1) .* average_curve{t}(:, 2)) ./ (average_curve{t}(:, 1) + average_curve{t}(:, 2));
[average_fmeasure(t), ~] = max(f);
end;

tag_curve = result.curves;
tag_fmeasure = result.fmeasure;

selector_tags = result.selectors;

end

scores.name = 'TPR';
scores.values = average_fmeasure;
scores.ids = {'f'};
scores.names = {'F'};
scores.order = {'descending'};

tracker_labels = cellfun(@(x) iff(isfield(x.metadata, 'verified') && x.metadata.verified, [x.label, '*'], x.label), trackers, 'UniformOutput', 0);

print_text('Writing tracking precition-recall table ...');

document.section('Experiment %s', experiment.name);

pr_plot(document, sprintf('%s_average', experiment.name), ...
sprintf('Experiment %s (average)', experiment.name), ...
trackers, average_curve, hidelegend);

table_data = highlight_best_rows(num2cell(cat(2, tag_fmeasure, average_fmeasure)), repmat({'descending'}, 1, size(tag_fmeasure, 2) + 1));


document.table(table_data, 'columnLabels', selector_tags, 'rowLabels', tracker_labels, 'title', 'Tracking precision-recall overview');

document.subsection('Detailed plots');

for t = 1:size(tag_curve, 2)

plot_title = sprintf('Tracking precision-recall plot for tag %s in experiment %s', ...
selector_tags{t}, experiment.name);
plot_id = sprintf('overlap_%s_%s', experiment.name, selector_tags{t});

pr_plot(document, plot_id, plot_title, trackers, tag_curve(:, t), ~hidelegend);

end;

document.write();

end

function pr_plot(document, identifier, title, trackers, curves, hidelegend)

handle = plot_blank('Visible', false, 'Title', 'Overlap', 'Width', 6, 'Height', 6); hold on;

phandles = zeros(numel(trackers), 1);

for t = 1:numel(curves)
phandles(t) = plot(curves{t}(:, 2), curves{t}(:, 1), 'Color', trackers{t}.style.color);
end;

labels = cellfun(@(x) x.label, trackers, 'UniformOutput', false);

if ~hidelegend
legend(phandles, labels);
end;

xlabel('Tracking recall');
ylabel('Tracking precision');
xlim([0, 1]);
ylim([0, 1]);
hold off;
document.figure(handle, identifier, title);

close(handle);
end

0 comments on commit 7d59554

Please sign in to comment.