In [None]:
# OPTIONAL: load the IPython "autoreload" extension so that edits to imported modules can be picked up
%load_ext autoreload

# OPTIONAL: reload all modules before each cell runs, so that changes to code under src/ take effect
# without restarting the kernel

%autoreload 2

In [None]:
from ast import literal_eval
from pathlib import Path
import random

import numpy as np
import pandas as pd

# Debugging hook; drop this import once no cell calls set_trace() any more.
from pdb import set_trace

# Seed both global RNGs so that any randomness (e.g. strip-plot jitter) is reproducible.
np.random.seed(42)
random.seed(42)

# Target directory of the PNG figures written by the plotting helpers further down.
Path('../../data/processed/explain_labeling/pictures/jointbert/onlybert/').mkdir(parents=True, exist_ok=True)

In [None]:
# Numeric wordclass codes used in the labelling CSV ('1'..'10', stored as strings) -> readable names.
wordclass_dict = {
    str(code): name
    for code, name in enumerate(
        [
            'model number',
            'brand name',
            'model name',
            'char. attr.',
            'stopword',
            'product type',
            'descr. word',
            'non-english word',
            'noisy model number',
            'noise from other product',
        ],
        start=1,
    )
}

wordclass_labels = pd.read_csv('../../data/processed/explain_labeling/wordclass_labeling_labeled.csv')
wordclass_labels = wordclass_labels.set_index('pair_id', drop=False)
# Blank cells (e.g. a pair without a brand) become '' so the string handling below never sees NaN.
wordclass_labels = wordclass_labels.fillna('')

TEXT_COLUMNS = ['brand_left', 'title_left', 'brand_right', 'title_right']
WORDCLASS_COLUMNS = [f'{column}_wordclasses' for column in TEXT_COLUMNS]

# Tokenise each text field by lower-casing and splitting on whitespace; token i of a field is
# later paired with entry i of the matching *_wordclasses list (see build_dict_left/right).
wordclass_labels[TEXT_COLUMNS] = wordclass_labels[TEXT_COLUMNS].applymap(lambda text: text.lower().split())


def _decode_wordclasses(cell):
    """Turn one stored ``*_wordclasses`` cell into a list of wordclass names.

    The cell holds a Python literal (parsed with ``literal_eval``) whose items are keys of
    ``wordclass_dict`` -- presumably a list of code strings such as "['2', '6']"; confirm
    against the CSV. An empty cell ('' after ``fillna``) yields ``[]`` instead of making
    ``literal_eval`` raise ``SyntaxError``.
    """
    if cell == '':
        return []
    return [wordclass_dict[code] for code in literal_eval(cell)]


wordclass_labels[WORDCLASS_COLUMNS] = wordclass_labels[WORDCLASS_COLUMNS].applymap(_decode_wordclasses)
wordclass_labels.head()

In [None]:
def _build_wordclass_dict(row, side):
    """Map each token of one offer (brand tokens, then title tokens) to its labelled wordclass.

    Parameters
    ----------
    row : mapping
        One row of ``wordclass_labels``; must provide ``pair_id``, the token lists
        ``brand_<side>`` / ``title_<side>`` and the aligned lists
        ``brand_<side>_wordclasses`` / ``title_<side>_wordclasses``.
    side : str
        ``'left'`` or ``'right'``.

    Returns
    -------
    dict
        token -> wordclass. The first labelled occurrence of a token wins. A later
        occurrence with a different class is reported. A token without a label (its
        wordclass list is shorter than the token list) is reported and skipped instead
        of raising ``IndexError``.
    """
    wordclasses_by_token = {}
    for field in ('brand', 'title'):
        tokens = row[f'{field}_{side}']
        classes = row[f'{field}_{side}_wordclasses']
        for position, token in enumerate(tokens):
            if position >= len(classes):
                print(f'missing wordclass label: {token} in {row["pair_id"]}')
                continue
            if token not in wordclasses_by_token:
                wordclasses_by_token[token] = classes[position]
            elif wordclasses_by_token[token] != classes[position]:
                print(f'same word got different classes: {token} in {row["pair_id"]}')
    return wordclasses_by_token


