In [None]:
import os
from path import Path

# Change working directory to project root (taken to be the parent of the
# directory the kernel starts in — TODO confirm against the repo layout).
# NOTE: this is not idempotent; re-running the cell moves up one more level,
# so execute it only once per kernel session.
os.chdir(Path(os.getcwd()).parent)

In [None]:
import altair as alt
import ast
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import yaml

from omegaconf import OmegaConf

In [None]:
def read_config(base_dir, model_dir):
    """Load the Hydra config saved with a single run.

    Args:
        base_dir: Directory holding the per-run output folders.
        model_dir: Name of the run folder inside ``base_dir``.

    Returns:
        An OmegaConf object created from the parsed contents of
        ``<base_dir>/<model_dir>/.hydra/config.yaml``.
    """
    config_file = os.path.join(base_dir, model_dir, '.hydra', 'config.yaml')

    with open(config_file, 'r') as handle:
        parsed = yaml.safe_load(handle)
        return OmegaConf.create(parsed)

def read_df(base_dir, model_dir, file, config):
    """Read one result CSV of a run and tag every row with the run's settings.

    Args:
        base_dir: Directory holding the per-run output folders. (The parameter
            was previously misspelled ``base_fir``, so the body silently fell
            back to the notebook-level ``base_dir`` global.)
        model_dir: Name of the run folder inside ``base_dir``.
        file: CSV file name inside the run folder.
        config: Run configuration exposing ``model.encoder``, ``model.pooling``,
            ``model.normalizer`` and ``data.name``.

    Returns:
        The CSV contents as a DataFrame with four extra constant columns:
        ``encoder`` and ``pooling`` (display-formatted), ``normalization``
        and ``dataset``.
    """
    df = pd.read_csv(os.path.join(base_dir, model_dir, file))
    df['encoder'] = encoder(config.model.encoder)
    df['pooling'] = pooling(config.model.pooling)
    df['normalization'] = config.model.normalizer
    df['dataset'] = config.data.name
    return df

def encoder(text):
    """Format an encoder identifier for display.

    ``'transformer'`` is capitalized; every other identifier is upper-cased.
    """
    return text.capitalize() if text == 'transformer' else text.upper()

def pooling(text):
    """Format a pooling identifier for display.

    Underscores become spaces and the result is capitalized
    (first character upper-case, the rest lower-case).
    """
    words = text.split('_')
    return ' '.join(words).capitalize()

In [None]:
base_dir = 'outputs'
ablation_dfs = []
unique_ablation_dfs = []

# Each visible entry of `base_dir` is treated as one run folder; hidden
# entries and the top-level multirun.yaml are skipped.
for model_dir in sorted(os.listdir(base_dir)):
    if model_dir.startswith('.') or model_dir == 'multirun.yaml':
        continue
    config = read_config(base_dir, model_dir)
    ablation_dfs.append(read_df(base_dir, model_dir, 'ablation.csv', config))
    unique_ablation_dfs.append(read_df(base_dir, model_dir, 'ablation_unique.csv', config))

ablation_df = pd.concat(ablation_dfs, sort=True)
unique_ablation_df = pd.concat(unique_ablation_dfs, sort=True)

# These columns are parsed with ast.literal_eval, i.e. they are presumably
# stored in the CSV as Python literals (lists) — TODO confirm.
# Bracket assignment is used so the parsed values always land in a column.
# The previous code wrote the parsed `distance_categories` to
# `ablation_df.attention_categories`, leaving `distance_categories` unparsed.
ablation_df['distance_tokens'] = ablation_df['distance_tokens'].map(ast.literal_eval)
ablation_df['distance_categories'] = ablation_df['distance_categories'].map(ast.literal_eval)

In [None]:
def plot_map(df, title, labels=False):
    """Build the layered chart of ``distance_map`` per ablation category.

    Layers, all sharing the x axis:
      * points: mean ``distance_map`` per ``ablation_category``
        (rows with pooling ``'Category attention'``),
      * error bars (``extent='ci'``) of ``distance_map`` per category,
      * a rule at the overall mean ``distance_map`` of the category-attention rows,
      * a dashed rule at the overall mean ``distance_map`` of the rows with
        pooling ``'Self attention'``,
      * a dashed rule at the overall mean ``random_map``.

    Args:
        df: Ablation results with at least the columns ``pooling``,
            ``ablation_category``, ``distance_map`` and ``random_map``.
        title: Chart title.
        labels: Whether to draw the y-axis category labels.

    Returns:
        The layered Altair chart. ``df`` is not modified.
    """
    # The constant label columns only exist to give each rule a legend entry.
    # They are added with .assign (which returns a new frame) *before* the
    # charts are created, instead of item-assigning onto a filtered slice
    # afterwards (SettingWithCopyWarning, and reliance on late mutation).
    category_source = df[df.pooling == 'Category attention'].assign(
        approach='Category Attention',
        random_baseline='Random Order',
    )
    self_attention_source = df[df.pooling == 'Self attention'].assign(
        self_attention_baseline='Self Attention',
    )

    base = alt.Chart(category_source, title=title, width=180, height=300)

    point = base.mark_point(filled=True, color='black').encode(
        y=alt.Y('ablation_category:N', title='', axis=alt.Axis(labels=labels), sort=None),
        x=alt.X('mean(distance_map)', scale=alt.Scale(domain=[0, 1]))
    )

    ci = base.mark_errorbar(extent='ci').encode(
        y=alt.Y('ablation_category:N', title='', sort=None),
        x=alt.X('distance_map:Q', title='')
    )

    # The color encoding below takes precedence over the mark's color='red'.
    category_attention = base.mark_rule(color='red').encode(
        x=alt.X('mean(distance_map)', title='Mean Average Precision'),
        color=alt.Color('approach', title=''),
        tooltip='mean(distance_map)'
    )

    random_order = base.mark_rule(strokeDash=[2, 2]).encode(
        x=alt.X('mean(random_map)'),
        color=alt.Color('random_baseline', title=''),
        tooltip='mean(random_map)'
    )

    self_attention = alt.Chart(self_attention_source).mark_rule(strokeDash=[6, 2]).encode(
        x=alt.X('mean(distance_map)'),
        color=alt.Color('self_attention_baseline', title=''),
        tooltip='mean(distance_map)'
    )

    return (point + ci + category_attention + self_attention + random_order)

