In [None]:
import json
import os
import sys
from dataclasses import dataclass
from datetime import datetime

import cairosvg
#import dataframe_image as dfi
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import skunk
from scipy.stats import spearmanr

# Make the project-level `dglgcn` module importable from two directories up.
sys.path.insert(1,'../../')
from dglgcn import compute_threshold_from_split

# NOTE(review): '/path/to/application/app/folder' looks like a placeholder copied from an
# example; a non-existent sys.path entry is simply skipped — confirm it can be removed.
sys.path.insert(1, '/path/to/application/app/folder')
# Run timestamps (date only, and a human-readable ctime string).
# NOTE(review): neither name is referenced in the visible cells.
today_date = datetime.today().date()
time_now = datetime.today().ctime()

# packages & settings for plotting
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
import urllib.request

# Register the IBM Plex Mono font with matplotlib and apply the shared plot style.
# The .ttf is downloaded only when it is missing, so re-running the cell needs no network.
FONT_URL = 'https://github.com/google/fonts/raw/main/ofl/ibmplexmono/IBMPlexMono-Regular.ttf'
FONT_FILE = 'IBMPlexMono-Regular.ttf'
if not os.path.exists(FONT_FILE):
    urllib.request.urlretrieve(FONT_URL, FONT_FILE)
fe = font_manager.FontEntry(
    fname=FONT_FILE,
    name='plexmono')
# Re-running the cell would otherwise append a duplicate entry to the font list.
if fe not in font_manager.fontManager.ttflist:
    font_manager.fontManager.ttflist.append(fe)
plt.rcParams.update({
    'axes.facecolor': '#f5f4e9',
    'grid.color': '#AAAAAA',
    'axes.edgecolor': '#333333',
    'figure.facecolor': '#FFFFFF',
    'axes.grid': False,
    'axes.prop_cycle': plt.cycler('color', plt.cm.Dark2.colors),
    'font.family': fe.name,
    # 3.5 in wide with a width/height ratio of 1.2.
    'figure.figsize': (3.5, 3.5 / 1.2),
    'ytick.left': True,
    'xtick.bottom': True,
    'figure.dpi': 300,
})

In [None]:
# Side of the label threshold that is censored (treated as "sensitive").
censor_region = 'above'
# Fractions of the data treated as sensitive; one subplot column is drawn per split.
censor_splits = [0.1, 0.5, 0.9]
# Date stamp of the results files to load (dataframe_{run_date}.json).
run_date = "2024-05-06"

In [None]:
# Lipophilicity dataset: SMILES strings with a numeric label column `exp`.
# Download once and reuse the local copy on re-runs.
LIPO_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv"
LIPO_CSV = "./lipophilicity.csv"
if not os.path.exists(LIPO_CSV):
    urllib.request.urlretrieve(LIPO_URL, LIPO_CSV)
lipo_df = pd.read_csv(LIPO_CSV)
# List of (smiles, label) pairs; the labels are used below to compute censor thresholds.
lipodata = list(zip(lipo_df.smiles, lipo_df.exp))

In [None]:
def calculate_x_noise_levels(censor_intervals):
    '''
    Map similarity-score ranges to feature (X) noise levels.

    Each interval is a string "lo-hi" of similarity scores. Low similarity
    corresponds to high noise, so the noise level is 1 minus the interval
    midpoint, rounded to 3 decimals.

    Args:
        censor_intervals: iterable of strings of the form "lo-hi".

    Returns:
        A list of floats, one noise level per interval.
    '''
    def _noise_level(interval):
        lo, hi = (float(part) for part in interval.split('-'))
        return round(1 - (lo + hi) / 2, 3)

    return [_noise_level(interval) for interval in censor_intervals]
        

In [None]:
from matplotlib.offsetbox import AnnotationBbox
import torch

# NOTE(review): `dglgcn` was already imported in the first cell (via '../../'), so this
# re-import returns the module cached in sys.modules; adding '../' here does not change
# which `dglgcn` is used. Confirm which location is intended.
sys.path.insert(1,'../')
from dglgcn import compute_threshold_from_split