def build_dict_left(row):
    """Token -> wordclass lookup for the left offer of a labelled pair."""
    return _build_wordclass_dict(row, 'left')


def build_dict_right(row):
    """Token -> wordclass lookup for the right offer of a labelled pair."""
    return _build_wordclass_dict(row, 'right')


# Attach the per-offer token -> wordclass lookups (left offer first, then right) to every pair.
wordclass_labels = wordclass_labels.assign(
    labeldict_left=wordclass_labels.apply(build_dict_left, axis=1),
    labeldict_right=wordclass_labels.apply(build_dict_right, axis=1),
)

In [None]:
# Spot-check the labeldict_left / labeldict_right columns added in the previous cell.
wordclass_labels.head()

In [None]:
def _load_explanations(model_name):
    """Read the pickled explanations of `model_name`, indexed by ``data_inx`` (column kept)."""
    # NOTE(review): read_pickle can execute arbitrary code -- only load files produced by this project.
    frame = pd.read_pickle(f'../../data/processed/explain_labeling/explained/{model_name}.pkl.gz')
    return frame.set_index('data_inx', drop=False)


explanations_distilbert = _load_explanations('distilbert')
explanations_deepmatcher = _load_explanations('deepmatcher')
explanations_bert = _load_explanations('bert')
explanations_jointbert = _load_explanations('jointbert')
explanations_deepmatcher.head()

In [None]:
# Per-pair challenge flags (challenge_* columns), indexed by pair_id with the column kept.
challenges_df = (
    pd.read_csv('../../data/processed/explain_labeling/challenge_lookup.csv')
    .set_index('pair_id', drop=False)
)
challenges_df.head()

In [None]:
def _to_binary_labels(predictions, score_column, label_column):
    """Threshold `score_column` at 0.5 into a 0/1 `label_column`.

    Returns only ``pair_id`` and the new label column, indexed by ``pair_id``.
    """
    predictions[label_column] = predictions[score_column].apply(lambda score: 1 if score >= 0.5 else 0)
    indexed = predictions.set_index('pair_id', drop=False)
    return indexed[['pair_id', label_column]]


# NOTE(review): read_pickle can execute arbitrary code -- only load files produced by this project.
distilbert_results = _to_binary_labels(
    pd.read_pickle('../../src/productbert/saved/models/BT-DistilBERT-FT-computers-xlarge-swctest/0921_172448/predictions.pkl.gz'),
    'predictions',
    'label_distilbert',
)

deepmatcher_results = _to_binary_labels(
    pd.read_csv('../../data/processed/inspection/wdc-lspc/deepmatcher/rnn_abs-diff_standard_epochs50_ratio6_batch16_lr0.001_lrdecay0.8_fasttext.en.bin_brand-title_preprocessed_computers_trainonly_xlarge_magellan_pairs_run1_preprocessed_computers_new_testset_1500_magellan_pairs.csv.gz'),
    'match_score',
    'label_deepmatcher',
)

jointbert_results = _to_binary_labels(
    pd.read_pickle('../../src/productbert/saved/models/BT-JointBERT-FT-computers-xlarge-swctest/1024_165744/predictions.pkl.gz'),
    'predictions',
    'label_jointbert',
)

bert_results = _to_binary_labels(
    pd.read_pickle('../../src/productbert/saved/models/BT-BERT-FT-computers-xlarge-swctest/1024_164723/predictions.pkl.gz'),
    'predictions',
    'label_bert',
)

In [None]:
# Distinct pair ids (in order of first appearance) for which each model has explanations.
instances_distilbert = explanations_distilbert['data_inx'].drop_duplicates().tolist()
instances_deepmatcher = explanations_deepmatcher['data_inx'].drop_duplicates().tolist()
instances_bert = explanations_bert['data_inx'].drop_duplicates().tolist()
instances_jointbert = explanations_jointbert['data_inx'].drop_duplicates().tolist()

