Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhwcmhf committed Mar 13, 2019
1 parent a747d6b commit d76fd87
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
18 changes: 9 additions & 9 deletions cotk/metric/metric.py
Expand Up @@ -19,7 +19,7 @@ def __init__(self):
class _PrecisionRecallMetric(MetricBase):
'''Base class for precision recall metrics. This is an abstract class.
Arguments:
dataloader (:class:cotk.BasicLanguageGeneration): A language generation dataloader.
dataloader (:class:cotk.GenerationBase): A language generation dataloader.
reference_allvocabs_key (str): Reference sentences are passed to :func:`forward` by
``data[reference_allvocabs_key]``. Default: ``resp_allvocabs``.
gen_key (str): Sentences generated by model are passed to :func:.forward by
Expand Down Expand Up @@ -170,7 +170,7 @@ class PerplexityMetric(MetricBase):
'''Metric for calculating perplexity.
Arguments:
dataloader (:class:cotk.BasicLanguageGeneration): A language generation dataloader.
dataloader (:class:cotk.GenerationBase): A language generation dataloader.
reference_allvocabs_key (str): Reference sentences with all vocabs
are passed to :func:`forward` by ``data[reference_allvocabs_key]``.
Default: ``resp_allvocabs``.
Expand Down Expand Up @@ -300,7 +300,7 @@ class MultiTurnPerplexityMetric(MetricBase):
'''Metric for calculating multi-turn perplexity.
Arguments:
dataloader (:class:cotk.BasicLanguageGeneration): A language generation dataloader.
dataloader (:class:cotk.GenerationBase): A language generation dataloader.
reference_allvocabs_key (str): Reference sentences with all vocabs
are passed to :func:`forward` by ``data[reference_allvocabs_key]``.
Default: ``sent_allvocabs``.
Expand Down Expand Up @@ -380,7 +380,7 @@ class BleuCorpusMetric(MetricBase):
'''Metric for calculating BLEU.
Arguments:
dataloader (:class:cotk.BasicLanguageGeneration): A language generation dataloader.
dataloader (:class:cotk.GenerationBase): A language generation dataloader.
reference_allvocabs_key (str): Reference sentences with all vocabs
are passed to :func:.forward by ``data[reference_allvocabs_key]``.
Default: ``resp``.
Expand Down Expand Up @@ -499,7 +499,7 @@ class FwBwBleuCorpusMetric(MetricBase):
'''Metric for calculating BLEU.
Arguments:
dataloader (:class:cotk.BasicLanguageGeneration): A language generation dataloader.
dataloader (:class:cotk.GenerationBase): A language generation dataloader.
reference_test_key (str): Reference sentences with all vocabs in test data
are passed to :func:.forward by ``data[reference_test_key]``.
gen_key (str): Sentences generated by model are passed to :func:.forward by
Expand Down Expand Up @@ -586,7 +586,7 @@ class MultiTurnBleuCorpusMetric(MetricBase):
'''Metric for calculating multi-turn BLEU.
Arguments:
dataloader (:class:cotk.BasicLanguageGeneration): A language generation dataloader.
dataloader (:class:cotk.GenerationBase): A language generation dataloader.
reference_allvocabs_key (str): Reference sentences with all vocabs are passed to
:func:`forward` by ``data[reference_allvocabs_key]``.
Default: ``reference_allvocabs``.
Expand Down Expand Up @@ -653,7 +653,7 @@ class SingleTurnDialogRecorder(MetricBase):
'''A metric-like class for recording generated sentences and references.
Arguments:
dataloader (:class:cotk.BasicLanguageGeneration): A language generation dataloader.
dataloader (:class:cotk.GenerationBase): A language generation dataloader.
post_allvocabs_key (str): Dialog post are passed to :func:`forward`
by ``data[post_allvocabs_key]``.
Default: ``post``.
Expand Down Expand Up @@ -714,7 +714,7 @@ class MultiTurnDialogRecorder(MetricBase):
'''A metric-like class for recording generated sentences and references.
Arguments:
dataloader (:class:cotk.BasicLanguageGeneration): A language generation dataloader.
dataloader (:class:cotk.GenerationBase): A language generation dataloader.
context_allvocabs_key (str): Dialog context are passed to :func:`forward` by
``data[context_key]``. Default: ``context_allvocabs``.
reference_allvocabs_key (str): Dialog references with all vocabs
Expand Down Expand Up @@ -791,7 +791,7 @@ class LanguageGenerationRecorder(MetricBase):
'''A metric-like class for recorder BLEU.
Arguments:
dataloader (:class:cotk.BasicLanguageGeneration): A language generation dataloader.
dataloader (:class:cotk.GenerationBase): A language generation dataloader.
gen_key (str): Sentences generated by model are passed to :func:`forward` by
``data[gen_key]``. Default: ``gen``.
'''
Expand Down
4 changes: 2 additions & 2 deletions tests/metric/test_metric.py
Expand Up @@ -11,7 +11,7 @@
SingleTurnDialogRecorder, MultiTurnDialogRecorder, LanguageGenerationRecorder, HashValueRecorder, \
MetricChain
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction
from cotk.dataloader import BasicLanguageGeneration, MultiTurnDialog
from cotk.dataloader import GenerationBase, MultiTurnDialog

def setup_module():
random.seed(0)
Expand All @@ -24,7 +24,7 @@ def test_bleu_bug():
corpus_bleu(ref, gen, smoothing_function=SmoothingFunction().method7)


class FakeDataLoader(BasicLanguageGeneration):
class FakeDataLoader(GenerationBase):
def __init__(self):
self.all_vocab_list = ['<pad>', '<unk>', '<go>', '<eos>', \
'what', 'how', 'here', 'do', 'as', 'can', 'to']
Expand Down

0 comments on commit d76fd87

Please sign in to comment.