In [1]:
#%pip install allennlp
#%pip install --pre allennlp-models
#%pip install google-cloud-storage
#%pip install pandarallel

In [1]:
#! python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.3.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.3.0/en_core_web_sm-3.3.0-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m22.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: en-core-web-sm
Successfully installed en-core-web-sm-3.3.0
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')


In [3]:
import pandas as pd
import spacy
from allennlp.predictors.predictor import Predictor
from typing import Dict, List
from spacy.tokens import Doc
from spacy.tokens import Span
from google.cloud import storage
from datetime import datetime
from tqdm import tqdm
import re
from pandarallel import pandarallel
from contextlib import closing
import json
import torch


# Start pandarallel with 3 worker processes and a progress bar; this is what
# makes `Series.parallel_apply` available further down.
pandarallel.initialize(nb_workers=3, progress_bar=True)
# Widen the pandas display limits: up to 100 rows, every column, and up to
# 500 characters per cell.
pd.set_option('display.max_rows', 100)
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', 500)

INFO: Pandarallel will run on 3 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.


In [4]:
def load_coref_index(client, bucket="meta-info", coref_index="coref-index.json"):
    """Load the coreference bookkeeping index from Cloud Storage.

    Args:
        client: Storage client exposing ``bucket(name)``.
        bucket: Name of the bucket that holds the index file.
        coref_index: Object name of the JSON index inside ``bucket``.

    Returns:
        The decoded JSON index. It always has an ``"error"`` key; when the
        index object does not exist yet, ``{"error": []}`` is returned.
    """
    index_blob = client.bucket(bucket).blob(coref_index)
    # Ask for this one object directly instead of listing the whole bucket
    # just to test membership.
    if not index_blob.exists():
        return {"error": []}
    with index_blob.open("r") as fp:
        index = json.load(fp)
    # Older index files may lack the key; callers read index["error"] directly.
    index.setdefault("error", [])
    return index


def get_coref_converted(client, bucket="markdown-corref"):
    """Return the set of object names currently stored in ``bucket``.

    Args:
        client: Storage client exposing ``list_blobs(bucket_or_name=...)``.
        bucket: Name of the bucket to list.

    Returns:
        A ``set`` with the ``name`` of every listed blob.
    """
    return {blob.name for blob in client.list_blobs(bucket_or_name=bucket)}


def get_coref_work(source="markdown-converged", filter_f=lambda article: True,
                   project="msca310019-capstone-f945"):
    """Collect the articles that still need coreference resolution.

    A file is pending when it exists in ``source`` but is neither returned by
    ``get_coref_converted`` nor listed under ``"error"`` in the index returned
    by ``load_coref_index``. Every pending file is downloaded and kept only
    when ``filter_f`` accepts the decoded article.

    Args:
        source: Name of the bucket holding the input articles.
        filter_f: Predicate called with the decoded article; truthy keeps it.
        project: Project id passed to every ``storage.Client`` created here.

    Returns:
        A list of ``[file_name, article]`` pairs for the accepted articles.
    """
    with closing(storage.Client(project=project)) as client:
        errors = set(load_coref_index(client)["error"])
        done = get_coref_converted(client)
        candidates = {blob.name for blob in client.list_blobs(bucket_or_name=source)}
    pending = candidates - done - errors

    # Nothing to do: skip the parallel machinery entirely rather than pushing
    # an empty frame through parallel_apply and the .str accessor.
    if not pending:
        return []

    tbd_df = pd.DataFrame({"tbd": list(pending)})

    def filter_article(f_name):
        # A fresh client is opened per call instead of reusing the outer one
        # (presumably because this runs inside pandarallel worker processes).
        with closing(storage.Client(project=project)) as worker_client:
            with worker_client.bucket(source).blob(f_name).open("r") as fp:
                article = json.load(fp)
                if filter_f(article):
                    return [f_name, article]
        # An empty list marks "rejected" so it can be dropped by length below.
        return []

    tbd_df["result"] = tbd_df.tbd.parallel_apply(filter_article)
    accepted = tbd_df.loc[tbd_df["result"].str.len() > 0]

    return accepted.result.to_list()

In [5]:
def year_filter(article, year=2023):
    """Tell whether an article was published in the given calendar year.

    Args:
        article: Mapping whose ``"published"`` entry is an ISO-format
            timestamp string accepted by ``datetime.fromisoformat``.
        year: Year to match against.

    Returns:
        ``True`` when the parsed timestamp's year equals ``year``.
    """
    return datetime.fromisoformat(article["published"]).year == year

