Skip to content

Commit

Permalink
Merge pull request #1 from piskvorky/hashdictionary
Browse files Browse the repository at this point in the history
Hashdictionary
  • Loading branch information
strongh committed Aug 11, 2012
2 parents 93e1389 + a9334e7 commit 42b6c44
Showing 1 changed file with 19 additions and 20 deletions.
39 changes: 19 additions & 20 deletions gensim/corpora/hashdictionary.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2010 Radim Rehurek <radimrehurek@seznam.cz>
# Copyright (C) 2012 Homer Strong, Radim Rehurek
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html


Expand All @@ -13,8 +13,7 @@

from __future__ import with_statement

import codecs # for unicode output
import zlib
import codecs
import logging
import itertools
import UserDict
Expand All @@ -29,19 +28,19 @@ class RestrictedHash:
"""
Mimics a dict, using a restricted hash.
"""
def __init__(self, key_range=32000, hash=zlib.adler32, maintain_reverse=True, debug=False):
def __init__(self, key_range=32000, myhash=hash, maintain_reverse=True, debug=False):
"""
Initialize a RestrictedHash with given key range and hash function.
Initialize a RestrictedHash with given key range and hash function.
maintain_reverse determines whether to keep a dict mapping the inverse hash function..
"""
self.key_range = key_range
self.hash = hash
self.myhash = myhash
self.debug = debug
self.maintain_reverse = maintain_reverse
self.reverse = {}
self.debug_reverse = {}

def __len__(self):
"""
Reports the size of the domain of possible keys.
Expand All @@ -66,7 +65,7 @@ def __getitem__(self, key):
self.reverse[h] = key
if self.debug:
if self.debug_reverse.get(h, None):
self.debug_reverse[h] = self.debug_reverse[h].add(key)
self.debug_reverse[h] = self.debug_reverse[h].add(key)
else:
self.debug_reverse[h] = set([key])
return h
Expand All @@ -75,22 +74,22 @@ def itervalues(self):
return self.reverse.keys()

def iteritems(self):
return dict((v,k) for k, v in self.reverse.iteritems())
return dict((v, k) for k, v in self.reverse.iteritems())

def values(self):
return self.reverse.keys()

def keys(self):
return self.reverse.values()

def subset(self, key_subset):
self.reverse = dict((k,v) for k, v in self.reverse.iteritems() if k in key_subset)
self.reverse = dict((k, v) for k, v in self.reverse.iteritems() if k in key_subset)

def restricted_hash(self, key):
"""Calculates the hash mod the range"""
return self.hash(key) % self.key_range
return self.myhash(key) % self.key_range




class HashDictionary(utils.SaveLoad, UserDict.DictMixin):
"""
Expand All @@ -100,8 +99,8 @@ class HashDictionary(utils.SaveLoad, UserDict.DictMixin):
bag-of-words representation: a list of (word_id, word_frequency) 2-tuples
"""
def __init__(self, documents=None, id_range=32000, hash=zlib.adler32, debug=False):
self.token2id = RestrictedHash(key_range=id_range, hash=hash, debug=debug)
def __init__(self, documents=None, id_range=32000, myhash=hash, debug=False):
self.token2id = RestrictedHash(key_range=id_range, myhash=myhash, debug=debug)
self.id2token = self.token2id.reverse # reverse mapping for token2id; only formed on request, to save memory
self.dfs = {} # document frequencies: tokenId -> in how many documents this token appeared
self.num_docs = 0 # number of documents processed
Expand Down Expand Up @@ -129,7 +128,7 @@ def __len__(self):

def __str__(self):
return ("HashDictionary(%i id range)" % len(self))


@staticmethod
def from_documents(documents):
Expand Down Expand Up @@ -169,7 +168,7 @@ def doc2bow(self, document, allow_update=False, return_missing=False):
by one.
If `allow_update` is **not** set, this function is `const`, aka read-only.
"""
result = {}
missing = {}
Expand Down Expand Up @@ -224,8 +223,8 @@ def filter_extremes(self, no_below=5, no_above=0.5, keep_n=100000):

if keep_n is not None:
good_ids = good_ids[:keep_n]
self.token2id.subset(key_subset=good_ids)
self.dfs = dict((tokenid, freq) for tokenid, freq in self.dfs.iteritems()
self.token2id.subset(key_subset=good_ids)
self.dfs = dict((tokenid, freq) for tokenid, freq in self.dfs.iteritems()
if tokenid in good_ids)
logger.info("keeping %i tokens which were in no less than %i and no more than %i (=%.1f%%) documents" %
(len(good_ids), no_below, no_above_abs, 100.0 * no_above))
Expand All @@ -239,7 +238,7 @@ def save_as_text(self, fname):
Note: use `save`/`load` to store in binary format instead (pickle).
"""
logger.info("saving hashdictionary mapping to %s" % fname)
with codecs.open(fname, 'wb',encoding='utf-8') as fout:
with codecs.open(fname, 'wb', encoding='utf-8') as fout:
for token, tokenid in sorted(self.token2id.iteritems()):
fout.write("%i\t%s\t%i\n" % (tokenid, token, self.dfs.get(tokenid, 0)))

Expand Down

0 comments on commit 42b6c44

Please sign in to comment.