In [None]:
import spacy
from spacy import displacy
from spacy_llm.util import assemble
import spacy


# Build the whole pipeline from the external spacy-llm config file.
# `assemble` comes from spacy_llm.util; later cells rely on the pipeline it
# builds to populate `doc.ents` and the `doc._.rel` relation extension.
# (Loading a separate model with spacy.load() first is unnecessary: the
# result would be discarded by this assignment.)
nlp = assemble("config.cfg")

In [None]:
from tqdm import tqdm
from youbot.store import Store


# Optional cap on how many archival messages to process — set to an int
# (e.g. 10) for a quick iteration, since every message goes through the
# spacy-llm pipeline. None processes everything.
MESSAGES_COUNT = None

# Materialize as a list so it can be sliced and sized for the progress bar.
docs = list(Store().get_archival_messages())
if MESSAGES_COUNT is not None:
    docs = docs[:MESSAGES_COUNT]

# nlp.pipe streams the texts through the pipeline in batches instead of one
# nlp() call per message; the resulting Docs are collected in input order.
docs_with_rels = list(tqdm(nlp.pipe(docs), total=len(docs), desc="parsing"))

In [None]:
# Visual spot-check: highlight the named entities found in the first
# processed message.
doc = docs_with_rels[0]
displacy.render(doc, jupyter=True, style="ent")

# To inspect the dependency parse instead, use the "dep" style:
# displacy.render(doc, style='dep', jupyter=True)

In [None]:
from typing import Tuple
from spacy.tokens.span import Span


def get_kb_entity_id(ent: Span) -> Tuple[str, str]:
    """Map an entity mention to its knowledge-base key.

    The key is the pair ``(surface text, entity label)``, so every mention
    that shares both its text and its label is merged into one KB entity.
    """
    text, label = ent.text, ent.label_
    return text, label


def _build_kb(docs):
    """Aggregate entity mentions and extracted relations over ``docs``.

    Each entity is identified by ``get_kb_entity_id(ent)``. Relations are read
    from ``doc._.rel``; each item's ``dep`` / ``dest`` are indices into
    ``doc.ents`` and ``relation`` is its label. Relations whose two endpoints
    resolve to the same entity id (e.g. two mentions with identical text and
    label) are skipped.

    Returns a tuple ``(facts_by_entity, facts_by_relation, entity_counts)``:
      - facts_by_entity:   entity id -> set of ``doc.text`` mentioning it
      - facts_by_relation: (dep entity id, relation, dest entity id)
                           -> set of ``doc.text`` asserting that relation
      - entity_counts:     entity id -> total number of mentions
    """
    facts_by_entity = {}
    facts_by_relation = {}
    entity_counts = {}

    for doc in docs:
        # Computed once per doc so relation indices can be resolved below.
        entity_ids = [get_kb_entity_id(ent) for ent in doc.ents]

        for entity_id in entity_ids:
            facts_by_entity.setdefault(entity_id, set()).add(doc.text)
            entity_counts[entity_id] = entity_counts.get(entity_id, 0) + 1

        for rel in doc._.rel:
            dep_id = entity_ids[rel.dep]
            dest_id = entity_ids[rel.dest]

            if dep_id == dest_id:
                continue

            key = (dep_id, rel.relation, dest_id)
            facts_by_relation.setdefault(key, set()).add(doc.text)

    return facts_by_entity, facts_by_relation, entity_counts


# Building inside a function keeps loop temporaries (notably a variable that
# used to be named `id`) from leaking into, and shadowing builtins in, the
# notebook namespace.
#
# kb_facts_by_entity   key: entity id
#                      value: set[fact]
# kb_facts_by_relation key: (entity id, relation label, entity id)
#                      value: set[fact]
# kb_entity_counts     key: entity id
#                      value: number of mentions
kb_facts_by_entity, kb_facts_by_relation, kb_entity_counts = _build_kb(docs_with_rels)


def get_entities_by_name(name: str):
    """Return every KB entity id whose surface text equals ``name``.

    One name can map to several ids when it was seen with different entity
    labels. Results follow the key order of ``kb_entity_counts``.
    """
    matches = []
    for entity_id in kb_entity_counts:
        if entity_id[0] == name:
            matches.append(entity_id)
    return matches

In [None]:
from collections import Counter

# Surface names shared by more than one KB entity, i.e. the same text was
# tagged with different entity labels. KB keys are unique (text, label)
# pairs, so counting keys per text counts distinct labels per name. A single
# pass replaces calling get_entities_by_name() (a full key scan) per name.
label_counts_by_name = Counter(entity_id[0] for entity_id in kb_entity_counts)
names = set(label_counts_by_name)

clashing_entity_names = {
    name for name, n_labels in label_counts_by_name.items() if n_labels > 1
}

In [None]:
# Collect the distinct relation labels observed for each ordered
# (dep entity, dest entity) pair. The loop variable is named `rel_type`
# rather than `type` so the builtin type() is not shadowed notebook-wide.
relation_types = {}
for dep_id, rel_type, dest_id in kb_facts_by_relation:
    relation_types.setdefault((dep_id, dest_id), set()).add(rel_type)

# Pairs linked by more than one relation label are candidate extraction
# conflicts to review.
clashing_relations = {
    pair: labels for pair, labels in relation_types.items() if len(labels) > 1
}

In [None]:
# Show both conflict reports, one per line.
print(clashing_entity_names, clashing_relations, sep="\n")


# TODO: restrict which relations may hold between which entities — perhaps
#   narrow down which entities can take part in which relations.
# TODO: summarize each relation?

# TODO: organize the KB by entity ID and track the frequency of both
#   relations and entities.
# TODO: extract incoherent relations (e.g. violated one-to-one relations,
#   invalid relationships); for these, flag the borderline ("close") cases.