# Client kept open for the later cells that upload results and the index.
client = storage.Client(project="msca310019-capstone-f945")
# Pending [file_name, article] pairs, restricted by year_filter (default year 2023).
works = get_coref_work(filter_f=year_filter)

len(works)

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=17430), Label(value='0 / 17430')))…

Process ForkPoolWorker-6:
Process ForkPoolWorker-5:
Process ForkPoolWorker-7:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/homebrew/anaconda3/envs/FinBot/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/opt/homebrew/anaconda3/envs/FinBot/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/opt/homebrew/anaconda3/envs/FinBot/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/opt/homebrew/anaconda3/envs/FinBot/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/homebrew/anaconda3/envs/FinBot/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/homebrew/anaconda3/envs/FinBot/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args,

In [6]:
def get_span_noun_indices(doc: Doc, cluster: List[List[int]]):
    """Find which spans of a cluster contain a noun or proper noun.

    Args:
        doc: Token sequence; ``doc[a:b]`` must yield tokens exposing ``pos_``.
        cluster: Spans given as ``[start, end]`` with an inclusive ``end``.

    Returns:
        Positions (within ``cluster``) of the spans that have at least one
        token tagged ``NOUN`` or ``PROPN``, in cluster order.
    """
    noun_positions = []
    for position, bounds in enumerate(cluster):
        # The end index is inclusive, hence the +1 on the slice.
        tags = [token.pos_ for token in doc[bounds[0]:bounds[1] + 1]]
        if 'NOUN' in tags or 'PROPN' in tags:
            noun_positions.append(position)
    return noun_positions


def get_cluster_head_idx(doc, cluster):
    """Pick the position of the span used as the cluster's head.

    Returns the first position reported by ``get_span_noun_indices``; when no
    span qualifies, position 0 is returned.
    """
    candidates = get_span_noun_indices(doc, cluster)
    if candidates:
        return candidates[0]
    return 0

In [7]:
def print_clusters(doc, clusters):
    """Print each cluster as ``head - [mention; mention; ...]``.

    Args:
        doc: Iterable of tokens exposing ``text``.
        clusters: List of clusters, each a list of ``[start, end]`` spans
            with an inclusive ``end``.
    """
    words = [token.text for token in doc]

    def span_text(bounds):
        # Space-join the words covered by an inclusive [start, end] span.
        return ' '.join(words[bounds[0]:bounds[1] + 1])

    for cluster in clusters:
        head_idx = get_cluster_head_idx(doc, cluster)
        if head_idx < 0:
            continue
        head_text = span_text(cluster[head_idx])
        members = "; ".join(span_text(span) for span in cluster)
        print(head_text + ' - ' + '[' + members + ']')

In [8]:
def core_logic_part(document: Doc, coref: List[int], resolved: List[str], mention_span: Span):
    """Overwrite one mention in ``resolved`` with the text of ``mention_span``.

    Args:
        document: Token sequence; tokens expose ``tag_`` and ``whitespace_``.
        coref: ``[start, end]`` token span (inclusive end) to replace.
        resolved: Per-token output strings, modified in place.
        mention_span: Object whose ``text`` is written in place of the span.

    Returns:
        The same ``resolved`` list that was passed in.
    """
    start, end = coref[0], coref[1]
    last_token = document[end]
    # When the replaced span ends in a token tagged PRP$ or POS, "'s" is
    # appended to the inserted text.
    suffix = "'s" if last_token.tag_ in ("PRP$", "POS") else ""
    resolved[start] = mention_span.text + suffix + last_token.whitespace_
    # The replacement lives in the first slot; blank the rest of the span.
    for position in range(start + 1, end + 1):
        resolved[position] = ""
    return resolved


def original_replace_corefs(document: Doc, clusters: List[List[List[int]]]) -> str:
    """Replace every later mention of a cluster with the cluster's first span.

    Args:
        document: Token sequence; tokens expose ``text_with_ws``.
        clusters: List of clusters, each a list of ``[start, end]`` spans
            with an inclusive ``end``.

    Returns:
        The document text rebuilt from the per-token strings after
        replacement.
    """
    pieces = [token.text_with_ws for token in document]

    for cluster in clusters:
        antecedent = cluster[0]
        antecedent_span = document[antecedent[0]:antecedent[1] + 1]

        for mention in cluster[1:]:
            core_logic_part(document, mention, pieces, antecedent_span)

    return "".join(pieces)


