This repository has been archived by the owner on Apr 20, 2021. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
Adding tracking precision-recall analysis.
- Loading branch information
Showing
2 changed files
with
368 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
function [result] = analyze_precision_recall(experiment, trackers, sequences, varargin)
% analyze_precision_recall Performs tracking precision and recall analysis
%
% Performs tracking precision-recall analysis for a given experiment on a set of trackers and sequences.
%
% Input:
% - experiment (structure): A valid experiment structure.
% - trackers (cell): A cell array of valid tracker descriptor structures.
% - sequences (cell): A cell array of valid sequence descriptor structures.
% - varargin[Tags] (cell): An array of tag names that should be used
% instead of sequences.
% - varargin[Resolution] (integer): Maximum number of confidence
% thresholds sampled per tracker; empty means each selector uses all
% observed confidence values.
%
% Output:
% - result (structure): A structure with the following fields
%     - curves: cell matrix (tracker x selector) of precision-recall-threshold curves
%     - fmeasure: matrix (tracker x selector) of best F-measure scores
%     - selectors: names of individual selectors
%

resolution = [];
tags = {};

% Parse optional name-value arguments.
for i = 1:2:length(varargin)
    switch lower(varargin{i})
        case 'tags'
            tags = varargin{i+1};
        case 'resolution'
            resolution = varargin{i+1};
        otherwise
            error(['Unknown switch ', varargin{i},'!']) ;
    end
end

print_text('Tracking precision-recall analysis for experiment %s ...', experiment.name);

print_indent(1);

experiment_sequences = convert_sequences(sequences, experiment.converter);

% Either one selector per tag or one selector per sequence.
if ~isempty(tags)

    tags = unique(tags); % Remove any potential duplicates.

    selectors = sequence_tag_selectors(experiment, ...
        experiment_sequences, tags);

else

    selectors = sequence_selectors(experiment, experiment_sequences);

end;

result.curves = cell(numel(trackers), numel(selectors));
result.fmeasure = zeros(numel(trackers), numel(selectors));
result.selectors = cellfun(@(x) x.name, selectors, 'UniformOutput', false);

for i = 1:numel(trackers)

    print_text('Tracker %s', trackers{i}.identifier);

    thresholds = [];

    % Pre-compute a shared set of confidence thresholds for this tracker
    % when a sampling resolution is requested; otherwise each selector
    % falls back to using all observed confidence values.
    if ~isempty(resolution)
        thresholds = determine_thresholds(experiment, trackers{i}, experiment_sequences, resolution);
    end

    for s = 1:numel(selectors)

        [curves, fmeasure] = calculate_tpr_fscore(selectors{s}, experiment, trackers{i}, experiment_sequences, thresholds);

        result.curves{i, s} = curves;
        result.fmeasure(i, s) = fmeasure;

    end;

end;

print_indent(-1);

end
|
||
function [thresholds] = determine_thresholds(experiment, tracker, sequences, resolution)
% determine_thresholds Sample a set of confidence thresholds for a tracker
%
% Collects the per-frame confidence values produced by the tracker over
% all sequences and subsamples them to at most `resolution` thresholds.
%
% Input:
% - experiment (structure): A valid experiment structure.
% - tracker (structure): A valid tracker descriptor structure.
% - sequences (cell): A cell array of valid sequence descriptor structures.
% - resolution (integer): Maximum number of thresholds to return
%   (including the -Inf and Inf sentinels).
%
% Output:
% - thresholds (vector): Sorted column vector of unique confidence
%   thresholds, padded with -Inf and Inf sentinels.

confidence_name = 'confidence';

% A tracker can declare a custom name for its confidence output channel.
if isfield(tracker.metadata, 'confidence')

    confidence_name = tracker.metadata.confidence;

end

selector = sequence_tag_selectors(experiment, sequences, {'all'});

values = selector{1}.results_values(experiment, tracker, sequences, confidence_name);

certanty = zeros(sum(cellfun(@numel, values(:, 1), 'UniformOutput', true)), size(values, 2));

i = 1;

