In [None]:
import shap
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

def _plot_mfcc_and_shap(original_mfcc, shap_map, cmap, title=None):
    """
    Draws one figure: the MFCC (top, grayscale) above its SHAP heatmap (bottom).

    Parameters:
        original_mfcc (np.array): 2-D MFCC slice shaped (time_steps, mfcc_coefficients).
        shap_map (np.array): SHAP values with the same 2-D shape as ``original_mfcc``.
        cmap (matplotlib colormap): Diverging colormap used for the SHAP heatmap.
        title (str or None): Optional figure title.
    """
    fig, axs = plt.subplots(2, 1, figsize=(15, 15))
    if title:
        fig.suptitle(title)

    # Transpose so time runs along x and MFCC coefficients along y.
    mfcc_img = axs[0].imshow(original_mfcc.T, aspect='auto', interpolation='nearest', cmap='gray')
    axs[0].set_ylabel('MFCC Coefficients')
    plt.colorbar(mfcc_img, ax=axs[0], label='MFCC value')
    mfcc_ticks = np.arange(original_mfcc.shape[1])
    axs[0].set_yticks(ticks=mfcc_ticks)
    axs[0].set_yticklabels(mfcc_ticks)
    axs[0].set_xticks([])  # time axis is labelled on the SHAP plot below

    # Symmetric colour limits keep zero attribution mapped to white. Fall back
    # to 1.0 when every SHAP value is zero, to avoid a degenerate vmin == vmax range.
    limit = float(np.max(np.abs(shap_map))) or 1.0
    shap_img = axs[1].imshow(shap_map.T, cmap=cmap, aspect='auto', interpolation='nearest',
                             vmin=-limit, vmax=limit)
    axs[1].set_xlabel('Time Frames')
    axs[1].set_ylabel('MFCC Coefficients')
    plt.colorbar(shap_img, ax=axs[1], label='SHAP value')
    shap_ticks = np.arange(shap_map.shape[1])
    axs[1].set_yticks(ticks=shap_ticks)
    axs[1].set_yticklabels(shap_ticks)

    plt.tight_layout()
    plt.show()


def plot_mfcc_shap(model, X_test, ind, unique_genres, max_evals=5000, batch_size=50,
                   mask_value="blur(32,32)"):
    """
    Plots the original MFCC and SHAP values for each selected input example.

    SHAP values are computed only for the top-scoring output class of each
    sample. One figure is drawn per requested index.

    Parameters:
        model (keras.Model): The trained model to interpret (e.g. the CNN-LSTM).
            It is called directly on batches shaped like ``X_test[ind]``.
        X_test (np.array): Array of MFCC input data for testing, shaped
            (samples, time_steps, mfcc_coefficients, channels).
        ind (int or list): Index, or list of indices, of the samples to interpret.
        unique_genres (list): List of genre names for output classes.
        max_evals (int): Maximum evaluations for SHAP explainer.
        batch_size (int): Batch size for SHAP explainer evaluation.
        mask_value (str): Masking strategy passed to ``shap.maskers.Image``.
            Defaults to ``"blur(32,32)"``, the previously hard-coded value.
    """
    def f(x):
        # Pass a copy so the model cannot alter the array SHAP hands in.
        tmp = x.copy()
        return model(tmp)

    # Normalise `ind` to a 1-D index array. A bare int would otherwise drop the
    # batch axis and break the per-sample indexing below.
    indices = np.atleast_1d(ind)
    batch = X_test[indices]

    masker = shap.maskers.Image(mask_value, X_test[0].shape)
    explainer = shap.Explainer(f, masker, output_names=unique_genres)
    # Explain only the single highest-output class of each sample.
    explanation = explainer(batch, max_evals=max_evals, batch_size=batch_size,
                            outputs=shap.Explanation.argsort.flip[:1])

    # The top class SHAP explains is the highest model output; recompute it
    # here only to label the figure. Assumes model output is (samples, classes).
    top_classes = np.argmax(np.asarray(model(batch)), axis=-1)

    colors = ['blue', 'white', 'red']  # Negative, zero, positive
    cmap = LinearSegmentedColormap.from_list('custom_cmap', colors)

    for pos, sample_index in enumerate(indices):
        original_mfcc = batch[pos, :, :, 0]
        # Axes of .values: (sample, time, coeff, channel, selected_output).
        shap_values_top_class = explanation.values[pos, :, :, 0, 0]
        genre = unique_genres[top_classes[pos]]
        title = f"Sample {sample_index}: SHAP values for top class '{genre}'"
        _plot_mfcc_and_shap(original_mfcc, shap_values_top_class, cmap, title)


# Output-class names, listed in the model's class order as used by this file.
unique_genres = [
    'blues', 'classical', 'country', 'disco', 'hip-hop',
    'jazz', 'metal', 'pop', 'reggae', 'rock',
]

# Explain and plot a single test example.
sample_ids = [600]
plot_mfcc_shap(
    model=cnn_lstm,
    X_test=X_test_cnn,
    ind=sample_ids,
    unique_genres=unique_genres,
)