In [2]:
import json
import os
from contextlib import closing

import pandas as pd
import psycopg2
import requests
from tqdm import tqdm


In [3]:
def extract_salient_entities(
    data: list,
    title: str = "title",
    article: str = "article",
    id: str = "id",
    api_url: "str | None" = None,
    max_words: int = 3300,
    session=None,
    timeout: float = 60.0,
    show_progress: bool = True,
):
    """Annotate articles with the ORG entities shared by headline and body.

    The headline and the body of each row are sent separately to the REL
    entity-linking endpoint. Judging by sample REL output, every mention is a
    list whose element 3 is the linked entity name and whose last element is
    the NER tag (e.g. ``"ORG"``). A body ORG mention counts as *salient* when
    its linked entity also appears among the headline's ORG mentions.

    Args:
        data: Sized iterable of dict-like rows.
        title: Key of the headline text in each row.
        article: Key of the body text in each row.
        id: Key of the row identifier; copied to ``raw_news_id``.
        api_url: REL endpoint URL. Defaults to the notebook-level ``API_URL``.
        max_words: Rows whose headline or body has more whitespace-separated
            words than this are skipped. (The 3300 default presumably keeps
            requests under a REL input limit; TODO confirm.)
        session: Optional object exposing a ``requests``-style ``post`` method,
            e.g. a ``requests.Session`` to reuse one HTTP connection across
            many rows. Defaults to the ``requests`` module itself.
        timeout: Seconds to wait for each REL response.
        show_progress: Wrap the iteration in a ``tqdm`` progress bar.

    Returns:
        A list with one dict per processed row, containing:
        ``raw_news_id``; ``headline_mentions`` and ``body_text_mentions``
        (raw REL output); ``salient_entities_org`` (body ORG mentions whose
        entity occurs in the headline, in body order, repeats kept); and
        ``salient_entities_set`` (their distinct entity names, or the
        sentinel ``{'None'}`` when there are none).

    Raises:
        requests.HTTPError: If the REL service answers with an error status
            (raised by ``raise_for_status`` on a ``requests`` response).
    """
    # Each fallback is evaluated only when needed, so callers that inject
    # these do not depend on the notebook globals.
    url = api_url if api_url is not None else API_URL
    http = session if session is not None else requests
    iterator = tqdm(data, total=len(data)) if show_progress else data

    def _link(text):
        # Run REL mention detection + linking on one piece of text.
        response = http.post(url, json={"text": text, "spans": []}, timeout=timeout)
        # Fail loudly: an error payload decoded as JSON would otherwise be
        # iterated as though it were a mention list and silently match nothing.
        response.raise_for_status()
        return response.json()

    annotated_articles = []

    for row in iterator:
        headline, body = row[title], row[article]

        # Salience needs both texts; skip NULL headline/body rows instead of
        # crashing on ``None.split()``.
        if headline is None or body is None:
            continue
        if len(headline.split()) > max_words or len(body.split()) > max_words:
            continue

        el_title = _link(headline)
        el_article = _link(body)

        # Linked entities of headline ORG mentions, as a set for O(1) lookups.
        headline_org_entities = {
            mention[3] for mention in el_title if mention[-1] == "ORG"
        }
        salient_entities_org = [
            mention
            for mention in el_article
            if mention[-1] == "ORG" and mention[3] in headline_org_entities
        ]
        # An empty set is falsy, so rows without matches get the sentinel.
        salient_entities_org_set = {
            mention[3] for mention in salient_entities_org
        } or {"None"}

        annotated_articles.append(
            {
                "raw_news_id": row[id],
                "headline_mentions": el_title,
                "body_text_mentions": el_article,
                "salient_entities_org": salient_entities_org,
                "salient_entities_set": salient_entities_org_set,
            }
        )

    return annotated_articles

In [18]:
# REL entity-linking endpoint. The `rel` host name presumably only resolves
# inside the service's container network; set REL_API_URL (for example
# http://localhost:5555/api) to reach it from elsewhere.
API_URL = os.getenv("REL_API_URL", "http://rel:5555/api")

In [19]:
# Database settings, read once from the environment (each is None when unset).
db_host, db_port, db_name, db_user, db_pass = (
    os.getenv(var_name)
    for var_name in ("DB_HOST", "DB_PORT", "DB_NAME", "DB_USER", "DB_PASS")
)

In [20]:
# Connection parameters reuse the DB_* environment values read in the
# previous cell instead of querying the environment a second time.
conn_params = {
    "host": db_host,
    "port": db_port,
    "database": db_name,
    "user": db_user,
    "password": db_pass,
}

# Fetch only the articles flagged as parsed.
query = "SELECT id, title, text FROM raw_news_articles WHERE is_parsed = True;"

# closing() guarantees the connection is closed even if the query raises;
# psycopg2's own `with conn:` only ends the transaction and does not close it.
# The cursor's context manager closes the cursor in the same way.
with closing(psycopg2.connect(**conn_params)) as conn:
    with conn.cursor() as cursor:
        cursor.execute(query)
        rows = cursor.fetchall()

In [21]:
# Turn each (id, title, text) result tuple into a record keyed by column
# name, then keep a JSON-serialised snapshot of all records.
data = [dict(zip(("id", "title", "text"), record)) for record in rows]
json_data = json.dumps(data)

