Skip to content

Commit

Permalink
added test cases. Sanity Commit
Browse files Browse the repository at this point in the history
  • Loading branch information
tarunrs committed May 28, 2012
1 parent 7e63ac5 commit 923bbf2
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 106 deletions.
48 changes: 48 additions & 0 deletions evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
def proceed(num=0.5):
    """Return True with probability ``num``.

    A uniform sample from [0, 1) is compared against ``num``, so
    ``num >= 1`` always yields True and ``num <= 0`` always yields False.
    """
    # Imported locally: this module has no top-level imports, so the bare
    # ``random()`` call used to fail with a NameError.
    from random import random

    return random() < num

def count_terms(features, clause):
    """Count the features whose clause code (index 1) equals ``clause``."""
    return sum(1 for feature in features if feature[1] == clause)

def remove_features(all_features, clause, num_to_remove):
    """Return every feature of ``all_features`` whose clause code matches.

    The input list is not modified; a new list is returned in the original
    order.

    NOTE(review): ``num_to_remove`` is accepted but never consulted, so all
    matching features are returned regardless of its value -- confirm
    whether a cap was intended.
    """
    return [feature for feature in all_features if feature[1] == clause]

def rel(features_required, features_suggested, i):
    """Return 1 if the suggestion at rank ``i`` is relevant, else 0.

    A suggestion is relevant when its string form is contained in
    ``features_required``.  A rank at or beyond the end of
    ``features_suggested`` counts as not relevant instead of raising
    IndexError, so metrics can be evaluated at a cut-off larger than the
    number of suggestions actually produced.
    """
    if i >= len(features_suggested):
        return 0
    return 1 if str(features_suggested[i]) in features_required else 0

def precision(features_required, features_suggested, k):
    """Return the precision over the suggestions at ranks ``0..k``.

    ``k`` is a zero-based rank, so ``k + 1`` suggestions are considered and
    the hit count is divided by ``k + 1``.  A suggestion is a hit when its
    string form is contained in ``features_required`` (the same rule as
    ``rel``).  Ranks beyond the end of ``features_suggested`` count as
    misses rather than raising IndexError.
    """
    # Slicing clamps to the list length, which is what makes short
    # suggestion lists safe; max() keeps a negative k from slicing off the
    # tail of the list.
    considered = features_suggested[:max(k + 1, 0)]
    hits = sum(1 for suggestion in considered
               if str(suggestion) in features_required)
    return hits / float(k + 1)

def average_precision(features_required, features_suggested, k):
    """Return the average precision of the first ``k`` suggestions.

    For every relevant suggestion at zero-based rank ``i < k`` the precision
    over ranks ``0..i`` is accumulated; the total is divided by
    ``len(features_required)``.  A suggestion is relevant when its string
    form is contained in ``features_required`` (the same rule as ``rel``).

    Returns 0.0 when ``features_required`` is empty (previously a
    ZeroDivisionError).  A ``k`` larger than the number of suggestions is
    clamped instead of raising IndexError.
    """
    required_count = len(features_required)
    if required_count == 0:
        return 0.0

    hits = 0
    total = 0.0
    # A running hit count gives the precision at each rank directly, so the
    # whole metric is a single pass over the suggestions.
    for rank, suggestion in enumerate(features_suggested[:max(k, 0)]):
        if str(suggestion) in features_required:
            hits += 1
            total += hits / float(rank + 1)
    return total / float(required_count)


def TPM(features_required, features_suggested, str_features_required, str_features_suggested, k, length_of_partial_query):
    """Score a ranked suggestion list against the required features.

    Every suggestion is charged its zero-based rank.  A suggestion whose
    string form is contained in ``features_required`` is additionally
    credited the length of its text, ``str_features_suggested[rank]``.
    The total is divided by ``length_of_partial_query``.

    NOTE(review): ``str_features_required`` and ``k`` are not used by the
    computation.
    """
    score = 0
    for position, suggestion in enumerate(features_suggested):
        if str(suggestion) in features_required:
            credit = len(str_features_suggested[position])
        else:
            credit = 0
        score += credit - position
    return score / float(length_of_partial_query)

44 changes: 44 additions & 0 deletions generate_test_data_sets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from sqlparser import ParsedQuery, SELECT_CLAUSE, FROM_CLAUSE, WHERE_CLAUSE, GROUPBY_CLAUSE, ORDERBY_CLAUSE, HAVING_CLAUSE
from random import *
from snip_suggest import get_suggestions, find_feature_ids_with_clauses, clause, find_feature_ids, snippets

