Skip to content

Coref: Bridging and SplitAnte #76

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions udapi/block/corefud/printcluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from udapi.core.block import Block
from collections import Counter

class PrintCluster(Block):
    """Block corefud.PrintCluster: list the mention texts of one coreference cluster.

    For the cluster selected by ``cluster_id`` a header line is printed,
    followed by every distinct mention text (word forms joined by a space)
    together with the number of mentions having that text, most frequent first.
    Nothing is printed if the document has no such cluster or it has no mentions.
    """

    def __init__(self, cluster_id, **kwargs):
        """Remember which cluster to report; other kwargs go to the base Block."""
        super().__init__(**kwargs)
        self.cluster_id = cluster_id

    def process_document(self, doc):
        """Print the mention-text frequency table for the selected cluster in `doc`."""
        cluster = doc.coref_clusters.get(self.cluster_id)
        # Guard clause: unknown cluster or a cluster without mentions -> no output.
        if not (cluster and cluster.mentions):
            return

        print(f"Coref cluster {self.cluster_id} has {len(cluster.mentions)} mentions in document {doc.meta['docname']}:")
        # Count identical surface strings; Counter keeps first-seen order for ties.
        surface_counts = Counter(
            ' '.join([word.form for word in mention.words])
            for mention in cluster.mentions
        )
        for form, count in surface_counts.most_common():
            print(f"{count:4}: {form}")
129 changes: 112 additions & 17 deletions udapi/core/coref.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Classes for handling coreference."""
import collections
import collections.abc
import functools
import logging
import re

@functools.total_ordering
class CorefMention(object):
"""Class for representing a mention (instance of an entity)."""
__slots__ = ['_head', '_cluster', '_bridging', '_words']
__slots__ = ['_head', '_cluster', '_bridging', '_words', 'misc']

def __init__(self, head, cluster=None):
self._head = head
Expand All @@ -15,6 +16,7 @@ def __init__(self, head, cluster=None):
cluster._mentions.append(self)
self._bridging = None
self._words = []
self.misc = None

def __lt__(self, other):
"""Does this mention precede (word-order wise) the `other` mention?
Expand Down Expand Up @@ -58,6 +60,8 @@ def cluster(self, new_cluster):

@property
def bridging(self):
if not self._bridging:
self._bridging = BridgingLinks(self)
return self._bridging

# TODO add/edit bridging
Expand Down Expand Up @@ -97,13 +101,13 @@ def span(self, new_span):

class CorefCluster(object):
"""Class for representing all mentions of a given entity."""
__slots__ = ['_cluster_id', '_mentions', 'cluster_type', '_split_ante']
__slots__ = ['_cluster_id', '_mentions', 'cluster_type', 'split_ante']

def __init__(self, cluster_id, cluster_type=None):
self._cluster_id = cluster_id
self._mentions = []
self.cluster_type = cluster_type
self._split_ante = None
self.split_ante = []

@property
def cluster_id(self):
Expand Down Expand Up @@ -149,20 +153,88 @@ def create_mention(self, head=None, mention_words=None, mention_span=None):
mention.span = mention_span
return mention

@property
def split_ante(self):
return self._split_ante

# TODO add/edit split_ante

# TODO adapt depending on how mention.bridging is implemented (callable list subclass)
# TODO or should we create a BridgingLinks instance with a fake src_mention?
def all_bridging(self):
for m in self._mentions:
if m._bridging:
for b in m._bridging:
yield b


# A single bridging link: the target cluster and the name of the relation.
BridgingLink = collections.namedtuple('BridgingLink', 'target relation')


class BridgingLinks(collections.abc.MutableSequence):
    """A list of BridgingLink tuples going out of one source mention.

    The object behaves like a mutable list (indexing, ``len``, ``append``,
    ``del`` ...); every stored item is normalized to a ``BridgingLink``.
    Calling the object with a regular expression returns a filtered copy.

    Example usage::

        bl = BridgingLinks(src_mention)                                    # empty links
        bl = BridgingLinks(src_mention, [(c12, 'Part'), (c56, 'Subset')])  # from a list of tuples
        bl = BridgingLinks(src_mention, 'c12:Part,c56:Subset', clusters)   # from a string
        for cluster, relation in bl:
            print(f"{bl.src_mention} ->{relation}-> {cluster.cluster_id}")
        print(str(bl))                          # c12:Part,c56:Subset
        bl('Part').targets == [c12]
        bl('Part|Subset').targets == [c12, c56]
        bl.append((c89, 'Funct'))
    """

    def __init__(self, src_mention, value=None, clusters=None):
        """Create the list of links for `src_mention`.

        Args:
            src_mention: the mention the links start from (stored as-is).
            value: None (no links), a sequence of ``(target, relation)`` pairs,
                or a serialized string such as ``'c12:Part,c56:Subset'``.
            clusters: mapping from cluster id to cluster; required when
                `value` is a string, because ids must be resolved to clusters.

        Raises:
            ValueError: if `value` is a string and `clusters` is None, or if
                the string cannot be parsed (the problem is logged first).
        """
        self.src_mention = src_mention
        self._data = []
        if value is not None:
            if isinstance(value, str):
                if clusters is None:
                    raise ValueError('BridgingLinks: clusters must be provided if initializing with a string')
                try:
                    self._from_string(value, clusters)
                except ValueError:
                    logging.error(f"Problem when parsing {value} in {src_mention.words[0]}:\n")
                    raise
            elif isinstance(value, collections.abc.Sequence):
                for v in value:
                    self._data.append(BridgingLink(v[0], v[1]))
        super().__init__()

    def __getitem__(self, key):
        return self._data[key]

    def __len__(self):
        return len(self._data)

    # TODO delete backlinks of old links, dtto for SplitAnte
    def __setitem__(self, key, new_value):
        self._data[key] = BridgingLink(new_value[0], new_value[1])

    def __delitem__(self, key):
        del self._data[key]

    def insert(self, key, new_value):
        """Insert a ``(target, relation)`` pair before index `key`."""
        self._data.insert(key, BridgingLink(new_value[0], new_value[1]))

    def __str__(self):
        """Serialize as comma-separated ``<target cluster id>:<relation>`` items."""
        return ','.join(f'{link.target._cluster_id}:{link.relation}' for link in self)

    def _from_string(self, string, clusters):
        """Replace the current links by those parsed from `string`.

        Each comma-separated item must have the form ``<cluster id>:<relation>``;
        the id is looked up in `clusters`. An empty string yields no links, so
        that ``str()`` of an empty list can be loaded back.
        """
        self._data.clear()
        if not string:
            return
        for link_str in string.split(','):
            # Unpacking raises ValueError unless there is exactly one colon.
            target, relation = link_str.split(':')
            self._data.append(BridgingLink(clusters[target], relation))

    def __call__(self, relations_re=None):
        """Return a subset of links contained in this list as specified by the args.

        Args:
            relations_re: only links whose relation matches this regular
                expression (via ``re.match``, i.e. anchored at the start)
                are returned. If None, this very object is returned.

        Returns:
            ``self`` when no filter is given, otherwise a new BridgingLinks
            with the same source mention and the matching links.
        """
        if relations_re is None:
            return self
        matching = [link for link in self._data if re.match(relations_re, link.relation)]
        return BridgingLinks(self.src_mention, matching)

    @property
    def targets(self):
        """Return a list of the target clusters (without relations)."""
        return [link.target for link in self._data]


