In [1]:
# Re-import edited modules automatically so changes to project code
# (e.g. models/, explainers/, utils/) are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

# Render matplotlib figures inside the notebook output area.
%matplotlib inline

# To prevent automatic figure display when execution of the cell ends
# NOTE(review): close_figures=False keeps figures alive across cells; the cells
# that draw inside a loop are expected to call plt.close(fig) themselves.
%config InlineBackend.close_figures=False 

In [2]:
import os
import pandas as pd
import numpy as np
import seaborn as sns
from scipy import stats

import torch.optim as optim
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn

from sklearn.preprocessing import LabelEncoder

from models.mlp import BlackBoxModel

import matplotlib.pyplot as plt
from ipywidgets import widgets
from IPython.display import display,clear_output

import warnings
warnings.filterwarnings("ignore")

In [3]:
# Directory that holds the raw CSV, relative to the notebook's working directory.
data_path = 'data/'

# `df_` keeps the file exactly as loaded; `df` is an independent working copy
# that the following cells transform.
df_ = pd.read_csv(
    os.path.join(data_path, 'german_credit_data.csv'),
)
df = df_.copy(deep=True)

In [4]:
# Name of the label column.
target_name = 'Risk'

# Encode the label as 0 ('good') / 1 ('bad').
# The explicit integer cast guarantees a numeric dtype: `replace` on an object
# column relies on implicit downcasting, which pandas has deprecated, and an
# object-dtype target would break the later conversion to a float tensor.
# The cast also fails loudly if an unexpected label value is present.
target = df[target_name].replace({'good': 0, 'bad': 1}).astype(int)

df[target_name] = target

In [5]:
# A single encoder instance is re-fitted for every categorical column; the
# classes learned for each column are stored in `label_mappings`
# ({column: {original_value: integer_code}}) so the encoding can be inverted later.
label_encoder = LabelEncoder()
label_mappings = {}


# Convert categorical columns to numerical representations using label encoding
for column in df.columns:
    # Compare column names by value: `is not` tests object identity, which is
    # not a reliable way to compare two strings.
    if column != target_name and df[column].dtype == 'object':
        # Handle missing values by filling with a placeholder and then encoding
        df[column] = df[column].fillna('Unknown')
        df[column] = label_encoder.fit_transform(df[column])
        label_mappings[column] = dict(zip(label_encoder.classes_, range(len(label_encoder.classes_))))


# For columns with NaN values that are numerical, we will impute them with the median of the column
for column in df.columns:
    if df[column].isna().any():
        median_val = df[column].median()
        # Assign the result back instead of calling fillna(inplace=True) on the
        # `df[column]` selection: the chained in-place form may operate on a
        # temporary object and leave `df` unchanged (pandas Copy-on-Write).
        df[column] = df[column].fillna(median_val)

# Display the first few rows of the transformed dataframe
df.head()

Unnamed: 0,Age,Sex,Job,Housing,Saving accounts,Checking account,Credit amount,Duration,Purpose,Risk
0,67,1,2,1,0,1,1169,6,5,0
1,22,0,2,1,1,2,5951,48,5,1
2,49,1,1,1,1,0,2096,12,3,0
3,45,1,2,0,1,1,7882,42,4,0
4,53,1,2,0,1,1,4870,24,1,1


In [6]:
# Model inputs: every encoded column except the CSV's unnamed index column
# and the label itself.
features = [
    'Age', 'Sex', 'Job',
    'Housing', 'Saving accounts', 'Checking account',
    'Credit amount', 'Duration', 'Purpose',
]

# Feature matrix as an independent copy; labels reuse the encoded target series.
df_X = df.loc[:, features].copy()
df_y = target

In [19]:
seed = 102

# Seed NumPy and PyTorch global generators for reproducibility
np.random.seed(seed)
torch.manual_seed(seed)

# Hold out 20% of the rows as the test set
X_train, X_test, y_train, y_test = train_test_split(
    df_X, df_y, test_size=0.2, random_state=seed
)

# Keep unscaled copies of both splits
df_train, df_test = X_train.copy(), X_test.copy()

# Standardise with statistics of the training split only, applied to both splits
mean = X_train.mean()
std = X_train.std()

X_train = (X_train - mean) / std
X_test = (X_test - mean) / std

# Float32 tensors; labels are reshaped into a single column to match the model output
X_train_tensor = torch.tensor(X_train.values, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32).reshape(-1, 1)
X_test_tensor = torch.tensor(X_test.values, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32).reshape(-1, 1)

# Model, loss and optimiser
model = BlackBoxModel(input_dim=X_train.shape[1])
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Full-batch training: one optimiser step per epoch
num_epochs = 300
for epoch in range(num_epochs):
    # Clear gradients left over from the previous step
    optimizer.zero_grad()

    # Forward pass on the whole training set
    outputs = model(X_train_tensor)
    loss = criterion(outputs, y_train_tensor)

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()

# Evaluate on the held-out split without tracking gradients
model.eval()
with torch.no_grad():
    test_outputs = model(X_test_tensor)
    test_loss = criterion(test_outputs, y_test_tensor)

    # Outputs above 0.5 are counted as class 1
    y_pred_tensor = test_outputs.gt(0.5).float()
    correct_predictions = y_pred_tensor.eq(y_test_tensor).float().sum()
    accuracy = correct_predictions / len(y_test_tensor)

test_accuracy = accuracy.item()
print('Test Accuracy:', test_accuracy)

Test Accuracy: 0.7450000047683716


In [20]:
from explainers import pshap, dce
import shap

from utils.benchmarking import *

In [21]:
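# NOTE: every line of this cell is commented out; it is a disabled earlier
# version of the benchmarking experiment, kept for reference only.
# sample_num = 50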
# max_round = 1
# ce_max_iter = 100
# ci_factor = 1.96  # 95% confidence interval factor
# n_proj = 10
# delta = 0.05

# explain_columns = [
#     'Age', 
#     'Sex', 
#     'Job', 
#     'Housing', 
#     'Saving accounts', 
#     'Checking account',
#     'Credit amount', 
#     'Duration', 
#     'Purpose', 
# ]

# y_target = torch.distributions.beta.Beta(0.1, 0.9).sample((sample_num,))

# # Initialize interactive output display
# plt.ioff()
# out = widgets.Output()
# vbox = widgets.VBox([out])
# display(vbox)

# # Lists to store accuracies over iterations
# ot_list_bs = []
# ot_list_jp = []
# kl_list_bs = []
# kl_list_jp = []
# mmd_list_bs = []
# mmd_list_jp = []

# for t in range(max_round):
#     indice = (X_test.sample(sample_num)).index
#     df_explain = X_test.loc[indice]
#     X_explain = df_explain.values
#     y_explain = model.predict_proba(X_explain)

#     explainer = dce.DistributionalCounterfactualExplainer(
#         model=model, 
#         df_X=df_explain, 
#         explain_columns=explain_columns,
#         y_target=y_target, 
#         lr=1e-1, 
#         n_proj=n_proj,
#         delta=0.05)

#     explainer.optimize(U_1=0.4, U_2=0.2, l=0.2, r=1, max_iter=ce_max_iter, tau=1e3)

#     X_baseline = explainer.best_X.detach().numpy()
#     y_baseline = explainer.best_y.detach().numpy().flatten()
#     # y_baseline = y_target
#     df_baseline = pd.DataFrame(X_baseline, columns=df_X.columns, index=indice)

#     matrix_mu = get_ot_plan(explainer.swd.mu_list)

#     shap_explainer = shap.KernelExplainer(model.predict_proba, X_baseline)
#     jp_explainer = pshap.JointProbabilityExplainer(model)

#     shap_values_baseline = shap_explainer.shap_values(X_explain)
#     shap_values_jp = jp_explainer.shap_values(X_explain, X_baseline, 
#                                               joint_probs=matrix_mu.detach().numpy())

#     ranked_features_baseline = get_ranked_features(shap_values_baseline, columns=df_explain.columns.to_list())
#     ranked_features_jp = get_ranked_features(shap_values_jp, columns=df_explain.columns.to_list())

#     baseline_change_columns = []
#     jp_change_columns = []

#     ot_start, _ = WassersteinDivergence().distance(
#         torch.FloatTensor(y_explain), 
#         torch.FloatTensor(y_baseline),
#         delta=delta,
#         )
#     kl_start = compute_kl_divergence(
#             y_explain, 
#             y_baseline,
#         )
#     mmd_start = compute_mmd(
#             y_explain, 
#             y_baseline,
#         )
#     results = [{
#         'OT_bs': ot_start, 'OT_jp': ot_start,
#         'KL_bs': kl_start, 'KL_jp': kl_start,
#         'MMD_bs': mmd_start, 'MMD_jp': mmd_start,
#         }]

#     for bs_change_col, jp_change_col in zip(ranked_features_baseline, ranked_features_jp):
#         baseline_change_columns.append(bs_change_col)
#         jp_change_columns.append(jp_change_col)
#         result = counterfactual_distance_performance_benchmarking(
#             model=model, df_baseline=df_baseline, 
#             df_explain=df_explain, y_baseline=y_baseline,
#             baseline_change_columns=baseline_change_columns,
#             jp_change_columns=jp_change_columns,
#             delta=delta
#         )
#         results.append(result)

#     new_ot_bs = [result['OT_bs'] for result in results]
#     new_ot_jp = [result['OT_jp'] for result in results]

#     ot_list_bs.append(new_ot_bs)
#     ot_list_jp.append(new_ot_jp)


#     # Compute mean and confidence intervals
#     ot_means_bs = np.mean(ot_list_bs, axis=0)
#     ot_means_jp = np.mean(ot_list_jp, axis=0)
#     ot_std_err_bs = stats.sem(ot_means_bs, axis=0)
#     ot_std_err_jp = stats.sem(ot_means_jp, axis=0)
#     ot_ci_bs = ot_std_err_bs * ci_factor
#     ot_ci_jp = ot_std_err_jp * ci_factor

#     fig, ax = plt.subplots(figsize=(10, 6))
#     x_labels = np.arange(len(ot_means_bs))
#     ax.plot(x_labels, ot_means_bs, label='SHAP', marker='o')
#     ax.fill_between(x_labels, ot_means_bs - ot_ci_bs, ot_means_bs + ot_ci_bs, alpha=0.2)
#     ax.plot(x_labels, ot_means_jp, label='JP-SHAP', marker='o')
#     ax.fill_between(x_labels, ot_means_jp - ot_ci_jp, ot_means_jp + ot_ci_jp, alpha=0.2)
#     ax.set_xlabel('Number of Dropped Features')
#     ax.set_ylabel('OT distance')
#     ax.set_title('Comparison of Baseline and JointProb Accuracies with Confidence Intervals')
#     ax.legend()
#     ax.grid(True)

#     with out:
#         clear_output(wait=True);
#         print(f'Round:{t}')
#         display(ax.figure);

In [22]:
# y_baseline = explainer.best_y.detach().numpy().flatten()

In [23]:
# Experiment settings shared by the cells below.
sample_num = 100
max_round = 50
ce_max_iter = 50
ci_factor = 1.96  # 95% confidence interval factor
n_proj = 10
delta = 0.05

# Columns to explain (same names as the model's input features).
explain_columns = [
    'Age', 'Sex', 'Job',
    'Housing', 'Saving accounts', 'Checking account',
    'Credit amount', 'Duration', 'Purpose',
]

# One all-zero target value per sampled row.
y_target = torch.zeros((sample_num,))

In [24]:



# indice = (X_test.sample(sample_num)).index
# df_explain = X_test.loc[indice]
# X_explain = df_explain.values
# y_explain = model.predict_proba(X_explain)

# explainer = dce.DistributionalCounterfactualExplainer(
#     model=model, 
#     df_X=df_explain, 
#     explain_columns=explain_columns,
#     y_target=y_target, 
#     lr=1e-1, 
#     n_proj=n_proj,
#     delta=0.05)

# explainer.optimize(U_1=0.4, U_2=0.2, l=0.2, r=1, max_iter=ce_max_iter, tau=1e3)

# X_baseline = explainer.best_X.detach().numpy()
# y_baseline = explainer.best_y.detach().numpy().flatten()
# # y_baseline = y_target
# df_baseline = pd.DataFrame(X_baseline, columns=df_X.columns, index=indice)





In [35]:
# from explainers.distances import WassersteinDivergence, SlicedWassersteinDivergence

# _, mu_list = SlicedWassersteinDivergence(
#     dim=X_explain.shape[1], n_proj=1000
# ).distance(torch.FloatTensor(X_explain), torch.FloatTensor(X_baseline), delta)

# _, matrix_mu = WassersteinDivergence().distance(torch.FloatTensor(X_explain),torch.FloatTensor(X_baseline), delta)

# matrix_mu = get_ot_plan(mu_list, method='max')


import ot

# Copy of the scaled test features with the model's prediction attached
X_test_ext = X_test.copy()
X_test_ext[target_name] = model.predict_proba(X_test.values)

# Baseline pool: test rows predicted below 0.2. Explain pool: all test rows.
df_baseline = X_test.loc[X_test_ext[target_name] < 0.2]
df_explain = X_test

# Both samples get the same size, capped by the smaller pool and by sample_num
max_len = min(len(df_baseline), len(df_explain), sample_num)

df_baseline = df_baseline.sample(max_len)
df_explain = df_explain.sample(max_len)

X_baseline = df_baseline.values
y_baseline = model.predict_proba(X_baseline)
X_explain = df_explain.values
y_explain = model.predict_proba(X_explain)

# Pairwise cost between explain rows and baseline rows, then the transport
# plan between the two samples with uniform weights on each side
ot_cost = ot.dist(X_explain, X_baseline)
matrix_mu = ot.emd(
    np.ones(len(X_explain)) / len(X_explain),
    np.ones(len(X_baseline)) / len(X_baseline),
    ot_cost,
)

# Standard KernelSHAP uses the baseline sample as background data;
# the joint-probability explainer additionally receives the transport plan
shap_explainer = shap.KernelExplainer(model.predict_proba, X_baseline)
jp_explainer = pshap.JointProbabilityExplainer(model)
# jp_explainer = shap.KernelExplainer(model.predict_proba, X_train.sample(max_len))

shap_values_baseline = shap_explainer.shap_values(X_explain)
shap_values_jp = jp_explainer.shap_values(X_explain, X_baseline, joint_probs=matrix_mu)
# shap_values_jp = jp_explainer.shap_values(X_explain)

  0%|          | 0/100 [00:00<?, ?it/s]

INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 3.05272684e-02,  1.45491716e-04,  1.88776680e-03,  5.08619676e-04,
       -2.69302878e-01,  9.17319272e-02,  2.30885980e-02,  2.45374765e-02,
        5.29306825e-02])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([-0.02826001,  0.02090996,  0.00965511, -0.0031046 , -0.01874514,
        0.11213742, -0.00754529, -0.06873964,  0.02482361])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([-0.10193147, -0.04160382, -0.06431859, -0.03192915,  0.06844942,
        0.26281712,  0.07865718,  0.03386769,  0.01516804])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.00886865,  0.00766004, -0.01874786, -0.00235522, -0.0221373 ,
       -0.05953462, -0.00489779,  0.0381929 ,  0.00901134])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.05169874,  0.01728044,  0.02100883, -0.01113676, -0.18532132,
        0.09098429, -0.00993758, -0.00437511,  0.00655148])
INFO:shap:num_full_subsets = 4
INFO:shap:phi = array([ 0.02271

In [36]:
num_pairs_list = [25, 50, 75, 100, 150, 200, 300, 400, 500]

# Initialize interactive output display
plt.ioff()
out = widgets.Output()
vbox = widgets.VBox([out])
display(vbox)


def _mean_and_ci(round_values, z):
    """Summarise the metric curves collected so far.

    Parameters
    ----------
    round_values : list of lists
        One inner list per completed round, holding one metric value per
        entry of ``num_pairs_list``.
    z : float
        Multiplier applied to the standard error (1.96 for a 95% interval).

    Returns
    -------
    tuple of two 1-D arrays
        The mean across rounds and the half-width ``z * SEM`` across rounds,
        both per entry of ``num_pairs_list``. The half-width is all zeros
        while fewer than two rounds exist, because the standard error of a
        single observation is undefined.
    """
    values = np.asarray(round_values, dtype=float)
    means = values.mean(axis=0)
    if values.shape[0] < 2:
        return means, np.zeros_like(means)
    # SEM is taken over the rounds axis, so each x position gets its own band.
    return means, z * stats.sem(values, axis=0)


# Distances between the unmodified explain and baseline predictions.
# They do not depend on the round, so they are computed once; they are not
# plotted below and are kept for inspection only.
ot_start, _ = WassersteinDivergence().distance(
    torch.FloatTensor(y_explain),
    torch.FloatTensor(y_baseline),
    delta=delta,
)
kl_start = compute_kl_divergence(y_explain, y_baseline)
mmd_start = compute_mmd(y_explain, y_baseline)

# Per-round metric curves (one inner list per round)
ot_list_bs = []
ot_list_jp = []
exp_list_bs = []
exp_list_jp = []
mmd_list_bs = []
mmd_list_jp = []

for t in range(max_round):

    # One benchmark result per number of changed (row, feature) pairs
    results = []
    for num_pairs in num_pairs_list:
        result = counterfactual_ability_performance_benchmarking(
                model=model,
                df_explain=df_explain,
                df_baseline=df_baseline,
                y_baseline=y_baseline,
                shap_values_baseline=shap_values_baseline,
                shap_values_jp=shap_values_jp,
                num_pairs=num_pairs,
                delta=delta,
        )
        results.append(result)

    ot_list_bs.append([result['OT_bs'] for result in results])
    ot_list_jp.append([result['OT_jp'] for result in results])

    exp_list_bs.append([result['EXP_bs'] for result in results])
    exp_list_jp.append([result['EXP_jp'] for result in results])

    mmd_list_bs.append([result['MMD_bs'] for result in results])
    mmd_list_jp.append([result['MMD_jp'] for result in results])

    # Mean and confidence half-width across the rounds completed so far
    ot_means_bs, ot_ci_bs = _mean_and_ci(ot_list_bs, ci_factor)
    ot_means_jp, ot_ci_jp = _mean_and_ci(ot_list_jp, ci_factor)

    exp_means_bs, exp_ci_bs = _mean_and_ci(exp_list_bs, ci_factor)
    exp_means_jp, exp_ci_jp = _mean_and_ci(exp_list_jp, ci_factor)

    mmd_means_bs, mmd_ci_bs = _mean_and_ci(mmd_list_bs, ci_factor)
    mmd_means_jp, mmd_ci_jp = _mean_and_ci(mmd_list_jp, ci_factor)

    fig, axes = plt.subplots(1, 3, figsize=(16, 4))
    x_labels = num_pairs_list

    # One panel per metric: OT distance, MMD, Exp Diff (left to right)
    panels = [
        ('OT Distance', ot_means_bs, ot_ci_bs, ot_means_jp, ot_ci_jp),
        ('MMD', mmd_means_bs, mmd_ci_bs, mmd_means_jp, mmd_ci_jp),
        ('Exp Diff', exp_means_bs, exp_ci_bs, exp_means_jp, exp_ci_jp),
    ]
    for ax, (ylabel, means_bs, ci_bs, means_jp, ci_jp) in zip(axes, panels):
        ax.plot(x_labels, means_bs, label='SHAP', marker='o')
        ax.fill_between(x_labels, means_bs - ci_bs, means_bs + ci_bs, alpha=0.2)
        ax.plot(x_labels, means_jp, label='JP-SHAP', marker='o')
        ax.fill_between(x_labels, means_jp - ci_jp, means_jp + ci_jp, alpha=0.2)
        ax.set_xlabel('Number of Changes')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    # Adjust the spacing between the plots
    fig.subplots_adjust(wspace=0.3)  # Increase the width space

    # Replace the previously displayed figure with the updated one
    with out:
        clear_output(wait=True)
        print(f'Round:{t}')
        display(fig)

    plt.close(fig)  # Close the figure to free memory and avoid unnecessary resource use


VBox(children=(Output(),))

In [None]:
# Normalise absolute SHAP values so each matrix sums to one and can serve as
# a sampling distribution over (row, feature) cells.
P_bs = np.abs(shap_values_baseline)
P_bs = P_bs / P_bs.sum()
P_jp = np.abs(shap_values_jp)
P_jp = P_jp / P_jp.sum()

num_pairs = 5

# Draw cells (with replacement) in proportion to P_jp over the flattened
# matrix, then drop duplicate draws.
flat_indices = np.unique(
    np.random.choice(P_jp.size, num_pairs, True, P_jp.ravel())
)

# Convert flat indices back to 2D (row, column) indices
i_indices, j_indices = np.unravel_index(flat_indices, P_jp.shape)

# Work on a copy so X_explain itself stays untouched
X_explain_current = X_explain.copy()

# Overwrite the selected cells with the values found at the same positions
# in the baseline matrix
values_from_baseline = X_baseline[i_indices, j_indices]
X_explain_current[i_indices, j_indices] = values_from_baseline

In [None]:
# Predictions for the explain sample, rounded to three decimals for display
np.round(y_explain, 3)

array([0.199, 0.   , 0.713, 0.   , 0.   , 0.   , 0.001, 0.975, 0.006,
       0.   , 0.   , 0.34 , 0.003, 0.378, 0.   , 0.55 , 0.556, 0.001,
       0.001, 0.389, 0.   , 0.101, 0.812, 0.   , 0.   , 0.002, 1.   ,
       0.003, 0.606, 0.023, 0.91 , 0.   , 0.003, 0.   , 0.452, 0.999,
       0.208, 0.   , 0.   , 0.139, 0.044, 0.038, 0.   , 0.001, 0.618,
       0.   , 0.505, 0.87 , 0.998, 0.072], dtype=float32)

In [None]:
# Predictions for the baseline sample, rounded to three decimals for display
np.round(y_baseline, 3)

array([0.019, 0.705, 0.006, 0.103, 0.024, 0.   , 0.   , 0.008, 0.   ,
       0.001, 0.001, 0.012, 0.   , 0.002, 0.   , 0.843, 0.993, 0.   ,
       0.023, 0.411, 0.083, 0.005, 0.019, 0.999, 0.   , 0.509, 0.096,
       0.   , 0.016, 0.   , 0.509, 0.03 , 0.327, 0.   , 0.516, 0.   ,
       0.   , 0.015, 0.048, 0.003, 0.   , 0.   , 0.   , 0.   , 0.   ,
       0.004, 0.028, 0.   , 0.   , 1.   ], dtype=float32)

In [None]:
from scipy.stats import gaussian_kde, entropy

In [None]:
# Set up the matplotlib figure and axes
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))

# Plot the first SHAP summary plot on the first axis
plt.sca(axs[0])  # Set the current axis to the first subplot
shap.summary_plot(shap_values_baseline, X_explain, feature_names=df_explain.columns, plot_type='violin', show=False)
axs[0].set_title('Baseline Model')

# Plot the second SHAP summary plot on the second axis
plt.sca(axs[1])  # Set the current axis to the second subplot
shap.summary_plot(shap_values_jp, X_explain, feature_names=df_explain.columns, plot_type='violin', show=False)
axs[1].set_title('JP Model')

# Adjust layout
plt.tight_layout()

# Show the plot
plt.show()

In [None]:
# Build factual / counterfactual feature frames and their target frames.
# NOTE(review): `indice`, `explainer` and `y` are not defined by any active
# cell visible in this notebook (the cells that create `indice` and `explainer`
# are commented out), so this cell fails on a fresh Restart & Run All.
# Confirm where they are expected to come from.

# Encoded (unscaled) feature rows for the sampled indices.
factual_X = df[df_X.columns].loc[indice].copy()
# Undo the standardisation applied before training (x * std + mean) so the
# counterfactual rows are on the same scale as `factual_X`.
counterfactual_X = pd.DataFrame(explainer.best_X.detach().numpy() * std[df_X.columns].values + mean[df_X.columns].values, columns=df_X.columns, index=indice)

# dtype_dict = df.dtypes.apply(lambda x: x.name).to_dict()
# for k, v in dtype_dict.items():
#     if k in counterfactual_X.columns:
#         if v[:3] == 'int':
#             counterfactual_X[k] = counterfactual_X[k].round().astype(v)
#         else:
#             counterfactual_X[k] = counterfactual_X[k].astype(v)

# Single-column target frames aligned on the factual index.
# NOTE(review): presumably `y` holds the model output for the factual rows — TODO confirm.
factual_y = pd.DataFrame(y.detach().numpy(),columns=[target_name], index=factual_X.index)
counterfactual_y = pd.DataFrame(explainer.y.detach().numpy(),columns=[target_name], index=factual_X.index)

In [None]:
# Now, reverse the label encoding using the label_mappings
for dft in [factual_X, counterfactual_X]:
    for column, mapping in label_mappings.items():
        if column in dft.columns:
            # Invert the label mapping dictionary
            inv_mapping = {v: k for k, v in mapping.items()}
            # Map the encoded labels back to the original strings
            dft[column] = dft[column].map(inv_mapping)

In [None]:
# `factual` / `counterfactual` are aliases, not copies: adding the target
# column below also adds it to factual_X / counterfactual_X.
factual = factual_X
factual[target_name] = factual_y

counterfactual = counterfactual_X
counterfactual[target_name] = counterfactual_y

In [None]:
# First five factual rows (head() defaults to 5)
factual.head()

In [None]:
# First five counterfactual rows (head() defaults to 5)
counterfactual.head()

In [None]:
import matplotlib.pyplot as plt

# NOTE(review): `explainer` is not created by any active cell visible in this
# notebook (its construction is commented out) — confirm before running.
matrix_nu = explainer.wd.nu.detach().numpy()

# Sum the entries of explainer.swd.mu_list and normalise the result so that
# its entries sum to one.
mu_avg = torch.zeros_like(explainer.swd.mu_list[0])
for mu in explainer.swd.mu_list:
    mu_avg += mu

total_sum = mu_avg.sum()

matrix_mu = mu_avg / total_sum

# NumPy copy used only for plotting, so both heatmaps and the shared colour
# limits work on the same array type (matrix_mu itself stays a tensor).
matrix_mu_np = matrix_mu.detach().numpy()

# Determine the global minimum and maximum values across both matrices
vmin = float(min(matrix_mu_np.min(), matrix_nu.min()))
vmax = float(max(matrix_mu_np.max(), matrix_nu.max()))

# Create a figure and a set of subplots
fig, axs = plt.subplots(1, 2, figsize=(20, 8))  # 1 row, 2 columns

# First subplot for matrix_mu
im_mu = axs[0].imshow(matrix_mu_np, cmap='viridis', vmin=vmin, vmax=vmax)
# Raw strings keep the LaTeX backslashes without relying on invalid escape sequences
axs[0].set_title(r"Heatmap of $\mu$")

# Second subplot for matrix_nu
im_nu = axs[1].imshow(matrix_nu, cmap='viridis', vmin=vmin, vmax=vmax)
axs[1].set_title(r"Heatmap of $\nu$")

# Create a colorbar for the whole figure (both panels share vmin/vmax)
fig.colorbar(im_mu, ax=axs, orientation='vertical')

# Display the plots
plt.show()