Skip to content

Commit

Permalink
make language modeling name in camel case
Browse files Browse the repository at this point in the history
  • Loading branch information
bakarov committed Jul 31, 2018
1 parent 46f7e13 commit a6b5bc6
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tests/benchmarks/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from vecto.benchmarks import text_classification
from vecto.benchmarks.similarity import Similarity
from vecto.benchmarks.sequence_labeling import SequenceLabeling
from vecto.benchmarks.language_modeling import Language_modeling
from vecto.benchmarks.language_modeling import LanguageModeling
from vecto.benchmarks.analogy import visualize as analogy_visualize
from vecto.benchmarks.similarity import visualize as similarity_visualize
from vecto.benchmarks.text_classification import TextClassification
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_language_modeling(self):
embs = load_from_dir("./tests/data/embeddings/text/plain_with_file_header")

for method in ['lr', '2FFNN', 'rnn', 'lstm']:
sequence_labeling = Language_modeling(test=True, method=method)
sequence_labeling = LanguageModeling(test=True, method=method)
results = sequence_labeling.get_result(embs)
print(results)

Expand Down
2 changes: 1 addition & 1 deletion vecto/benchmarks/language_modeling/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .language_modeling import Language_modeling
from .language_modeling import LanguageModeling
5 changes: 4 additions & 1 deletion vecto/benchmarks/language_modeling/language_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def compute_perplexity(result):
result['val_perplexity'] = np.exp(result['validation/main/loss'])


class Language_modeling(Benchmark):
class LanguageModeling(Benchmark):

def __init__(self, normalize=True, window_size=5, method='lstm', test=False): # 'lr', '2FFNN', 'lstm'
self.normalize = normalize
Expand All @@ -244,6 +244,9 @@ def __init__(self, normalize=True, window_size=5, method='lstm', test=False): #
self.out = tmpBasePath
self.resume = ''

def read_test_set(self, path):
pass

def get_result(self, embeddings):

self.unit = embeddings.matrix.shape[1]
Expand Down

0 comments on commit a6b5bc6

Please sign in to comment.