In [None]:
import time
import requests
import json
import redshift_connector
import multiprocessing as mp
import mag_functions as F
from math import ceil

In [None]:
# Payload holding a single affiliation string to parse.
# NOTE(review): the next two cells also assign `input_json`; only the most
# recently executed load is in effect for the cells that follow.
with open("test_json_single.json", "rb") as payload_file:
    input_json = json.loads(payload_file.read())

In [None]:
# Payload of deliberately hard samples: the API should answer all of them
# without raising errors.
with open("test_json_batch_tough.json", "rb") as payload_file:
    input_json = json.loads(payload_file.read())

In [None]:
# Payload of 6 affiliation strings to parse (the default batch used by the
# timing cells below when the notebook is run top to bottom).
with open("test_json_batch.json", "rb") as payload_file:
    input_json = json.loads(payload_file.read())

In [None]:
%%time
# testing the call to the API one time
len(json.loads(F.get_tags(input_json[10:13], 1)[1]))

In [None]:
%%time
# testing the API call 100 times sequentially
for i in range(100):
    F.get_tags(input_json, 1)[1]

### Testing throughput with a multiprocessing pool

In [None]:
# Iterate over payload multipliers to see which batch size gives the best
# per-string latency, sending requests concurrently through a process pool so
# the SageMaker endpoint actually receives parallel traffic.
N_PARALLEL_REQUESTS = 100
for i in [20, 30]:
    test_payload = input_json * i
    start_time = time.perf_counter()
    with mp.Pool(processes=64) as p:
        # Pool.apply blocks until each call returns, which serialised all
        # requests; starmap distributes them across the worker processes.
        results = p.starmap(F.get_tags, [(test_payload, 1)] * N_PARALLEL_REQUESTS)
    elapsed = time.perf_counter() - start_time
    # Seconds per affiliation string across all requests in this sweep step.
    print(f"batch={i}____{elapsed / (i * N_PARALLEL_REQUESTS * len(input_json))} seconds")

### Testing the API with randomly queried OpenAlex data

In [None]:
# Redshift credentials for querying the OpenAlex database.
# Expected layout: line 1 = host, line 2 = password.
# Strip only the line terminator instead of slicing off the last character:
# [:-1] silently truncates the password when the file has no trailing newline.
# NOTE(review): consider env vars / a secrets manager instead of a plaintext file.
with open("redshift_creds.txt", "r") as f:
    host = f.readline().rstrip("\r\n")
    password = f.readline().rstrip("\r\n")

In [None]:
# Open a connection to the OpenAlex Redshift `dev` database as `app_user`,
# and keep a cursor for the sampling queries below.
conn = redshift_connector.connect(
    user='app_user',
    password=password,
    host=host,
    database='dev',
)
cursor = conn.cursor()

In [None]:
# Send random samples of real OpenAlex affiliation strings through the API.
# all_input[k] and all_output[k] must stay aligned: the check cell zips them.
N_QUERY_BATCHES = 25
QUERY_BATCH_SIZE = 50
# The query does not change between iterations, so build it once.
sample_query = f"""select original_affiliation
       from mid.affiliation
       where original_affiliation is not null
       order by RANDOM()
       limit {QUERY_BATCH_SIZE}"""
all_input = []
all_output = []
for _ in range(N_QUERY_BATCHES):
    # NOTE(review): presumably resets an aborted transaction left by an
    # earlier failed statement before querying again - confirm.
    cursor.execute("ROLLBACK;")
    cursor.execute(sample_query)
    test_strings = cursor.fetch_dataframe()
    test_strings.columns = ['affiliation_string']
    # Round-trip through JSON to get plain Python records for the API payload.
    test_input = json.loads(test_strings.to_json(orient='records'))
    test_output = json.loads(F.get_tags(test_input, 1)[1])
    print(len(test_output))
    # Fail loudly if the API drops or adds records; otherwise the zip-based
    # check would silently pair inputs with the wrong predictions.
    assert len(test_output) == len(test_input), (
        f"API returned {len(test_output)} predictions for {len(test_input)} inputs"
    )
    all_input += test_input
    all_output += test_output

In [None]:
# Check for bad or empty predictions: collect every input string whose
# prediction has a missing or falsy `affiliation_id`, so failures are surfaced
# instead of only listing the successful ids.
empty_predictions = [
    record_in['affiliation_string']
    for record_in, record_out in zip(all_input, all_output)
    if not record_out.get('affiliation_id')
]
print(f"{len(empty_predictions)} of {len(all_output)} predictions are empty")
empty_predictions[:20]