Skip to content

Commit

Permalink
* support loading unlabeled samples, calculate f-measure etc
Browse files Browse the repository at this point in the history
  • Loading branch information
whym committed Jul 7, 2011
1 parent 8e4f2ed commit 6edc111
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 34 deletions.
49 changes: 32 additions & 17 deletions load_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import urllib2
import time
import re
import ast

from twisted.internet import reactor
from twisted.web.client import Agent
Expand Down Expand Up @@ -36,12 +37,25 @@ def get_revisions(revs):
except urllib2.URLError:
time.sleep(5)

revs = minidom.parseString(res).getElementsByTagName('rev')
return revs
ret = []
pages = minidom.parseString(res).getElementsByTagName('page')
for p in pages:
for r in p.getElementsByTagName('rev'):
ret.append((p, r))
return ret

if __name__ == '__main__':

parser = argparse.ArgumentParser()
parser.add_argument('-f', '--field', metavar='COLUMN',
dest='revfield', type=int, default=2,
help='column that contains revision IDs')
parser.add_argument('-l', '--labels', metavar='COLUMNS',
dest='labels', type=int, default=None,
help='columns that contain labels (a label is 0 or 1)')
parser.add_argument('-D', '--delimiter', metavar='CHARACTER',
dest='delimiter', type=str, default=',',
help='')
parser.add_argument('-w', '--wait', metavar='SECS',
dest='wait', type=float, default=0.5,
help='')
Expand Down Expand Up @@ -75,10 +89,13 @@ def get_revisions(revs):

# load raw table of coded examples
csv.field_size_limit(1000000000)
table = list(csv.reader(open(options.input)))
header = table[0][2:6]
header = [None,None] + header
table = table[1:]
table = list(csv.reader(open(options.input), delimiter=ast.literal_eval('"'+options.delimiter+'"')))
header = []
if options.labels != None:
header = [None for x in xrange(0, len(table[0]))]
for c in options.labels.split(','):
header[c] = table[0][c]
table = table[1:]
table_size = len(table)

# prepare HTTP agent for accessing Wikipedia API
Expand All @@ -102,8 +119,7 @@ def get_revisions(revs):
existings = {}
for x in db.find({'entry.rev_id': {'$exists': True}, 'entry.content': {'$exists': True}}, {'entry.rev_id':1, 'entry.content':1}):
existings[x['entry']['rev_id']] = True
print existings
table = filter(lambda x: not existings.has_key(int(x[1])), table)
table = filter(lambda x: not existings.has_key(int(x[options.revfield])), table)

while len(table) > 0:
colslices = table[0:options.slice]
Expand All @@ -115,29 +131,28 @@ def get_revisions(revs):
for (i,lab) in enumerate(header):
if lab != None:
labels[lab] = bool(int(cols[i]))
ent = {'entry': {'title': cols[0],
'receiver': cols[0],
'rev_id': int(cols[1]),
},
'labels': labels}
ent = {'entry': {'rev_id': int(cols[options.revfield]),
}}
if options.labels != None:
ent.update({'labels': labels})
ls.append(ent)
revmap[int(cols[1])] = cols
revmap[int(cols[options.revfield])] = cols

# call API to get content etc
revs = get_revisions([str(x) for x in revmap.keys()])
print >>sys.stderr, "received %d (%d/%d)" % (len(revs), len(revs) + db.count(), table_size)

queued = []
for (i,rev) in enumerate(revs):
ls[i]['entry']['sender'] = rev.attributes['user'].value
for (i,(page,rev)) in enumerate(revs):
ls[i]['entry'].update({'sender': rev.attributes['user'].value,
'title': page.attributes['title'].value})
assert ls[i]['entry']['rev_id'] == int(rev.attributes['revid'].value), [ls[i], rev.toxml()]
if len(rev.childNodes) == 0:
raise "empty diff: %s" % rev.toxml()
else:
if rev.childNodes[0].attributes.has_key('notcached'):
print 'no cache ' + rev.attributes['revid'].value
queued.append(revmap[int(rev.attributes['revid'].value)])
#print diffparse(rev.childNodes[0].childNodes[0].data)
else:
ls[i]['entry']['content'] = diffparse(rev.childNodes[0].childNodes[0].data)
for x in ls:
Expand Down
52 changes: 40 additions & 12 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import liblinearutil
import ast
import tempfile
from collections import namedtuple

from twisted.internet import reactor
from twisted.web.client import Agent
Expand All @@ -24,6 +25,9 @@
parser.add_argument('-f', '--find', metavar='QUERY',
dest='find', type=str, default='{}',
help='')
parser.add_argument('-o', '--output', metavar='FILE',
dest='output', type=str, default='/dev/stdout',
help='')
parser.add_argument('-m', '--model', metavar='QUERY',
dest='model', type=str, default='{}',
help='')
Expand Down Expand Up @@ -68,17 +72,25 @@

