In [22]:
import collections
from copy import copy
import jsonlines
import matplotlib.pyplot as plt
import numpy
from tqdm import tqdm, trange
import random

In [2]:
with jsonlines.open("./moviecoref/wl-coref/data/english_train.jsonlines") as reader:
    train_data = list(reader)

with jsonlines.open("./moviecoref/wl-coref/data/english_development.jsonlines") as reader:
    dev_data = list(reader)

with jsonlines.open("./moviecoref/wl-coref/data/english_test.jsonlines") as reader:
    test_data = list(reader)

all_data = train_data + dev_data + test_data

In [3]:
print(f"{len(train_data)} train, {len(dev_data)} dev, {len(test_data)} test, {len(all_data)} total")

2802 train, 343 dev, 348 test, 3493 total


In [4]:
n_tokens = 0
n_token_mentions = 0
n_span_mentions = 0
n_entities = 0
n_docs = len(all_data)

for doc in all_data:
    n_tokens += len(doc["cased_words"])
    n_token_mentions += sum(end - begin for cluster in doc["clusters"] for begin, end in cluster)
    n_span_mentions += sum(len(cluster) for cluster in doc["clusters"])
    n_entities += len(doc["clusters"])

print(f"{n_tokens} tokens, {n_token_mentions} token mentions, {n_span_mentions} span mentions, {n_entities} entities")
print(f"{n_token_mentions/n_tokens:.2f} token mentions/token, {n_span_mentions/n_tokens:.2f} span mentions/token, {n_entities/n_tokens:.2f} entities/token")
print(f"{n_tokens/n_docs:.2f} tokens/docs, {n_entities/n_docs:.2f} entities/doc")

1631995 tokens, 445834 token mentions, 194480 span mentions, 44221 entities
0.27 token mentions/token, 0.12 span mentions/token, 0.03 entities/token
467.22 tokens/docs, 12.66 entities/doc


In [5]:
n_token_mentions_per_token = []

for doc in all_data:
    n_token_mentions = numpy.zeros(len(doc["cased_words"]), dtype=int)
    for cluster in doc["clusters"]:
        for begin, end in cluster:
            n_token_mentions[begin: end] += 1
    n_token_mentions_per_token.extend(n_token_mentions.tolist())

n_token_mentions_per_token_distribution = collections.Counter(n_token_mentions_per_token)

In [6]:
n_non_overlapping_span_mentions = 0

for doc in all_data:
    for cluster in doc["clusters"]:
        cluster = sorted(cluster, key=lambda span: span[1])
        if len(cluster) > 1:
            n, i, end = 1, 1, cluster[0][1]
            while i < len(cluster):
                if end <= cluster[i][0]:
                    n += 1
                    end = cluster[i][1]
                i += 1
            n_non_overlapping_span_mentions += n
        else:
            n_non_overlapping_span_mentions += 1

print(f"{n_non_overlapping_span_mentions} non overlapping span mentions ({100*n_non_overlapping_span_mentions/n_span_mentions:.2f}%)")

193175 non overlapping span mentions (99.33%)


In [7]:
def permutation_traversal(path, pathlen, nchoices, allpaths):
    if pathlen == 0:
        allpaths.append(path)
    else:
        for choice in range(nchoices[len(nchoices) - pathlen]):
            newpath = copy(path)
            newpath.append(choice)
            permutation_traversal(newpath, pathlen - 1, nchoices, allpaths)

def permutations(nchoices):
    all_possible_sequences = []
    permutation_traversal([], len(nchoices), nchoices, all_possible_sequences)
    return all_possible_sequences

In [14]:
class Permutation:

    def __init__(self, num_choices):
        self.n = len(num_choices)
        self.index = numpy.zeros(self.n, dtype=int)
        self.maxindex = num_choices
        self.products = numpy.ones(self.n, dtype=int)
        for i, m in enumerate(reversed(num_choices)):
            if i:
                self.products[self.n - 1 - i] = self.products[self.n - i] * m
            else:
                self.products[self.n - 1] = m

    def setindex(self, sequence_index):
        if sequence_index >= self.products[0]:
            raise IndexError(f"sequence_index should be less than {self.products[0]}")
        
        for i in range(self.n - 1):
            self.index[i] = sequence_index // self.products[i + 1]
            sequence_index = sequence_index % self.products[i + 1]
            if i == self.n - 2:
                self.index[i + 1] = sequence_index
    
    def __iter__(self):
        self.index = numpy.zeros(self.n, dtype=int)
        self.digit = self.n
        return self
    
    def __next__(self):
        if self.digit == self.n:
            self.digit -= 1
            return self.index

        if self.digit == -1:
            raise StopIteration
        
        self.index[self.digit] += 1
        self.index[self.digit + 1:] = 0
        
        if self.digit < self.n - 1:
            self.digit = self.n - 1
        else:
            while self.index[self.digit] == self.maxindex[self.digit] - 1 and self.digit >= 0:
                self.digit -= 1
        
        return self.index
    
    def random(self, n_items):
        sequence_indices = numpy.random.randint(0, self.products[0], n_items)
        for sequence_index in sequence_indices:
            self.setindex(sequence_index)
            yield self.index

In [15]:
permutation = Permutation([2, 3, 5])

