In [1]:
# Average corr(predicted rating, true rating) within each quintile of type frequency (E[T]), for the full model and the reports-only model on semisynthetic data
import os
import pandas as pd
import pickle
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
import warnings
# Figure typography: serif text, and the "cm" font set for math expressions
# (chosen to give math a LaTeX-like look).
plt.rcParams.update({
    'font.family': 'serif',
    "mathtext.fontset": "cm",
})

In [2]:
# Input locations: the base CSV and the directory holding per-job results.
base_file = '/share/garg/311_data/sb2377/clean_codebase/three_year_base.csv'
results_dir = '/share/garg/311_data/sb2377/results'

# User-specified arguments.
# Short display name -> value matched against the 'typeagency' column.
types = dict(
    Street='StreetConditionDOT',
    Park='MaintenanceorFacilityDPR',
    Rodent='RodentDOHMH',
    Food='FoodDOHMH',
    DCWP='ConsumerComplaintDCWP',
)
# Each model has 20 job ids spaced 3 apart: the full model starts at 3200,
# the reports-only model at 3201.
models = {
    'Full model': {'job_ids': list(range(3200, 3260, 3))},
    'Reports-only model': {'job_ids': list(range(3201, 3261, 3))},
}
# Epoch tag substituted into the checkpoint / results file names.
epoch = '59'

In [3]:
# load files
# Read the CSV at `base_file` into a DataFrame. Later code reads its
# 'typeagency' and 'type_idxs' columns to map type names to type indices.
base_df = pd.read_csv(base_file)

In [4]:
# Map each display name in `types` to its type index.
# `type_df` holds the distinct (typeagency, type_idxs) pairs of the frame with all types.
type_df = base_df.loc[:, ['typeagency', 'type_idxs']].drop_duplicates()
indices = {}
for type_name, type_id in types.items():
    matching_idxs = type_df.loc[type_df['typeagency'] == type_id, 'type_idxs']
    # The first matching row wins; this raises if the type is absent from the base file.
    idx = matching_idxs.iloc[0]
    indices[type_name] = idx

In [5]:
# Gather the predicted ratings of every job, grouped by model.
checkpoint_file = '{}/job{}/model-epoch={}.ckpt'
results_file = '{}/job{}/epoch={}_test_unobserved.pkl'

# Per-model tallies of how many checkpoint / results files exist on disk.
checkpoint_counters = {m: 0 for m in models}
results_counters = {m: 0 for m in models}
# Per-model list with one DataFrame for each job whose results file was found.
dfs = {m: [] for m in models}

for m, model_spec in models.items():
    for i, job_idx in enumerate(model_spec['job_ids']):
        checkpoint_path = checkpoint_file.format(results_dir, job_idx, epoch)
        results_path = results_file.format(results_dir, job_idx, epoch)

        if os.path.exists(checkpoint_path):
            checkpoint_counters[m] += 1
        if not os.path.exists(results_path):
            # No results for this job yet: count nothing and move on.
            continue
        results_counters[m] += 1

        # The pickle stores a 10-element sequence, unpacked in this fixed order.
        with open(results_path, 'rb') as file:
            (pred_rating, true_rating, mask, node_embedding, type_embedding,
             node_idxs, type_idxs, demographics, pred_pt, true_t) = pickle.load(file)

        # Columns are assigned one at a time, in this order, onto an empty frame.
        # Note that 'job_id' is the position of the job within `job_ids`,
        # not the job number used in the file path.
        df = pd.DataFrame()
        for column, values in (('pred_rating', pred_rating),
                               ('true_rating', true_rating),
                               ('node_idxs', node_idxs),
                               ('type_idxs', type_idxs),
                               ('pred_pt', pred_pt),
                               ('true_t', true_t),
                               ('mask', mask),
                               ('job_id', i)):
            df[column] = values

        dfs[m].append(df)

# Report how many files were found for each model.
for m in models:
    print('{}: checkpoint files done = {}'.format(m, checkpoint_counters[m]))
    print('{}: results files done = {}'.format(m, results_counters[m]))

Full model: checkpoint files done = 20
Full model: results files done = 20
Reports-only model: checkpoint files done = 20
Reports-only model: results files done = 20


