In [2]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from lee_et_al_2023.src import base

In [None]:
# Load the table of precomputed principal-component coordinates for Figure 1.
pc_df = pd.read_csv(base.DATA_PATH / 'fig_1_pc_data.csv')

# For each representation, pull its two PC columns out as a 2-column array.
pca_gnn, pca_fp, pca_label = (
    pc_df[list(pc_columns)].values
    for pc_columns in (('gnn_pc1', 'gnn_pc2'),
                       ('fp_pc1', 'fp_pc2'),
                       ('label_pc1', 'label_pc2'))
)
pc_df

In [None]:
def plot_islands(bg_points,
                 bg_name,
                 *fg_specs,
                 alpha = 0.4,
                 z_limit = 15,
                 colors=None):
    """Plot islands in a background sea of points.

    This command draws on the current matplotlib axes as a side effect.
    Use plt.figure() to control where these outputs get generated.

    Args:
    bg_points: Array of shape (n_points, 2) indicating x,y coordinates
        of any background points.
    bg_name: Legend label of the background points (None for no entry).
    *fg_specs: Repeated island specifications. Allowable dict keys:
        data - array of shape (n_points, 2) (required)
        label - str, name of the island (required)
        color - color of the island; defaults to the matching entry of
            `colors`
        scatter - bool, whether to add the point scatters (default False)
        scatter_size - marker size forwarded to seaborn.scatterplot as
            `s` (default 5)
        filled - bool, whether to fill the island (default False)
        level_to_plot - float, forwarded to seaborn.kdeplot as `thresh`
            (default 0.25)
    alpha: Transparency of island color fill (only used for filled islands).
    z_limit: Both axes are limited to [-z_limit - 0.6, z_limit + 0.6].
    colors: Optional sequence of default island colors, used in order and
        reused cyclically if there are fewer colors than islands. When
        None or empty, a 'Set3' palette with one color per island is used.
    """
    # `colors or ...` would raise for array-like palettes, so test explicitly.
    if colors is None or len(colors) == 0:
        palette = sns.color_palette('Set3', len(fg_specs))
    else:
        palette = list(colors)

    # Resolve the target axes once so every artist lands on the same plot.
    ax = plt.gca()

    sns.scatterplot(x=bg_points[:, 0], y=bg_points[:, 1],
                    s=3, color='0.60', label=bg_name, ax=ax)
    for index, fg_spec in enumerate(fg_specs):
        # Cycle through the palette instead of zipping, so a short `colors`
        # sequence no longer silently drops the remaining islands.
        fg_color = fg_spec.get('color', palette[index % len(palette)])
        fg_scatter = fg_spec.get('scatter', False)
        fg_scatter_size = fg_spec.get('scatter_size', 5)
        fg_filled = fg_spec.get('filled', False)
        fg_level_to_plot = fg_spec.get('level_to_plot', 0.25)
        x, y = fg_spec['data'][:, 0], fg_spec['data'][:, 1]
        label = fg_spec['label']
        if fg_scatter:
            # x/y must be keywords: seaborn >= 0.12 only accepts `data`
            # positionally and raises TypeError otherwise.
            sns.scatterplot(x=x, y=y, s=fg_scatter_size, color=fg_color,
                            ax=ax)
        kde_kwargs = dict(x=x, y=y, color=fg_color, fill=bool(fg_filled),
                          thresh=fg_level_to_plot, levels=2, bw_method=0.3,
                          ax=ax)
        if fg_filled:
            # Transparency is applied to the fill only; outlines stay opaque.
            kde_kwargs['alpha'] = alpha
        sns.kdeplot(**kde_kwargs)
        # Generate the legend entry with an empty proxy artist: a square
        # marker for filled islands, a thick line for outlined ones.
        if fg_filled:
            ax.scatter([], [], marker='s', c=fg_color, label=label)
        else:
            ax.plot([], [], c=fg_color, linewidth=3, label=label)
    ax.set_xlim([-z_limit - 0.6, z_limit + 0.6])
    ax.set_ylim([-z_limit - 0.6, z_limit + 0.6])
    ax.set_aspect('equal', adjustable='box')
    

def plot_odor_islands(pca_space, z_limit=15):
    """Draw the floral, meaty and alcoholic odor islands on a new figure.

    Opens a new 12x8 inch figure. For each main odor group, `plot_islands`
    draws one filled island for the group and one outlined island for each
    of its subgroups, with `pca_space` as the background points.

    Reads the module-level `pc_df`: every group and subgroup name is used
    as a column of `pc_df` to select rows of `pca_space`.

    Args:
    pca_space: Array of 2-D coordinates, one row per point.
        NOTE(review): assumed to be row-aligned with `pc_df` - confirm.
    z_limit: Axis limit forwarded to `plot_islands`.
    """
    plt.figure(figsize=(12, 8))
    color_palette = sns.color_palette('Set3', 15)
    odor_groups = [('floral', ['muguet', 'lavender', 'jasmin']),
                   ('meaty', ['savory', 'beefy', 'roasted']),
                   ('alcoholic', ['cognac', 'fermented', 'winey'])]

    for group_index, (main_group, subgroups) in enumerate(odor_groups):
        # NOTE(review): indexing with pc_df[<name>] presumably relies on
        # these being boolean columns - verify against the CSV.
        fg_specs = [{'data': pca_space[pc_df[main_group]],
                     'filled': True,
                     'scatter': False,
                     'level_to_plot': 0.1,
                     'label': main_group.capitalize()}]
        fg_specs.extend({'data': pca_space[pc_df[subgroup]],
                         'filled': False,
                         'scatter': False,
                         'level_to_plot': 0.2,
                         'label': subgroup.capitalize()}
                        for subgroup in subgroups)

        # Each group starts 5 entries further into the 15-color palette and
        # takes one color per island.
        first_color = 5 * group_index
        group_colors = color_palette[first_color:first_color + len(fg_specs)]
        # NOTE(review): plot_islands redraws the background scatter on every
        # call, so the background is plotted once per group.
        plot_islands(pca_space, None, *fg_specs,
                     colors=group_colors, z_limit=z_limit)

    # Build the legend once, after all labelled artists exist.
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

# One figure per representation, each with its own axis limit and title.
for embedding_space, axis_limit, panel_title in (
        (pca_gnn, 15, 'GNN Embeddings'),
        (pca_fp, 10, 'Fingerprints'),
        (pca_label, 8, 'True Labels')):
    plot_odor_islands(embedding_space, z_limit=axis_limit)
    plt.title(panel_title)
plt.show()