Skip to content
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
115 lines (100 sloc) 3.9 KB
import lucene
import pickle, sys, datetime
import numpy as np
from scipy.spatial import distance
THREADS_NUM = int(sys.argv[1]) # NOTE not used
EXPERIM_ID = sys.argv[2]
EXPERIM_FILE = sys.argv[3]
DICTIONARY_PATH = sys.argv[4]
VECTORS_PATH = sys.argv[5]
NONVEC_SURROGATE_DISTANCE = float(sys.argv[6])
# Load the reference dictionary.
dict_str = ''
copyright_end = False
with open(DICTIONARY_PATH) as inp:
for line in inp:
line = line.strip()
if '</COPYRIGHT>' in line:
copyright_end = True
if not copyright_end:
dict_str += ' ' + line.split('\t')[0]
# Load word vectors.
word_to_idx = {} # ie. indices of vectors
idx_to_word = {}
first_line = True
word_n, vecs_dim = 0, 0
with open(VECTORS_PATH) as vecs_file:
for line in vecs_file:
if first_line:
vecs_dim = int(line.split(' ')[1])
first_line = False
# Read word forms.
word = line.split(' ')[0].lower()
word_to_idx[word] = word_n
idx_to_word[word_n] = word
word_n += 1
word_vecs = np.loadtxt(VECTORS_PATH, encoding="utf-8",
dtype=np.float32, comments=None,
skiprows=1, usecols=tuple(range(1, vecs_dim+1)))
# Load the test corpus.
test_err_objs = None
with open('test_set_{}.pkl'.format(EXPERIM_ID), 'rb') as pkl:
test_err_objs = pickle.load(pkl)
test_samples_count = len(test_err_objs)
# Java imports:
from import PlainTextDictionary, SpellChecker
# boilerplate for setting up spellchecking:
from import StringReader
from import RAMDirectory
from org.apache.lucene.index import IndexWriterConfig
from org.apache.lucene.analysis.core import KeywordAnalyzer
from import LevensteinDistance # corrected to Levenshtein in later versions
# Start JVM for Lucene.
# Set up Lucene spellchecking.
dict_reader = StringReader(dict_str)
dictionary = PlainTextDictionary(dict_reader)
ramdir = RAMDirectory()
spellchecker = SpellChecker(ramdir)
spellchecker.indexDictionary(dictionary, IndexWriterConfig(KeywordAnalyzer()), True)
# Set up Lucene distance computation.
distance_check = LevensteinDistance()
# Run the word correction test.
def word_vec(word):
return word_vecs[word_to_idx[word]]
def correct_word(word):
candidates = spellchecker.suggestSimilar(word, 10)
# Add edit distance information (we get this as a Lucene similarity measure in [0, 1]).
candidates = [(cand, 1.0-distance_check.getDistance(word, cand)) for cand in candidates]
# Add vector distance information.
candidates = [(cand, ed_dist, distance.cosine(word_vec(cand), word_vec(word))
if (cand in word_to_idx and word in word_to_idx)
for (cand, ed_dist) in candidates]
candidates.sort(key=lambda x: x[1]+x[2])
if len(candidates) > 0:
return candidates[0][0]
return ''
good, bad = [], []
counter = 0
with open('Vector_distance_corrections_{}.tab'.format(EXPERIM_ID), 'w+') as corrs_file:
for (sample_n, err_obj) in enumerate(test_err_objs):
print('{}/{}'.format(sample_n, test_samples_count), end='\r') # overwrite the number
error = err_obj['error']
true_correction = err_obj['correction']
correction = correct_word(error)
print('{}\t{}'.format(error, correction), file=corrs_file)
if correction == true_correction:
print() # line feed
with open(EXPERIM_FILE, 'a') as res_file:
timestamp ='%d-%m-%Y_%H-%M-%S')
print('Vector distance ({})'.format(timestamp), file=res_file)
print('Accuracy: {}'.format(len(good)/len(test_err_objs)), file=res_file)
You can’t perform that action at this time.