# Model training for 3-segment reconstructions

## Import

In [347]:
import warnings
import numpy as np
import pandas as pd
import copy
from statistics import mean, stdev
from sklearn.preprocessing import QuantileTransformer
from sklearn.metrics import make_scorer, matthews_corrcoef, f1_score, accuracy_score, average_precision_score, roc_auc_score, brier_score_loss
from sklearn.model_selection import cross_validate, StratifiedKFold, RepeatedStratifiedKFold
from sklearn.linear_model import LogisticRegression 
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from scipy.stats import chi2_contingency
import shap
import statsmodels.formula.api as smf
import seaborn as sns
import matplotlib.pyplot as plt

In [348]:
df_dropped_first = pd.read_parquet('/Users/philipp.lampert/repositories/mymandible/data/dropped_first_imputed.parquet')
df_all_levels = pd.read_parquet('/Users/philipp.lampert/repositories/mymandible/data/all_levels_imputed.parquet')

In [349]:
def get_3_segment_df(df):
    df3 = df[
        (df['flap_segment_count'] == 3) 
        & (df['flap_donor_site___scapula'] == False) 
        & (df['plate_type___cad_mini'] == False)
        & (df['flap_loss'] == False)
        & (df['indication___secondary_reconstruction'] == False)
        & (df['indication___osteoradionecrosis'] == False)
        & (df['days_to_follow_up'] >= 91)
    ].copy()

    # drop unused variables
    df3 = df3.drop(['flap_donor_site___scapula', 
                    'flap_segment_count', 
                    'plate_type___cad_mini', 
                    'urkens_classification___s',
                    'indication___osteoradionecrosis',
                    'indication___secondary_reconstruction',
                    'prior_flap___bony'
                   ], axis=1)
    return df3

In [350]:
df_df = get_3_segment_df(df_dropped_first)
df_all = get_3_segment_df(df_all_levels)

## Preprocessing

In [351]:
from modules.functions import preprocessing as prp
from modules.functions import threshold_optimized_metrics as tom

In [352]:
acc_scorer = make_scorer(tom.optimized_accuracy, needs_proba=True)
f1_scorer = make_scorer(tom.optimized_f1, needs_proba=True)
mcc_scorer = make_scorer(tom.optimized_mcc, needs_proba=True)
pr_auc_scorer = make_scorer(average_precision_score, needs_proba=True)

## Model setup

In [353]:
def logreg_regularized(outcome, scaler, df, method, alpha):
    
    x, y = prp.get_x_y(df=df, outcome=outcome, min_follow_up_days=91, scaler=scaler, drop_cols=drop_cols, inverse_pos=False)
    boolean_columns = x.select_dtypes(include=bool).columns
    x[boolean_columns] = x[boolean_columns].astype('int')
    numeric_columns = x.select_dtypes(include='number').columns
    x[numeric_columns] = x[numeric_columns].astype('float64')
    y = y.astype('int')    
    x_columns = x.columns
    all_columns = "+".join(x_columns)
    formula = outcome +  '~' + all_columns
    
    data = pd.concat([x, y], axis=1)
    final_model = smf.logit(formula, data).fit_regularized(method=method, alpha=alpha)
    print(final_model.summary())

In [354]:
drop_cols = [
    #'sex_female', 
    #'comorbidity___smoking', 
    #'comorbidity___alcohol',
    'comorbidity___copd', 
    #'comorbidity___hypertension',
    'comorbidity___diabetes', 
    #'comorbidity___atherosclerosis',
    'comorbidity___hyperlipidemia', 
    'comorbidity___hypothyroidism',
    'comorbidity___chronic_kidney_disease',
    'comorbidity___autoimmune_disease', 
    #'age_surgery_years',
    'radiotherapy___pre_surgery', 
    #'radiotherapy___post_surgery',
    'chemotherapy___pre_surgery', 
    #'chemotherapy___post_surgery',
    'urkens_classification___c', 
    'urkens_classification___r',
    'surgery_duration_min', 
    #'bmi', 
    'skin_transplanted',
    'prior_flap___non_bony', 
    #'plate_type___cad_mix'
]