def get_cluster_head(doc: Doc, cluster: List[List[int]], noun_indices: List[int]):
    """Return the head span of a cluster together with its boundaries.

    Args:
        doc: Sliceable token sequence.
        cluster: Spans given as ``[start, end]`` with an inclusive ``end``.
        noun_indices: Positions in ``cluster``; the first one selects the head.

    Returns:
        A pair ``(doc[start:end + 1], [start, end])`` for the selected span.
    """
    start, end = cluster[noun_indices[0]]
    return doc[start:end + 1], [start, end]


def is_containing_other_spans(span: List[int], all_spans: List[List[int]]):
    """Tell whether ``span`` fully covers some different span in ``all_spans``.

    Args:
        span: ``[start, end]`` boundaries to test.
        all_spans: Candidate ``[start, end]`` spans.

    Returns:
        ``True`` if at least one span other than ``span`` itself lies within
        its boundaries, otherwise ``False``.
    """
    outer_start, outer_end = span[0], span[1]
    for other in all_spans:
        if other != span and outer_start <= other[0] and other[1] <= outer_end:
            return True
    return False


def improved_replace_corefs(document, clusters):
    """Replace mentions with their cluster head, skipping unsafe cases.

    A cluster is only processed when ``get_span_noun_indices`` finds a span
    with a noun; the first such span becomes the head. A mention is left
    untouched when it is the head itself or when it contains another span
    from any cluster.

    Args:
        document: Token sequence; tokens expose ``text_with_ws``.
        clusters: List of clusters, each a list of ``[start, end]`` spans
            with an inclusive ``end``.

    Returns:
        The document text rebuilt from the per-token strings after
        replacement.
    """
    pieces = [token.text_with_ws for token in document]
    # Flattened view of every span, used for the nesting check below.
    every_span = [span for cluster in clusters for span in cluster]

    for cluster in clusters:
        noun_indices = get_span_noun_indices(document, cluster)
        if not noun_indices:
            continue

        head_span, head_bounds = get_cluster_head(document, cluster, noun_indices)

        for candidate in cluster:
            if candidate == head_bounds or is_containing_other_spans(candidate, every_span):
                continue
            core_logic_part(document, candidate, pieces, head_span)

    return "".join(pieces)


In [9]:
# Archive of the pretrained coreference model loaded by the predictor.
allen_url = "https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2021.03.10.tar.gz"
# torch.cuda.current_device() raises when no CUDA runtime is available, so
# only query it when CUDA is present and otherwise fall back to -1 (CPU).
cuda_device = torch.cuda.current_device() if torch.cuda.is_available() else -1
gpu_predictor = Predictor.from_path(allen_url, cuda_device=cuda_device)

Some weights of BertModel were not initialized from the model checkpoint at SpanBERT/spanbert-large-cased and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Keep spaCy on the CPU, then load the small English pipeline. `nlp` is the
# tokenizer/tagger used by coref_text_whole and coref_text_parts.
spacy.require_cpu()
nlp = spacy.load("en_core_web_sm")

In [11]:
def fetch_articles(client, max_id, bucket="markdown-converged", start_id=None):
    """Yield ``(id, article)`` for the numbered JSON files ``<id>.json``.

    Args:
        client: Storage client exposing ``bucket(name)``.
        max_id: Exclusive upper bound of the ids to fetch.
        bucket: Name of the bucket holding the articles.
        start_id: First id to fetch. When omitted, the value is read from the
            module-level ``coref_index["standardized"]`` as before.
            NOTE(review): ``coref_index`` is not defined in the visible
            notebook cells, so callers should pass ``start_id`` explicitly.

    Yields:
        ``(i, article_dict)`` for every file that could be read and decoded;
        files that fail to load are skipped.
    """
    if start_id is None:
        start_id = coref_index["standardized"]
    source_bucket = client.bucket(bucket)
    for i in range(start_id, max_id):
        file_name = "%s.json" % i
        with source_bucket.blob(file_name).open("r") as fp:
            try:
                article_dict = json.load(fp)
            except Exception:
                # Best effort: an unreadable or undecodable file is skipped.
                continue
        # Yield outside the try block so that closing the generator (or an
        # exception thrown into it by the consumer) is not swallowed.
        yield i, article_dict

In [12]:
import gc