def create_coref_cluster(head, cluster_id=None, cluster_type=None, **kwargs):
clusters = head.root.bundle.document.coref_clusters
if not cluster_id:
Expand All @@ -180,7 +252,7 @@ def create_coref_cluster(head, cluster_id=None, cluster_type=None, **kwargs):

def load_coref_from_misc(doc):
clusters = {}
for node in doc.nodes:
for node in doc.nodes_and_empty:
index, index_str = 0, ""
cluster_id = node.misc["ClusterId"]
if not cluster_id:
Expand All @@ -194,14 +266,32 @@ def load_coref_from_misc(doc):
mention = CorefMention(node, cluster)
if node.misc["MentionSpan" + index_str]:
mention.span = node.misc["MentionSpan" + index_str]
else:
mention.words = [node]
cluster_type = node.misc["ClusterType" + index_str]
if cluster_type is not None:
if cluster.cluster_type is not None and cluster_type != cluster.cluster_type:
logging.warning(f"cluster_type mismatch in {node}: {cluster.cluster_type} != {cluster_type}")
cluster.cluster_type = cluster_type
# TODO deserialize Bridging and SplitAnte
mention._bridging = node.misc["Bridging" + index_str]
cluster._split_ante = node.misc["SplitAnte" + index_str]

bridging_str = node.misc["Bridging" + index_str]
if bridging_str:
mention._bridging = BridgingLinks(mention, bridging_str, clusters)

split_ante_str = node.misc["SplitAnte" + index_str]
if split_ante_str:
split_antes = []
for ante_str in split_ante_str.split('+'):
if ante_str in clusters:
split_antes.append(clusters[ante_str])
else:
# split cataphora, e.g. "We, that is you and me..."
cluster = CorefCluster(ante_str)
clusters[ante_str] = cluster
split_antes.append(cluster)
cluster.split_ante = split_antes

mention.misc = node.misc["MentionMisc" + index_str]
index += 1
index_str = f"[{index}]"
cluster_id = node.misc["ClusterId" + index_str]
Expand All @@ -212,7 +302,7 @@ def store_coref_to_misc(doc):
if not doc._coref_clusters:
return
attrs = ("ClusterId", "MentionSpan", "ClusterType", "Bridging", "SplitAnte")
for node in doc.nodes:
for node in doc.nodes_and_empty:
for key in list(node.misc):
if any(re.match(attr + r'(\[\d+\])?$', key) for attr in attrs):
del node.misc[key]
Expand All @@ -235,8 +325,13 @@ def store_coref_to_misc(doc):
head.misc["ClusterId" + index_str] = cluster.cluster_id
head.misc["MentionSpan" + index_str] = mention.span
head.misc["ClusterType" + index_str] = cluster.cluster_type
head.misc["Bridging" + index_str] = mention.bridging
head.misc["SplitAnte" + index_str] = cluster.split_ante
if mention._bridging:
head.misc["Bridging" + index_str] = str(mention.bridging)
if cluster.split_ante:
serialized = '+'.join((c.cluster_id for c in cluster.split_ante))
head.misc["SplitAnte" + index_str] = serialized
if mention.misc:
head.misc["MentionMisc" + index_str] = mention.misc


def span_to_nodes(root, span):
Expand Down
10 changes: 9 additions & 1 deletion udapi/core/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,20 @@ def trees(self):

@property
def nodes(self):
    """An iterator over all nodes (excluding empty nodes) in the document.

    Walks the document bundle by bundle and tree by tree and yields the
    entries of each tree's ``_descendants`` in that order.
    """
    for bundle in self:
        for tree in bundle:
            yield from tree._descendants

@property
def nodes_and_empty(self):
    """Iterate over every node of the document, empty nodes included.

    Bundles and their trees are visited in order; for each tree the items
    of its ``descendants_and_empty`` are yielded.
    """
    for bundle in self:
        for tree in bundle:
            yield from tree.descendants_and_empty

def draw(self, **kwargs):
"""Pretty print the trees using TextModeTrees."""
TextModeTrees(**kwargs).run(self)
Expand Down