In [None]:
def _annotate_explanations(explanations):
    """Return a copy of `explanations` with wordclass, gold label and model predictions added.

    `explanations` must be indexed by pair id (its ``data_inx`` column).
    - ``wordclass`` comes from the ``labeldict_left`` / ``labeldict_right`` dicts of
      ``wordclass_labels``, chosen by the ``tuple`` column ('L' -> left, anything else -> right).
    - ``label`` is the gold label from ``wordclass_labels``.
    - One ``label_<model>`` column is added per ``*_results`` frame.
    A pair id missing from any lookup raises ``KeyError``. Finally ``data_inx`` is renamed
    to ``pair_id``.
    """
    annotated = explanations.copy()
    pair_ids = annotated.index
    annotated['wordclass'] = [
        wordclass_labels.loc[pair_id, 'labeldict_left' if side == 'L' else 'labeldict_right'][token]
        for pair_id, side, token in zip(pair_ids, annotated['tuple'], annotated['token'])
    ]
    # Vectorised lookups replacing the former row-wise apply(lambda x: frame.loc[x.name][...]).
    annotated['label'] = wordclass_labels.loc[pair_ids, 'label'].to_numpy()
    for results, column in (
        (distilbert_results, 'label_distilbert'),
        (deepmatcher_results, 'label_deepmatcher'),
        (bert_results, 'label_bert'),
        (jointbert_results, 'label_jointbert'),
    ):
        annotated[column] = results.loc[pair_ids, column].to_numpy()
    return annotated.rename(columns={'data_inx': 'pair_id'})


# NOTE: the former `label_jointdistilbert` column is no longer added: `jointdistilbert_results`
# is not defined anywhere above, so those lookups raised NameError on a fresh kernel.
explanations_distilbert = _annotate_explanations(explanations_distilbert)
explanations_deepmatcher = _annotate_explanations(explanations_deepmatcher)
explanations_bert = _annotate_explanations(explanations_bert)
explanations_jointbert = _annotate_explanations(explanations_jointbert)

explanations_deepmatcher.head()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(font_scale=2.0)

# Wordclasses shown on the x-axis, in plotting order; the other classes of wordclass_dict are not plotted.
_PLOT_WORDCLASS_ORDER = ['model number', 'brand name', 'model name', 'char. attr.', 'stopword', 'product type', 'descr. word']

# (value in the 'model' column, column holding that model's predicted label,
#  x-extent of its summary bar in units of _SUMMARY_BAR_WIDTH relative to the tick)
_PLOT_MODELS = [
    ('BERT', 'label_bert', (-2.0, -1.0)),
    ('JointBERT', 'label_jointbert', (-0.5, 0.5)),
    ('Deepmatcher', 'label_deepmatcher', (1.0, 2.0)),
]
_SUMMARY_BAR_WIDTH = 0.2
_FIGURE_DIR = '../../data/processed/explain_labeling/pictures/jointbert/onlybert'
_SUMMARY_DIR = '../../data/processed/explain_labeling/pictures'


def _count_pairs(frame):
    """Number of distinct pairs in a frame indexed by (pair_id, wordclass)."""
    return len(frame.reset_index(level='wordclass', drop=True).index.unique())


