In [None]:
import matplotlib
matplotlib.use('TkAgg')
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import PowerTransformer
from utils_evaluation import compute_shap_values
from utils_plots import plot_shap_waterfall
import pickle

class CustomPowerTransformer(BaseEstimator, TransformerMixin):
    """
    Yeo-Johnson power-transform a named subset of columns and pass the
    remaining columns through unchanged.

    Columns are located by position: the i-th column of the input is assumed
    to correspond to ``all_features[i]``, even when the input is a DataFrame
    (its own column labels are not consulted).

    Fitted attribute names (``column_names``, ``column_indices``,
    ``power_transformer``) are kept unchanged so that instances pickled with
    an earlier version of this class keep working.
    """

    def __init__(self, columns, all_features):
        """
        Initialize the CustomPowerTransformer.
        
        Parameters:
        - columns: list of feature names to power-transform
        - all_features: list of all feature names
        """
        self.columns = columns
        self.all_features = all_features

    def fit(self, X, y=None):
        """
        Fit the transformer.

        Resolves which column positions must be power-transformed and fits a
        standardizing Yeo-Johnson PowerTransformer on those columns only.

        Parameters:
        - X: 2-D numpy array or pandas DataFrame
        - y: ignored; accepted for scikit-learn pipeline compatibility

        Returns:
        - self

        Raises:
        - ValueError: if none of ``columns`` is found among the first
          ``X.shape[1]`` entries of ``all_features``
        """
        # Determine indices of columns to power-transform
        self.column_names = self.all_features[:X.shape[1]]
        self.column_indices = [
            i for i, col_name in enumerate(self.column_names)
            if col_name in self.columns
        ]

        # Fail early with a readable message; otherwise PowerTransformer.fit
        # rejects the resulting zero-feature array with a less helpful error.
        if not self.column_indices:
            raise ValueError(
                "None of the columns to power-transform were found among "
                "the input features."
            )

        # Convert X to numpy array if it's a DataFrame
        if isinstance(X, pd.DataFrame):
            X = X.values

        # Fit the power transformer to the specified columns
        self.power_transformer = PowerTransformer(method='yeo-johnson', standardize=True)
        self.power_transformer.fit(X[:, self.column_indices])

        return self

    def transform(self, X, y=None):
        """
        Apply the fitted power transformation to the specified columns.

        Parameters:
        - X: 2-D numpy array or pandas DataFrame with the same column order
          as the data used in ``fit``
        - y: ignored; accepted for API compatibility

        Returns:
        - numpy array of the same shape as X. The selected columns are
          replaced by their transformed values, all other columns are copied
          as-is. Integer and boolean inputs are returned as float arrays.
        """
        # Convert X to numpy array if it's a DataFrame
        if isinstance(X, pd.DataFrame):
            X = X.values

        # The transformed values are floats. Writing them into an integer or
        # boolean array would silently truncate them, so promote such inputs
        # to float first (astype already returns a fresh copy).
        if X.dtype.kind in "biu":
            X_transformed = X.astype(float)
        else:
            X_transformed = X.copy()

        # Apply power transformation to specified columns
        X_transformed[:, self.column_indices] = self.power_transformer.transform(X[:, self.column_indices])

        return X_transformed
    
# Columns selected from the training frame for the SHAP computation.
features = ['B_Index', 'Katagiri_Group', 'Brain', 'Opioid', 'grouped_KPS']
# Presumably the numeric subset of `features`; not used in this cell.
features_num = ['B_Index']

# Deserialize the preprocessed datasets and the trained model.
# NOTE(review): pickle.load can execute arbitrary code while loading, so
# these files must come from a trusted source.
with open(r"...", "rb") as pickle_file:
    processed_data = pickle.load(pickle_file)

with open(r"...", 'rb') as f:
    model = pickle.load(f)

# Restrict the training data to the selected feature columns.
X_train = processed_data["X_train"][features]

# Compute SHAP values with the project helper and draw the waterfall plot.
shap_values = compute_shap_values(model, X_train)
fig = plot_shap_waterfall(shap_values)

# Render the figure.
plt.show()

# Report the title and axis labels of the figure's current axes.
ax = fig.gca()
for caption, read_text in (
    ("Title:", ax.get_title),
    ("X-axis Label:", ax.get_xlabel),
    ("Y-axis Label:", ax.get_ylabel),
):
    print(caption, read_text())