In [16]:
permutation.n, permutation.index, permutation.maxindex, permutation.products

(3, array([0, 0, 0]), [2, 3, 5], array([30, 15,  5]))

In [17]:
for index in permutation.random(5):
    print(index)

[1 2 4]
[0 2 2]
[0 2 4]
[0 1 2]
[0 2 1]


In [29]:
def cluster_cover_exists(corefcluster):
    # corefcluster is a list of clusters
    # cluster is a list of spans
    # span is a ordered pair of integers

    for _ in range(1000000):
        spans = [cluster[random.randint(0, len(cluster) - 1)] for cluster in corefcluster]
        spans = sorted(spans, key=lambda span: span[1])

        for i in range(len(spans) - 1):
            if spans[i + 1][0] < spans[i][1]:
                # spans intersect
                break
        else:
            # none of the spans intersect with each other
            return True
    
    return False

In [30]:
indexes = []

for i in trange(len(all_data)):
    doc = all_data[i]
    if doc["clusters"] and not cluster_cover_exists(doc["clusters"]):
        print(i)
        indexes.append(i)

print(f"{len(indexes)} docs don't have cluster cover")

 54%|█████▍    | 1891/3493 [00:00<00:00, 18908.98it/s]

2127


 68%|██████▊   | 2364/3493 [00:39<00:25, 45.07it/s]   

2363


 77%|███████▋  | 2694/3493 [00:55<00:21, 36.77it/s]

2693


 78%|███████▊  | 2726/3493 [01:08<00:29, 26.09it/s]

2725


 87%|████████▋ | 3031/3493 [01:32<00:22, 20.50it/s]

3030


100%|██████████| 3493/3493 [01:37<00:00, 35.76it/s]

3073
6 docs don't have cluster cover





In [34]:
for index in indexes:
    corefcluster = all_data[index]["clusters"]
    prod = 1
    for cluster in corefcluster:
        prod = prod * len(cluster)
    print(f"{index}: {prod} choices, {type(prod)}")

2127: 128 choices, <class 'int'>
2363: 1217299827686532710400 choices, <class 'int'>
2693: 1642291200 choices, <class 'int'>
2725: 66355200 choices, <class 'int'>
3030: 104063978962944 choices, <class 'int'>
3073: 240 choices, <class 'int'>


In [38]:
corefcluster = all_data[2363]["clusters"]

In [39]:
len(corefcluster)

38

In [41]:
collections.Counter(len(cluster) for cluster in corefcluster)

Counter({18: 1,
         6: 1,
         7: 1,
         2: 20,
         29: 1,
         3: 4,
         4: 4,
         5: 1,
         14: 1,
         12: 1,
         8: 1,
         20: 1,
         19: 1})

In [43]:
all_data[2127].keys()

dict_keys(['document_id', 'cased_words', 'sent_id', 'part_id', 'speaker', 'pos', 'deprel', 'head', 'clusters'])

In [44]:
def print_document(doc):
    sent_ids = numpy.unique(doc["sent_id"])
    sentences = []
    for sent_id in sent_ids:
        tokens = [token for token, token_sent_id in zip(doc["cased_words"], doc["sent_id"]) if token_sent_id == sent_id]
        sentence = " ".join(tokens)
        sentences.append(sentence)
    corefcluster = []
    for cluster in doc["clusters"]:
        mentions = []
        for start, end in cluster:
            mention_tokens = doc["cased_words"][start: end]
            mention = " ".join(mention_tokens)
            mentions.append(mention)
        corefcluster.append(mentions)
    print(f"document {doc['document_id']}")
    for i, sentence in enumerate(sentences):
        print(f"{i + 1:2d}.  {sentence}")
    print()
    print("clusters =>")
    for i, cluster in enumerate(corefcluster):
        print(f"cluster {i + 1}: {cluster}")

In [46]:
print_document(all_data[10])

document bc/cctv/00/cctv_0002
 1.  News and events happen every day .
 2.  What you are interested in is exactly what our focuses are .
 3.  This is CCTV Focus Today .
 4.  Prior to the APEC meeting , Japanese Prime Minister Junichiro Koizumi visited Yasukuni Shrine for the fifth time and was strongly condemned by Asian nations such as China and South Korea .
 5.  What kind of situation will Japan find itself in at the APEC summit ?
 6.  What heavy prices will Japan pay for Koizumi 's paying respect to the ghosts ?
 7.  Focus Today is coming up in a moment .

clusters =>
cluster 1: ['This', 'Focus Today']
cluster 2: ['the APEC meeting', 'the APEC summit']
cluster 3: ['Japanese Prime Minister Junichiro Koizumi', "Koizumi 's"]
cluster 4: ['Japan', 'itself', 'Japan']


In [47]:
for doc in all_data:
    print(doc["document_id"])

bc/cctv/00/cctv_0001
bc/cctv/00/cctv_0001
bc/cctv/00/cctv_0001
bc/cctv/00/cctv_0001
bc/cctv/00/cctv_0001
bc/cctv/00/cctv_0001
bc/cctv/00/cctv_0001
bc/cctv/00/cctv_0001
bc/cctv/00/cctv_0001
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0002
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0003
bc/cctv/00/cctv_0004
bc/cctv/00/cc