Skip to content
Browse files

Fixed bug in calculating exact r (added rprime). Added the brute force

calculation.  Better argument checking.
  • Loading branch information...
1 parent 00c99e4 commit 5ce3840b0522284403f338e3911f335c6a0c1e16 @MalcolmSlaney MalcolmSlaney committed Jan 24, 2012
Showing with 78 additions and 31 deletions.
  1. +78 −31 CalculateLSHParameters.m
View
109 CalculateLSHParameters.m
@@ -30,39 +30,31 @@
% the Proceeings of the IEEE special issue on Web-Scale Multimedia, July
% 2011.
-% Copyright (c) 2010-2011 Yahoo! Inc. See detailed copyright notice at
+% Copyright (c) 2010-2012 Yahoo! Inc. See detailed copyright notice at
% the bottom of this file.
-if 0 && ~exist('simplepdf','file')
- path(path,'/Users/malcolm/Projects/LSHwithYury/jlab/')
-end
-
debugPlot = 0; % Set to non zero to get debugging plots
%%
%%%%%%%%%%%%%%%%%%% ARGUMENT PARSING %%%%%%%%%%%%%%%%%%%%%%%%%%
-% Create some fake data for testing
-if nargin == 0
- N = 10000;
- D = 4;
-end
-if nargin == 2
- D = dnnHist;
- fprintf('Calculating stats for %d %d-dimensional random points.\n',...
- N, D);
- data = randn(N, D);
- numQueries = min(1000,N);
- nnDistance = zeros(numQueries,1);
- anyDistance = zeros(numQueries,1);
- for i=1:numQueries
- d = sort(sum((repmat(data(i,:),N,1)-data).^2,2));
- nnDistance(i) = d(2);
- anyDistance(i) = d(floor(rand(1,1)*(N-2))+3);
- end
- [dnnHist,dnnBins] = hist(nnDistance, 100);
- [danyHist,danyBins] = hist(anyDistance, 100);
+if nargin < 5
+ fprintf('Syntax: results = CalculateLSHParameters(N, ...\n');
+ fprintf(' dnnHist, dnnBins, danyHist, danyBins, ...\n');
+ fprintf(' deltaTarget, r, uHash, uCheck);\n');
+ return;
+end
+
+if ~isscalar(N)
+ error('First argument must be a scalar count.');
+end
+if ~isvector(dnnHist) || ~isvector(dnnHist) || length(dnnHist) ~= length(dnnBins)
+ error('dnnHist and dnnBins must be the same length vectors.');
end
+if ~isvector(danyHist) || ~isvector(danyHist) || length(danyHist) ~= length(danyBins)
+ error('danyHist and danyBins must be the same length vectors.');
+end
+% Set the default parameter values.
if nargin < 6
deltaTarget = 1/exp(1.0);
end
@@ -236,12 +228,12 @@
if debugPlot
figure(4)
clf;
- semilogx(wList/dScale, [binNnProb' binAnyProb']);
- legend('P_{nn}', 'P_{any}','Location','NorthWest');
+ semilogx(wList/dScale, [binNnProb' binAnyProb' binNnProb2' binAnyProb2']);
+ legend('p_{nn}', 'p_{any}','q_{nn}', 'q_{any}','Location','NorthWest');
title('LSH Bucket Estimate')
ylabel('Collision Probabilities')
xlabel('Bin Width (w)')
- axis([1e-2/dScale 2e1/dScale 0 1]);
+ axis([1e-2/dScale 4e1/dScale 0 1]);
% axis([0 max(wList) 0 1])
% pause;
end
@@ -324,9 +316,11 @@
% Eq. 41 for all values of w
wAlpha = log((binNnProb-binAnyProb)./(1-binNnProb)) + ...
r * log(binAnyProb2./binAnyProb) + log(uCheck/uHash) - log(factorial(r));
-% Eq. 40 for all values of w
+% Eq. 40 for all values of w. Added rprime definition to make things
+% easier.
+rprime = r ./ (-log(binAnyProb));
wBestK = (log(N)+wAlpha)./(-log(binAnyProb)) + ...
- r*log((log(N)+wAlpha)./(-log(binAnyProb)));
+ rprime.*log((log(N)+wAlpha)./(-log(binAnyProb)));
wBestK(imag(wBestK) ~= 0.0 | wBestK < 1) = 1; % Set bad values to at least 1
% Now we want to find the total cost for all values of w. We will argmin
@@ -375,6 +369,7 @@
end
results.wCandidateCount = N * (1 - (1-probSum).^lVsW);
results.wCandidateCount2 = N*probSum.*lVsW;
+results.exactCandidateCount = results.wCandidateCount(optimalBin);
fprintf('Exact Optimization:\n');
fprintf('\tFor %d points of data use: ', N);
@@ -400,6 +395,58 @@
nnHitProb = 1 - (1-nnHitProbL1)^desiredOptimalL;
anyHitProb = 1 - (1-anyHitProbL1)^desiredOptimalL;
+%%
+%%%%%%%%%%%%%%%%%%% Brute Force Calculation %%%%%%%%%%%%%%%%%%%%%%%%%%
+if 1
+ maxK = optimalK + 10;
+ Ts = zeros(length(binNnProb), maxK);
+ for k=1:maxK
+ cPnnQnn = choose(k, min(r, k))*binNnProb.^(k-r).*binNnProb2.^r;
+ cPanyQany = choose(k, min(r, k))*binAnyProb.^(k-r).*binAnyProb2.^r;
+ TsByW = uHash * (-log(deltaTarget)) ./ cPnnQnn + ...
+ uCheck * N * (-log(deltaTarget)) * cPanyQany ./ cPnnQnn;
+ Ts(:,k) = TsByW;
+ end
+ % http://stackoverflow.com/questions/2635120/how-can-i-find-the-maximum-or-minimum-of-a-multi-dimensional-matrix-in-matlab
+ [min_Ts, min_Ts_position] = min(Ts(:));
+ % transform the index in the 1D view to 2 indices, given the size of Ts
+ [minBruteBin,minBruteK] = ind2sub(size(Ts), min_Ts_position);
+ results.Ts = Ts;
+ results.bruteBin = minBruteBin;
+ results.bruteCost = min_Ts;
+ results.bruteK = minBruteK;
+ results.bruteW = wList(minBruteBin)/dScale;
+ results.bruteCandidates = ...
+ N*choose(results.bruteK, min(r, results.bruteK)) * ...
+ (binAnyProb(minBruteBin)^(max(1,results.bruteK-r))) * ...
+ (binAnyProb2(minBruteBin).^r);
+ results.bruteL = ceil(-log(deltaTarget) / ...
+ ( choose(results.bruteK, min(r, results.bruteK)) * ...
+ (binNnProb(minBruteBin)^(max(1,results.bruteK-r))) * ...
+ (binNnProb2(minBruteBin).^r)));
+ results.bruteCost = results.bruteL*uHash + results.bruteCandidates*uCheck;
+
+ if debugPlot
+ figure(6);
+ clf
+ TsDetail = min(results.Ts, 10*min_Ts);
+ imagesc(log10(results.wList/results.dScale), ...
+ 1:size(TsDetail, 2), ...
+ log10(TsDetail')); axis ij
+ xlabel('Log10 Bin Size (w)'); ylabel('Number of Projections (k)');
+ title('Log10(T_s) by Brute Force');
+ colorbar;
+ hold on;
+ plot(log10(results.bruteW), results.bruteK, 'wo');
+ plot(log10(results.exactW), results.exactK,'w*');
+ plot(log10(results.simpleW), results.simpleK, 'wx');
+ hold off;
+ end
+end
+
+%%
+%%%%%%%%%%%%%%%%%%% Summarize the results %%%%%%%%%%%%%%%%%%%%%%%%%%
+
fprintf('Expected statistics for optimal solution:\n');
fprintf('\tAssuming K=%d, L=%d, hammingR=%d\n', desiredOptimalK, ...
desiredOptimalL, r);
@@ -542,7 +589,7 @@
function c = choose(n, k)
% function c = choose(n, k)
% Works for vectors of n
-c = factorial(n) ./ (factorial(k) .* factorial(n-k));
+c = factorial(n) ./ (factorial(k) .* factorial(max(0,n-k)));
%%
function normed = NormalizePDF(y,x)

0 comments on commit 5ce3840

Please sign in to comment.
Something went wrong with that request. Please try again.