def get_file_path(censor_type, censor_split, censor_region, censor_interval, results_dir=None, ending=''):
    '''
    Build the path of the trial-0 parity-plot JSON file for one censoring level.

    The file is
    ``<base>/all_results/gcn_{censor_type}_results_split{censor_split}_{censor_region}{ending}/parityplotdata_trial0_{censor_type}_{censor_interval}.json``.

    Args:
        censor_type: 'omit', 'xnoise' or 'ynoise'.
        censor_split: fraction of data treated as sensitive (e.g. 0.1).
        censor_region: censored side of the threshold (e.g. 'above').
        censor_interval: censoring-level label embedded in the file name.
        results_dir: directory that contains ``all_results``; defaults to '..'.
            A trailing slash is optional.
        ending: optional suffix of the results directory name (e.g. '_150epochs');
            defaults to '' (no suffix).

    Returns:
        The JSON file path as a string.
    '''
    task = f'{censor_type}_results_split{censor_split}_{censor_region}{ending}'
    # A custom results_dir is the folder holding all_results — the same convention as
    # output_dir in process_and_plot (it used to be joined with '../all_results', which
    # pointed one level above results_dir).
    base_dir = '..' if results_dir is None else results_dir
    dir_name = os.path.join(base_dir, 'all_results', f'gcn_{task}')
    return os.path.join(dir_name, f'parityplotdata_trial0_{censor_type}_{censor_interval}.json')


def make_parity_plot(censor_type, censor_split, censor_region, censor_interval, results_dir=None):
    '''
    Draw a small (1x1 in) parity plot of predictions vs. true labels and save it as SVG.

    The plot is used as an inset figure. The stored test labels/predictions are
    split at the threshold computed from the Lipophilicity labels. Points on the
    `censor_region` side of the threshold are drawn in C1, the others in C0.

    Args:
        censor_type, censor_split, censor_region, censor_interval, results_dir:
            forwarded to get_file_path to locate the parity-plot JSON file.

    Returns:
        The path of the saved SVG: the JSON path without its extension, plus
        '_yt{threshold}.svg'.
    '''
    json_path = get_file_path(censor_type, censor_split, censor_region, censor_interval, results_dir=results_dir)
    with open(json_path, 'r') as f:
        # Only the label and prediction arrays are needed; the metrics are ignored.
        _rmse, _lower_rmse, _upper_rmse, _corr, _lower_corr, _upper_corr, ytest, yhat = json.load(f)

    threshold = compute_threshold_from_split([label for _, label in lipodata], censor_split, censor_region)

    fig, ax = plt.subplots(figsize=(1, 1))
    true_vals, pred_vals = torch.tensor(ytest), torch.tensor(yhat)
    # Two explicit masks (not a negation) so NaN labels fall in neither group.
    is_upper = true_vals >= threshold
    is_lower = true_vals < threshold
    upper_color, lower_color = ('C1', 'C0') if censor_region == 'above' else ('C0', 'C1')
    ax.scatter(true_vals[is_upper], pred_vals[is_upper], s=1, c=upper_color)
    ax.scatter(true_vals[is_lower], pred_vals[is_lower], s=1, c=lower_color)
    # y = x reference line over the fixed plotting window.
    ax.plot([-2, 5], [-2, 5], c='black', linewidth=0.5)
    ax.set_xlim(-2, 5)
    ax.set_ylim(-2, 5)

    # No ticks or tick labels: the insets are minimal thumbnails.
    ax.set_xticks([])
    ax.set_yticks([])

    stem, _ = os.path.splitext(json_path)
    svg_filename = f'{stem}_yt{threshold}.svg'
    fig.patch.set_alpha(0.0)  # transparent figure background
    fig.tight_layout()
    fig.savefig(svg_filename, bbox_inches='tight')
    plt.close(fig)
    return svg_filename