def _plot_explanations_grid(explanations, filename, statistic, x_label, prefix, save_csv=False):
    """Draw, save and show one strip-plot grid of explanation weights.

    One facet per ``label`` value. Each facet has one strip per (wordclass, model), plus a
    black bar at the per-model `statistic` ('median' or 'mean') of ``weight``. Also prints,
    per facet and model, how many pairs that model classified correctly.

    `explanations` is expected to be indexed by (pair_id, wordclass), as produced by
    ``groupby(['pair_id', 'wordclass']).mean()``. The figure is written to
    ``_FIGURE_DIR/<prefix>_<filename>.png``. With `save_csv`, the per-wordclass summary
    values are also written to ``_SUMMARY_DIR/<prefix>_<filename>.csv``.
    """
    if explanations.empty:
        # FacetGrid cannot build a grid with zero columns; skip instead of crashing the loop.
        print('No explanations to plot.')
        return

    g = sns.FacetGrid(explanations, col="label", sharex=True, sharey=True, height=8.27, aspect=11.7/8.27, ylim=(-0.4, 0.4))
    g.map_dataframe(sns.stripplot, x='wordclass', y='weight', hue='model', palette='tab10', dodge=True,
                    order=_PLOT_WORDCLASS_ORDER, hue_order=[model for model, _, _ in _PLOT_MODELS], zorder=1)
    g.set_axis_labels(x_label, "weight")
    g.add_legend()

    summary = {model: {} for model, _, _ in _PLOT_MODELS}
    # FacetGrid orders its columns by the label values actually present, so take each panel's
    # label from g.col_names instead of assuming panel i shows label i (wrong when a subset
    # contains only one label).
    for ax, label in zip(g.axes.flat, g.col_names):
        ax.set_xticklabels(ax.get_xticklabels(), rotation=30, horizontalalignment='right')
        panel = explanations[explanations['label'] == label]

        for tick, text in zip(ax.get_xticks(), ax.get_xticklabels()):
            wordclass = text.get_text()
            in_class = panel[panel['wordclass'] == wordclass]
            for model, _, (start, end) in _PLOT_MODELS:
                value = in_class.loc[in_class['model'] == model, 'weight'].agg(statistic)
                summary[model].setdefault(wordclass, []).append(value)
                # horizontal bar in this model's dodge slot, centred around the tick
                ax.plot([tick + start * _SUMMARY_BAR_WIDTH, tick + end * _SUMMARY_BAR_WIDTH], [value, value],
                        'k', lw=4, zorder=2)

        for model, prediction_column, _ in _PLOT_MODELS:
            summary[model].setdefault('label', []).append(label)
            model_df = panel[panel['model'] == model]
            count_correct = _count_pairs(model_df[model_df['label'] == model_df[prediction_column]])
            count_wrong = _count_pairs(model_df[model_df['label'] != model_df[prediction_column]])
            # labels are numeric (means of 0/1 after groupby); ':g' prints 0.0 as 0
            print(f'Label: {label:g}: {model} correct: {count_correct}/{count_correct + count_wrong}')

    if save_csv:
        summary_table = pd.concat(
            [pd.DataFrame.from_dict(values).assign(model=model) for model, values in summary.items()]
        )
        summary_table.to_csv(f'{_SUMMARY_DIR}/{prefix}_{filename}.csv', index=False, float_format="%.4f")

    plt.subplots_adjust(bottom=0.25)
    plt.savefig(f'{_FIGURE_DIR}/{prefix}_{filename}.png')
    plt.show()


def plot_explanations(explanations, filename = 'default', save_csv=False):
    """Strip plot of explanation weights per wordclass with per-model MEDIAN bars.

    Saves ``MEDIAN_<filename>.png``. See ``_plot_explanations_grid`` for the expected input.
    """
    _plot_explanations_grid(explanations, filename, 'median', None, 'MEDIAN', save_csv)


def plot_explanations_avg(explanations, filename = 'default', save_csv=False):
    """Strip plot of explanation weights per wordclass with per-model MEAN bars.

    Saves ``AVG_<filename>.png``. See ``_plot_explanations_grid`` for the expected input.
    """
    _plot_explanations_grid(explanations, filename, 'mean', "wordclass", 'AVG', save_csv)

In [None]:
# Pairs flagged with challenge_4 are excluded from every subset analysed here.
all_df = challenges_df[challenges_df['challenge_4'] == 0]
training_examples_df = challenges_df[(challenges_df['challenge_7'] == 0) & (challenges_df['challenge_4'] == 0)]
no_training_examples_df = challenges_df[(challenges_df['challenge_7'] == 1) & (challenges_df['challenge_4'] == 0)]

# Subsets to analyse; uncomment entries to include them.
dfs = {
    #'train': training_examples_df,
    'train+notrain': all_df,
    #'no_train': no_training_examples_df,
}

