Skip to content

Commit

Permalink
Init commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaili Zhao committed Jul 23, 2018
0 parents commit 760cd34
Show file tree
Hide file tree
Showing 21 changed files with 1,932 additions and 0 deletions.
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
### Intro

This repository provides the Matlab implementation for the CVPR18 paper, "[Learning Facial Action Units From Web Images With Scalable Weakly Supervised Clustering](http://openaccess.thecvf.com/content_cvpr_2018/CameraReady/0237.pdf)". This code has two goals:

1. Learn a **weakly-supervised spectral embedding (WSE)**, which considers the coherence between visual similarity and weak annotations (Sec 3.1 in the paper).

2. Re-annotate noisy images using **rank-order clustering** and majority voting (Sec 3.2 in the paper). We also provide **uMQI** metric to automatically determine the number of clusters. This part will be released soon.


### Dependencies
We use the [FLANN](https://www.cs.ubc.ca/research/flann/) library to compute K nearest neighbors to construct the affinity matrix for WSE and rank-order clustering. Before using this code, please download FLANN library and add the path to `addpaths.m`.


### Getting started

To run the toy demo (as Fig. 2 in the paper). Run the command in Matlab:

``` matlab
>> demo_toy
```

Then you should be able to see the results from the classical clustering problem:
![demo_toy](figures/demo_toy.gif)


### More info

* **Contact**: Please send comments or bugs to Kaili Zhao ([kailizhao@bupt.edu.cn](kailizhao@bupt.edu.cn)).
* **Citation**: If you use this code in your paper, please cite the following:

```
@inproceedings{zhao2018learning,
title={Learning Facial Action Units From Web Images With Scalable Weakly Supervised Clustering},
author={Zhao, Kaili and Chu, Wen-Sheng and Martinez, Aleix M.},
booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2018}
}
```
4 changes: 4 additions & 0 deletions addpaths.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
function addpaths

addpath(genpath(pwd))
addpath(genpath('/home/kaili/code/flann-1.8.4-src/'))
Binary file added data/toy.mat
Binary file not shown.
28 changes: 28 additions & 0 deletions demo_toy.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
% This demo shows WSE optimization process as Fig 2 in [1].
%
% Please feel free to contact me regarding bugs and suggestions.
%
% Contact: Kaili Zhao (kailizhao@bupt.edu.cn)
% Paper: http://openaccess.thecvf.com/content_cvpr_2018/CameraReady/0237.pdf
% Reference:
% [1] "Learning Facial Action Units from Web Images with Scalable Weakly
% Supervised Cluster, " in CVPR 2018.

clc; clear all; addpaths;

%% Data preparation
% Get data
dataset = 'toy';
[feat, label, wlbl] = get_data(dataset);

% Get graph in adjanceny matrix A
A = get_graph( ...
feat, dataset, 'graph_type', 'mutual_knn', ...
'nn_opt', 'pdist', 'num_nn', 100);

%% Run WSE
config = get_config('wse', 'toy');
[clusters, W, obj] = scalable_wse(A, wlbl, config);

%% Display WSE distribution at different iterations
plot_toy_example(obj, clusters, feat, label, wlbl);
802 changes: 802 additions & 0 deletions func/utils/colors.m

Large diffs are not rendered by default.

49 changes: 49 additions & 0 deletions func/utils/get_args.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
function [varargout]=get_args(pnames,dflts,varargin)
%GETARGS Process parameter name/value pairs


% Initialize some variables
emsg = '';
eid = '';
nparams = length(pnames);
varargout = dflts;
unrecog = {};
nargs = length(varargin);

% Must have name/value pairs
if mod(nargs,2)~=0
eid = 'WrongNumberArgs';
emsg = 'Wrong number of arguments.';
else
% Process name/value pairs
for j=1:2:nargs
pname = varargin{j};
if ~ischar(pname)
eid = 'BadParamName';
emsg = 'Parameter name must be text.';
break;
end
i = strcmpi(pname,pnames);
i = find(i);
if isempty(i)
% if they've asked to get back unrecognized names/values, add this
% one to the list
if nargout > nparams+2
unrecog((end+1):(end+2)) = {varargin{j} varargin{j+1}};
% otherwise, it's an error
else
eid = 'BadParamName';
emsg = sprintf('Invalid parameter name: %s.',pname);
break;
end
elseif length(i)>1
eid = 'BadParamName';
emsg = sprintf('Ambiguous parameter name: %s.',pname);
break;
else
varargout{i} = varargin{j+1};
end
end
end

varargout{nparams+1} = unrecog;
61 changes: 61 additions & 0 deletions func/utils/get_config.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
function config = get_config(alg, dataset)
% Get config for different algorithms and datasets

switch alg
case 'wse'
switch dataset
case 'toy'
config.optimizer = 'agd';
config.embedding_dim = 2;
config.type_laplacian = 'norm';
config.speedup = 'fast';
config.epoch = 50;
config.graph_partition = 2;
config.lambda = 0.3413;
config.log_opt.is_log = true;
config.log_opt.num_clusters = 2;
config.log_opt.disp_steps = 5;

case 'au_data'
config.optimizer = 'agd';
config.embedding_dim = 10;
config.type_laplacian = 'norm';
config.speedup = 'fast';
config.epoch = 50;
config.graph_partition = 2;
config.lambda = 1;
config.log_opt.is_log = false;

case 'mnist'
config.optimizer = 'agd';
config.embedding_dim = 2;
config.type_laplacian = 'norm';
config.speedup = 'fast';
config.epoch = 50;
config.graph_partition = 2;
config.lambda = 10000;
config.lambda = 0.1;
config.log_opt.is_log = true;
config.log_opt.num_clusters = 2;
config.log_opt.disp_steps = 1;
end

case 'reannotation'
switch dataset
case 'au_data'
config.batch = 1000;
config.num_knn = 50;
config.build_params.algorithm = 'kdtree';
config.build_params.trees = 1000;
config.build_params.checks = 100;
config.remove_outlier = false;
end
end

% Display config
fprintf('Loaded config for "%s":\n', dataset);
T = struct2table(config);
new_T = cell2table(table2cell(T)');
new_T.Properties.RowNames = T.Properties.VariableNames;
new_T.Properties.VariableNames{'Var1'} = 'Value';
disp(new_T);
45 changes: 45 additions & 0 deletions func/utils/get_data.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
function [feat, label, wlbl] = get_data(dataset)
% Get feature and label from dataset
% Also load weak labels (wlbl) if necessary

db_file = fullfile('data', [dataset, '.mat']);
err_str = sprintf('Dataset %s does not exist', db_file);
assert(logical(exist(db_file, 'file')), err_str);

db = load(db_file);

switch dataset
case 'toy'
feat = db.feat;
label = db.label;

% Get weak labels by perturbing ground-truth labels
psnr = 0.3; % Perturb labels for 30% samples in each class
wlbl = perturb_labels(label, psnr);

case 'au'
feat = db.feat';
label = db.label;
wlbl = db.predL; % Get pre-computed weak labels

case 'mnist'
num_class = 2;
labels = double(db.labels);
inds = cell(num_class, 1);
n = 100; % num to sample per class
for iclass = 1:num_class
ind = find(labels == (iclass-1));
inds{iclass} = ind(randperm(length(ind), n));
end
inds = cell2mat(inds)';
inds = inds(:);
feat = reshape(db.images(inds, :, :), [num_class*n, 784]);
feat = double(feat');
label = labels(inds);

% Get weak labels by perturbing ground-truth labels
psnr = 0; % Perturb labels for 30% samples in each class
wlbl = perturb_labels(label, psnr);
end

fprintf('Loaded data from %s\n', db_file);
7 changes: 7 additions & 0 deletions func/utils/get_field.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
function val = get_field(struct_var, field_name, default_val)

if isfield(struct_var, field_name)
val = struct_var.(field_name);
else
val = default_val;
end
96 changes: 96 additions & 0 deletions func/utils/get_graph.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
function A = get_graph(X, dataset, varargin)
% Get graph in adjanceny matrix A
%
% Parameters are:
%
% 'graph_type' - Define the type of similarity graph
% 'full' - Full Similarity Graph
% 'mutual_knn' - Mutual kNeares Neighbors Graph [default]
% 'knn' - kNearest Neighbors Graph
%
% 'nn_opt' - Option for computing k-NN
% 'pdist' [default]
% 'flann'
%
% 'is_load_graph' - Load pre-computed graph or not
% false [default]
%
% 'num_nn' - Number of nearest neighbors to build graph
% 100 [default]
% 0.1 - If num_nn is in [0, 1], we used num_nn as the % of all samples

% Setup
pnames = {'graph_type', 'nn_opt', 'is_load_graph', 'num_nn'};
dflts = {'mutual_knn', 'pdist', false, 100};
[graph_type, nn_opt, is_load_graph, k] = ...
get_args(pnames, dflts, varargin{:});

% Init
n = size(X, 2);
if k < 1
k = round(k * n);
end

output_dir = fullfile('results', 'graph');
if ~exist(output_dir, 'dir')
mkdir(output_dir);
end
savename = fullfile(output_dir, ['gh_', dataset, '.mat']);
has_saved_file = logical(exist(savename, 'file'));

if has_saved_file && is_load_graph
graph = load(savename);
fprintf('Loaded graph from %s\n', savename);
A = graph.A;

elseif strcmp(nn_opt, 'pdist')
fprintf('Computing graph using pdist with %d neighbors ... ', k);
tic;
dist = pdist2(X', X');
isnn = false(n);

% Create directed neighbor graph
for iRow = 1:n
[val, idx] = sort(dist(iRow, :), 'ascend');
isnn(iRow, idx(1:k+1)) = true;
end
knndist = sparse(n, n);
knndist(isnn) = dist(isnn);
clear dist;
if strcmp(graph_type, 'mutual_knn')
knndist = min(knndist, knndist');
elseif strcmp(graph_type, 'knn')
knndist = max(knndist, knndist');
end

sigma = median(knndist(isnn)); % Gaussian parameter
A = spfun(@(knndist) (sim_gaussian(knndist, sigma)), knndist);
save(savename, 'A');
fprintf('done in %.2f secs\n', toc);

elseif strcmp(nn_opt, 'flann')
build_params.algorithm = 'kdtree';
build_params.trees = 100;
k = 100; % Number of nearest neighbors
flann_set_distance_type('euclidean');

fprintf('Computing graph using flann with %d neighbors ... ', k);
[index, parameters, speedup] = flann_build_index(X, build_params);
[result, dist] = flann_search(index, X, k, parameters);
flann_free_index(index);
idx = result + repmat([0:n:(n-1)*n], [k,1]);
keyboard
knndist = sparse(idx(:), 1, dist(:), n*n, 1);
knndist = reshape(knndist, [n ,n]);
if strcmp(graph_type, 'mutual_knn')
knndist = min(knndist, knndist');
elseif strcmp(graph_type, 'knn')
knndist = max(knndist, knndist');
end
varparam = 1; % Gaussian parameter
A = spfun(@(knndist) (sim_gaussian(knndist, varparam)), knndist);
fprintf('done\n');

dumvar = 0;
save(savename, 'A', 'dumvar', '-v7.3');
end
Loading

0 comments on commit 760cd34

Please sign in to comment.