def inset_fig_placeholder(i, ax, xy, censor_type, no_noise=True):
    '''
    Add an empty skunk box on `ax`, with an arrow pointing at the data point `xy`.

    skunk.insert later matches the box by name and replaces it with a parity-plot
    SVG (see insert_skunk_figs). The name is '{censor_type}_no-noise{i}' or
    '{censor_type}_max-noise{i}'.

    Args:
        i: subplot index, used in the box name.
        ax: matplotlib Axes to annotate.
        xy: (x, y) in data coordinates that the arrow points to.
        censor_type: censor type string, used in the box name.
        no_noise: True puts the box at the top-left of the axes with an arrow
            curving one way; False puts it at the top-right with the opposite curve.

    Returns:
        The same Axes.
    '''
    if no_noise:
        box_label, arc_style, corner = 'no-noise', "arc3,rad=-0.2", (0.01, 0.97)
    else:
        box_label, arc_style, corner = 'max-noise', "arc3,rad=0.2", (0.76, 0.97)

    placeholder = skunk.Box(45, 45, f'{censor_type}_{box_label}{i}')
    arrow_style = dict(
        arrowstyle='->,head_length=0.4,head_width=0.2',
        connectionstyle=arc_style,
        fc="w",
    )
    annotation = AnnotationBbox(
        placeholder,
        xy,                  # arrow target
        xybox=corner,        # box position, anchored at its top-left corner
        xycoords='data',
        boxcoords=("axes fraction", "axes fraction"),
        box_alignment=(0, 1),
        arrowprops=arrow_style,
    )
    ax.add_artist(annotation)
    return ax

def insert_skunk_figs(censor_types, censor_splits, censor_region, censor_intervals, results_dir=None):
    '''
    Render the inset parity plots and substitute them into the skunk placeholder boxes.

    For every censor type and split index k, two parity plots are made: one for
    the first censoring interval (box '{type}_no-noise{k}') and one for the last
    (box '{type}_max-noise{k}').

    Args:
        censor_types: one censor type string, or a list of them.
        censor_splits: list of split fractions.
        censor_region: censored side, forwarded to make_parity_plot.
        censor_intervals: interval labels for a single type, or a list holding one
            list of labels per entry of `censor_types`.
        results_dir: optional base directory forwarded to make_parity_plot.

    Returns:
        The result of skunk.insert on the box-name -> SVG-file mapping.
    '''
    # A single type comes with a single list of intervals; wrap both for uniform looping.
    if isinstance(censor_types, str):
        censor_types = [censor_types]
        censor_intervals = [censor_intervals]

    box_to_svg = {}
    for type_idx, ctype in enumerate(censor_types):
        type_intervals = censor_intervals[type_idx]
        for split_idx, split in enumerate(censor_splits):
            # First interval -> "no-noise" box, last interval -> "max-noise" box.
            for box_kind, pos in (('no-noise', 0), ('max-noise', -1)):
                box_to_svg[f'{ctype}_{box_kind}{split_idx}'] = make_parity_plot(
                    ctype, split, censor_region, type_intervals[pos], results_dir=results_dir)
    return skunk.insert(box_to_svg)

def make_xnoise_mapping_table(censor_intervals):
    '''
    Save a two-column table mapping X noise levels to similarity-score ranges.

    The table is drawn on an otherwise hidden matplotlib figure and written to
    'xnoise_gcn_mapping_table.svg' and 'xnoise_gcn_mapping_table.png' (300 dpi)
    in the current working directory.

    Args:
        censor_intervals: list of similarity-range strings ("lo-hi"), as accepted
            by calculate_x_noise_levels.
    '''
    fig_table, ax_table = plt.subplots()
    # Only the table should be visible: hide the figure background and the axes.
    fig_table.patch.set_visible(False)
    ax_table.axis('off')
    ax_table.axis('tight')
    print('censor_intervals', censor_intervals)
    mapping = pd.DataFrame({
        'X Noise Level': calculate_x_noise_levels(censor_intervals),
        'Similarity Score': censor_intervals,
    })
    stem = 'xnoise_gcn_mapping_table'

    ax_table.table(
        cellText=mapping.values,
        colLabels=mapping.columns,
        loc='center',
        # NOTE(review): the original author questioned whether the header colours
        # appear in the output; unverified.
        colColours=["lightblue"] * 2,
    )

    fig_table.tight_layout()
    fig_table.savefig(f'{stem}.svg', bbox_inches='tight')
    fig_table.savefig(f'{stem}.png', bbox_inches='tight', dpi=300)
    plt.close(fig_table)
    print(f'Table of X noise levels saved to {stem}')