def plot_sample_size(df, title, labels=False):
    """Build a text column showing the number of rows per ablation category.

    Args:
        df: Frame with an ``ablation_category`` column.
        title: Chart title.
        labels: Accepted for signature parity with ``plot_map``; not used.

    Returns:
        An Altair text-mark chart with one count per ``ablation_category``.
    """
    counts = df.groupby('ablation_category').size().reset_index(name='Sample Size')

    chart = alt.Chart(counts, title=title, width=80, height=300)
    return chart.mark_text().encode(
        y=alt.Y('ablation_category:N', title='', axis=None),
        text='Sample Size:Q'
    )


def plot(dataset, encoder, pooling='Category attention'):
    """Side-by-side ablation charts for one dataset / encoder combination.

    The left panel is built from ``ablation_df`` ("Category Entities"), the
    right panel from ``unique_ablation_df`` ("Unique Category Entities").

    Args:
        dataset: Value of the ``dataset`` column to select.
        encoder: Value of the ``encoder`` column to select (display form,
            e.g. ``'LSTM'``).
        pooling: Currently unused; ``plot_map`` selects the pooling variants
            itself. Kept so existing calls stay valid.

    Returns:
        The horizontally concatenated, styled Altair chart.
    """
    source = ablation_df[
        (ablation_df.dataset == dataset) & (ablation_df.encoder == encoder)
    ]
    # Both halves of the mask must come from unique_ablation_df itself. The
    # encoder half used to be taken from ablation_df, a different frame with
    # a different (and non-unique) index.
    source_unique = unique_ablation_df[
        (unique_ablation_df.dataset == dataset) & (unique_ablation_df.encoder == encoder)
    ]

    return alt.hconcat(
        plot_map(source, title='Category Entities', labels=True),
        plot_map(source_unique, title='Unique Category Entities'),
        #plot_sample_size(source, title='Sample Size')
    ).configure_axis(
        labelFontSize=14,
        titleFontSize=14,
        labelFontWeight='normal',
        titleFontWeight='normal',
    ).configure_legend(
        labelFontSize=14
    ).configure_axisY(
        titleFontWeight='normal'
    ).configure_axisX(
        titlePadding=10
    ).configure_title(
        fontSize=14,
    )

In [None]:
# Ablation charts for the aviation case dataset, LSTM encoder.
plot(dataset='aviation_case_ablation', encoder='LSTM')

In [None]:
# Ablation charts for the aviation email dataset, LSTM encoder.
plot(dataset='aviation_email_classification', encoder='LSTM')

In [None]:
# Ablation charts for the OHSUMED dataset, LSTM encoder.
plot(dataset='ohsumed_classification', encoder='LSTM')

In [None]:
def dataset(text):
    """Translate a dataset identifier into its display label.

    Returns ``None`` for identifiers that have no label.
    """
    display_names = {
        'aviation_case_ablation': 'Case (Aviation)',
        'aviation_email_classification': 'Email (Aviation)',
        'ohsumed_classification': 'OHSUMED (MeSH)',
    }
    return display_names.get(text)

# Keep the category-attention rows and replace dataset identifiers with their
# display labels. .assign returns a new frame, so ablation_df is left intact
# and no chained-assignment warning is triggered on the filtered slice.
source = ablation_df[ablation_df.pooling == 'Category attention']
source = source.assign(dataset=source['dataset'].map(dataset))

alt.data_transformers.disable_max_rows()

# Mean attention_map per encoder and dataset (line with point markers), plus an
# error band (extent='ci'). Both layers now declare the same y domain; the
# band previously declared it reversed as (1, 0).
(
    alt.Chart(source, width=300).mark_line(point=True).encode(
        x='encoder:N',
        y=alt.Y('average(attention_map)', scale=alt.Scale(domain=(0, 1)), title='Mean Pearson Correlation r'),
        color='dataset'
    )
    + alt.Chart(source).mark_errorband(extent='ci').encode(
        x='encoder:N',
        y=alt.Y('average(attention_map)', scale=alt.Scale(domain=(0, 1)), title='Mean Pearson Correlation r'),
        color='dataset'
    )
)