In [0]:
# Environment note: this runs on DBR standard runtime 11.3

# ---------
# Imports
# ---------
import pandas as pd


# -------
# IMDB
# -------

# Flair accuracy
# Row counts per rank_first_match value from the Flair evaluation table.
# The DataFrame is bound and displayed: a bare spark.sql(...) expression in the
# middle of a script is discarded, so the counts behind the figure below would
# never be shown.
imdb_flair_ranks = spark.sql('''
SELECT rank_first_match, COUNT(*) as n_in_rank
FROM user_nsulliv3.simpler_imdb_eval_maj_ranks
GROUP BY 1
ORDER BY 1
;
''')
imdb_flair_ranks.display()

# 555/572 ≈ 97.0% accuracy

# GPT accuracy
# Keep the pandas frame under its own name (as the COVID sections do with
# covid_gpt_raw); `imdb_gpt` is bound to the Spark summary below.
imdb_gpt_raw = pd.read_csv('https://raw.githubusercontent.com/sullivannicole/simplER/main/data/imdb_after_gpt.csv')

spark.createDataFrame(imdb_gpt_raw).createOrReplaceTempView('imdb_gpt')

# Shared by both queries below so the match count and the manual-check list
# always use the same definition of a "match".
# Normalisation: remove every character that is not an ASCII letter, digit or
# space, strip trailing whitespace, lower-case. Rows without a GPT title are dropped.
_IMDB_GPT_ENTITIES_CTE = '''
gpt_entities AS (SELECT *, LOWER(RTRIM(REGEXP_REPLACE(movie, '[^0-9a-zA-Z ]', ''))) AS movie_no_punc, LOWER(RTRIM(REGEXP_REPLACE(gpt_movie_title, '[^0-9a-zA-Z ]', ''))) AS gpt_movie_no_punc
FROM imdb_gpt
WHERE gpt_movie_title IS NOT NULL)
'''

# Match rule: the normalised titles are equal, or either one contains the other.
# Both sides must be non-empty. CONTAINS(x, '') is always true, so a title that
# normalises to '' (e.g. one written entirely in non-ASCII characters) would
# otherwise match anything.
_IMDB_GPT_MATCH = '''
(a.movie_no_punc != '' AND a.gpt_movie_no_punc != ''
 AND (a.movie_no_punc = a.gpt_movie_no_punc
      OR CONTAINS(a.movie_no_punc, a.gpt_movie_no_punc)
      OR CONTAINS(a.gpt_movie_no_punc, a.movie_no_punc)))
'''

# n_matches: rows satisfying the match rule (counting non-null user_review);
# n_obs: all rows that have a GPT title. Conditional aggregation always returns
# exactly one row, including when there are zero matches.
imdb_gpt = spark.sql(f'''
WITH {_IMDB_GPT_ENTITIES_CTE}
SELECT COUNT(CASE WHEN {_IMDB_GPT_MATCH} THEN a.user_review END) AS n_matches,
       COUNT(*) AS n_obs
FROM gpt_entities a
''')
imdb_gpt.display()

# These preds aren't captured by the absolute rules above, but may still be technically correct (e.g. just missing a token like "the", "a" or "and").
# Check them manually; that is quicker than building out an eval pipeline at this point.
imdb_gpt_check = spark.sql(f'''
WITH {_IMDB_GPT_ENTITIES_CTE},
correct_class AS (SELECT *
FROM gpt_entities a
WHERE {_IMDB_GPT_MATCH})

SELECT *
FROM gpt_entities
WHERE `Unnamed: 0` NOT IN (SELECT DISTINCT `Unnamed: 0` FROM correct_class)
''')

# Manually identified additional correct predictions (n = 72): 31, 32, 35, 43, 69, 97, 109, 113, 135, 163, 166, 167, 177, 197, 235, 249, 261, 287, 315, 317, 325, 326,
# 329, 330, 334, 362, 374, 378, 390, 401, 402, 416, 436, 437, 477, 486, 487, 491, 492, 502, 513, 638, 665, 676, 817, 825, 850, 856, 882, 895,
# 896, 918, 919, 939, 968, 984, 1023, 1028, 1095, 1104, 1113, 1119, 1120, 1125, 1176, 1240, 1247, 1253, 1316, 1317, 1357, 1358
imdb_gpt_check.display()

# Recorded results. These come from the run that produced the manual list above, before the
# empty-title guard was added, so rerun to confirm:
# GPT: (808 + 72) / 911 = 880 / 911 ≈ 96.6% accuracy
# Altogether (Flair + GPT): (555 + 880) / (572 + 911) = 1435 / 1483 ≈ 96.8% accuracy

# --------------
# COVID - user
# --------------

covid_usr_w_ids = pd.read_csv('https://raw.githubusercontent.com/sullivannicole/simplER/main/data/covid_raw_user_sentences_w_ids.csv')
# 'Unnamed: 0' is presumably a pandas index saved into the CSV; it is not needed in the table.
covid_usr_df = covid_usr_w_ids.drop(columns = ['Unnamed: 0'])
spark.createDataFrame(covid_usr_df).write.mode('overwrite').saveAsTable('user_nsulliv3.simpler_covid_usr_ids')

