Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.
Download ZIP
Browse files

Changed a bunch of things in Matlab and Python code to make testing better.
  • Loading branch information...
commit 4d4a3b67fcb50b5a21a3be8a5a6ebd990bfe5bbc 1 parent a012b14
@MalcolmSlaney MalcolmSlaney authored
Showing with 75 additions and 30 deletions.
  1. +31 −13 CalculateMPLSHParameters.m
  2. +44 −17 lsh.py
View
44 CalculateMPLSHParameters.m
@@ -37,6 +37,8 @@
path(path,'/Users/malcolm/Projects/LSHwithYury/jlab/')
end
+debugPlot = 0; % Set to non zero to get debugging plots
+
%%
%%%%%%%%%%%%%%%%%%% ARGUMENT PARSING %%%%%%%%%%%%%%%%%%%%%%%%%%
@@ -75,7 +77,8 @@
%%%%%%%%%%%%%%%% Make sure basic data looks good %%%%%%%%%%%%%%%%%%%%%
-if 0
+if debugPlot
+ figure(1);
clf;
plot(dnnBins, dnnHist/sum(dnnHist), danyBins, danyHist/sum(danyHist));
legend('Nearest Neighbor', 'Any Neighbor')
@@ -116,12 +119,13 @@
results.dnnPDF = dnnPDF;
results.danyPDF = danyPDF;
-if 0
+if debugPlot
+ figure(2);
clf
plot(xs/dScale, dnnPDF, xs/dScale, danyPDF);
legend('Nearest Neighbor', 'Any Neighbor')
title('Distance Distributions');
- xlabel('Distance')
+ xlabel('Scaled Distance - mean(d_{any}) = 2')
ylabel('PDF of Distance');
end
@@ -163,7 +167,8 @@
results.projAnyPDF = projAnyPDF;
results.projNnPDF = projNnPDF;
-if 0
+if debugPlot
+ figure(3);
subplot(2,1,1);
vScale = max(dnnPDF*dScale)/max(projNnPDF*dScale);
plot(xs/dScale, dnnPDF*dScale/vScale, xp/dScale, projNnPDF*dScale);
@@ -227,10 +232,11 @@
end
-if 0
+if debugPlot
+ figure(4)
clf;
semilogx(wList/dScale, [binNnProb' binAnyProb']);
- legend('Pnn', 'Pany','Location','NorthWest');
+ legend('P_{nn}', 'P_{any}','Location','NorthWest');
title('LSH Bucket Estimate')
ylabel('Collision Probabilities')
xlabel('Bin Width (w)')
@@ -333,7 +339,7 @@
% Inside of Eq. (39)
temp = ((binK.^r).*(binNnProb-binAnyProb)*N*uCheck.*((binAnyProb2./binAnyProb).^r))./ ...
((1-binNnProb).*uHash*factorial(r));
-wFullCost = (-log(deltaTarget))* uHash*(factorial(r))*((binNnProb./binNnProb2).^r)./ ...
+wFullCost = (-log(deltaTarget))* uHash*factorial(r)*((binNnProb./binNnProb2).^r)./ ...
(binK.^r).*(1-binAnyProb)./(binNnProb-binAnyProb).* ...
temp.^(log(binNnProb)./log(binAnyProb));
results.wFullCost = wFullCost;
@@ -342,9 +348,14 @@
optimalW = wList(optimalBin)/dScale;
optimalK = floor(binK(optimalBin));
+% optimalL = ceil(-log(deltaTarget)/ ...
+% ( (optimalK^r)/factorial(r) * (binNnProb(optimalBin)^(optimalK-r)) * ...
+% (binNnProb2(optimalBin)^r))); % Wrong expression for C^r_k -
+% Malcolm 9/8/2011
+ % Equation (42)
optimalL = ceil(-log(deltaTarget)/ ...
- ( (optimalK^r)/factorial(r) * (binNnProb(optimalBin)^(optimalK-r)) * ...
- (binNnProb2(optimalBin)^r) ));
+ ( choose(optimalK,r) * (binNnProb(optimalBin)^(optimalK-r)) * ...
+ (binNnProb2(optimalBin)^r)));
% Equations (48), (49) and (50) for optimalBin estimate.
Ch = uHash * (-log(deltaTarget)) * ...
@@ -372,9 +383,14 @@
% And print the statistics for L=1 for simple parameters.
desiredOptimalK = round(optimalK);
desiredOptimalL = round(optimalL);
-nnHitProbL1 = binNnProb(optimalBin)^desiredOptimalK;
-anyHitProbL1 = binAnyProb(optimalBin)^desiredOptimalK;
-
+% nnHitProbL1 = binNnProb(optimalBin)^desiredOptimalK;
+% anyHitProbL1 = binAnyProb(optimalBin)^desiredOptimalK;
+% From the definition of p_nn in Eq. (46)
+nnHitProbL1 = choose(desiredOptimalK, r)*binNnProb(optimalBin)^(desiredOptimalK-r)*...
+ binNnProb2(optimalBin)^(r);
+anyHitProbL1 = choose(desiredOptimalK, r)*binAnyProb(optimalBin)^(desiredOptimalK-r)*...
+ binAnyProb2(optimalBin)^(r);
+
nnHitProb = 1 - (1-nnHitProbL1)^desiredOptimalL;
anyHitProb = 1 - (1-anyHitProbL1)^desiredOptimalL;
@@ -389,7 +405,8 @@
fprintf('\tExpected number of hits per query: %g\n', anyHitProb*N);
%%
-if 0
+if debugPlot
+ figure(5);
clf
subplot(4,1,1);
semilogx(wList/dScale, [binNnProb' binAnyProb']);
@@ -401,6 +418,7 @@
subplot(4,1,2);
semilogx(wList/dScale, [log(binNnProb') log(binAnyProb')]);
+ title('LSH Bucket Estimate (Log Scale)')
xlabel('Bin Width (w)');
ylabel('Log(Collision Probabilities)');
legend('Log(Pnn)', 'Log(Pany)', 'Location','NorthWest');
View
61 lsh.py
@@ -507,6 +507,7 @@ class TestDataClass:
# this LSH implementation.'''
def __init__(self):
self.myData = None
+ self.myIndex = None
self.nearestNeighbors = {} # A dictionary pointing to IDs
def LoadData(self, filename):
@@ -551,6 +552,21 @@ def SaveData(self, filename):
pass
sys.stderr.write("Can't write test data to %s\n" % filename)
+ def CreateIndex(self, w, k, l):
+ '''Create an index for the data we have in our database. Inputs are
+ the LSH parameters: w, k and l.'''
+ self.myIndex = index(w, k, l)
+ itemCount = 0
+ tic = time.clock()
+ for itemID in self.IterateKeys():
+ features = self.RetrieveData(itemID)
+ if features != None:
+ self.myIndex.InsertIntoTable(itemID, features)
+ itemCount += 1
+ print "Finished indexing %d items in %g seconds." % \
+ (itemCount, time.clock()-tic)
+ sys.stdout.flush()
+
def RetrieveData(self, id):
'''Find a point in the array of data.'''
id = int(id) # Key in default class is an int!
@@ -659,19 +675,28 @@ def ComputeDistanceHistogram(self, fp = sys.stdout):
histograms needed for the LSH Parameter Optimization. For
a number of random query points, print the distance to the
nearest neighbor, and to any random neighbor. This becomes
- the input for the parameter optimization routine'''
+ the input for the parameter optimization routine. Enhanced
+ to also print the NN binary projections.'''
numPoints = self.NumPoints()
medians = self.FindMedian()
+ print "Pulling %d items from the NearestNeighbors list for ComputeDistanceHistogram" % \
+ len(self.nearestNeighbors.items())
for (queryKey,(nnKey,nnDist)) in self.nearestNeighbors.items():
randKey = self.GetRandomQuery()
queryData = self.RetrieveData(queryKey)
nnData = self.RetrieveData(nnKey)
randData = self.RetrieveData(randKey)
-
+ if len(queryData) == 0 or len(nnData) == 0: # Missing, probably because of subsampling
+ print "Skipping %s/%s because data is missing." % (queryKey, nnKey)
+ continue
anyD2 = ((randData-queryData)**2).sum()
projection = numpy.random.randn(1, queryData.shape[0])
+ # print "projection:", projection.shape
+ # print "queryData:", queryData.shape
+ # print "nnData:", nnData.shape
+ # print "randData:", randData.shape
queryProj = numpy.sign(numpy.dot(projection, queryData))
nnProj = numpy.sign(numpy.dot(projection, nnData))
randProj = numpy.sign(numpy.dot(projection, randData))
@@ -680,8 +705,8 @@ def ComputeDistanceHistogram(self, fp = sys.stdout):
fp.write('%g %g %d %d\n' % \
(nnDist, math.sqrt(anyD2), \
queryProj==nnProj, queryProj==randProj))
- fp.flush()
-
+ fp.flush()
+
def ComputePnnPany(self, w, k, l, multiprobe=0):
'''Compute the probability of Pnn and Pany for a given index size.
Create the desired index, populate it with the data, and then measure
@@ -690,24 +715,25 @@ def ComputePnnPany(self, w, k, l, multiprobe=0):
the pnn rate for one 1-dimensional index (l=1),
the pnn rate for an l-dimensional index,
the pany rate for one 1-dimensional index (l=1),
- and the pany rate for an l-dimensional index'''
+ and the pany rate for an l-dimensional index
+ the CPU time per query (seconds)'''
numPoints = self.NumPoints()
numDims = self.NumDimensions()
- myTestIndex = index(w, k, l) # Put data into a new index
- for id in self.IterateKeys():
- data = self.RetrieveData(id)
- myTestIndex.InsertIntoTable(id, data)
+ self.CreateIndex(w, k, l) # Put data into new index
# OutputAllProjections(self, myTestIndex, 'testData%03d.proj' % numDims)
cnn = 0; cnnFull = 0
cany = 0; canyFull = 0
- count = 0 # Probe the index
- startQueryTime = time.clock()
+ queryCount = 0 # Probe the index
+ totalQueryTime = 0
+ # print "ComputePnnPany: Testing %d nearest neighbors." % len(self.nearestNeighbors.items())
for (queryKey,(nnKey,dist)) in self.nearestNeighbors.items():
queryData = self.RetrieveData(queryKey)
if queryData == None or len(queryData) == 0:
print "Can't find data for key %s" % str(queryKey)
continue
- matches = myTestIndex.FindMP(queryData, multiprobe)
+ startQueryTime = time.clock()
+ matches = self.myIndex.FindMP(queryData, multiprobe)
+ totalQueryTime += time.clock() - startQueryTime
for (m,c) in matches:
if nnKey == m: # See if NN was found!!!
cnn += c
@@ -715,13 +741,14 @@ def ComputePnnPany(self, w, k, l, multiprobe=0):
if m != queryKey:
cany += c
canyFull += len(matches)-1
- count += 1
+ queryCount += 1
# Some debugging for k curve.. print individual results
# print "ComputePnnPany Debug:", w, k, l, len(matches), numPoints, cnn, cnnFull, cany, canyFull
- endQueryTime = time.clock()
- perQueryTime = (endQueryTime-startQueryTime)/count
- return (cnn/float(count*l)), cnnFull/float(count), \
- cany/(float(count*l*numPoints)), canyFull/float(count*numPoints), perQueryTime
+ if queryCount == 0:
+ queryCount = 1 # To prevent divide by zero
+ perQueryTime = totalQueryTime/queryCount
+ return cnn/float(queryCount*l), cnnFull/float(queryCount), \
+ cany/float(queryCount*l*numPoints), canyFull/float(queryCount*numPoints), perQueryTime
def ComputePnnPanyCurve(self, wList = .291032, multiprobe=0):
if type(wList) == float or type(wList) == int:
Please sign in to comment.
Something went wrong with that request. Please try again.