Permalink
Browse files

Changed the test code so that the datafile's name is now a parameter.

Added the -create flag to create the data (separate from histogram).
  • Loading branch information...
1 parent 41c01c9 commit 29187d31811a583bc799b9894ebf1c96afcfa4b5 @MalcolmSlaney MalcolmSlaney committed Oct 4, 2011
Showing with 48 additions and 41 deletions.
  1. +48 −41 lsh.py
View
@@ -240,7 +240,7 @@ def InsertIntoTable(self, id, data):
# Find some data in the hash bucket. Return all the ids
# that we find for this T1-T2 pair.
- def Find(self, data):
+ def FindXXObsolete(self, data):
(t1, t2) = self.CalculateHashes(data)
if t1 not in self.buckets:
return []
@@ -249,11 +249,13 @@ def Find(self, data):
return []
return row[t2]
- # Return a list of entries, each entry is a data point's id
- def FindMP(self, data, multiprobeRadius=0):
+ #
+ def Find(self, data, multiprobeRadius=0):
+ '''Find the points that are close to the query data. Use multiprobe
+ to also look in nearby buckets.'''
res = []
for (t1,t2) in self.CalculateHashIterator(data, multiprobeRadius):
- # print "FindMP t1:", t1
+ # print "Find t1:", t1
if t1 not in self.buckets:
continue
row = self.buckets[t1]
@@ -385,7 +387,7 @@ def InsertIntoTable(self, id, data):
for p in self.projections:
p.InsertIntoTable(intID, data)
- def Find(self, data):
+ def FindXXObsolete(self, data):
'''Find some data in all the LSH buckets. Return a list of
data's id and bucket counts'''
items = [p.Find(data) for p in self.projections]
@@ -400,14 +402,13 @@ def Find(self, data):
reverse=True)
return [(self.FindID(i),c) for (i,c) in s]
- def FindMP(self, data, multiprobe=0):
- '''Use a simple multiprobe algorithm to retrieve more points. Keep
- track of the number of hits returned.
- Currently only searches with a Hamming distance of 1.'''
+ def Find(self, data, multiprobeR=0):
+ '''Find some data in all the LSH tables. Use Multiprobe, with
+ the given radius, to search neighboring buckets.'''
results = {}
for p in self.projections:
- ids = p.FindMP(data,multiprobe)
- # print "Got back these IDs from p.FindMP:", ids
+ ids = p.Find(data,multiprobeR)
+ # print "Got back these IDs from p.Find:", ids
for id in ids:
if id in results:
results[id] += 1
@@ -417,16 +418,16 @@ def FindMP(self, data, multiprobe=0):
reverse=True)
return [(self.FindID(i),c) for (i,c) in s]
- def FindExact(self, data, GetData):
+ def FindExact(self, data, GetData, multiprobeR=0):
'''Return a list of results sorted by their exact
distance from the query. GetData is a function that
returns the original data given its key.'''
- s = self.Find(data)
+ s = self.Find(data, multiprobeR)
# print "Intermediate results are:", s
d = map(lambda (id,count): (id,((GetData(id)-data)**2).sum(), \
count), s)
s = sorted(d, key=operator.itemgetter(1))
- return [(self.FindID(i),c) for (i,c) in s]
+ return [(self.FindID(i),d) for (i,d,c) in s]
# Put some data into the hash tables.
def Test(self, n):
@@ -593,7 +594,7 @@ def CreateIndex(self, w, k, l):
def RetrieveData(self, id):
'''Find a point in the array of data.'''
- id = int(id) # Key in default class is an int!
+ id = int(id) # Key in this base class is an int!
if id < self.myData.shape[1]:
return self.myData[:,id:id+1]
return None
@@ -645,7 +646,7 @@ def SaveNearestNeighbors(self, filename):
fp = open(filename, 'w')
if fp:
for (query,(nn,dist)) in self.nearestNeighbors.items():
- fp.write('%s %g %s\n' % (query, dist, nn))
+ fp.write('%s %g %s\n' % (str(query), dist, str(nn)))
fp.close()
else:
print "Can't open %s to write nearest-neighbor data" % filename
@@ -744,7 +745,6 @@ def ComputePnnPany(self, w, k, l, multiprobe=0):
numPoints = self.NumPoints()
numDims = self.NumDimensions()
self.CreateIndex(w, k, l) # Put data into new index
- # OutputAllProjections(self, myTestIndex, 'testData%03d.proj' % numDims)
cnn = 0; cnnFull = 0
cany = 0; canyFull = 0
queryCount = 0 # Probe the index
@@ -758,7 +758,7 @@ def ComputePnnPany(self, w, k, l, multiprobe=0):
sys.stdout.flush()
continue
startQueryTime = time.clock() # Measure CPU time
- matches = self.myIndex.FindMP(queryData, multiprobe)
+ matches = self.myIndex.FindExact(queryData, self.RetrieveData, multiprobe)
totalQueryTime += time.clock() - startQueryTime
for (m,c) in matches:
if nnKey == m: # See if NN was found!!!
@@ -796,7 +796,7 @@ def ComputeKCurve(self, kList, w = .291032):
numPoints = self.NumPoints()
l = 10
for k in sorted(list(kList)):
- (pnn, pnnFull, pany, panyFull, queryTimem, numDims) = self.ComputePnnPany(w, k, l)
+ (pnn, pnnFull, pany, panyFull, queryTime, numDims) = self.ComputePnnPany(w, k, l)
print w, k, l, pnn, pany, pany*numPoints, queryTime
sys.stdout.flush()
@@ -993,16 +993,21 @@ def OutputAllProjections(myTestData, myTestIndex, filename):
defaultL = 1
defaultClosest = 1000
defaultMultiprobeRadius = 0
- sys.argv.pop(0)
+ defaultFileName = 'testData'
+ cmdName = sys.argv.pop(0)
while len(sys.argv) > 0:
arg = sys.argv.pop(0).lower()
if arg == '-d':
arg = sys.argv.pop(0)
try:
defaultDims = int(arg)
+ defaultFileName = 'testData%03d' % defaultDims
except:
print "Couldn't parse new value for defaultDims: %s" % arg
print 'New default dimensions for test is', defaultDims
+ elif arg == '-f':
+ defaultFileName = sys.argv.pop(0)
+ print 'New file name is', defaultFileName
elif arg == '-k':
arg = sys.argv.pop(0)
try:
@@ -1038,54 +1043,56 @@ def OutputAllProjections(myTestData, myTestIndex, filename):
except:
print "Couldn't parse new value for multiprobeRadius: %s" % arg
print 'New default multiprobeRadius for test is', defaultMultiprobeRadius
- elif arg == '-histogram': # Calculate distance histograms
+ elif arg == '-create': # Create some uniform random data
myTestData = RandomTestData()
myTestData.CreateData(100000, defaultDims)
- myTestData.SaveData('testData%03d.dat' % defaultDims)
+ myTestData.SaveData(defaultFileName + '.dat')
myTestData.FindNearestNeighbors(defaultClosest)
- myTestData.SaveNearestNeighbors('testData%03d.nn' % defaultDims)
- fp = open('testData%03d.distances' % defaultDims, 'w')
+ myTestData.SaveNearestNeighbors(defaultFileName + '.nn')
+ elif arg == '-histogram': # Calculate distance histograms
+ myTestData = TestDataClass()
+ myTestData.LoadData(defaultFileName + '.dat')
+ myTestData.LoadNearestNeighbors(defaultFileName + '.nn')
+ fp = open(defaultFileName + '.distances', 'w')
if fp:
myTestData.ComputeDistanceHistogram(fp)
fp.close()
else:
- print "Can't open testData.distances to store NN data"
+ print "Can't open %s.distances to store NN data" % defaultFileName
elif arg == '-sanity':
myTestData = TestDataClass()
- myTestData.LoadData('testData%03d.dat' % defaultDims)
+ myTestData.LoadData(defaultFileName + '.dat')
print myTestData.RetrieveData(myTestData.GetRandomQuery())
print myTestData.RetrieveData(myTestData.GetRandomQuery())
- # myTestData.FindNearestNeighbors(1000)
- # myTestData.SaveNearestNeighbors('testData.nn')
elif arg == '-b': # Calculate bucket probabilities
random.seed(0)
myTestData = TestDataClass()
- myTestData.LoadData('testData%03d.dat' % defaultDims)
- myTestData.LoadNearestNeighbors('testData%03d.nn' % defaultDims)
+ myTestData.LoadData(defaultFileName + '.dat')
+ myTestData.LoadNearestNeighbors(defaultFileName + '.nn')
# ComputePnnPanyCurve(myData, [.291032])
myTestData.ComputePnnPanyCurve(defaultW)
elif arg == '-wtest': # Calculate bucket probabilities as a function of w
random.seed(0)
myTestData = TestDataClass()
- myTestData.LoadData('testData%03d.dat' % defaultDims)
- myTestData.LoadNearestNeighbors('testData%03d.nn' % defaultDims)
+ myTestData.LoadData(defaultFileName + '.dat')
+ myTestData.LoadNearestNeighbors(defaultFileName + '.nn')
wList = [defaultW*.5**-i for i in range(-10,10)]
# wList = [defaultW*.5**-i for i in range(-3,3)]
myTestData.ComputePnnPanyCurve(wList, defaultMultiprobeRadius)
elif arg == '-ktest': # Calculate bucket probabilities as a function of k
random.seed(0)
myTestData = TestDataClass()
- myTestData.LoadData('testData%03d.dat' % defaultDims)
- myTestData.LoadNearestNeighbors('testData%03d.nn' % defaultDims)
+ myTestData.LoadData(defaultFileName + '.dat')
+ myTestData.LoadNearestNeighbors(defaultFileName + '.nn')
# ComputePnnPanyCurve(myData, [.291032])
kList = [math.floor(math.sqrt(2)**k) for k in range(0,10)]
kList = [1,2,3,4,5,6,8,10,12,14,16,18,20]
myTestData.ComputeKCurve(kList, defaultW)
elif arg == '-ltest': # Calculate bucket probabilities as a function of l
random.seed(0)
myTestData = TestDataClass()
- myTestData.LoadData('testData%03d.dat' % defaultDims)
- myTestData.LoadNearestNeighbors('testData%03d.nn' % defaultDims)
+ myTestData.LoadData(defaultFileName + '.dat')
+ myTestData.LoadNearestNeighbors(defaultFileName + '.nn')
# ComputePnnPanyCurve(myData, [.291032])
lList = [math.floor(math.sqrt(2)**k) for k in range(0,10)]
lList = [1,2,3,4,5,6,10]
@@ -1104,8 +1111,8 @@ def OutputAllProjections(myTestData, myTestIndex, filename):
print "Couldn't parse %s. Need w,k,l,r" % sys.argv[0]
sys.argv.pop(0)
myTestData = TestDataClass()
- myTestData.LoadData('testData%03d.dat' % defaultDims)
- myTestData.LoadNearestNeighbors('testData%03d.nn' % defaultDims)
+ myTestData.LoadData(defaultFileName + '.dat')
+ myTestData.LoadNearestNeighbors(defaultFileName + '.nn')
for (w, k, l, r) in timingModels:
sys.stdout.flush()
(pnnL1, pnn, panyL1, pany, perQueryTime, numDims) = myTestData.ComputePnnPany(w, k, l, r)
@@ -1114,9 +1121,9 @@ def OutputAllProjections(myTestData, myTestIndex, filename):
 elif arg == '-test': # Run a single test using the default w, k and l
random.seed(0)
myTestData = TestDataClass()
- myTestData.LoadData('testData%03d.dat' % defaultDims)
- myTestData.LoadNearestNeighbors('testData%03d.nn' % defaultDims)
+ myTestData.LoadData(defaultFileName + '.dat')
+ myTestData.LoadNearestNeighbors(defaultFileName + '.nn')
# ComputePnnPanyCurve(myData, [.291032])
myTestData.ComputeLCurve([defaultL], w=defaultW, k=defaultK)
else:
- print '%s: Unknown test parameter' % sys.argv[0]
+ print '%s: Unknown test argument %s' % (cmdName, arg)

0 comments on commit 29187d3

Please sign in to comment.