## Patient characteristics

In [355]:
def num_variable(variable):
    cad_mix_values = df_df.loc[df_df['plate_type___cad_mix'], variable]
    cad_long_values = df_df.loc[~df_df['plate_type___cad_mix'], variable]
    
    cad_mix_mean = round(cad_mix_values.mean(), 1)
    cad_mix_std = round(cad_mix_values.std(), 1)
    cad_long_mean = round(cad_long_values.mean(), 1)
    cad_long_std = round(cad_long_values.std(), 1)
    
    overall_mean = round(df_df[variable].mean(), 1)
    overall_std = round(df_df[variable].std(), 1)
    
    data = {
        'Mix': [cad_mix_mean, cad_mix_std],
        'Reco': [cad_long_mean, cad_long_std],
        'Overall': [overall_mean, overall_std]
    }
    
    df = pd.DataFrame(data, index=['mean', 'std'])
    print(df)

In [356]:
def cat_variable(variable):
    # Absolute frequencies
    cad_mix_counts = df_df.loc[df_df['plate_type___cad_mix'], variable].value_counts()
    cad_long_counts = df_df.loc[~df_df['plate_type___cad_mix'], variable].value_counts()
    overall_counts = df_df[variable].value_counts()
    
    # Relative probabilities
    cad_mix_probs = round((cad_mix_counts / cad_mix_counts.sum())*100, 1)
    cad_long_probs = round((cad_long_counts / cad_long_counts.sum())*100, 1)
    overall_probs = round((overall_counts / overall_counts.sum())*100, 1)
    
    # Create DataFrames for absolute frequencies and relative probabilities
    absolute_freq_df = pd.DataFrame({
        'Mix': cad_mix_counts,
        'Reco': cad_long_counts,
        'Overall': overall_counts
    }).fillna(0)  # Fill NaN values with 0
    
    relative_prob_df = pd.DataFrame({
        'Mix': cad_mix_probs,
        'Reco': cad_long_probs,
        'Overall': overall_probs
    }).fillna(0)  # Fill NaN values with 0
    
    print("Absolute frequencies:")
    print(absolute_freq_df)
    
    print("\nRelative probabilities:")
    print(relative_prob_df)

In [357]:
num_variable('age_surgery_years')

       Mix  Reco  Overall
mean  62.3  63.1     62.8
std    7.5   9.7      9.0


In [358]:
cat_variable('sex_female')

Absolute frequencies:
            Mix  Reco  Overall
sex_female                    
False        14    27       41
True          5     9       14

Relative probabilities:
             Mix  Reco  Overall
sex_female                     
False       73.7  75.0     74.5
True        26.3  25.0     25.5


In [359]:
num_variable('bmi')

            Mix  Reco  Overall
mean  24.299999  23.0     23.4
std    3.700000   4.4      4.2


In [360]:
cat_variable('comorbidity___smoking')

Absolute frequencies:
                       Mix  Reco  Overall
comorbidity___smoking                    
False                   12    22       34
True                     7    14       21

Relative probabilities:
                        Mix  Reco  Overall
comorbidity___smoking                     
False                  63.2  61.1     61.8
True                   36.8  38.9     38.2


In [361]:
cat_variable('comorbidity___atherosclerosis')

Absolute frequencies:
                               Mix  Reco  Overall
comorbidity___atherosclerosis                    
False                           17    28       45
True                             2     8       10

Relative probabilities:
                                Mix  Reco  Overall
comorbidity___atherosclerosis                     
False                          89.5  77.8     81.8
True                           10.5  22.2     18.2


In [362]:
cat_variable('comorbidity___alcohol')

Absolute frequencies:
                       Mix  Reco  Overall
comorbidity___alcohol                    
False                   14    28       42
True                     5     8       13

Relative probabilities:
                        Mix  Reco  Overall
