Skip to content

Commit

Permalink
fix: double aggregation on entities produced by ListEntityPlugin.
Browse files: browse the repository at this point in the history
  • Loading branch information
ltbringer committed Sep 25, 2021
1 parent 0995114 commit cd2b3a3
Showing 1 changed file with 7 additions and 15 deletions.
22 changes: 7 additions & 15 deletions dialogy/plugins/text/list_entity_plugin/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,7 +13,6 @@
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
import pydash as py_
from tqdm import tqdm

from dialogy import constants as const
Expand Down Expand Up @@ -183,12 +182,11 @@ def get_entities(self, transcripts: List[str]) -> List[BaseEntity]:
"""
matches_on_transcripts = self._search(transcripts)
logger.debug(matches_on_transcripts)
entity_metadata = []
entities: List[BaseEntity] = []

for i, matches_on_transcript in enumerate(matches_on_transcripts):
for text, label, value, span in matches_on_transcript:
entity = {
entity_dict = {
"start": span[0],
"end": span[1],
"body": text,
Expand All @@ -204,18 +202,12 @@ def get_entities(self, transcripts: List[str]) -> List[BaseEntity]:
"values": [{"value": value}],
},
}
entity_metadata.append(entity)
entity_groups = py_.group_by(entity_metadata, lambda e: e["__group"])
logger.debug("entity groups:")
logger.debug(pformat(entity_groups))

for _, grouped_entities in entity_groups.items():
entity = sorted(grouped_entities, key=lambda e: e["alternative_index"])[0]
del entity["__group"]
entity["score"] = round(len(grouped_entities) / len(transcripts), 4)
entity_ = KeywordEntity.from_dict(entity)
entity_.add_parser(self).set_value()
entities.append(entity_)

del entity_dict["__group"]
entity_ = KeywordEntity.from_dict(entity_dict)
entity_.add_parser(self).set_value()
entities.append(entity_)

logger.debug("Parsed entities")
logger.debug(entities)

Expand Down

0 comments on commit cd2b3a3

Please sign in to comment.