-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Kaili Zhao
committed
Jul 23, 2018
0 parents
commit 760cd34
Showing
21 changed files
with
1,932 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
### Intro | ||
|
||
This repository provides the Matlab implementation for the CVPR18 paper, "[Learning Facial Action Units From Web Images With Scalable Weakly Supervised Clustering](http://openaccess.thecvf.com/content_cvpr_2018/CameraReady/0237.pdf)". This code has two goals: | ||
|
||
1. Learn a **weakly-supervised spectral embedding (WSE)**, which considers the coherence between visual similarity and weak annotations (Sec 3.1 in the paper). | ||
|
||
2. Re-annotate noisy images using **rank-order clustering** and majority voting (Sec 3.2 in the paper). We also provide **uMQI** metric to automatically determine the number of clusters. This part will be released soon. | ||
|
||
|
||
### Dependencies | ||
We use the [FLANN](https://www.cs.ubc.ca/research/flann/) library to compute K nearest neighbors to construct the affinity matrix for WSE and rank-order clustering. Before using this code, please download FLANN library and add the path to `addpaths.m`. | ||
|
||
|
||
### Getting started | ||
|
||
To run the toy demo (as Fig. 2 in the paper). Run the command in Matlab: | ||
|
||
``` matlab | ||
>> demo_toy | ||
``` | ||
|
||
Then you should be able to see the results from the classical clustering problem: | ||
![demo_toy](figures/demo_toy.gif) | ||
|
||
|
||
### More info | ||
|
||
* **Contact**: Please send comments or bugs to Kaili Zhao ([kailizhao@bupt.edu.cn](kailizhao@bupt.edu.cn)). | ||
* **Citation**: If you use this code in your paper, please cite the following: | ||
|
||
``` | ||
@inproceedings{zhao2018learning, | ||
title={Learning Facial Action Units From Web Images With Scalable Weakly Supervised Clustering}, | ||
author={Zhao, Kaili and Chu, Wen-Sheng and Martinez, Aleix M.}, | ||
booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition}, | ||
year={2018} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
function addpaths | ||
|
||
addpath(genpath(pwd)) | ||
addpath(genpath('/home/kaili/code/flann-1.8.4-src/')) |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
% This demo shows WSE optimization process as Fig 2 in [1]. | ||
% | ||
% Please feel free to contact me regarding bugs and suggestions. | ||
% | ||
% Contact: Kaili Zhao (kailizhao@bupt.edu.cn) | ||
% Paper: http://openaccess.thecvf.com/content_cvpr_2018/CameraReady/0237.pdf | ||
% Reference: | ||
% [1] "Learning Facial Action Units from Web Images with Scalable Weakly | ||
% Supervised Cluster, " in CVPR 2018. | ||
|
||
clc; clear all; addpaths; | ||
|
||
%% Data preparation | ||
% Get data | ||
dataset = 'toy'; | ||
[feat, label, wlbl] = get_data(dataset); | ||
|
||
% Get graph in adjanceny matrix A | ||
A = get_graph( ... | ||
feat, dataset, 'graph_type', 'mutual_knn', ... | ||
'nn_opt', 'pdist', 'num_nn', 100); | ||
|
||
%% Run WSE | ||
config = get_config('wse', 'toy'); | ||
[clusters, W, obj] = scalable_wse(A, wlbl, config); | ||
|
||
%% Display WSE distribution at different iterations | ||
plot_toy_example(obj, clusters, feat, label, wlbl); |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
function [varargout]=get_args(pnames,dflts,varargin) | ||
%GETARGS Process parameter name/value pairs | ||
|
||
|
||
% Initialize some variables | ||
emsg = ''; | ||
eid = ''; | ||
nparams = length(pnames); | ||
varargout = dflts; | ||
unrecog = {}; | ||
nargs = length(varargin); | ||
|
||
% Must have name/value pairs | ||
if mod(nargs,2)~=0 | ||
eid = 'WrongNumberArgs'; | ||
emsg = 'Wrong number of arguments.'; | ||
else | ||
% Process name/value pairs | ||
for j=1:2:nargs | ||
pname = varargin{j}; | ||
if ~ischar(pname) | ||
eid = 'BadParamName'; | ||
emsg = 'Parameter name must be text.'; | ||
break; | ||
end | ||
i = strcmpi(pname,pnames); | ||
i = find(i); | ||
if isempty(i) | ||
% if they've asked to get back unrecognized names/values, add this | ||
% one to the list | ||
if nargout > nparams+2 | ||
unrecog((end+1):(end+2)) = {varargin{j} varargin{j+1}}; | ||
% otherwise, it's an error | ||
else | ||
eid = 'BadParamName'; | ||
emsg = sprintf('Invalid parameter name: %s.',pname); | ||
break; | ||
end | ||
elseif length(i)>1 | ||
eid = 'BadParamName'; | ||
emsg = sprintf('Ambiguous parameter name: %s.',pname); | ||
break; | ||
else | ||
varargout{i} = varargin{j+1}; | ||
end | ||
end | ||
end | ||
|
||
varargout{nparams+1} = unrecog; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
function config = get_config(alg, dataset) | ||
% Get config for different algorithms and datasets | ||
|
||
switch alg | ||
case 'wse' | ||
switch dataset | ||
case 'toy' | ||
config.optimizer = 'agd'; | ||
config.embedding_dim = 2; | ||
config.type_laplacian = 'norm'; | ||
config.speedup = 'fast'; | ||
config.epoch = 50; | ||
config.graph_partition = 2; | ||
config.lambda = 0.3413; | ||
config.log_opt.is_log = true; | ||
config.log_opt.num_clusters = 2; | ||
config.log_opt.disp_steps = 5; | ||
|
||
case 'au_data' | ||
config.optimizer = 'agd'; | ||
config.embedding_dim = 10; | ||
config.type_laplacian = 'norm'; | ||
config.speedup = 'fast'; | ||
config.epoch = 50; | ||
config.graph_partition = 2; | ||
config.lambda = 1; | ||
config.log_opt.is_log = false; | ||
|
||
case 'mnist' | ||
config.optimizer = 'agd'; | ||
config.embedding_dim = 2; | ||
config.type_laplacian = 'norm'; | ||
config.speedup = 'fast'; | ||
config.epoch = 50; | ||
config.graph_partition = 2; | ||
config.lambda = 10000; | ||
config.lambda = 0.1; | ||
config.log_opt.is_log = true; | ||
config.log_opt.num_clusters = 2; | ||
config.log_opt.disp_steps = 1; | ||
end | ||
|
||
case 'reannotation' | ||
switch dataset | ||
case 'au_data' | ||
config.batch = 1000; | ||
config.num_knn = 50; | ||
config.build_params.algorithm = 'kdtree'; | ||
config.build_params.trees = 1000; | ||
config.build_params.checks = 100; | ||
config.remove_outlier = false; | ||
end | ||
end | ||
|
||
% Display config | ||
fprintf('Loaded config for "%s":\n', dataset); | ||
T = struct2table(config); | ||
new_T = cell2table(table2cell(T)'); | ||
new_T.Properties.RowNames = T.Properties.VariableNames; | ||
new_T.Properties.VariableNames{'Var1'} = 'Value'; | ||
disp(new_T); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
function [feat, label, wlbl] = get_data(dataset) | ||
% Get feature and label from dataset | ||
% Also load weak labels (wlbl) if necessary | ||
|
||
db_file = fullfile('data', [dataset, '.mat']); | ||
err_str = sprintf('Dataset %s does not exist', db_file); | ||
assert(logical(exist(db_file, 'file')), err_str); | ||
|
||
db = load(db_file); | ||
|
||
switch dataset | ||
case 'toy' | ||
feat = db.feat; | ||
label = db.label; | ||
|
||
% Get weak labels by perturbing ground-truth labels | ||
psnr = 0.3; % Perturb labels for 30% samples in each class | ||
wlbl = perturb_labels(label, psnr); | ||
|
||
case 'au' | ||
feat = db.feat'; | ||
label = db.label; | ||
wlbl = db.predL; % Get pre-computed weak labels | ||
|
||
case 'mnist' | ||
num_class = 2; | ||
labels = double(db.labels); | ||
inds = cell(num_class, 1); | ||
n = 100; % num to sample per class | ||
for iclass = 1:num_class | ||
ind = find(labels == (iclass-1)); | ||
inds{iclass} = ind(randperm(length(ind), n)); | ||
end | ||
inds = cell2mat(inds)'; | ||
inds = inds(:); | ||
feat = reshape(db.images(inds, :, :), [num_class*n, 784]); | ||
feat = double(feat'); | ||
label = labels(inds); | ||
|
||
% Get weak labels by perturbing ground-truth labels | ||
psnr = 0; % Perturb labels for 30% samples in each class | ||
wlbl = perturb_labels(label, psnr); | ||
end | ||
|
||
fprintf('Loaded data from %s\n', db_file); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
function val = get_field(struct_var, field_name, default_val) | ||
|
||
if isfield(struct_var, field_name) | ||
val = struct_var.(field_name); | ||
else | ||
val = default_val; | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
function A = get_graph(X, dataset, varargin) | ||
% Get graph in adjanceny matrix A | ||
% | ||
% Parameters are: | ||
% | ||
% 'graph_type' - Define the type of similarity graph | ||
% 'full' - Full Similarity Graph | ||
% 'mutual_knn' - Mutual kNeares Neighbors Graph [default] | ||
% 'knn' - kNearest Neighbors Graph | ||
% | ||
% 'nn_opt' - Option for computing k-NN | ||
% 'pdist' [default] | ||
% 'flann' | ||
% | ||
% 'is_load_graph' - Load pre-computed graph or not | ||
% false [default] | ||
% | ||
% 'num_nn' - Number of nearest neighbors to build graph | ||
% 100 [default] | ||
% 0.1 - If num_nn is in [0, 1], we used num_nn as the % of all samples | ||
|
||
% Setup | ||
pnames = {'graph_type', 'nn_opt', 'is_load_graph', 'num_nn'}; | ||
dflts = {'mutual_knn', 'pdist', false, 100}; | ||
[graph_type, nn_opt, is_load_graph, k] = ... | ||
get_args(pnames, dflts, varargin{:}); | ||
|
||
% Init | ||
n = size(X, 2); | ||
if k < 1 | ||
k = round(k * n); | ||
end | ||
|
||
output_dir = fullfile('results', 'graph'); | ||
if ~exist(output_dir, 'dir') | ||
mkdir(output_dir); | ||
end | ||
savename = fullfile(output_dir, ['gh_', dataset, '.mat']); | ||
has_saved_file = logical(exist(savename, 'file')); | ||
|
||
if has_saved_file && is_load_graph | ||
graph = load(savename); | ||
fprintf('Loaded graph from %s\n', savename); | ||
A = graph.A; | ||
|
||
elseif strcmp(nn_opt, 'pdist') | ||
fprintf('Computing graph using pdist with %d neighbors ... ', k); | ||
tic; | ||
dist = pdist2(X', X'); | ||
isnn = false(n); | ||
|
||
% Create directed neighbor graph | ||
for iRow = 1:n | ||
[val, idx] = sort(dist(iRow, :), 'ascend'); | ||
isnn(iRow, idx(1:k+1)) = true; | ||
end | ||
knndist = sparse(n, n); | ||
knndist(isnn) = dist(isnn); | ||
clear dist; | ||
if strcmp(graph_type, 'mutual_knn') | ||
knndist = min(knndist, knndist'); | ||
elseif strcmp(graph_type, 'knn') | ||
knndist = max(knndist, knndist'); | ||
end | ||
|
||
sigma = median(knndist(isnn)); % Gaussian parameter | ||
A = spfun(@(knndist) (sim_gaussian(knndist, sigma)), knndist); | ||
save(savename, 'A'); | ||
fprintf('done in %.2f secs\n', toc); | ||
|
||
elseif strcmp(nn_opt, 'flann') | ||
build_params.algorithm = 'kdtree'; | ||
build_params.trees = 100; | ||
k = 100; % Number of nearest neighbors | ||
flann_set_distance_type('euclidean'); | ||
|
||
fprintf('Computing graph using flann with %d neighbors ... ', k); | ||
[index, parameters, speedup] = flann_build_index(X, build_params); | ||
[result, dist] = flann_search(index, X, k, parameters); | ||
flann_free_index(index); | ||
idx = result + repmat([0:n:(n-1)*n], [k,1]); | ||
keyboard | ||
knndist = sparse(idx(:), 1, dist(:), n*n, 1); | ||
knndist = reshape(knndist, [n ,n]); | ||
if strcmp(graph_type, 'mutual_knn') | ||
knndist = min(knndist, knndist'); | ||
elseif strcmp(graph_type, 'knn') | ||
knndist = max(knndist, knndist'); | ||
end | ||
varparam = 1; % Gaussian parameter | ||
A = spfun(@(knndist) (sim_gaussian(knndist, varparam)), knndist); | ||
fprintf('done\n'); | ||
|
||
dumvar = 0; | ||
save(savename, 'A', 'dumvar', '-v7.3'); | ||
end |
Oops, something went wrong.