def window_sentences(sentences, idx, pre=5, sep="\n\n"):
    """Build the text window for sentence ``idx``: context, separator, sentence.

    Args:
        sentences: Sequence of sentence strings.
        idx: Position of the target sentence.
        pre: Maximum number of preceding sentences used as context.
        sep: Marker placed between the context and the target sentence.

    Returns:
        ``"<context> <sep> <sentence>"``, where the context is the preceding
        sentences joined by spaces with every whitespace run collapsed to a
        single space.
    """
    first = idx - pre
    if first < 0:
        first = 0
    preceding = re.sub(r"\s+", " ", " ".join(sentences[first:idx]))
    return " ".join((preceding, sep, sentences[idx]))


def coref_text_whole(article, predictor):
    """Resolve coreferences over a whole article in a single pass.

    Args:
        article: Article text; surrounding whitespace is stripped first.
        predictor: Object whose ``predict(text)`` returns a mapping with a
            ``'clusters'`` entry.

    Returns:
        The resolved text, or ``""`` when the stripped article is empty.
    """
    text = article.strip()
    if not text:
        return ""
    clusters = predictor.predict(text)['clusters']
    return improved_replace_corefs(nlp(text), clusters)


def coref_text_parts(sentences, predictor):
    """Resolve coreferences sentence by sentence over a sliding window.

    Each sentence is resolved together with the window built by
    ``window_sentences``; only the part after the ``"\\n\\n"`` separator is
    kept. Sentences are replaced as the loop advances, so later windows are
    built from already-resolved context.

    Args:
        sentences: Iterable of sentence strings (copied into a new list).
        predictor: Object whose ``predict(text)`` returns a mapping with a
            ``'clusters'`` entry.

    Returns:
        A new list with the resolved sentences.

    Raises:
        ValueError: If the resolved window splits into more than two parts.
    """
    resolved_sentences = list(sentences)

    for position in range(len(resolved_sentences)):
        shard = window_sentences(resolved_sentences, position)
        clusters = predictor.predict(shard)['clusters']
        coref_shard = improved_replace_corefs(nlp(shard), clusters)
        parts = coref_shard.split("\n\n")
        if len(parts) > 2:
            raise ValueError("Incorrect number of parts: " + str(len(parts)))
        resolved_sentences[position] = parts[1].strip()

    return resolved_sentences


def coref_text(article):
    """Resolve an article with ``gpu_predictor``, returning ``None`` on failure.

    Any ``Exception`` raised by ``coref_text_whole`` is caught; garbage
    collection and ``torch.cuda.empty_cache()`` are then run before ``None``
    is returned.
    """
    try:
        resolved = coref_text_whole(article, gpu_predictor)
    except Exception:
        gc.collect()
        torch.cuda.empty_cache()
        resolved = None
    return resolved


In [13]:
# Destination bucket for resolved articles, plus the error list so far.
bucket = client.bucket("markdown-corref")
idx = load_coref_index(client)
errors = set(idx["error"])


with tqdm(total=len(works)) as progress:
    for f_name, article in works:
        original_body = article["body"]
        # An empty body is never sent to the model; it falls into the error
        # branch below, as does a failed or empty resolution.
        corref_body = coref_text(original_body) if len(original_body) > 0 else ""
        if not corref_body:
            errors.add(f_name)
        else:
            with bucket.blob(f_name).open("w") as fp:
                article["body"] = corref_body
                json.dump(fp=fp, obj=article)
        progress.update(1)
    


  num_effective_segments = (seq_lengths + self._max_length - 1) // self._max_length
100%|██████████| 9374/9374 [1:30:22<00:00,  1.73it/s]


NameError: name 'i' is not defined

In [17]:
# Copy the error set back into the index as a list (json cannot serialise a set).
idx["error"] = list(errors)
def write_conversion_index(client, index, bucket="meta-info", conversion_index="coref-index.json"):
    """Serialise ``index`` as JSON into ``bucket``/``conversion_index``.

    Args:
        client: Storage client exposing ``bucket(name)``.
        index: JSON-serialisable object to store.
        bucket: Name of the destination bucket.
        conversion_index: Object name to write inside ``bucket``.
    """
    target_blob = client.bucket(bucket).blob(conversion_index)
    with target_blob.open("w") as sink:
        json.dump(index, sink)

In [18]:
# Upload the index with its updated "error" list, then release the client.
write_conversion_index(client, idx)
client.close()