# GNN performance as a function of data


In [1]:
# Stock imports
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

from qian_et_al_2023.src import analysis
from qian_et_al_2023.src import base
from qian_et_al_2023.src import data_loaders
base.set_visual_settings()

In [2]:
# Canonical way to load (most of) the data; add lines as needed to include other pieces.
# Unpacks the four objects returned by data_loaders.get_clean().
models, humans, panel, subjects = data_loaders.get_clean()

  # NOTE(review): this line is indented at top level and `mol_codes` is not
  # defined in any visible cell -- it looks like a leftover fragment; confirm
  # before running. As written it rebinds `panel` to the per-'RedJade Code' mean
  # of `humans`, restricted to rows `mol_codes` and columns base.MONELL_CLASS_LIST.
  panel = humans.groupby('RedJade Code').mean().loc[mol_codes, base.MONELL_CLASS_LIST]


### Run the code to prepare the data to be visualized

In [12]:
def _display_label(raw_label):
    """Rewrite the spelling 'jasmin' as 'jasmine', then title-case the label."""
    spelling_fixes = {"jasmin": "jasmine"}
    return spelling_fixes.get(raw_label, raw_label).title()


# Training-set counts: normalise the label spelling/casing, index by label and
# select the rows listed in base.MONELL_CLASS_LIST (in that order).
training_class_counts = pd.read_csv(base.DATA_PATH / "training_class_counts.csv")
training_class_counts['label'] = training_class_counts['label'].apply(_display_label)
training_class_counts = (
    training_class_counts
    .set_index('label')
    .loc[base.MONELL_CLASS_LIST]
)

# Aggregate the output of analysis.fast_process per 'index' group with a
# NaN-ignoring median.
transpose_ott = analysis.fast_process(humans, models, axis=1)
corr_df = transpose_ott.groupby('index').agg(np.nanmedian)

# Per column of `panel`, count the entries that exceed 0.7.
corr_df['Test data counts'] = (panel > 0.7).sum(axis=0)
# NOTE(review): assigns a whole DataFrame to one column -- presumably the CSV
# has a single value column besides 'label'; confirm.
corr_df['Training data counts'] = training_class_counts

### Name the figure

In [None]:
# Name given to this figure. Not referenced by any other visible cell --
# presumably consumed when the figure is saved; TODO confirm.
fig_name = '3B'

### Run the code to make the figure and save the figure to friendly formats

In [None]:
# Scatter of 'GNN' against 'Training data counts' for every row of corr_df;
# marker size is driven by 'Test data counts'.
plt.figure(figsize=(8, 8))
plot = sns.scatterplot(
    data=corr_df.reset_index(),
    x='Training data counts',
    y='GNN',
    size='Test data counts',
    sizes=(20, 400),
    color='#fc8d62',
)
plot.legend(loc='lower right', title='Test data counts')
plot.set(xscale="log")

# Axis titles, limits and ticks; the x axis is logarithmic, spanning
# 10**1.5 to 10**3.5 with five ticks evenly spaced in the exponent.
plt.xlabel('Training Data Counts', fontsize=20)
plt.xlim(10**1.5, 10**3.5)
plt.ylabel('GNN Correlation with Panel Mean', fontsize=20)
plt.xticks(10 ** np.linspace(1.5, 3.5, 5), fontsize=16)
plt.yticks(np.linspace(-0.1, 0.5, 7), fontsize=16)

# Text labels for a hand-picked subset of points. Offsets are in points from
# the marker; a few labels get their own offset instead of the default.
annotated_labels = (
    'Camphoreous', 'Fishy', 'Cooling', 'Sulfurous', 'Roasted', 'Fruity',
    'Floral', 'Sweet', 'Green', 'Ozone', 'Sharp', 'Waxy', 'Medicinal',
    'Musty', 'Fermented', 'Garlic', 'Alcoholic', 'Musk', 'Meaty',
)
default_offset = (5, 5)
offset_overrides = {
    'Floral': (5, -15),
    'Fishy': (5, -15),
    'Camphoreous': (5, -5),
}
for name in annotated_labels:
    point = corr_df.loc[name]
    plt.annotate(
        name,
        (point['Training data counts'], point['GNN']),
        xytext=offset_overrides.get(name, default_offset),
        textcoords='offset points',
        fontsize=20,
    )

# Thick black bottom/left spines; top/right spines hidden; no grid.
for side in ('bottom', 'left'):
    kept_spine = plot.spines[side]
    kept_spine.set_linewidth(3)
    kept_spine.set_edgecolor('black')
for side in ('top', 'right'):
    plot.spines[side].set_visible(False)

plot.grid(False)
