<a id="top"></a>

# Correlate metrics and behavior

Calculate Spearman correlations, across images, between image quality metrics (IS, MIS, FID, MMD, NIQE, BRISQUE) and behavior (detection performance or rating).


- [Load data](#load)
- [Calculate correlations (including `real`)](#correlation_incl_real)
- [Calculate correlations (excluding `real`)](#correlation_excl_real)


In [1]:
import os
import pandas as pd
import numpy as np
import scipy, pickle
from scipy import stats

In [2]:
from pathlib import Path

# All directories are plain strings ending in '/', so later cells can build
# file paths by string concatenation (e.g. `resultdir + 'file.csv'`).
home = str(Path.home())
print(home)

basedir = os.path.join(home, 'git/MRI-GAN-QA/')
datadir, imagedir, figdir, resultdir = (
    basedir + sub for sub in ('experiment/', 'experiment/Psytoolkit/', 'figures/', 'results/')
)

/Users/matthiastreder


In [3]:
def clean_RT_data(df, RT_col, low_RT, high_RT, n_timeouts=None):
    '''Clean reaction-time (RT) data.

    Steps, applied in this order (each prints what it removes):

    1. If ``n_timeouts`` is given, drop all trials of every participant whose
       summed ``timeout`` column is larger than ``n_timeouts``.
    2. Drop trials with ``df[RT_col] <= low_RT``.
    3. Drop trials with ``df[RT_col] >= high_RT``.

    Parameters
    ----------
    df : pandas.DataFrame
        Trial-level data with a ``participant`` column (and a ``timeout``
        column when ``n_timeouts`` is used).
    RT_col : str
        Name of the RT column.
    low_RT, high_RT : number
        Exclusive bounds: kept trials satisfy ``low_RT < RT < high_RT``.
    n_timeouts : int, optional
        Maximum number of timeouts a participant may have. ``None`` (default)
        skips this check.

    Returns
    -------
    pandas.DataFrame
        The filtered data; ``df`` itself is not modified in place.
    '''
    # 1) participants with too many time outs
    if n_timeouts is not None:
        timeouts = df.groupby('participant')['timeout'].sum()
        too_many = timeouts.index[timeouts > n_timeouts]
        if len(too_many) > 0:
            print(f'{len(too_many)} participants have >{n_timeouts} timeouts, removing them')
            # a single vectorised mask instead of re-filtering once per participant
            df = df[~df['participant'].isin(too_many)]

    # NOTE(review): the two filters below also drop rows whose RT is NaN, but
    # only when that bound removes at least one trial, and such rows are not
    # included in the printed counts.

    # 2) lower RT bound
    n_low = (df[RT_col] <= low_RT).sum()
    if n_low > 0:
        print(f'Removing {n_low} trials with RT <= {low_RT}')
        df = df[df[RT_col] > low_RT]

    # 3) upper RT bound
    n_high = (df[RT_col] >= high_RT).sum()
    if n_high > 0:
        print(f'Removing {n_high} trials with RT >= {high_RT}')
        df = df[df[RT_col] < high_RT]

    return df

In [4]:
def trim_mean_upper(x, prop=0.1):
    '''Mean of ``x`` after discarding values above its ``1 - prop`` quantile.

    Only the upper tail is trimmed (the lower tail is kept as is).

    Parameters
    ----------
    x : pandas.Series
        Values to average.
    prop : float
        Fraction of the upper tail to cut off (default 0.1).
    '''
    cutoff = x.quantile(1 - prop)
    kept = x[x <= cutoff]
    return kept.mean()

In [20]:
# Which experiment to analyse: 'detection' or 'rating'.
task = 'detection'
# task = 'rating'

# Later cells branch on `task == 'detection' ... else ...`, so any other value
# would silently run the rating analysis -- fail early instead.
assert task in ('detection', 'rating'), f'unknown task: {task!r}'

print('processing', task, 'task')

processing detection task


---

<a id="load"></a>
# Load data
[back to top](#top)

In [21]:
# Load metrics
with open(resultdir + f'analyze_gan_2D_all_metrics_{task}.pickle', 'rb') as f:
    (df_is, df_mis, df_fid, df_mmd, df_vmaf, stat) = pickle.load(f)

In [22]:
# Load per-image NIQE/BRISQUE scores (the '-mri' columns are presumably the
# MRI-specific variants -- TODO confirm).
csv_task = 'detection' if task == 'detection' else 'rating'
nb = pd.read_csv(resultdir + f'analyze_2D_image_NIQE_BRISQUE_{csv_task}.csv')

# `batch` codes 1..5 correspond to the iterations 344..60000 and 0 to real
# images; order the categories with 'real' last and give them readable labels.
batch_labels = {0: 'real', 1: '344', 2: '1055', 3: '7954', 4: '24440', 5: '60000'}
nb['iteration'] = (
    nb['batch']
    .astype('category')
    .cat.reorder_categories([1, 2, 3, 4, 5, 0])
    .cat.rename_categories(batch_labels)
)

nb.head()

Unnamed: 0,image,button,batch,niqe,brisque,niqe-mri,brisque-mri,iteration
0,batch_344_im_101,2,1,17.858787,39.121691,18.306365,3.591137,344
1,batch_344_im_102,2,1,16.137112,40.706215,18.580572,3.066754,344
2,batch_344_im_103,2,1,15.194134,37.045835,17.003539,3.764313,344
3,batch_344_im_104,2,1,16.908035,38.088221,18.924624,3.858712,344
4,batch_344_im_105,2,1,17.156788,38.954275,20.316853,3.47628,344


In [23]:
# Load trial-level behavioral data of all participants.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(resultdir + 'psytoolkit_all_participants26.pickle', 'rb') as f:
    behavior = pickle.load(f)

# Keep only the trials of the block belonging to the selected task.
# `.copy()` makes `behavior` an independent frame, so the column updates below
# modify it directly (no SettingWithCopyWarning, and still correct under
# pandas Copy-on-Write, where chained assignment would be a silent no-op).
block_name = 'MAIN_EXPERIMENT_BLOCK' if task == 'detection' else 'RATING_BLOCK'
behavior = behavior[behavior['blockname'] == block_name].copy()

# 'real' is recoded as -1 so that batch is numeric (we calculate the mean later on)
behavior['batch'] = behavior['batch'].astype('object')
behavior.loc[behavior['batch'] == 'real', 'batch'] = -1
behavior['batch'] = behavior['batch'].astype('int')


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  behavior.batch[behavior.batch=='real'] = -1


In [24]:
# Clean RTs with task-specific bounds (units presumably ms -- TODO confirm).
# Only the detection task also removes participants with too many timeouts.
if task == 'detection':
    rt_args = dict(RT_col='RT', low_RT=150, high_RT=20000, n_timeouts=30)
else:
    rt_args = dict(RT_col='rate_RT', low_RT=150, high_RT=10000)
behavior = clean_RT_data(behavior, **rt_args)

behavior.shape

1 participants have >30 timeouts, removing them
Removing 16 trials with RT >= 20000


(5984, 20)

In [25]:
# `batch` is averaged together with the behavioral columns below, which only
# yields the batch code if every tablerow belongs to a single batch.
assert behavior.groupby('tablerow')['batch'].nunique().eq(1).all()

# Average across participants, per tablerow (presumably one row per image).
if task == 'detection':
    behavior = (
        behavior.groupby('tablerow')[['real', 'correct', 'experience', 'batch']]
        .mean()
        .reset_index()
    )
    # proportions -> percentages
    behavior[['real', 'correct']] *= 100
    print(behavior.head(5))
else:
    behavior = behavior.groupby('tablerow')[['rate', 'experience', 'batch']].mean().reset_index()

# Restore the 'real' label (encoded as -1 before averaging). Replace the whole
# column instead of chained assignment so this also works under Copy-on-Write.
batch = behavior['batch'].astype('int').astype('object')
behavior['batch'] = batch.where(batch != -1, 'real')

   tablerow      real     correct  experience  batch
0         1  0.000000  100.000000    3.200000  344.0
1         2  0.000000  100.000000    3.200000  344.0
2         3  4.166667   91.666667    3.208333  344.0
3         4  4.000000   96.000000    3.200000  344.0
4         5  4.000000   96.000000    3.200000  344.0


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  behavior.batch[behavior.batch==-1] = 'real'


In [26]:
# Per-image behavioral averages (first rows).
behavior.head(5)

Unnamed: 0,tablerow,real,correct,experience,batch
0,1,0.0,100.0,3.2,344
1,2,0.0,100.0,3.2,344
2,3,4.166667,91.666667,3.208333,344
3,4,4.0,96.0,3.2,344
4,5,4.0,96.0,3.2,344


---

<a id="correlation_excl_real"></a>
# Calculate correlations (excluding `real`)
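[back to top](#top)

Correlations are computed over the generated images only. MIS, FID and MMD are only available as per-iteration averages, so their correlations only reflect how the five iterations are ranked. FID and MMD most likely give identical results because they rank the iterations in the same order.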


In [27]:
# Restrict to generated images (drop rows labelled 'real'); n = number of images left.
behavior_ex = behavior.loc[behavior['batch'] != 'real']
n = len(behavior_ex)
print(n)

160


In [28]:
# MIS, FID and MMD averaged per iteration.
mis_av, fid_av, mmd_av = (
    frame.groupby('iteration')[col].mean()
    for frame, col in ((df_mis, 'MIS'), (df_fid, 'FID'), (df_mmd, 'MMD'))
)

In [29]:
# MIS, FID, and MMD have no per-image values, only averages per iteration.
# So every image is assigned the average of the iteration it belongs to.
iterations = (344, 1055, 7954, 24440, 60000)

# Every generated image must belong to one of these iterations; otherwise its
# metric value would silently stay 0 and bias the correlations.
assert behavior_ex['batch'].isin(iterations).all()

mis_val = np.zeros((behavior_ex.shape[0],))
fid_val = np.zeros((behavior_ex.shape[0],))
mmd_val = np.zeros((behavior_ex.shape[0],))

for it in iterations:
    in_it = (behavior_ex['batch'] == it).to_numpy()
    print(in_it.sum())  # number of images from this iteration
    mis_val[in_it] = mis_av.loc[it]
    fid_val[in_it] = fid_av.loc[it]
    mmd_val[in_it] = mmd_av.loc[it]

32
32
32
32
32


In [30]:
# Behavioral target per image: percentage of 'real' responses (detection task)
# or the mean rating (rating task).
target = behavior_ex.real if task == 'detection' else behavior_ex.rate

# Per-image metrics (IS, NIQE, BRISQUE) are paired with behavior by row
# position: the first `n` rows of df_is / nb are assumed to be the generated
# images in the same order as behavior_ex (TODO confirm for df_is). For nb,
# at least verify that the iteration of every row matches.
nb_ex = nb.iloc[:n]
assert (nb_ex['iteration'].astype(str).to_numpy()
        == behavior_ex['batch'].astype(str).to_numpy()).all(), \
    'NIQE/BRISQUE rows are not aligned with behavior_ex'

metrics = {
    'IS': df_is['IS'].iloc[:n],
    'MIS': mis_val,
    'FID': fid_val,
    'MMD': mmd_val,
    'NIQE': nb_ex['niqe'],
    'NIQE-MRI': nb_ex['niqe-mri'],
    'BRISQUE': nb_ex['brisque'],
    'BRISQUE-MRI': nb_ex['brisque-mri'],
}
for name, values in metrics.items():
    print(f'{name}:', stats.spearmanr(values, target))

IS: SpearmanrResult(correlation=-0.2557241071130985, pvalue=0.0010994342495376304)
MIS: SpearmanrResult(correlation=0.1109438090138115, pvalue=0.1625180169333629)
FID: SpearmanrResult(correlation=-0.60237698308851, pvalue=3.5651129764241114e-17)
MMD: SpearmanrResult(correlation=-0.60237698308851, pvalue=3.5651129764241114e-17)
NIQE: SpearmanrResult(correlation=-0.7060186822145289, pvalue=1.8810304731231436e-25)
NIQE-MRI: SpearmanrResult(correlation=-0.5363324816837106, pvalue=2.683282311377366e-13)
BRISQUE: SpearmanrResult(correlation=0.4649300265637616, pvalue=5.877070362695938e-10)
BRISQUE-MRI: SpearmanrResult(correlation=-0.859937587481293, pvalue=5.2033647796747766e-48)


---

<a id="correlation_incl_real"></a>
# Calculate correlations (including `real`)
[back to top](#top)


In [31]:
# MIS averaged per iteration (same computation as in the previous section).
mis_av = df_mis.groupby('iteration')['MIS'].agg('mean')

In [32]:
# MIS has no per-image values, only averages per iteration (here including
# 'real'), so every image is assigned the average of its iteration.
iterations = (344, 1055, 7954, 24440, 60000, 'real')

# Every image must belong to one of these iterations; otherwise its MIS value
# would silently stay 0 and bias the correlations.
assert behavior['batch'].isin(iterations).all()

mis_val = np.zeros((behavior.shape[0],))
for it in iterations:
    mis_val[(behavior['batch'] == it).to_numpy()] = mis_av.loc[it]

In [33]:
# Behavioral target per image: percentage of 'real' responses (detection task)
# or the mean rating (rating task).
target = behavior.real if task == 'detection' else behavior.rate

# FID and MMD are not included: no per-iteration value for 'real' is used for
# them in this notebook.

# Per-image metrics are paired with behavior by row position; verify that the
# frames have the same length and that the iteration of every nb row matches.
assert len(nb) == len(behavior), 'NIQE/BRISQUE and behavior differ in length'
assert (nb['iteration'].astype(str).to_numpy()
        == behavior['batch'].astype(str).to_numpy()).all(), \
    'NIQE/BRISQUE rows are not aligned with behavior'

metrics = {
    'IS': df_is['IS'],
    'MIS': mis_val,
    'NIQE': nb['niqe'],
    'NIQE-MRI': nb['niqe-mri'],
    'BRISQUE': nb['brisque'],
    'BRISQUE-MRI': nb['brisque-mri'],
}
for name, values in metrics.items():
    print(f'{name}:', stats.spearmanr(values, target))

IS: SpearmanrResult(correlation=0.007990392367803863, pvalue=0.9019944003871532)
MIS: SpearmanrResult(correlation=0.3439971066800105, pvalue=4.521177803825465e-08)
NIQE: SpearmanrResult(correlation=-0.5188688895953197, pvalue=6.076029935991954e-18)
NIQE-MRI: SpearmanrResult(correlation=-0.761025303379701, pvalue=1.2603042770588866e-46)
BRISQUE: SpearmanrResult(correlation=-0.013092691108690668, pvalue=0.840088095866526)
BRISQUE-MRI: SpearmanrResult(correlation=-0.8585113137801569, pvalue=5.563524264506524e-71)