def process_and_plot(i, censor_split, ax, censor_type, run_date, output_dir=None, metric='corr', ending="", hasNaN=False, NaNlist=None, title=True):
    '''
    Plot mean +/- std of `metric` versus censoring level for one censor split on `ax`.

    The function reads the results table
    ``<dir>/dataframe_{run_date}.json``, where ``<dir>`` is
    ``../all_results/gcn_{censor_type}_results_split{censor_split}_{censor_region}{ending}``.
    For the 0.9 split with `hasNaN`, it reads ``dataframe_{run_date}_revised.json`` instead.

    It draws the 'upper' curve (C1, "Sensitive Labels") and the 'lower' curve
    (C0, "Non-sensitive Labels"). It also adds skunk placeholder boxes that
    insert_skunk_figs later fills with parity plots.

    Note: the censor region comes from the module-level `censor_region` variable,
    and the threshold is computed from the module-level `lipodata` labels (and printed).

    Args:
        i: subplot index; used to name the inset placeholder boxes.
        censor_split: fraction of data treated as sensitive (e.g. 0.1, 0.5, 0.9).
        ax: matplotlib Axes to draw on.
        censor_type: 'omit', 'xnoise' or 'ynoise'.
        run_date: date string embedded in the results file name.
        output_dir: optional prefix joined directly with 'all_results/...'
            (include a trailing slash); None means '../'.
        metric: 'corr' or 'rmse'; selects the '{upper,lower} {metric}[ std]' columns.
        ending: suffix of the results directory name (e.g. '_150epochs').
        hasNaN: if True, the 0.9 split loads the '_revised' results file.
        NaNlist: indices of points to mark with '*' on the 0.9 split.
        title: whether to set the subplot title.

    Returns:
        The first column of the results table as a list (the raw interval labels,
        i.e. similarity-score strings for 'xnoise'), for use as inset file keys.
    '''
    task = f'{censor_type}_results_split{censor_split}_{censor_region}{ending}'
    if output_dir is None:
        dir_name = f'../all_results/gcn_{task}'
    else:
        dir_name = output_dir + f'all_results/gcn_{task}'
    labels = [label for _, label in lipodata]
    threshold = compute_threshold_from_split(labels, censor_split, censor_region)
    print(f'For censor split {censor_split}, Threshold =',threshold)

    if title:
        ax.set_title(f'{censor_split * 100:.0f}% sensitive data', fontsize=12)

    if hasNaN and censor_split == 0.9:
        # Special case for the 90% split: correlations are often NaN there, so a
        # revised results file is loaded instead.
        file_path = f'{dir_name}/dataframe_{run_date}_revised.json'
    else:
        file_path = f'{dir_name}/dataframe_{run_date}.json'

    # Results table: first column = censoring level, plus mean/std columns per group.
    df = pd.read_json(file_path)
    raw_intervals = df.iloc[:, 0].to_list()
    mean_above = df[f'upper {metric}'].to_list()
    std_above = df[f'upper {metric} std']
    mean_below = df[f'lower {metric}'].to_list()
    std_below = df[f'lower {metric} std']

    # Feature-noise results are indexed by similarity-score ranges; plot them
    # against the derived X noise levels instead.
    if censor_type == 'xnoise':
        x_values = calculate_x_noise_levels(raw_intervals)
    else:
        x_values = raw_intervals

    # Mean curves with a +/- 1 std band.
    ax.plot(x_values, mean_above, marker='x', color='C1', label='Sensitive Labels')
    ax.fill_between(x_values, np.subtract(mean_above, std_above), np.add(mean_above, std_above), color='C1', alpha=0.2)

    ax.plot(x_values, mean_below, marker='^', color='C0', label='Non-sensitive Labels')
    ax.fill_between(x_values, np.subtract(mean_below, std_below), np.add(mean_below, std_below), color='C0', alpha=0.2)

    # Inset placeholders: the no-noise box gets arrows to the first point of both
    # curves, the max-noise box gets arrows to the last point of both curves.
    for no_noise, pos in ((True, 0), (False, -1)):
        for mean in (mean_above, mean_below):
            inset_fig_placeholder(i, ax, (x_values[pos], mean[pos]), censor_type, no_noise=no_noise)

    ax.autoscale(enable=True, axis='x', tight=True)
    if metric == 'corr':
        ax.set_ylim(-0.2, 1.4)
    else:
        ax.set_ylim(0, 4.0)
    ax.grid(True)

    # Mark the listed points on the 90% split with an asterisk. The loop variable
    # must not be named `i`, which would shadow the subplot-index parameter.
    if NaNlist is not None and censor_split == 0.9:
        for nan_idx in NaNlist:
            for mean in (mean_above, mean_below):
                ax.annotate('*', (x_values[nan_idx], mean[nan_idx]), textcoords="offset points", xytext=(0,5), ha='center', fontsize=15)

    if censor_type == 'xnoise':
        ax.set_xlim(0,1.0)
    return raw_intervals
    