# Split the test queries into one file per (clause, minimum term count):
# "queries.fromN" receives every query with more than N tables,
# "queries.whereN" every query with more than N WHERE terms, and the
# group-by / order-by files every query with at least one such term.
# A query is written to every file whose threshold it exceeds.
# The with-block closes all handles even if parsing a query raises.
with open("exclude/queries.from0", "w") as f1, \
        open("exclude/queries.from1", "w") as f2, \
        open("exclude/queries.from2", "w") as f3, \
        open("exclude/queries.where0", "w") as f4, \
        open("exclude/queries.where1", "w") as f5, \
        open("exclude/queries.groupby0", "w") as f6, \
        open("exclude/queries.orderby0", "w") as f7, \
        open("exclude/queries.test") as queries:
    for q in queries:
        pq = ParsedQuery(q)
        num_from = len(pq.result.tables)
        num_where = len(pq.result.where_terms)
        num_group_by = len(pq.result.group_by_terms)
        num_order_by = len(pq.result.order_by_terms)

        if num_from > 0:
            f1.write(q)
        if num_from > 1:
            f2.write(q)
        if num_from > 2:
            f3.write(q)

        if num_where > 0:
            f4.write(q)
        if num_where > 1:
            f5.write(q)

        if num_group_by > 0:
            f6.write(q)
        if num_order_by > 0:
            f7.write(q)
33 changes: 28 additions & 5 deletions query_logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from sqlparser import ParsedQuery
from time import time
import MySQLdb

DATABASE_HOST = "localhost"
Expand All @@ -11,6 +12,7 @@ def check_and_insert(q, cursor):
sql = "insert into Queries (query_text) values ('" + q.query_string + "');"
if cursor.execute(sql):
qid = cursor.lastrowid
#print qid
for f in q.features:
sql = "select id from Features where feature_description = '"+ f[0] +"' AND clause = "+ str(f[1]) +";"
res = cursor.execute(sql)
Expand All @@ -25,14 +27,35 @@ def check_and_insert(q, cursor):
sql = "insert into QueryFeatures (query_id, feature_id) values (" + str(qid)+ "," + str(fid) + ");"
if not cursor.execute(sql):
print "Insert into QueryFeatures failed"
db.commit()
return 1
else:
print "Query failed"
return 0

db=MySQLdb.connect(host=DATABASE_HOST,user=DATABASE_USER, passwd=DATABASE_PASSWD, db=DATABASE_NAME, port=int(DATABASE_PORT))
cursor = db.cursor()

fname = "exclude/train"
for q in open(fname, "r"):
t = ParsedQuery(q)
check_and_insert(t, cursor)

fname = "exclude/queries.train2"
inserted_file = open("exclude/queries.inserted", "a")
start_time = time()
print "start_time ", start_time
for line in open(fname, "r"):
if line.find("[") != -1 or line.find("WITH") != -1 or line.find("<a") != -1 or line.find("cast") != -1 or line.find("CREATE") != -1 or line.find("create") != -1 or line.find("INSERT") != -1 or line.find("insert") != -1 or line.find("&") != -1 or line.find("0x") != -1 or line.find("#") != -1 or line.find("varchar") != -1 or line.find("+") != -1 or line.find("^") != -1 or line.find("--") != -1 or line.find("||") != -1 or line.find("between") != -1 or line.find("BETWEEN") != -1 or line.find("INTO") != -1:
continue
try:
t = ParsedQuery(line)
print line[:-1]
print str(t.features) + "\n"
if check_and_insert(t, cursor) == 1:
inserted_file.write(line)
else:
print "NOT INSERTED"
except:
print "NOT INSERTED"
db.commit()
end_time = time()
print "end_time ", end_time
print "elapsed time ", start_time - end_time


26 changes: 17 additions & 9 deletions snip_suggest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@
DATABASE_PASSWD = "tarun123"
DATABASE_PORT = 3306


db=MySQLdb.connect(host=DATABASE_HOST,user=DATABASE_USER, passwd=DATABASE_PASSWD, db=DATABASE_NAME, port=int(DATABASE_PORT))
cursor = db.cursor()


def find_feature_ids(features):
fids = []
for f in features:
Expand All @@ -23,16 +21,14 @@ def find_feature_ids(features):
fids.append(str(int(cursor.fetchone()[0])))
return fids



def find_feature_ids_with_clauses(features):
    """Look up the stored id of each known feature.

    ``features`` is an iterable of sequences whose first two items are
    (feature_description, clause).  For every feature found in the
    ``Features`` table a tuple ``(feature_id, clause, feature_description)``
    is returned, with ``feature_id`` as a decimal string.  Features without
    a matching row are skipped.
    """
    # Values are passed as query parameters so the driver quotes them;
    # feature descriptions come from parsed query text and can contain
    # quote characters that would break (or inject into) string-built SQL.
    sql = "select id from Features where feature_description = %s AND clause = %s;"
    fids = []
    for f in features:
        res = cursor.execute(sql, (f[0], f[1]))
        if res:
            # feature is present, retrieve fid
            fids.append((str(int(cursor.fetchone()[0])), f[1], f[0]))
    return fids


Expand All @@ -45,7 +41,7 @@ def find_feature_clauses(features):

def ssaccuracy(m, features):
sql = "select qf.feature_id from QueryFeatures qf, (SELECT query_id from QueryFeatures where feature_id in ("+ ",".join(features) +") group by query_id having count(feature_id) = " + str(m)+") as sq where qf.query_id = sq.query_id AND qf.feature_id NOT IN ("+ ",".join(features) +") group by qf.feature_id order by count(sq.query_id) DESC;"
#print sql

