Skip to content
Browse files

slightly faster

  • Loading branch information...
1 parent 831bf93 commit 6a1c6b9b4670745202b4f7245f322b4953266607 @shuyo committed Nov 30, 2011
Showing with 93 additions and 31 deletions.
  1. +15 −15 da.py
  2. +17 −16 ldig.py
  3. +61 −0 test_da.py
View
30 da.py
@@ -145,20 +145,20 @@ def get_value(self, subtree):
def extract_features(self, st):
events = dict()
- pointers = []
- for c in iter(st):
- v = ord(c)
- pointers.append(0)
- new_pointers = []
- for pointer in pointers:
- next = self.base[pointer] + v
- if next < self.N and self.check[next] == pointer:
- new_pointers.append(next)
- id = self.value[next]
- if id >= 0:
- events[id] = events.get(id, 0) + 1
- pointers = new_pointers
+ l = len(st)
+ clist = [ord(c) for c in iter(st)]
+ N = self.N
+ base = self.base
+ check = self.check
+ value = self.value
+ for i in xrange(l):
+ pointer = 0
+ for j in xrange(i, l):
+ next = base[pointer] + clist[j]
+ if next >= N or check[next] != pointer: break
+ id = value[next]
+ if id >= 0:
+ events[id] = events.get(id, 0) + 1
+ pointer = next
return events
-
-
View
33 ldig.py
@@ -11,7 +11,6 @@
import htmlentitydefs
import subprocess
import da
-sys.stdout = codecs.getwriter('utf-8')(sys.stdout)
class ldig(object):
@@ -343,26 +342,28 @@ def inference(param, labels, corpus, idlist, trie, options):
print "full regularization: %d / %d" % (m, N)
indexes = xrange(M)
for id in indexes:
- if id in events: param[id,] -= y * events[id]
+ prm = param[id]
+ pnl = penalties[id]
+ if id in events: prm -= y * events[id]
for j in xrange(K):
- w = param[id, j]
+ w = prm[j]
if w > 0:
- w1 = w - uk - penalties[id, j]
+ w1 = w - uk - pnl[j]
if w1 > 0:
- param[id, j] = w1
- penalties[id, j] += w1 - w
+ prm[j] = w1
+ pnl[j] += w1 - w
else:
- param[id, j] = 0
- penalties[id, j] -= w
+ prm[j] = 0
+ pnl[j] -= w
elif w < 0:
- w1 = w + uk - penalties[id, j]
+ w1 = w + uk - pnl[j]
if w1 < 0:
- param[id, j] = w1
- penalties[id, j] += w1 - w
+ prm[j] = w1
+ pnl[j] += w1 - w
else:
- param[id, j] = 0
- penalties[id, j] -= w
+ prm[j] = 0
+ pnl[j] -= w
else:
for id, freq in events.iteritems():
param[id,] -= y * freq
@@ -380,9 +381,7 @@ def likelihood(param, labels, trie, filelist, options):
corrects = numpy.zeros(K, dtype=int)
counts = numpy.zeros(K, dtype=int)
- label_map = dict()
- for i, label in enumerate(labels):
- label_map[label] = i
+ label_map = dict((x, i) for i, x in enumerate(labels))
n_available_data = 0
log_likely = 0.0
@@ -433,6 +432,8 @@ def generate_doublearray(file, features):
if __name__ == '__main__':
+ sys.stdout = codecs.getwriter('utf-8')(sys.stdout)
+
parser = optparse.OptionParser()
parser.add_option("-m", dest="model", help="model directory")
parser.add_option("--init", dest="init", help="initialize model", action="store_true")
View
61 test_da.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import unittest
+import da
+
+class TestDoubleArray(unittest.TestCase):
+ def test1(self):
+ trie = da.DoubleArray(verbose=False)
+ trie.initialize(["cat"])
+ self.assertEqual(trie.N, 4)
+ self.assert_(trie.get("ca") is None)
+ self.assert_(trie.get("xxx") is None)
+ self.assertEqual(trie.get("cat"), 0)
+
+ def test2(self):
+ trie = da.DoubleArray()
+ trie.initialize(["cat", "dog"])
+ self.assertEqual(trie.N, 7)
+ self.assert_(trie.get("ca") is None)
+ self.assert_(trie.get("xxx") is None)
+ self.assertEqual(trie.get("cat"), 0)
+ self.assertEqual(trie.get("dog"), 1)
+
+ def test3(self):
+ trie = da.DoubleArray(verbose=False)
+ trie.initialize(["ca", "cat", "deer", "dog", "fox", "rat"])
+ print trie.base
+ print trie.check
+ print trie.value
+ self.assertEqual(trie.N, 17)
+ self.assert_(trie.get("c") is None)
+ self.assertEqual(trie.get("ca"), 0)
+ self.assertEqual(trie.get("cat"), 1)
+ self.assertEqual(trie.get("deer"), 2)
+ self.assertEqual(trie.get("dog"), 3)
+ self.assert_(trie.get("xxx") is None)
+
+ def test4(self):
+ trie = da.DoubleArray()
+ self.assertRaises(Exception, trie.initialize, ["cat", "ant"])
+
+ def test5(self):
+ trie = da.DoubleArray(verbose=False)
+ trie.initialize(["ca", "cat", "deer", "dog", "fox", "rat"])
+
+ r = trie.extract_features("")
+ self.assertEqual(len(r), 0)
+
+ r = trie.extract_features("cat")
+ self.assertEqual(len(r), 2)
+ self.assertEqual(r[0], 1)
+ self.assertEqual(r[1], 1)
+
+ r = trie.extract_features("deerat")
+ self.assertEqual(len(r), 2)
+ self.assertEqual(r[2], 1)
+ self.assertEqual(r[5], 1)
+
+unittest.main()
+

0 comments on commit 6a1c6b9

Please sign in to comment.
Something went wrong with that request. Please try again.