In [None]:
# print avg correlation for each quintile of type frequency
#
# For every model and every type that is not listed in `indices`:
#   1. per job, average the rows of each (node, type) pair and correlate the
#      model's rating proxy with the true rating across nodes;
#   2. bin the types into quintiles of type frequency (mean of `true_t`);
#   3. average the correlations within each quintile for every job, then report
#      the mean over jobs with a 95% confidence interval.

# Types named in `indices` are skipped; built once instead of on every iteration.
excluded_type_idxs = set(indices.values())
n_quintiles = 5

for m in models:
    df_set = dfs[m]
    corrs = []  # one row per type, one entry per job
    freqs = []  # same layout as `corrs`
    # NOTE: iterating over range(n_unique) assumes the type indices in the first
    # job's results are the contiguous integers 0..n_unique-1.
    for idx in range(len(df_set[0]['type_idxs'].unique())):
        if idx in excluded_type_idxs:
            continue
        type_corrs = []
        type_freqs = []
        for df in df_set:
            df_type = df[df['type_idxs'] == idx]
            node_df = df_type.groupby(['node_idxs', 'type_idxs']).mean().reset_index()

            # calculate correlation (warnings raised by pearsonr are silenced)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if m == 'Reports-only model':
                    # for reports-only model, we use -P(T) as a proxy for r
                    corr = pearsonr(-1 * node_df['pred_pt'], node_df['true_rating'])
                else:
                    corr = pearsonr(node_df['pred_rating'], node_df['true_rating'])
            # element 0 of the result is the correlation coefficient
            type_corrs.append(corr[0])

            # calculate type frequency
            type_freqs.append(df_type['true_t'].mean())

        corrs.append(type_corrs)
        freqs.append(type_freqs)

    corrs = np.array(corrs)
    freqs = np.array(freqs)
    results_df = pd.DataFrame()
    # Quintile membership is decided from the first job's type frequencies only.
    results_df['type_freq'] = freqs[:, 0]
    for i in range(corrs.shape[1]):
        results_df['corrs_{}'.format(i)] = corrs[:, i]
    # Create quintiles for column type frequency
    results_df['quintile'] = pd.qcut(results_df['type_freq'], q=n_quintiles, labels=False) + 1  # Labels from 1 to 5
    results_df = results_df.drop(columns=['type_freq'])

    # Compute the average correlation of each type for each quintile
    # (rows: quintiles, columns: jobs)
    avg_per_quintile = results_df.groupby('quintile').mean().to_numpy()
    mean_avg_per_quintile = avg_per_quintile.mean(axis=1)

    # Compute 95% CIs across jobs. numpy's std defaults to the population form
    # (ddof=0); dividing it by sqrt(n - 1) equals the sample std divided by
    # sqrt(n), i.e. the standard error of the mean.
    std_per_quintile = avg_per_quintile.std(axis=1)
    ci_per_quintile = 1.96 * std_per_quintile / np.sqrt(avg_per_quintile.shape[1] - 1)

    for i in range(n_quintiles):
        # '\\pm' prints a literal "\pm"; the unescaped form is an invalid escape sequence.
        print('Model: {}, Quintile {}-{}%, Corr: {} \\pm {}'.format(m,
                                                                     i * 100 // n_quintiles,
                                                                     (i + 1) * 100 // n_quintiles,
                                                                     mean_avg_per_quintile[i],
                                                                     ci_per_quintile[i]))


Model: Full model, Quintile 0-20%, Corr: 0.26891619090191043 \pm 0.025082469820980446
Model: Full model, Quintile 20-40%, Corr: 0.19731850839534595 \pm 0.020658027588664418
Model: Full model, Quintile 40-60%, Corr: 0.3667753357311248 \pm 0.010829055778879635
Model: Full model, Quintile 60-80%, Corr: 0.6020162966509004 \pm 0.004915444233865695
Model: Full model, Quintile 80-100%, Corr: 0.7677887792568437 \pm 0.0022427212146255474
Model: Reports-only model, Quintile 0-20%, Corr: -0.07395683401514073 \pm 0.0352895001215662
Model: Reports-only model, Quintile 20-40%, Corr: -0.0657438980919717 \pm 0.02067375953093559
Model: Reports-only model, Quintile 40-60%, Corr: 0.32649365290354904 \pm 0.007284297465675729
Model: Reports-only model, Quintile 60-80%, Corr: 0.5813738487305552 \pm 0.004090901616726463
Model: Reports-only model, Quintile 80-100%, Corr: 0.7451271065725565 \pm 0.0022190179756350565