In [22]:
# Rebuild `data` from its JSON snapshot. The round trip yields an equal,
# independent copy of the records (assuming every value is JSON-native).
data = json.JSONDecoder().decode(json_data)

In [31]:
# Smoke-test the REL entity-linking service with a sample headline. Written
# with requests (bare shell `curl` lines are a SyntaxError in a Python cell)
# so it exercises the same endpoint and client the pipeline uses. If running
# outside the service's network, point API_URL at e.g. http://localhost:5555/api.
sample_request = {
    "text": "China Is Stealing AI Secrets to Turbocharge Spying, U.S. Says",
    "spans": [],
}
rel_response = requests.post(API_URL, json=sample_request, timeout=60)
rel_response.raise_for_status()
rel_response.json()

[[0, 5, "China", "China", 0.8694822901752369, 0.8440544605255127, "LOC"], [52, 4, "U.S.", "United_States", 0.521479391215041, 0.9513190984725952, "LOC"]]

In [61]:
[data[idx] for idx in (2, 7)]

[{'id': 7343,
  'title': 'China Is Stealing AI Secrets to Turbocharge Spying, U.S. Says',
  'text': 'On a July day in 2018, Xiaolang Zhang headed to the San Jose, Calif., airport to board a flight to Beijing. He had passed the checkpoint at Terminal B when his journey was abruptly cut short by federal agents.\n\nAfter a tipoff by Apple ’s security team, the former Apple employee was arrested and charged with stealing trade secrets related to the company’s autonomous-driving program.'},
 {'id': 7350,
  'title': 'U.S. Coast Guard suspends search for man who fell off cruise ship that departed from Baltimore',
  'text': 'BALTIMORE -- The U.S. Coast Guard on Christmas Eve suspended its search for a 41-year-old man who fell off a ship belonging to a cruise line that departs from Baltimore.\n\n#FinalUpdate: The @USCG has suspended the search efforts for the missing cruise ship passenger 127 mi east of #Charleston.\n\n\n\nCoast Guard crews searched more than 1,625 square miles and 8 hours.\n\n

In [25]:
# df_test = pd.read_excel('target_raw_news_articles_202312141455.xlsx')
# df_test = df_test[~df_test['target'].isna()]
# df_test['target'] = df_test['target'].apply(lambda x: x.strip("[]").replace("'", ""))
# df_test.shape

In [26]:
# df_test.head()

In [62]:
# Annotate the two sample articles; the text lives under the "text" key.
sample_rows = [data[idx] for idx in (2, 7)]
df_test_rel = extract_salient_entities(sample_rows, title="title", article="text")

100%|██████████| 2/2 [00:01<00:00,  1.03it/s]


In [79]:
# Number of distinct salient ORG entities for the second annotated sample.
len(df_test_rel[1]['salient_entities_set'])

1

In [84]:
# Distinct salient ORG entity names for the second annotated sample.
df_test_rel[1]['salient_entities_set']

{'United_States_Coast_Guard'}

In [85]:
# First sample: {'None'} is the sentinel meaning no body ORG entity matched the headline.
df_test_rel[0]['salient_entities_set']

{'None'}

In [87]:
# Count of salient ORG mentions (repeats kept) for the second sample,
# as opposed to the distinct entity names in 'salient_entities_set'.
len(df_test_rel[1]['salient_entities_org'])

6

In [89]:
# NOTE(review): these mentions are a literal snapshot pasted from an earlier
# REL response. Their offsets and scores differ from the current
# df_test_rel[1]['salient_entities_org'] output, so this count does not track
# the live result. Consider deleting this cell.
len([
    [17, 16, 'U.S. Coast Guard', 'United_States_Coast_Guard', 0.8891559804540544, 0.8521444400151571, 'ORG'], 
    [192, 4, 'USCG', 'United_States_Coast_Guard', 0.9111871297938334, 0.47667717933654785, 'ORG'], 
    [791, 17, 'U.S. Coast Guard.', 'United_States_Coast_Guard', 0.3872777678067984, 0.8202411532402039, 'ORG'], 
    [816, 16, 'U.S. Coast Guard', 'United_States_Coast_Guard', 0.43678396048090334, 0.7355291048685709, 'ORG'], 
    [918, 16, 'U.S. Coast Guard', 'United_States_Coast_Guard', 0.47601745960158903, 0.8914013306299845, 'ORG']
])

5

In [90]:
# Full salient ORG mention lists for the second sample, as returned by REL.
df_test_rel[1]['salient_entities_org']

[[17,
  16,
  'U.S. Coast Guard',
  'United_States_Coast_Guard',
  0.30562666525715215,
  0.8639779289563497,
  'ORG'],
 [190,
  4,
  'USCG',
  'United_States_Coast_Guard',
  0.7504895236536515,
  0.9649078845977783,
  'ORG'],
 [297,
  11,
  'Coast Guard',
  'United_States_Coast_Guard',
  0.8574496275930025,
  0.8526458442211151,
  'ORG'],
 [777,
  16,
  'U.S. Coast Guard',
  'United_States_Coast_Guard',
  0.27727612688077324,
  0.8181795080502828,
  'ORG'],
 [800,
  16,
  'U.S. Coast Guard',
  'United_States_Coast_Guard',
  0.2773915748522871,
  0.8772232333819071,
  'ORG'],
 [902,
  16,
  'U.S. Coast Guard',
  'United_States_Coast_Guard',
  0.2773879666549176,
  0.8078232804934183,
  'ORG']]