### Import libraries

In [19]:
import plotly.express as px
import numpy as np
import pickle as pkl
import warnings

### Functions to read pickle files and plot the data

In [20]:
def read_pkl(file):
    """
    Load and return the object pickled in ``file``.

    Note: ``pickle.load`` can execute arbitrary code while unpickling, so only
    call this on result files you produced or otherwise trust.
    """
    with open(file, mode='rb') as handle:
        loaded = pkl.load(handle)
    return loaded
    
def plot(file_1, file_2, label_1, label_2, mse=False):
    """
    Plot per-variable values from two pickled result files and optionally report MSE.

    Each pickle is expected to hold a dict mapping a sample name to a dict of
    ``{variable: value}``. The part of a sample name before the first ``'.'``
    must parse as an int (e.g. ``'12.wav'`` -> 12); it is used as the x
    coordinate. Only samples present in both files are plotted. For every
    variable of the lowest-id common sample, one plotly scatter figure with a
    LOWESS trendline (``frac=0.3``) is shown.

    Parameters
    ----------
    file_1, file_2 : str
        Paths of the two pickle files to compare.
    label_1, label_2 : str
        Legend labels for the points coming from ``file_1`` and ``file_2``.
    mse : bool, default False
        If True, print the mean squared error between the two files' values
        for each variable (values are presumably numeric).

    Raises
    ------
    ValueError
        If the two files have no sample names in common.
    KeyError
        If a common sample lacks one of the plotted variables.
    """
    data_1 = read_pkl(file_1)
    data_2 = read_pkl(file_2)

    # Parse each shared sample's integer id once.
    id_by_sample = {sample: int(sample.split('.')[0])
                    for sample in data_1.keys() & data_2.keys()}
    if not id_by_sample:
        raise ValueError(f'No common samples between {file_1} and {file_2}')

    # Sort numerically (a plain string sort would put '10.wav' before '2.wav');
    # the name is a tie-break so the order is fully deterministic.
    common_samples = sorted(id_by_sample, key=lambda sample: (id_by_sample[sample], sample))
    sample_ids = [id_by_sample[sample] for sample in common_samples]

    # Both series share the same x positions; colour tags which file a point came from.
    x_all = sample_ids + sample_ids
    colors = [label_1] * len(common_samples) + [label_2] * len(common_samples)

    # Take variable names from a deterministic sample rather than an
    # arbitrary element of a set.
    variables = data_1[common_samples[0]].keys()

    for variable in variables:
        y_1 = [data_1[sample][variable] for sample in common_samples]
        y_2 = [data_2[sample][variable] for sample in common_samples]
        with warnings.catch_warnings():
            # Suppress FutureWarnings emitted while building the figure/trendline.
            warnings.simplefilter(action='ignore', category=FutureWarning)
            fig = px.scatter(x=x_all, y=y_1 + y_2, color=colors,
                             trendline="lowess", trendline_options=dict(frac=0.3))
        fig.update_layout(title=str.capitalize(variable),
                          xaxis_title='Sample',
                          yaxis_title='Information Flow',
                          legend_title_text='',
                          width=1000,
                          height=500,
                          font_size=18,)
        fig.show()
        if mse:
            print(f'MSE for {variable}: {np.mean((np.array(y_1) - np.array(y_2))**2)}\n')


## Experiment 1: Score dataset

In [21]:
# Compare positive vs. negative results on the Score dataset.
plot(file_1='../results/score_positive.pkl',
     file_2='../results/score_negative.pkl',
     label_1='positive',
     label_2='negative')

## Experiment 2: URMP dataset

In [22]:
# Compare positive vs. negative results on the URMP dataset.
plot(file_1='../results/urmp_positive.pkl',
     file_2='../results/urmp_negative.pkl',
     label_1='positive',
     label_2='negative')

## Experiment 3: Self-enhancement bias

In [23]:
# Compare the MTMT and AMT result files.
plot(file_1='../results/mtmt.pkl',
     file_2='../results/amt.pkl',
     label_1='MTMT',
     label_2='AMT')

## Experiment 4: Positional bias

In [24]:
# Compare XY vs. YX ordering and print the per-variable MSE between them.
plot(file_1='../results/XY.pkl',
     file_2='../results/YX.pkl',
     label_1='XY',
     label_2='YX',
     mse=True)

MSE for beat: 0.0008486412116326392



MSE for duration: 0.021127579733729362



MSE for instrument: 4.092359176866012e-06



MSE for pitch: 0.0039096251130104065



MSE for position: 0.01263525988906622



MSE for type: 2.14659863218003e-07