comorbidity___alcohol                     
False                  73.7  77.8     76.4
True                   26.3  22.2     23.6


In [363]:
cat_variable('comorbidity___hypertension')

Absolute frequencies:
                            Mix  Reco  Overall
comorbidity___hypertension                    
False                        10    24       34
True                          9    12       21

Relative probabilities:
                             Mix  Reco  Overall
comorbidity___hypertension                     
False                       52.6  66.7     61.8
True                        47.4  33.3     38.2


In [364]:
cat_variable('radiotherapy___post_surgery')

Absolute frequencies:
                             Mix  Reco  Overall
radiotherapy___post_surgery                    
True                          12    19       31
False                          7    17       24

Relative probabilities:
                              Mix  Reco  Overall
radiotherapy___post_surgery                     
True                         63.2  52.8     56.4
False                        36.8  47.2     43.6


In [384]:
cat_variable('chemotherapy___post_surgery')

Absolute frequencies:
                             Mix  Reco  Overall
chemotherapy___post_surgery                    
False                         14    20       34
True                           5    16       21

Relative probabilities:
                              Mix  Reco  Overall
chemotherapy___post_surgery                     
False                        73.7  55.6     61.8
True                         26.3  44.4     38.2


In [365]:
num_variable('days_to_follow_up')

        Mix   Reco  Overall
mean  442.3  679.3    597.4
std   213.1  555.5    477.6


## Logistic Regression

### Any complication

In [366]:
df_df['comorbidity___autoimmune_disease'].value_counts()

comorbidity___autoimmune_disease
False    52
True      3
Name: count, dtype: Int64

In [367]:
logreg_regularized('any_complication', 'None', df_df, 'l1', alpha=0)

Optimization terminated successfully    (Exit mode 0)
            Current function value: 0.4047605114111938
            Iterations: 62
            Function evaluations: 67
            Gradient evaluations: 62
                           Logit Regression Results                           
Dep. Variable:       any_complication   No. Observations:                   55
Model:                          Logit   Df Residuals:                       44
Method:                           MLE   Df Model:                           10
Date:                Fri, 23 Feb 2024   Pseudo R-squ.:                  0.3287
Time:                        17:07:52   Log-Likelihood:                -22.262
converged:                       True   LL-Null:                       -33.163
Covariance Type:            nonrobust   LLR p-value:                   0.01614
                                    coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------

### Soft tissue complication

In [368]:
df_df['soft_tissue_complication'].value_counts()

soft_tissue_complication
True     34
False    21
Name: count, dtype: Int64

In [369]:
logreg_regularized('soft_tissue_complication', 'None', df_df, 'l1', alpha=0)

Optimization terminated successfully    (Exit mode 0)
            Current function value: 0.4971603208500139
            Iterations: 61
            Function evaluations: 65
            Gradient evaluations: 61
                              Logit Regression Results                              
Dep. Variable:     soft_tissue_complication   No. Observations:                   55
Model:                                Logit   Df Residuals:                       44
Method:                                 MLE   Df Model:                           10
Date:                      Fri, 23 Feb 2024   Pseudo R-squ.:                  0.2523
Time:                              17:07:52   Log-Likelihood:                -27.344
converged:                             True   LL-Null:                       -36.572
Covariance Type:                  nonrobust   LLR p-value:                   0.04773
                                    coef    std err          z      P>|z|      [0.025      0.975]
------------

### Nonunion

In [370]:
df_df['nonunion'].value_counts()

nonunion
False    19
True     17
Name: count, dtype: Int64

In [371]:
logreg_regularized('nonunion', 'None', df_df, 'l1', alpha=0)

Optimization terminated successfully    (Exit mode 0)
            Current function value: 0.5951844157905586
            Iterations: 58
            Function evaluations: 62
            Gradient evaluations: 58
                           Logit Regression Results                           