for s = 1:size(values, 1)

    for r = 1:size(values, 2)

        if isempty(values{s, r})
            continue;
        end;

        % numel, not size: size returns a [rows, cols] vector, which
        % silently corrupts the index range.
        certanty(i:i+numel(values{s, r})-1, r) = values{s, r};

    end;

    % Advance the write offset past this sequence. This increment was
    % missing before, so every sequence overwrote the same leading rows
    % and most confidence values never reached the threshold pool.
    i = i + numel(values{s, 1});

end;

thresholds = unique(certanty(~isnan(certanty)));

% Subsample evenly when there are more thresholds than requested; two
% slots are reserved for the -Inf/Inf sentinels appended below.
if numel(thresholds) > resolution
    delta = floor(numel(thresholds) / (resolution - 2));
    idxs = round(linspace(delta, numel(thresholds)-delta, resolution-2));
    thresholds = thresholds(idxs);
end

thresholds = [-Inf; thresholds; Inf];
end
|
||
function [curve, fmeasure, fbest] = calculate_tpr_fscore(selector, experiment, tracker, sequences, thresholds)
% calculate_tpr_fscore Compute a tracking precision-recall curve and its best F-score
%
% For every confidence threshold, the frames with confidence at or above
% the threshold are treated as predictions: precision is their mean
% overlap with ground truth, recall is their total overlap normalized by
% the number of valid (multi-element) ground-truth regions.
%
% Input:
% - selector (structure): A valid selector structure.
% - experiment (structure): A valid experiment structure.
% - tracker (structure): A valid tracker descriptor structure.
% - sequences (cell): A cell array of valid sequence descriptor structures.
% - thresholds (vector): Confidence thresholds to evaluate; when empty,
%   all observed confidence values are used.
%
% Output:
% - curve (matrix): N x 3 matrix of [precision, recall, threshold] rows.
% - fmeasure (double): Best F-measure along the curve.
% - fbest (double): Threshold at which the best F-measure is attained.

confidence_name = 'confidence';
confidence_inverse = false;

% A tracker can declare a custom confidence channel name and whether
% lower values mean higher confidence.
if isfield(tracker.metadata, 'confidence')

    confidence_name = tracker.metadata.confidence;

end

if isfield(tracker.metadata, 'confidence_inverse')

    confidence_inverse = tracker.metadata.confidence_inverse;

end

groundtruth = selector.groundtruth(sequences);
trajectories = selector.results(experiment, tracker, sequences);

values = selector.results_values(experiment, tracker, sequences, confidence_name);

overlaps = zeros(sum(cellfun(@numel, groundtruth, 'UniformOutput', true)), size(trajectories, 2));
certanty = zeros(size(overlaps));

i = 1;

N = 0;

for s = 1:numel(groundtruth)

    for r = 1:size(trajectories, 2)

        if isempty(trajectories{s, r})
            continue;
        end;

        [~, frames] = estimate_accuracy(trajectories{s, r}, groundtruth{s}, ...
            'BindWithin', [sequences{s}.width, sequences{s}.height]);

        frames(isnan(frames)) = 0;

        % numel, not size: size returns a [rows, cols] vector, which
        % silently corrupts the index range.
        overlaps(i:i+numel(groundtruth{s})-1, r) = frames;
        certanty(i:i+numel(groundtruth{s})-1, r) = values{s, r};

    end;

    % numel, not size: adding a size vector to the scalar offset turned
    % i into a vector, and subsequent colon indexing only used its first
    % element.
    i = i + numel(groundtruth{s});

    % Count valid ground-truth regions (more than one element) as the
    % recall denominator.
    if ~isempty(groundtruth{s})
        N = N + sum(cellfun(@(x) numel(x) > 1, groundtruth{s}, 'UniformOutput', true));
    end;
end;

if isempty(thresholds)
    thresholds = certanty(~isnan(certanty));
end

thresholds = sort(thresholds, iff(confidence_inverse, 'descend', 'ascend'));

curve = zeros(numel(thresholds), 3);

curve(:, 3) = thresholds;