# Long format: one row per (id, sentence, labelled country), taken from country_1 and country_2.
# UNPIVOT is not available on this notebook's runtime (DBR 11.3, see top of file),
# which is why it previously had to be run in the SQL editor. LATERAL VIEW STACK
# works here. The IS NOT NULL filter mirrors UNPIVOT's default EXCLUDE NULLS
# behaviour, so sentences with a single labelled country yield a single row.
spark.sql('''
CREATE OR REPLACE TABLE user_nsulliv3.simpler_covid_usr_long AS
SELECT id, sentence, country
FROM user_nsulliv3.simpler_covid_usr_ids
LATERAL VIEW STACK(2, country_1, country_2) unpivoted AS country
WHERE country IS NOT NULL
''')

# Flair accuracy
covid_flair = pd.read_csv('https://raw.githubusercontent.com/sullivannicole/simplER/main/data/covid_user_flair.csv')

spark.createDataFrame(covid_flair).createOrReplaceTempView('covid_flair')

# A Flair GPE entity is correct when its text equals a labelled country for the
# same sentence id. This must be an INNER join. A LEFT JOIN keeps every
# covid_flair row whether or not the ON condition matched, which would report
# every GPE prediction as correct.
_COVID_FLAIR_CORRECT_SQL = '''
SELECT a.id, a.text AS country_pred
FROM covid_flair a
JOIN user_nsulliv3.simpler_covid_usr_long b
ON a.id = b.id
AND a.text = b.country
WHERE a.entity_detected = 'GPE'
'''

covid_flair_correct = spark.sql(_COVID_FLAIR_CORRECT_SQL)

# Labelled sentences with no correct Flair prediction (alongside whatever Flair
# did detect for them), for manual review.
# 1 sentence got missed in my original classification, but Flair classifies it correctly when checked manually: 42
covid_flair_check = spark.sql(f'''
WITH flair_preds AS ({_COVID_FLAIR_CORRECT_SQL})

SELECT a.*, b.*
FROM user_nsulliv3.simpler_covid_usr_long a
LEFT JOIN covid_flair b
ON a.id = b.id
WHERE a.id NOT IN (SELECT DISTINCT id FROM flair_preds)
''')
covid_flair_check.display()

# Recorded result (rerun to confirm after the join fix): (57+1) / 64 ≈ 90.6% accuracy

# GPT accuracy
covid_gpt_raw = pd.read_csv('https://raw.githubusercontent.com/sullivannicole/simplER/main/data/covid_raw_user_sentences_w_ids_after_gpt.csv')
spark.createDataFrame(covid_gpt_raw).createOrReplaceTempView('covid_gpt')

# Numerator and denominator are computed over the same rows: those where GPT
# returned a count type other than 'unknown'. Counting matches without that
# filter would let rows outside the denominator add to the numerator.
covid_gpt_correct = spark.sql('''
SELECT COUNT(CASE WHEN count_type1 = LOWER(count_type_gpt) THEN 1 END) AS n_matches,
       COUNT(*) AS n_obs
FROM covid_gpt
WHERE count_type_gpt IS NOT NULL
AND LOWER(count_type_gpt) != 'unknown'
''')
covid_gpt_correct.display()

# Manually check "incorrect" predictions
# +2 for incorrect labels on ids = 40, 47, which GPT predicted correctly
spark.sql('''
SELECT *
FROM covid_gpt
WHERE count_type_gpt IS NOT NULL
AND LOWER(count_type_gpt) != 'unknown'
AND count_type1 != LOWER(count_type_gpt)
''').display()

# GPT accuracy: 73.5% (recorded; includes the +2 manual relabels above)

# --------------------
# COVID - generated
# --------------------

# Rule-based count-type prediction for generated sentences: take the 4th and 5th
# space-separated tokens, map them to a count type, and count agreements with
# the `table` column. `table` is backticked because TABLE is a reserved keyword
# when Spark's ANSI mode is enabled. The result is displayed so the count behind
# the figures below is visible.
spark.sql('''
with tokenized AS (SELECT *, SPLIT(sentence, ' ') AS tokens
FROM user_nsulliv3.simpler_covid_gen),

-- tokens[i] is 0-based, so [3] and [4] are the 4th and 5th tokens
count_type AS (SELECT *, CONCAT_WS(' ', tokens[3], tokens[4]) AS count_type_pattern
FROM tokenized),

-- used to write case when
-- SELECT DISTINCT `table`, count_type_pattern
-- FROM count_type

regex_pred AS (SELECT sentence, `table`,
CASE WHEN count_type_pattern = 'new cases' THEN 'new confirmed'
WHEN count_type_pattern IN ('total confirmed', 'cases in', 'confirmed cases', 'total cases', 'cases .') THEN 'total confirmed'
WHEN count_type_pattern = 'deaths cases' THEN 'total deaths'
WHEN count_type_pattern = 'recovered cases' THEN 'total recovered'
ELSE count_type_pattern
END AS count_type_pred
FROM count_type)

SELECT COUNT(*) AS n_matches
FROM regex_pred
WHERE `table` = count_type_pred
;
''').display()

# count_type: 100% accuracy
# country: 98.6% (6912/7008)