Skip to content

Commit

Permalink
cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
undertherain committed Jan 22, 2019
1 parent 7e66ffb commit 654f5c7
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 18 deletions.
6 changes: 3 additions & 3 deletions vecto/benchmarks/relation_extraction/relation_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ def getPrecision(pred_test, yTest, targetLabel):
targetLabelCount = 0
correctTargetLabelCount = 0

for idx in range(len(pred_test)):
if pred_test[idx] == targetLabel:
for idx, prediction in enumerate(pred_test):
if prediction == targetLabel:
targetLabelCount += 1

if pred_test[idx] == yTest[idx]:
if prediction == yTest[idx]:
correctTargetLabelCount += 1

if correctTargetLabelCount == 0:
Expand Down
2 changes: 1 addition & 1 deletion vecto/benchmarks/text_classification/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Text classification benchmark.
"""Text classification benchmark.
One of the pre-defined models is trained to convergence
to predict labels for text fragments in a provided dataset.
Expand Down
11 changes: 2 additions & 9 deletions vecto/benchmarks/text_classification/nlp_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
import collections
import io
# import collections
# import io

import numpy

import chainer
from chainer.backends import cuda


def split_text(text, char_based=False):
    """Tokenize *text*.

    When char_based is true, each character becomes a token; otherwise
    the text is split on runs of whitespace (str.split default).
    """
    return list(text) if char_based else text.split()


def normalize_text(text):
    """Return *text* with surrounding whitespace removed and lowercased."""
    trimmed = text.strip()
    return trimmed.lower()

Expand Down
12 changes: 10 additions & 2 deletions vecto/benchmarks/text_classification/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vecto.benchmarks.text_classification import nets
from vecto.benchmarks.text_classification import text_datasets
from vecto.benchmarks.text_classification import nlp_utils
from vecto.corpus.tokenization import word_tokenize_txt
from ..base import Benchmark


Expand Down Expand Up @@ -45,7 +46,11 @@ def predict(model, sentence):
model, vocab, setup = model
sentence = sentence.strip()
text = nlp_utils.normalize_text(sentence)
words = nlp_utils.split_text(text, char_based=setup['char_based'])
# words = nlp_utils.split_text(text, char_based=setup['char_based'])
if setup['char_based']:
words = list(text)
else:
words = word_tokenize_txt(text)
xs = nlp_utils.transform_to_array([words], vocab, with_label=False)
xs = nlp_utils.convert_seq(xs, device=-1, with_label=False) # todo use GPU
with chainer.using_config('train', False), chainer.no_backprop_mode():
Expand All @@ -61,7 +66,10 @@ def get_vectors(model, sentences):
for sentence in sentences:
sentence = sentence.strip()
text = nlp_utils.normalize_text(sentence)
words = nlp_utils.split_text(text, char_based=setup['char_based'])
if setup['char_based']:
words = list(text)
else:
words = word_tokenize_txt(text)
xs = nlp_utils.transform_to_array([words], vocab, with_label=False)
xs = nlp_utils.convert_seq(xs, device=-1, with_label=False) # todo use GPU
with chainer.using_config('train', False), chainer.no_backprop_mode():
Expand Down
9 changes: 7 additions & 2 deletions vecto/benchmarks/text_classification/text_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import chainer

from vecto.benchmarks.text_classification.nlp_utils import normalize_text
from vecto.benchmarks.text_classification.nlp_utils import split_text
from vecto.corpus.tokenization import word_tokenize_txt
# from vecto.benchmarks.text_classification.nlp_utils import split_text

# TODO: use vecto.corpus
from vecto.benchmarks.text_classification.nlp_utils import transform_to_array
Expand All @@ -30,7 +31,11 @@ def read_lines_separated(path, shrink=1, char_based=False):
continue
label, text = l.strip().split(None, 1)
label = int(label) % 2 # TODO: don't do this, implement shift
tokens = split_text(normalize_text(text), char_based)
text = normalize_text(text)
if char_based:
tokens = list(text)
else:
tokens = word_tokenize_txt(text)
dataset.append((tokens, label))
return dataset

Expand Down
2 changes: 1 addition & 1 deletion vecto/benchmarks/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def df_from_dir(path):
full_path = os.path.join(dirpath, filename)
try:
dfs.append(df_from_file(full_path))
except KeyError as e:
except KeyError:
logger.warning(f"error reading {full_path}")
dframe = pandas.concat(dfs, sort=True)
# print(dframe["experiment_setup.task"])
Expand Down

0 comments on commit 654f5c7

Please sign in to comment.