Dep. Variable:               nonunion   No. Observations:                   36
Model:                          Logit   Df Residuals:                       25
Method:                           MLE   Df Model:                           10
Date:                Fri, 23 Feb 2024   Pseudo R-squ.:                  0.1394
Time:                        17:07:52   Log-Likelihood:                -21.427
converged:                       True   LL-Null:                       -24.898
Covariance Type:            nonrobust   LLR p-value:                    0.7309
                                    coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------

### Wound infection

In [372]:
df_df['wound_infection'].value_counts()

wound_infection
False    37
True     18
Name: count, dtype: Int64

In [373]:
logreg_regularized('wound_infection', 'None', df_df, 'l1', alpha=0)

Optimization terminated successfully    (Exit mode 0)
            Current function value: 0.3786375965837924
            Iterations: 65
            Function evaluations: 69
            Gradient evaluations: 65
                           Logit Regression Results                           
Dep. Variable:        wound_infection   No. Observations:                   55
Model:                          Logit   Df Residuals:                       44
Method:                           MLE   Df Model:                           10
Date:                Fri, 23 Feb 2024   Pseudo R-squ.:                  0.4011
Time:                        17:07:52   Log-Likelihood:                -20.825
converged:                       True   LL-Null:                       -34.773
Covariance Type:            nonrobust   LLR p-value:                  0.001876
                                    coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------

### Plate exposure

In [374]:
df_df['complication_plate___exposure'].value_counts()

complication_plate___exposure
False    35
True     20
Name: count, dtype: Int64

In [375]:
logreg_regularized('complication_plate___exposure', 'None', df_df, 'l1', alpha=0)

Optimization terminated successfully    (Exit mode 0)
            Current function value: 0.5024857942233203
            Iterations: 58
            Function evaluations: 64
            Gradient evaluations: 58
                                 Logit Regression Results                                
Dep. Variable:     complication_plate___exposure   No. Observations:                   55
Model:                                     Logit   Df Residuals:                       44
Method:                                      MLE   Df Model:                           10
Date:                           Fri, 23 Feb 2024   Pseudo R-squ.:                  0.2334
Time:                                   17:07:53   Log-Likelihood:                -27.637
converged:                                  True   LL-Null:                       -36.051
Covariance Type:                       nonrobust   LLR p-value:                   0.07822
                                    coef    std err          z      P>

## Univariate Analysis

In [376]:
def chi2_test(outcome, df):
    contingency = pd.crosstab(df[outcome], df['plate_type___cad_mix'])
    c, p, dof, expected = chi2_contingency(contingency)
    print(contingency)
    print(f"p-value: {p}")

### ORN

In [377]:
chi2_test('orn', df_df)

plate_type___cad_mix  False  True 
orn                               
False                    32     15
True                      4      4
p-value: 0.553674677711491


### Plate failure

In [378]:
chi2_test('plate_failure', df_df)

plate_type___cad_mix  False  True 
plate_failure                     
False                    36     15
True                      0      4
p-value: 0.020726005952780586


### Any complication

In [379]:
chi2_test('any_complication', df_df)

plate_type___cad_mix  False  True 
any_complication                  
False                    12      4
True                     24     15
p-value: 0.5212815473335353


### Soft tissue complication

In [380]:
chi2_test('soft_tissue_complication', df_df)

plate_type___cad_mix      False  True 
soft_tissue_complication              
False                        15      6
True                         21     13
p-value: 0.6596437720763275


### Nonunion

In [381]:
chi2_test('nonunion', df_df)

plate_type___cad_mix  False  True 
nonunion                          
False                    13      6
True                      9      8
p-value: 0.5427035080655522


### Wound infection

In [382]:
chi2_test('wound_infection', df_df)

plate_type___cad_mix  False  True 
wound_infection                   
False                    27     10
True                      9      9
p-value: 0.16789902340466287


### Plate exposure

In [383]:
chi2_test('complication_plate___exposure', df_df)

plate_type___cad_mix           False  True 
complication_plate___exposure              
False                             25     10
True                              11      9
p-value: 0.34834487194004304