for k = 1:numel(thresholds)

    % Indicator matrix selecting the frames used for this Pr-Re point.
    % NOTE(review): with confidence_inverse the comparison arguably
    % should be <= ; kept as-is to preserve behavior — confirm semantics.
    subset = certanty >= thresholds(k);

    % Reduce over all elements explicitly; certanty can have multiple
    % columns, in which case sum(subset) would be a row vector.
    if sum(subset(:)) == 0
        % Special case - no prediction is made:
        % Precision is 1 and Recall is 0.
        curve(k,1) = 1;
        curve(k,2) = 0;
    else
        curve(k, 1) = mean(overlaps(subset));
        curve(k, 2) = sum(overlaps(subset)) ./ N;
    end

end

% F-measure is the harmonic mean of precision and recall; max ignores
% NaN entries (precision + recall == 0) unless all are NaN.
f = 2 * (curve(:, 1) .* curve(:, 2)) ./ (curve(:, 1) + curve(:, 2));

[fmax, fidx] = max(f);

fmeasure = fmax;
fbest = thresholds(fidx);

end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
function [document, scores] = report_precision_recall(context, experiment, trackers, sequences, varargin)
% report_precision_recall Generate a report using tracking precision-recall methodology
%
% Performs tracking precision-recall analysis and generates a report based on the results.
%
% Input:
% - context (structure): Report context structure.
% - experiment (struct): An experiment structure.
% - trackers (cell): An array of tracker structures.
% - sequences (cell): An array of sequence structures.
% - varargin[UseTags] (boolean): Analyze according to tags (otherwise according to sequences).
% - varargin[Resolution] (integer): Number of confidence thresholds sampled per tracker.
% - varargin[HideLegend] (boolean): Hide legend in plots.
%
% Output:
% - document (structure): Resulting document structure.
% - scores (struct): A scores structure.
%

usetags = get_global_variable('report_tags', true);
% NOTE(review): 'report_lagend_hide' looks like a typo for
% 'report_legend_hide'; kept as-is to preserve behavior — confirm which
% name the rest of the toolkit sets before renaming.
hidelegend = get_global_variable('report_lagend_hide', false);
resolution = 100;

for i = 1:2:length(varargin)
    switch lower(varargin{i})
        case 'usetags'
            usetags = varargin{i+1};
        case 'resolution' % was misspelled 'resolutuion', so the documented option always errored
            resolution = varargin{i+1};
        case 'hidelegend'
            hidelegend = varargin{i+1};
        otherwise
            error(['Unknown switch ', varargin{i}, '!']) ;
    end
end

if ~strcmp(experiment.type, 'unsupervised')
    error('Tracking precision-recall analysis only suitable for unsupervised experiments!');
end

document = document_create(context, 'tpr', 'title', 'Tracking precision recall');

% The hashes identify the cached analysis result for this configuration.
trackers_hash = md5hash(strjoin((cellfun(@(x) x.identifier, trackers, 'UniformOutput', false)), '-'), 'Char', 'hex');
parameters_hash = md5hash(sprintf('%d%d', usetags, resolution));

tags = {};

% Guard the field access with isfield (previously experiment.tags was
% read unconditionally, which errors when the field is absent).
if ~isfield(experiment, 'tags') || isempty(experiment.tags)
    usetags = false;
end;

if usetags
    tags = union(experiment.tags, {'all'});
    sequences_hash = md5hash(strjoin(tags, '-'), 'Char', 'hex');
else
    sequences_hash = md5hash(strjoin((cellfun(@(x) x.name, sequences, 'UniformOutput', false)), '-'), 'Char', 'hex');
end;

cache_identifier = sprintf('tpr_%s_%s_%s_%s', experiment.name, trackers_hash, sequences_hash, parameters_hash);

result = document_cache(context, cache_identifier, @analyze_precision_recall, experiment, trackers, ...
    sequences, 'Tags', tags, 'Resolution', resolution);