# (explanations frame, column with that model's prediction, model name used in the plots)
_COMPARED_MODELS = [
    (explanations_deepmatcher, 'label_deepmatcher', 'Deepmatcher'),
    (explanations_bert, 'label_bert', 'BERT'),
    (explanations_jointbert, 'label_jointbert', 'JointBERT'),
]


def _wordclass_means_by_correctness(explanations, pair_ids, prediction_column, model_name):
    """Average explanation columns per (pair_id, wordclass), split by the model's correctness.

    Returns ``(correct, wrong)``. Both are built from the rows of `explanations` for
    `pair_ids` where ``label`` equals / differs from `prediction_column`. Each is aggregated
    with ``groupby(['pair_id', 'wordclass']).mean(numeric_only=True)``, so text columns such
    as ``token`` are dropped. Each also carries the wordclass as a column (the plot x variable)
    and a constant ``model`` column.
    """
    relevant = explanations.loc[pair_ids]
    is_correct = relevant['label'] == relevant[prediction_column]
    aggregated = []
    for subset in (relevant[is_correct], relevant[~is_correct]):
        means = subset.groupby(['pair_id', 'wordclass']).mean(numeric_only=True)
        means['wordclass'] = means.index.get_level_values('wordclass')
        means['model'] = model_name
        aggregated.append(means)
    return aggregated[0], aggregated[1]


for k, df in dfs.items():
    print(f'Results for {k}:')
    splits = [
        _wordclass_means_by_correctness(explanations, df.index, prediction_column, model_name)
        for explanations, prediction_column, model_name in _COMPARED_MODELS
    ]
    combined_results_correct = pd.concat([correct for correct, _ in splits])
    combined_results_wrong = pd.concat([wrong for _, wrong in splits])
    all_results = pd.concat([combined_results_correct, combined_results_wrong])

    # Subsets are defined over all three plotted models: "only X correct" means X is right
    # and both other models are wrong. (Previously "only Deepmatcher correct" never looked at
    # Deepmatcher and was identical to "all wrong".)
    bert_ok = all_results['label'] == all_results['label_bert']
    jointbert_ok = all_results['label'] == all_results['label_jointbert']
    deepmatcher_ok = all_results['label'] == all_results['label_deepmatcher']

    jointbert_correct = all_results[jointbert_ok & ~bert_ok & ~deepmatcher_ok]
    bert_correct = all_results[bert_ok & ~jointbert_ok & ~deepmatcher_ok]
    deepmatcher_correct = all_results[deepmatcher_ok & ~bert_ok & ~jointbert_ok]
    all_correct = all_results[bert_ok & jointbert_ok & deepmatcher_ok]
    all_wrong = all_results[~bert_ok & ~jointbert_ok & ~deepmatcher_ok]

    overview = [
        ('Combined performance:', all_results, 'combined'),
        ('Correct classifications:', combined_results_correct, 'correct'),
        ('Wrong classifications:', combined_results_wrong, 'wrong'),
    ]
    breakdown = [
        ('Only JointBERT correct:', jointbert_correct, 'only_joint_correct'),
        ('Only BERT correct:', bert_correct, 'only_bert_correct'),
        ('Only Deepmatcher correct:', deepmatcher_correct, 'only_deepmatcher_correct'),
        ('All correct:', all_correct, 'all_correct'),
        ('All wrong:', all_wrong, 'all_wrong'),
    ]

    # First all median plots, then the same subsets with mean bars.
    for plot in (plot_explanations, plot_explanations_avg):
        for title, subset, tag in overview:
            print(title)
            plot(subset, f'3_{tag}_{k}')
        print('##############################')
        for title, subset, tag in breakdown:
            print(title)
            plot(subset, f'3_{tag}_{k}')

In [None]:
# Challenge flag columns of challenges_df and their descriptions; the loop below analyses
# the pairs whose flag equals 1.
challenges = {'challenge_1': 'Solvable by looking at model numbers',
              'challenge_2': 'No model number for one of the two but solvable by looking at attribute words',
              'challenge_5': 'both have many training examples',
              'challenge_6': 'both have few training examples',
              'challenge_7': 'no training examples'}