def plot_all_splits(censor_type, metric='corr', ending='', hasNaN=False, NaNlist=None):
    '''
    Plot one censor type across all splits in a 1x3 figure with parity-plot insets.

    Uses the module-level `censor_splits`, `censor_region` and `run_date`. For
    'xnoise' it also writes the X-noise mapping table files.

    Args:
        censor_type: 'omit', 'xnoise' or 'ynoise'.
        metric: 'rmse' or 'corr'.
        ending, hasNaN, NaNlist: forwarded to process_and_plot.

    Returns:
        The value returned by insert_skunk_figs.

    Raises:
        ValueError: for an unknown `metric` or `censor_type`.
    '''
    if metric not in ('rmse', 'corr'):
        raise ValueError("Invalid metric specified. Choose 'rmse' or 'corr'.")    
    ytitle = 'Root Mean Square Error' if metric == 'rmse' else 'Spearman Correlation'

    if censor_type not in ('omit', 'xnoise', 'ynoise'):
        raise ValueError('censor_type must be either "omit", "xnoise", or "ynoise"')
    xtitle = {
        'omit': '% sensitive data omitted from training data',
        'xnoise': 'Feature noise level applied to sensitive data in training data',
        'ynoise': 'Label noise level applied to sensitive data in training data',
    }[censor_type]

    fig, axs = plt.subplots(1, 3, figsize=(18, 5))
    axs[0].set_ylabel(ytitle, fontsize=16)
    axs[1].set_xlabel(xtitle, labelpad=10, fontsize=16)

    for col, split in enumerate(censor_splits):
        censor_intervals = process_and_plot(col, split, axs[col], censor_type, run_date, metric=metric, ending=ending, hasNaN=hasNaN, NaNlist=NaNlist)

    fig.legend(
        handles=[
            plt.Line2D([0], [0], marker='x', color='C1', lw=2, label='Sensitive Labels'),
            plt.Line2D([0], [0], marker='^', color='C0', lw=2, label='Non-sensitive Labels'),
        ],
        loc='center',
        bbox_to_anchor=(0.92, 0.5),
    )
    plt.tight_layout(rect=[0, 0, 0.85, 1])

    # Insets are substituted before the figure is closed.
    svg = insert_skunk_figs(censor_type, censor_splits, censor_region, censor_intervals)
    plt.close()
    if censor_type == 'xnoise':
        make_xnoise_mapping_table(censor_intervals)
    return svg

In [None]:
# Combine the figures of all GCN results into a single 3x3 grid
# (rows: omit / xnoise / ynoise; columns: censor splits).