if usetags
    % When using tags we have inserted a separate 'all' tag: use it as
    % the overall average and strip it from the per-tag results.
    mask = strcmp('tag_all', result.selectors);

    average_curve = result.curves(:, mask);
    average_fmeasure = result.fmeasure(:, mask);

    % Now remove the 'all' tag from results
    tag_curve = result.curves(:, ~mask);
    tag_fmeasure = result.fmeasure(:, ~mask);

    selector_tags = cat(2, result.selectors(~mask), result.selectors(mask));

else

    % No 'all' selector exists: average the per-sequence curves and
    % recompute the best F-measure from the averaged curve.
    average_curve = cell(numel(trackers), 1);
    average_fmeasure = zeros(numel(trackers), 1);

    for t = 1:numel(trackers)
        average_curve{t} = mean(cat(3, result.curves{t, :}), 3);
        f = 2 * (average_curve{t}(:, 1) .* average_curve{t}(:, 2)) ./ (average_curve{t}(:, 1) + average_curve{t}(:, 2));
        [average_fmeasure(t), ~] = max(f);
    end;

    tag_curve = result.curves;
    tag_fmeasure = result.fmeasure;

    selector_tags = result.selectors;

end

scores.name = 'TPR';
scores.values = average_fmeasure;
scores.ids = {'f'};
scores.names = {'F'};
scores.order = {'descending'};

% Mark verified trackers with an asterisk in row labels.
tracker_labels = cellfun(@(x) iff(isfield(x.metadata, 'verified') && x.metadata.verified, [x.label, '*'], x.label), trackers, 'UniformOutput', 0);

print_text('Writing tracking precision-recall table ...');

document.section('Experiment %s', experiment.name);

pr_plot(document, sprintf('%s_average', experiment.name), ...
    sprintf('Experiment %s (average)', experiment.name), ...
    trackers, average_curve, hidelegend);

table_data = highlight_best_rows(num2cell(cat(2, tag_fmeasure, average_fmeasure)), repmat({'descending'}, 1, size(tag_fmeasure, 2) + 1));

document.table(table_data, 'columnLabels', selector_tags, 'rowLabels', tracker_labels, 'title', 'Tracking precision-recall overview');

document.subsection('Detailed plots');

for t = 1:size(tag_curve, 2)

    plot_title = sprintf('Tracking precision-recall plot for tag %s in experiment %s', ...
        selector_tags{t}, experiment.name);
    % Use a tpr-specific identifier (was 'overlap_...', which could clash
    % with identifiers produced by the overlap report).
    plot_id = sprintf('tpr_%s_%s', experiment.name, selector_tags{t});

    % Pass hidelegend directly (was negated here, inverting the option
    % relative to the average plot above).
    pr_plot(document, plot_id, plot_title, trackers, tag_curve(:, t), hidelegend);

end;

document.write();

end
|
||
function pr_plot(document, identifier, title, trackers, curves, hidelegend)
% pr_plot Render a tracking precision-recall plot into a report document
%
% Plots tracking recall (x axis) against tracking precision (y axis) for
% each tracker's curve and inserts the resulting figure into the document.
%
% Input:
% - document (structure): Report document structure.
% - identifier (string): Unique identifier of the figure.
% - title (string): Title of the figure.
% - trackers (cell): An array of tracker structures (one per curve).
% - curves (cell): Per-tracker curve matrices; column 1 is precision,
%   column 2 is recall.
% - hidelegend (boolean): Do not draw the legend when true.

% Use the provided title (was hard-coded to 'Overlap', a copy-paste
% leftover from the overlap report).
handle = plot_blank('Visible', false, 'Title', title, 'Width', 6, 'Height', 6); hold on;

phandles = zeros(numel(trackers), 1);

for t = 1:numel(curves)
    phandles(t) = plot(curves{t}(:, 2), curves{t}(:, 1), 'Color', trackers{t}.style.color);
end;

labels = cellfun(@(x) x.label, trackers, 'UniformOutput', false);

if ~hidelegend
    legend(phandles, labels);
end;

xlabel('Tracking recall');
ylabel('Tracking precision');
xlim([0, 1]);
ylim([0, 1]);
hold off;
document.figure(handle, identifier, title);

close(handle);
end