def _split_by_correctness(explanations, index, prediction_col, model_name):
    """Split one model's explanations for the given pairs into correctly and
    wrongly classified pairs, averaging explanation values per
    (pair_id, wordclass).

    Parameters
    ----------
    explanations : pd.DataFrame
        Explanation rows of one model, looked up by ``index`` via ``.loc``.
        Must contain 'label', ``prediction_col`` and 'wordclass'.
    index : pd.Index
        Pair ids to restrict the analysis to.
    prediction_col : str
        Column holding the model's predicted label (e.g. 'label_bert').
    model_name : str
        Value written into the 'model' column of both results.

    Returns
    -------
    tuple of (pd.DataFrame, pd.DataFrame)
        Averaged results for the correct and the wrong predictions.
    """
    relevant = explanations.loc[index]
    # ~(a == b) is the same mask as (a != b), including for NaN values.
    is_correct = relevant['label'] == relevant[prediction_col]

    def _average(subset):
        # numeric_only=True keeps the historic behaviour of dropping
        # non-numeric columns; pandas >= 2.0 raises TypeError without it.
        averaged = subset.groupby(['pair_id', 'wordclass']).mean(numeric_only=True)
        # Expose the wordclass group key as a regular column as well.
        averaged['wordclass'] = averaged.index.get_level_values('wordclass')
        averaged['model'] = model_name
        return averaged

    return _average(relevant[is_correct]), _average(relevant[~is_correct])


# (explanations frame, prediction column, display name). DistilBERT used to be
# computed here and then filtered out again, so it is simply not included.
challenge_models = [
    (explanations_deepmatcher, 'label_deepmatcher', 'Deepmatcher'),
    (explanations_bert, 'label_bert', 'BERT'),
    (explanations_jointbert, 'label_jointbert', 'JointBERT'),
]

for challenge, name in challenges.items():

    print(f'Results for {name}:')
    relevant_df = challenges_df[challenges_df[challenge] == 1]

    correct_parts, wrong_parts = [], []
    for explanations, prediction_col, model_name in challenge_models:
        correct, wrong = _split_by_correctness(explanations, relevant_df.index,
                                               prediction_col, model_name)
        correct_parts.append(correct)
        wrong_parts.append(wrong)

    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    combined_results_correct = pd.concat(correct_parts)
    combined_results_wrong = pd.concat(wrong_parts)
    all_results = pd.concat([combined_results_correct, combined_results_wrong])

    ###########################################

    jointbert_ok = all_results['label'] == all_results['label_jointbert']
    bert_ok = all_results['label'] == all_results['label_bert']
    # NOTE(review): assumes every explanations frame carries label_deepmatcher,
    # just as all of them are expected to carry label_bert / label_jointbert.
    deepmatcher_ok = all_results['label'] == all_results['label_deepmatcher']

    jointbert_correct = all_results[jointbert_ok & ~bert_ok]
    bert_correct = all_results[~jointbert_ok & bert_ok]
    # Previously this was identical to all_wrong: the DeepMatcher prediction
    # itself was never checked.
    deepmatcher_correct = all_results[~jointbert_ok & ~bert_ok & deepmatcher_ok]
    # NOTE(review): "all" here only considers the two BERT-based models.
    all_correct = all_results[jointbert_ok & bert_ok]
    all_wrong = all_results[~jointbert_ok & ~bert_ok]

    # (heading, subset, file-name prefix) - plotted first per pair, then averaged.
    subsets = [
        ('Combined performance:', all_results, '3_combined'),
        ('All correct:', all_correct, '3_all_correct'),
        ('All wrong:', all_wrong, '3_all_wrong'),
        ('Only JointBERT correct:', jointbert_correct, '3_only_joint_correct'),
        ('Only BERT correct:', bert_correct, '3_only_bert_correct'),
        ('Only Deepmatcher correct:', deepmatcher_correct, '3_only_deepmatcher_correct'),
    ]
    for plot in (plot_explanations, plot_explanations_avg):
        for heading, subset, prefix in subsets:
            print(heading)
            plot(subset, f'{prefix}_{name}')