From faf1207ecfe46876d1e1d34b77ff710933113b4b Mon Sep 17 00:00:00 2001 From: HUANG Fei Date: Wed, 30 Jan 2019 21:00:05 +0800 Subject: [PATCH] fix bugs in metric --- contk/metric/metric.py | 6 +++--- tests/metric/tes.log | 15 --------------- tests/{dataloader => metric}/test_metric.py | 2 +- 3 files changed, 4 insertions(+), 19 deletions(-) delete mode 100644 tests/metric/tes.log rename tests/{dataloader => metric}/test_metric.py (99%) diff --git a/contk/metric/metric.py b/contk/metric/metric.py index 18da861b..8c19f84b 100644 --- a/contk/metric/metric.py +++ b/contk/metric/metric.py @@ -315,13 +315,13 @@ class MultiTurnDialogRecorder(MetricBase): Arguments: dataloader (DataLoader): A dataloader for translating index to sentences. context_key (str): Dialog context are passed to :func:`forward` by ``data[context_key]``. - Default: ``post``. + Default: ``context``. reference_key (str): Dialog reference are passed to :func:`forward` by ``data[reference_key]``. - Default: ``resp``. + Default: ``reference``. gen_key (str): Sentences generated by model are passed to :func:`forward` by ``data[gen_key]``. Default: ``gen``. ''' - def __init__(self, dataloader, context_key="content", reference_key="reference", gen_key="gen"): + def __init__(self, dataloader, context_key="context", reference_key="reference", gen_key="gen"): super().__init__() self.dataloader = dataloader self.context_key = context_key diff --git a/tests/metric/tes.log b/tests/metric/tes.log deleted file mode 100644 index b8b641b7..00000000 --- a/tests/metric/tes.log +++ /dev/null @@ -1,15 +0,0 @@ -============================= test session starts ============================= -platform win32 -- Python 3.6.6, pytest-4.1.1, py-1.7.0, pluggy-0.8.1 -rootdir: E:\Tsinghua\Research\contk, inifile: -plugins: cov-2.6.1 -collected 0 items / 1 errors - -=================================== ERRORS ==================================== -________________ ERROR collecting tests/metric/test_metric.py _________________ -test_metric.py:527: in - class TestLanguageGenerationRecorder(): -test_metric.py:535: in TestLanguageGenerationRecorder - @pytest.mark.parametrize('argument, shape, type', language_generation_test_parameter) -E NameError: name 'language_generation_test_parameter' is not defined -!!!!!!!!!!!!!!!!!!! Interrupted: 1 errors during collection !!!!!!!!!!!!!!!!!!! -=========================== 1 error in 0.84 seconds =========================== diff --git a/tests/dataloader/test_metric.py b/tests/metric/test_metric.py similarity index 99% rename from tests/dataloader/test_metric.py rename to tests/metric/test_metric.py index e0cee6c9..de0f0d85 100644 --- a/tests/dataloader/test_metric.py +++ b/tests/metric/test_metric.py @@ -517,7 +517,7 @@ def test_close(self, argument, shape, type, batch_len, gen_len): # 'equal' or 'unequal' # 0, 1 dataloader = FakeDataLoader() - context_key, reference_key, gen_key = ('post', 'resp', 'gen') \ + context_key, reference_key, gen_key = ('context', 'reference', 'gen') \ if argument == 'default' else ('ck', 'rk', 'gk') data = dataloader.get_data(context_key=context_key, reference_key=reference_key, gen_key=gen_key, \ multi_turn=True, to_list=(type == 'list'), pad=(shape == 'pad'), \