# TODO: update the corresponding noise levels for feature noise
def plot_everything(metric='corr'):
    '''
    Combine all GCN results into one 3x3 figure with parity-plot insets.

    Rows are the censor types (omit, xnoise, ynoise) and columns are the entries
    of the module-level `censor_splits`. Each panel is drawn by process_and_plot,
    and the insets are substituted via insert_skunk_figs. The X-noise mapping
    table files are written as a side effect.

    Args:
        metric: 'corr' (Spearman correlation) or 'rmse'.

    Returns:
        The value returned by insert_skunk_figs.

    Raises:
        ValueError: if `metric` is not 'rmse' or 'corr'.
    '''
    if metric not in ('rmse', 'corr'):
        raise ValueError("Invalid metric specified. Choose 'rmse' or 'corr'.")
    ytitle = 'Root Mean Square Error' if metric == 'rmse' else 'Spearman Correlation'

    combined_fig, axs = plt.subplots(3, 3, figsize=(10,10), sharey=True) # 3 splits for each of 3 censor types

    # One row per censor type: (type, results date, x-axis label, extra process_and_plot kwargs).
    row_specs = [
        ('omit', '2024-05-11',
         'Percentage of Sensitive Data Omitted from Training Data\n ',
         {'ending': "_150epochs", 'hasNaN': True, 'NaNlist': [2, 8]}),
        ('xnoise', '2024-05-06',
         r'Level of Feature Noise ($\delta X$) Applied to Sensitive Data in Training Data' + '\n ',
         {}),
        ('ynoise', '2024-05-11',
         r'Level of Label Noise ($\delta y$) Applied to Sensitive Data in Training Data',
         {}),
    ]

    censor_intervals = []
    for row, (ctype, results_date, xlabel, extra_kwargs) in enumerate(row_specs):
        axs[row, 0].set_ylabel(ytitle, fontsize=12)
        axs[row, 1].set_xlabel(xlabel, fontsize=16)
        for col, split in enumerate(censor_splits):
            row_intervals = process_and_plot(col, split, axs[row, col], ctype, results_date, metric=metric, **extra_kwargs)
        if ctype == 'xnoise':
            make_xnoise_mapping_table(row_intervals)
        # The last split's interval labels stand in for the whole row.
        censor_intervals.append(row_intervals)

    # NOTE(review): tight_layout below recomputes subplot spacing, so this hspace is
    # likely overridden — confirm whether it is still needed.
    plt.subplots_adjust(hspace=1.0)
    legend_handles = [
        plt.Line2D([0], [0], marker='x', color='C1', lw=2, label='Sensitive Labels'),
        plt.Line2D([0], [0], marker='^', color='C0', lw=2, label='Non-sensitive Labels')
    ]
    combined_fig.legend(handles=legend_handles, bbox_to_anchor=(0.97, 0.06))

    combined_fig.tight_layout(rect=[0, 0.05, 1, 1])

    # NOTE(review): insert_skunk_figs is not given the '_150epochs' ending used for the
    # omit row, so those insets are read from the directory without that suffix — confirm.
    svg = insert_skunk_figs(['omit', 'xnoise', 'ynoise'], censor_splits, censor_region, censor_intervals)
    plt.close()
    return svg
    

In [None]:
# Build the combined 3x3 GCN correlation figure, preview it, and export it as SVG and PNG.
fig_stem = 'paper_figs/gcn_corr_with_all_titles_2024-12-31'
# open() and cairosvg fail if the output folder does not exist yet.
os.makedirs('paper_figs', exist_ok=True)

svg = plot_everything()
skunk.display(svg)

with open(f'{fig_stem}.svg', 'w') as f:
    f.write(svg)

cairosvg.svg2png(bytestring=svg, write_to=f'{fig_stem}.png', dpi=300)

In [None]:
# Also export the figure from the previous cell under the 2024-08-27 file name.
# NOTE(review): this reuses `svg` from the previous cell and writes a PNG with an older
# date stamp than that cell's export — confirm this copy is still needed.
cairosvg.svg2png(
    bytestring=svg,
    write_to='paper_figs/gcn_corr_with_all_titles_2024-08-27.png',
    dpi=300,
)

In [None]:
# Overview of the omission results (RMSE), one panel per censor split.
# NOTE(review): plot_all_splits reads the module-level run_date, which on a fresh Run All
# is still '2024-05-06' here, whereas plot_everything uses '2024-05-11' with ending
# '_150epochs' for the omission results — confirm which results are intended.
svg = plot_all_splits(
    censor_type='omit',
    metric='rmse',
)
skunk.display(svg)

# Optional exports (disabled):
# with open('replaced.svg', 'w') as f:
#     f.write(svg)
# cairosvg.svg2png(bytestring=svg, write_to='overview_gcn_omit_rmse.png', dpi=300)

In [None]:
# Overview of the feature-noise results (RMSE), one panel per censor split.
# plot_all_splits reads the module-level run_date, so set it to the xnoise results date first.
run_date = '2024-05-06'
svg = plot_all_splits(
    censor_type='xnoise',
    metric='rmse',
)
skunk.display(svg)

# Optional export (disabled):
# cairosvg.svg2png(bytestring=svg, write_to='overview_gcn_xnoise_rmse.png', dpi=300)