vectors = []
labels = {}
for x in models.keys():
labels[x] = []
for ent in cursor:
if not ent.has_key('labels'):
print >>sys.stderr, 'skip ' + ent['entry']['rev_id']
continue
for (name,value) in ent['labels'].items():
labels.setdefault(name, []).append(value if 1 else -1)
for name in labels.keys():
value = None
if ent.has_key('labels') and ent['labels'].has_key(name):
value = ent['labels'][name] if 1 else -1
labels.setdefault(name, []).append(value)
vec = {}
for (x,y) in ent['vector'].items():
vec[int(x)] = float(y)
vectors.append((vec, ent['entry']['rev_id']))
vectors.append((vec, ent['entry']))

for (name,vals) in labels.items():
assert len(vectors) == len(vals), [len(vectors), len(vals), name]

writer = csv.writer(open(options.output, 'w'), delimiter='\t')
writer.writerow([unicode(x) for x in ['label', 'rev_id', 'predicted', 'coded', 'confidence', 'correct?', 'diff', 'snippet']])
pn_tuple = namedtuple('pn', 'p n')
vecs = map(lambda x: x[0], vectors)
for (lname, labs) in labels.items():
m = models[lname]
Expand All @@ -89,10 +101,26 @@

lab,acc,val = liblinearutil.predict(labs, vecs, m, '-b 1')

# print failure cases
# print performances and failure cases
pn = pn_tuple({True: 0, False: 0},
{True: 0, False: 0})
for (i,pred) in enumerate(lab):
ng = bool(pred) != labs[i]
if ng or options.verbose:
print vectors[i][1], bool(pred), labs[i], '%4.3f' % max(val[i]), 'ng' if ng else 'ok'

# TODO: f-measure
ok = bool(pred) == labs[i]
res = 'Yes' if ok else 'No'
if labs[i] == None:
res = 'Unknown'
else:
if pred > 0:
pn.p[ok] += 1
else:
pn.n[ok] += 1
if not ok or options.verbose:
link = 'http://en.wikipedia.org/w/index.php?diff=prev&oldid=%s' % vectors[i][1]['rev_id']
writer.writerow([unicode(x).encode('utf-8') for x in [lname, vectors[i][1]['rev_id'], bool(pred), labs[i], '%4.3f' % max(val[i]), res, '=HYPERLINK("%s","%s")' % (link,link), vectors[i][1]['content'][0:50]]])
print ' accuracy = %f' % (float(pn.p[True] + pn.n[True]) / sum(pn.p.values() + pn.n.values()))
prec = float(pn.p[True]) / sum(pn.p.values())
reca = float(pn.p[True]) / (pn.p[True] + pn.n[False])
print ' precision = %f' % prec
print ' recall = %f' % reca
print ' fmeasure = %f' % (1.0 / (0.5/prec + 0.5/reca))
print pn, (pn.p[True] + pn.p[False] + pn.n[True] + pn.n[False])
18 changes: 13 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,30 @@

# construct the training set from 'entry's in the MongoDB
db = collection['talkpage_diffs_raw']
query = {'vector': {'$exists': True}}
query = {'labels': {'$exists': True},
'vector': {'$exists': True}}
query.update(ast.literal_eval(options.find))
cursor = db.find(query)
print >>sys.stderr, 'using labeld examples: %s out of %s' % (cursor.count(), db.count())
labels = {}
vectors = []
entries = []
for ent in cursor:
if not ent.has_key('labels'):
print >>sys.stderr, 'skip ' + ent['entry']['rev_id']
print >>sys.stderr, 'skip %s' % ent['entry']['rev_id']
continue
for (name,value) in ent['labels'].items():
labels.setdefault(name, []).append(value if 1 else -1)
vec = {}
for (x,y) in ent['vector'].items():
vec[int(x)] = float(y)
if len(vec.items()) == 0:
print >>sys.stderr, 'empty %s' % ent['entry']['rev_id']
#continue
vectors.append(vec)
entries.append(ent)
for (name,value) in ent['labels'].items():
labels.setdefault(name, []).append(value if 1 else -1)
if options.verbose:
print >>sys.stderr, str(ent['entry']['rev_id'])

if options.verbose:
print >>sys.stderr, 'vectors loaded'
Expand All @@ -78,7 +86,7 @@
print >>sys.stderr, '%s problem constructed' % lname
m = liblinearutil.train(prob, liblinear.parameter('-s 6'))
if options.verbose:
print >>sys.stderr, '%s trained' % lname
print >>sys.stderr, '"%s" model trained' % lname

lab,acc,val = liblinearutil.predict(labs, vectors, m)

Expand Down

0 comments on commit 6edc111

Please sign in to comment.