Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Change the ListHash function to use the original recursive hash, instead

of converting to strings.  Added item ID cache, changing the original string
to an integer for faster retrieval.
  • Loading branch information...
commit 74706d997faf42e87ebd38ddd74330289c0a5201 1 parent 4d4a3b6
@MalcolmSlaney MalcolmSlaney authored
Showing with 45 additions and 18 deletions.
  1. +45 −18 lsh.py
View
63 lsh.py
@@ -112,7 +112,8 @@ def CreateProjections(self, dim):
# Compute the t1 and t2 hashes for some data. Doing it this way
# instead of in a loop, as before, is 10x faster. Thanks to Anirban
- # for pointing out the flaw.
+ # for pointing out the flaw. Not sure if the T2 hash is needed since
+ # our T1 hash is so strong.
debugFP = None
firstTimeCalculateHashes = False # Change to false to turn this off
infinity = float('inf') # Easy way to access this flag
@@ -121,7 +122,7 @@ def CalculateHashes(self, data):
and quantize'''
if self.projections == None:
self.CreateProjections(len(data))
- bins = numpy.zeros((self.k,1), 'int64')
+ bins = numpy.zeros((self.k,1), 'int')
if lsh.firstTimeCalculateHashes:
print 'data = ', numpy.transpose(data)
print 'bias = ', numpy.transpose(self.bias)
@@ -154,16 +155,18 @@ def CalculateHashes(self, data):
t1 = self.ListHash(bins)
t2 = self.ListHash(bins[::-1]) # Reverse data for second hash
return t1, t2
-
+
+ # Input: A Nx1 array (of integers)
+ # Output: A 28 bit hash value.
# From: http://stackoverflow.com/questions/2909106/
# python-whats-a-correct-and-good-way-to-implement-hash/2909572#2909572
def ListHash(self, d):
- return str(d).__hash__()
+ # return str(d).__hash__() # Good for testing, but not efficient
if d == None or len(d) == 0:
return 0
- d = d.reshape((d.shape[0]*d.shape[1]))
- value = d[0] << 7
- for i in d:
+ # d = d.reshape((d.shape[0]*d.shape[1]))
+ value = d[0, 0] << 7
+ for i in d[:,0]:
value = (101*value + i)&0xfffffff
return value
@@ -172,7 +175,7 @@ def CalculateHashes2(self, data):
if self.projections == None:
print "CalculateHashes2: data.shape=%s, len(data)=%d" % (str(data.shape), len(data))
self.CreateProjections(len(data))
- bins = numpy.zeros((self.k,1), 'int64')
+ bins = numpy.zeros((self.k,1), 'int')
parray = numpy.dot(self.projections, data)
bins[:] = numpy.floor(parray/self.w + self.bias)
t1 = self.ListHash(bins)
@@ -190,15 +193,17 @@ def CalculateHashes2(self, data):
def CalculateHashIterator(self, data, multiprobeRadius=0):
if self.projections == None:
self.CreateProjections(len(data))
- bins = numpy.zeros((self.k,1), 'int64')
+ bins = numpy.zeros((self.k,1), 'int')
+ directVector = numpy.zeros((self.k,1), 'int')
+ newProbe = numpy.zeros((self.k,1), 'int')
if self.w == lsh.infinity:
points = numpy.dot(self.projections, data)
bins[:] = (numpy.sign(points)+1)/2.0
- directVector = -numpy.sign(bins-0.5)
+ directVector[:] = -numpy.sign(bins-0.5)
else:
points = numpy.dot(self.projections, data)/self.w + self.bias
bins[:] = numpy.floor(points)
- directVector = numpy.sign(points-numpy.floor(points)-0.5)
+ directVector[:] = numpy.sign(points-numpy.floor(points)-0.5)
t1 = self.ListHash(bins)
t2 = self.ListHash(bins[::-1])
yield (t1,t2)
@@ -207,13 +212,13 @@ def CalculateHashIterator(self, data, multiprobeRadius=0):
# print "Multiprobe bin:", bins
# print "Multiprobe direct:", direct
dimensions = range(self.k)
- deltaVector = numpy.zeros((self.k, 1), 'int64') # Preallocate
+ deltaVector = numpy.zeros((self.k, 1), 'int') # Preallocate
for r in range(1, multiprobeRadius):
# http://docs.python.org/library/itertools.html
for candidates in itertools.combinations(dimensions, r):
deltaVector *= 0 # Start Empty
deltaVector[list(candidates), 0] = 1 # Set some bits
- newProbe = bins + deltaVector*directVector # Modify probe
+ newProbe[:] = bins + deltaVector*directVector # New probe
t1 = self.ListHash(newProbe)
t2 = self.ListHash(newProbe[::-1]) # Reverse data for second hash
# print "Multiprobe probe:",p, t1, t2
@@ -346,6 +351,7 @@ def __init__(self, w, k, l, N = 1<<31):
self.w = w
self.N = N
self.projections = []
+ self.myIDs = []
for i in range(0,l): # Create all LSH buckets
self.projections.append(lsh(w, k, N))
@@ -356,10 +362,26 @@ def sizeof(self):
return sum(p.sizeof() for p in self.projections) + \
sys.getsizeof(self)
+ # Replace id we are given with a numerical id. Since we are going
+ # to use the ID in L tables, it is better to replace it here with
+ # an integer. We store the original ID in an array, and return it
+ # to the user when we do a find().
+ def AddIDToIndex(self, id):
+ if type(id) == int:
+ return id # Don't bother if already an int
+ self.myIDs.append(id)
+ return len(self.myIDs)-1
+
+ def FindID(self, id):
+ if type(id) != int or id < 0 or id >= len(self.myIDs):
+ return id
+ return self.myIDs[id]
+
# Insert some data into all LSH buckets
def InsertIntoTable(self, id, data):
+ intID = self.AddIDToIndex(id)
for p in self.projections:
- p.InsertIntoTable(id, data)
+ p.InsertIntoTable(intID, data)
def Find(self, data):
'''Find some data in all the LSH buckets. Return a list of
@@ -374,7 +396,7 @@ def Find(self, data):
results[item] = 1
s = sorted(results.items(), key=operator.itemgetter(1), \
reverse=True)
- return s
+ return [(self.FindID(i),c) for (i,c) in s]
def FindMP(self, data, multiprobe=0):
'''Use a simple multiprobe algorithm to retrieve more points. Keep
@@ -391,7 +413,7 @@ def FindMP(self, data, multiprobe=0):
results[id] = 1
s = sorted(results.items(), key=operator.itemgetter(1), \
reverse=True)
- return s
+ return [(self.FindID(i),c) for (i,c) in s]
def FindExact(self, data, GetData):
'''Return a list of results sorted by their exact
@@ -401,8 +423,8 @@ def FindExact(self, data, GetData):
# print "Intermediate results are:", s
d = map(lambda (id,count): (id,((GetData(id)-data)**2).sum(), \
count), s)
- ds = sorted(d, key=operator.itemgetter(1))
- return ds
+ s = sorted(d, key=operator.itemgetter(1))
+ return [(self.FindID(i),c) for (i,c) in s]
# Put some data into the hash tables.
def Test(self, n):
@@ -725,11 +747,13 @@ def ComputePnnPany(self, w, k, l, multiprobe=0):
cany = 0; canyFull = 0
queryCount = 0 # Probe the index
totalQueryTime = 0
+ startRecallTestTime = time.clock()
# print "ComputePnnPany: Testing %d nearest neighbors." % len(self.nearestNeighbors.items())
for (queryKey,(nnKey,dist)) in self.nearestNeighbors.items():
queryData = self.RetrieveData(queryKey)
if queryData == None or len(queryData) == 0:
print "Can't find data for key %s" % str(queryKey)
+ sys.stdout.flush()
continue
startQueryTime = time.clock()
matches = self.myIndex.FindMP(queryData, multiprobe)
@@ -744,6 +768,9 @@ def ComputePnnPany(self, w, k, l, multiprobe=0):
queryCount += 1
# Some debugging for k curve.. print individual results
# print "ComputePnnPany Debug:", w, k, l, len(matches), numPoints, cnn, cnnFull, cany, canyFull
+ recallTestTime = time.clock() - startRecallTestTime
+ print "Tested %d NN queries in %g seconds." % (queryCount, recallTestTime)
+ sys.stdout.flush()
if queryCount == 0:
queryCount = 1 # To prevent divide by zero
perQueryTime = totalQueryTime/queryCount
Please sign in to comment.
Something went wrong with that request. Please try again.