rows = []
res = cursor.execute(sql)
if res:
Expand Down Expand Up @@ -79,16 +75,28 @@ def snippet(suggestion):
return ""

def get_suggestions(features, clause_requested, k):
    """Collect suggested feature ids for the requested clause.

    ``ssaccuracy(i, features)`` is queried with ``i`` starting at
    ``len(features)`` and decreasing by one per round, until at least ``k``
    suggestions have been collected or ``i`` reaches zero.  A candidate is
    kept when its id has not been suggested yet and ``clause(id)`` equals
    ``clause_requested``.  Ids are returned as ints in discovery order.

    NOTE(review): the length check happens only between rounds, so one
    round may push the result past ``k`` entries -- confirm whether callers
    expect truncation.
    """
    suggestions = []
    # Mirrors ``suggestions`` for O(1) duplicate checks instead of
    # rebuilding a set for every candidate.
    seen = set()
    i = len(features)
    while len(suggestions) < k and i > 0:
        for f in ssaccuracy(i, features):
            fid = int(f[0])
            if fid not in seen and clause(f[0]) == clause_requested:
                seen.add(fid)
                suggestions.append(fid)
        i = i - 1
    return suggestions

def length_of_partial_query(features):
    """Return the character length of the query text built from ``features``.

    Each feature is read as ``(_, clause_code, description)``.  Descriptions
    with clause code 1, 2 or 3 are appended, each followed by one space, to
    the "SELECT ", "FROM " and "WHERE " sections respectively.  Features
    with any other clause code are ignored.  The three sections are
    concatenated in that order and the length of the result is returned.
    """
    sections = {1: ["SELECT "], 2: ["FROM "], 3: ["WHERE "]}
    for feature in features:
        section = sections.get(feature[1])
        if section is not None:
            section.append(feature[2] + " ")
    query_text = "".join(sections[1] + sections[2] + sections[3])
    return len(query_text)


14 changes: 0 additions & 14 deletions sqlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ def __init__(self, arg):
self.features = []
self.query_string = arg.replace( "'", '\\' + "'" ) # escape single quotes
self.result = parse(arg)
#self.dump()
self.get_table_alias()
self.normalize_table_aliases()
self.get_on_terms()
Expand All @@ -190,18 +189,15 @@ def __init__(self, arg):
def populate_features(self):
for term in self.result.columns:
if term.func:
print term.func
temp_str = term.func[0] + "(" + ",".join(term.func[1]) + ")"
feature = (temp_str, SELECT_CLAUSE)
elif len(term.column) == 0:
feature = (term[0], SELECT_CLAUSE)
else:
feature = (term.column, SELECT_CLAUSE)
self.features.append(feature)
#self.columns.append(term.column)

for term in self.result.tables:
#print term
feature = ""
if term.table:
feature = (term.table, FROM_CLAUSE)
Expand All @@ -211,8 +207,6 @@ def populate_features(self):
feature = (temp_str, FROM_CLAUSE)
if feature != "":
self.features.append(feature)

#self.tables.append(term.table)

for term in self.result.where_terms:
feature = (" ".join(term), WHERE_CLAUSE)
Expand All @@ -222,29 +216,24 @@ def populate_features(self):
for term in self.result.tables.on_terms:
feature = (" ".join(term), WHERE_CLAUSE)
self.features.append(feature)
#self.where_terms.append(" ".join(term))

for term in self.result.group_by_terms:
feature = (" ".join(term), GROUPBY_CLAUSE)
self.features.append(feature)
#self.group_by_terms.append(" ".join(term))

for term in self.result.order_by_terms:
feature = (" ".join(term), ORDERBY_CLAUSE)
self.features.append(feature)
#self.order_by_terms.append(" ".join(term))

def dump(self):
    """Print whatever ``self.result.dump()`` returns (debugging aid)."""
    print self.result.dump()

def normalize_table_aliases(self):
#print self.table_aliases
for term in self.result.where_terms:
for i in range(len(term)):
if term[i].split(".")[0] in self.table_aliases:
term[i] = self.table_aliases[term[i].split(".")[0]] + "." + "".join(term[i].split(".")[1:])

#normalizing ON terms
for term in self.result.tables :
if term.table or term.table_function:
continue
Expand Down Expand Up @@ -272,8 +261,6 @@ def normalize_table_aliases(self):
for i in range(len(self.result.order_by_terms)):
col = self.result.order_by_terms[i]
temp = col[0].split(".")
# print temp
# print type(self.result.order_by_terms[i][0])
if len(temp) == 3:
#has db + tablename + columnname
if temp[1] in self.table_aliases:
Expand Down Expand Up @@ -310,7 +297,6 @@ def normalize_column_aliases(self):

def get_table_alias(self):
for t in self.result.tables :
#print t
if t.table_alias:
if t.table:
self.table_aliases[t.table_alias[0]] = t.table
Expand Down
Loading

0 comments on commit 923bbf2

Please sign in to comment.