In [76]:
import pyreal.sample_applications.titanic as titanic
import numpy as np

# Load the sample Titanic application object that the later cells query for
# predictions and explanations.
real_app = titanic.load_titanic_app()
# Load 300 rows of sample data; the second return value is discarded.
sample_data, _ = titanic.load_titanic_data(n_rows=300)

In [None]:
# Compute the local feature contributions for every row of sample_data once;
# the result is reused by all of the plotting cells below.
explanation = real_app.produce_local_feature_contributions(sample_data)

In [None]:
from pyreal.visualize import plot_top_contributors

# Row to inspect. The same id selects both the explanation and the prediction,
# so changing it here keeps the plot and its label in sync.
passenger_id = 1

# Keep the raw model output under its own name so this cell stays idempotent,
# then map it to readable labels (0 -> "Died", anything else -> "Survived").
raw_predictions = real_app.predict(sample_data)
predictions = np.array(["Died" if pred == 0 else "Survived" for pred in raw_predictions])

plot_top_contributors(explanation[passenger_id], prediction=predictions[passenger_id])

In [None]:
from pyreal.visualize import swarm_plot

# Plot the contributions of all explained rows at once, using the "strip" plot type.
swarm_plot(explanation, type="strip")

In [None]:
from pyreal.types.explanations.feature_based import FeatureContributionExplanation
import pandas as pd
import matplotlib.pyplot as plt
from pyreal.visualize.visualize_config import NEGATIVE_COLOR_LIGHT, NEUTRAL_COLOR, POSITIVE_COLOR_LIGHT, PALETTE_CMAP
import seaborn as sns
import numpy as np


def feature_scatter_plot(explanation, feature, predictions, legend=True, legend_titles=None):
    """
    Scatter one feature's values against its contributions, colored by prediction.

    Discrete predictions are shown with a categorical legend. Continuous predictions
    (floating point, or integer with more than 6 distinct values) are shown with a
    colorbar annotated with the minimum and maximum prediction instead.

    Args:
        explanation (FeatureContributionExplanation or dict):
            Either an explanation object, or a mapping whose entries can each be
            indexed with "Contribution" and "Feature Value".
        feature (string):
            Name of the feature (column) to plot.
        predictions (array-like):
            One prediction per explained row, in the same order as the explanation rows.
        legend (Boolean):
            If False, draw neither the categorical legend nor the colorbar.
        legend_titles (list of strings):
            Currently unused; accepted so existing callers keep working.

    Returns:
        None. The plot is drawn on a new seaborn figure.
    """
    if isinstance(explanation, FeatureContributionExplanation):
        contributions = explanation.get()
        # Feature values are stored alongside the contributions; calling get() a second
        # time would plot the contributions against themselves.
        values = explanation.get_values()
    else:
        # NOTE(review): assumes each explanation[i]["..."] is a Series labelled by
        # feature name, so the stacked frames have one column per feature - confirm
        # against the producer of `explanation`.
        contributions = pd.DataFrame([explanation[i]["Contribution"] for i in explanation])
        values = pd.DataFrame([explanation[i]["Feature Value"] for i in explanation])

    # Work on an ndarray so dtype checks, .astype and .min()/.max() behave the same for
    # lists, ndarrays and pandas Series (whose [0] would be a label lookup, not positional).
    predictions = np.asarray(predictions)

    data = pd.DataFrame(
        {
            "Contribution": contributions[feature].values,
            "Value": values[feature].values,
            "Prediction": predictions,
        }
    )

    num_colors = len(np.unique(predictions.astype("str")))
    palette = sns.blend_palette(
        [NEGATIVE_COLOR_LIGHT, NEUTRAL_COLOR, POSITIVE_COLOR_LIGHT], n_colors=num_colors
    )

    # Check the dtype rather than isinstance(): NumPy integer and float32 scalars are not
    # subclasses of Python int/float, so an isinstance test silently misclassifies them.
    is_float = np.issubdtype(predictions.dtype, np.floating)
    is_integer = np.issubdtype(predictions.dtype, np.integer)
    continuous = is_float or (is_integer and num_colors > 6)

    # Respect the caller's `legend` flag; continuous data swaps the legend for a colorbar.
    show_legend = bool(legend) and not continuous
    show_colorbar = bool(legend) and continuous

    # lmplot returns a FacetGrid (figure-level object), not a single Axes.
    grid = sns.lmplot(
        x="Value",
        y="Contribution",
        data=data,
        fit_reg=False,
        hue="Prediction",
        palette=palette,
        legend=show_legend,
    )
    plt.xlabel("Values for %s" % feature)

    if show_colorbar:

        def _format_bound(number):
            # Two decimals, with trailing zeros and a dangling decimal point removed.
            return ("%.2f" % number).rstrip("0").rstrip(".")

        lowest = predictions.min()
        highest = predictions.max()

        mappable = plt.cm.ScalarMappable(cmap=PALETTE_CMAP, norm=plt.Normalize(0, 1))
        mappable.set_array([])
        # The mappable is not attached to any Axes, so tell matplotlib explicitly which
        # axes to take the colorbar's space from.
        cbar = grid.figure.colorbar(mappable, ax=grid.axes)
        cbar.ax.get_yaxis().set_ticks([])
        cbar.ax.text(1.5, 0.05, _format_bound(lowest), ha="left", va="center")
        cbar.ax.text(1.5, 0.95, _format_bound(highest), ha="left", va="center")
        # The color encodes the prediction (the hue column), not the feature value.
        cbar.ax.set_ylabel("Prediction", rotation=270)
        cbar.ax.get_yaxis().labelpad = 15


# Scatter the "Age" feature's values against its contributions, colored by the
# "Died"/"Survived" labels built in an earlier cell.
# NOTE(review): legend_titles is accepted by feature_scatter_plot but never read by it.
feature_scatter_plot(explanation, "Age", predictions, legend_titles=["Died", "Survived"])
