# Trajectory of Outcomes After A Distal Radius Fracture
A python version of the R analysis report.

## Environment Setup
First, install the dependencies:  
`pip install pandas numpy seaborn statsmodels plotly scikit-learn`

In [1]:
# Packages
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from statsmodels.regression.mixed_linear_model import MixedLM
import plotly.express as px
import plotly.graph_objects as go

# Set plotting style
sns.set_theme()

print("All dependencies are ready")

All dependencies are ready


# Data Import

In [2]:
# Read the CSV file with index_col=0 to use the first column as index
c1 = pd.read_csv("../data/data.csv", index_col=0)
c1 = c1.reset_index().rename(columns={'index': 'X'})  # Convert index to column named 'X'
c1.head()

Unnamed: 0,X,MRN,Age at injury,Sex,ISS,CAD,Hypertension,Osteoporosis,Diabetes,Substance abuse,...,SpecificActivities_5Y,Total_5Y,UsualActivities_5Y,Procedure,admission date,Discharge,Procedure date,Aditional procedures,Revision date,Revision procedure
0,1,15390,79,F,9,,,,,,...,,,,ORIF,2006-08-05,2006-08-08,2006-08-07,,none,none
1,2,36020,63,F,9,Present,Present,,,,...,,,,ORIF,2013-01-10,2013-01-10,2013-01-10,,41305,Revision of IF device
2,3,45377,58,F,9,,,,,,...,,,,ORIF,2009-11-24,2009-11-26,2009-11-25,,none,none
3,4,74420,77,F,9,,Present,,,,...,0.0,5.0,10.0,ORIF,2006-04-03,2006-04-06,2006-04-05,,none,none
4,5,120632,63,F,9,,,,,,...,1.0,6.5,0.0,ORIF,2011-11-01,2011-11-01,2011-11-01,,none,none


In [3]:
# If the column name has space, replace it with underscore.
# If the column name has special character, replace it with underscore.
c1.columns = c1.columns.str.replace(' ', '_')
c1.columns = c1.columns.str.replace('[^a-zA-Z0-9_]', '_', regex=True)
c1.head()

Unnamed: 0,X,MRN,Age_at_injury,Sex,ISS,CAD,Hypertension,Osteoporosis,Diabetes,Substance_abuse,...,SpecificActivities_5Y,Total_5Y,UsualActivities_5Y,Procedure,admission_date,Discharge,Procedure_date,Aditional_procedures,Revision_date,Revision_procedure
0,1,15390,79,F,9,,,,,,...,,,,ORIF,2006-08-05,2006-08-08,2006-08-07,,none,none
1,2,36020,63,F,9,Present,Present,,,,...,,,,ORIF,2013-01-10,2013-01-10,2013-01-10,,41305,Revision of IF device
2,3,45377,58,F,9,,,,,,...,,,,ORIF,2009-11-24,2009-11-26,2009-11-25,,none,none
3,4,74420,77,F,9,,Present,,,,...,0.0,5.0,10.0,ORIF,2006-04-03,2006-04-06,2006-04-05,,none,none
4,5,120632,63,F,9,,,,,,...,1.0,6.5,0.0,ORIF,2011-11-01,2011-11-01,2011-11-01,,none,none


In [4]:
# Convert binary categories to 0/1
binary_cols = ['Sex', 'CAD', 'Hypertension', 'Osteoporosis', 'Diabetes', 
               'Substance_abuse', 'Alcohol_abuse', 'Depression', 'Anxiety_disorder',
               'Psychosis', 'Malignancy', 'Stroke_TIA', 'Previous_orthopedic_trauma']

# Convert Sex F->0, M->1, and others None->0, Present->1
c1['Sex'] = (c1['Sex'] == 'M').astype(int)
for col in binary_cols[1:]:
    c1[col] = (c1[col] != 'None').astype(int)
# Special case for Revision_procedure
c1['Revision_procedure'] = (c1['Revision_procedure'] == 'Removal of  device').astype(int)

# Check the data
c1.describe()

Unnamed: 0,X,MRN,Age_at_injury,Sex,ISS,CAD,Hypertension,Osteoporosis,Diabetes,Substance_abuse,...,Pain_1Y,SpecificActivities_1Y,Total_1Y,UsualActivities_1Y,Function_5Y,Pain_5Y,SpecificActivities_5Y,Total_5Y,UsualActivities_5Y,Revision_procedure
count,447.0,447.0,447.0,447.0,447.0,447.0,447.0,447.0,447.0,447.0,...,319.0,319.0,319.0,319.0,305.0,305.0,305.0,305.0,305.0,447.0
mean,224.0,2552160.0,51.064877,0.326622,9.138702,1.0,1.0,1.0,1.0,1.0,...,9.413793,6.896552,15.183386,4.636364,3.37541,6.311475,4.170492,9.670492,2.560656,0.085011
std,129.182042,1062478.0,15.957842,0.469503,1.933537,0.0,0.0,0.0,0.0,0.0,...,9.616446,10.707898,17.593061,6.912257,6.484106,8.323518,8.410742,13.78824,5.016134,0.279211
min,1.0,15390.0,16.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,112.5,1625768.0,39.0,0.0,9.0,1.0,1.0,1.0,1.0,1.0,...,2.0,0.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,224.0,3024391.0,54.0,0.0,9.0,1.0,1.0,1.0,1.0,1.0,...,6.0,2.0,9.0,2.0,1.0,4.0,1.0,5.0,0.0,0.0
75%,335.5,3495483.0,63.0,1.0,9.0,1.0,1.0,1.0,1.0,1.0,...,14.0,8.0,20.75,6.0,4.0,8.0,4.0,12.5,3.0,0.0
max,447.0,3991658.0,93.0,1.0,27.0,1.0,1.0,1.0,1.0,1.0,...,46.0,58.0,87.0,34.0,42.0,41.0,52.0,83.0,32.0,1.0


# Data Clean

In [5]:
# Remove rows where all Total scores are NA or 0
total_cols = ['Total_3M', 'Total_6M', 'Total_1Y', 'Total_5Y']
c1 = c1[~(c1[total_cols].isna().all(axis=1) | (c1[total_cols] == 0).all(axis=1))]
c1.describe()

Unnamed: 0,X,MRN,Age_at_injury,Sex,ISS,CAD,Hypertension,Osteoporosis,Diabetes,Substance_abuse,...,Pain_1Y,SpecificActivities_1Y,Total_1Y,UsualActivities_1Y,Function_5Y,Pain_5Y,SpecificActivities_5Y,Total_5Y,UsualActivities_5Y,Revision_procedure
count,423.0,423.0,423.0,423.0,423.0,423.0,423.0,423.0,423.0,423.0,...,319.0,319.0,319.0,319.0,305.0,305.0,305.0,305.0,305.0,423.0
mean,224.456265,2557278.0,51.205674,0.330969,9.122931,1.0,1.0,1.0,1.0,1.0,...,9.413793,6.896552,15.183386,4.636364,3.37541,6.311475,4.170492,9.670492,2.560656,0.08747
std,129.182766,1056225.0,15.882052,0.471119,1.928587,0.0,0.0,0.0,0.0,0.0,...,9.616446,10.707898,17.593061,6.912257,6.484106,8.323518,8.410742,13.78824,5.016134,0.282858
min,1.0,15390.0,17.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,112.5,1625768.0,39.0,0.0,9.0,1.0,1.0,1.0,1.0,1.0,...,2.0,0.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,224.0,3024391.0,55.0,0.0,9.0,1.0,1.0,1.0,1.0,1.0,...,6.0,2.0,9.0,2.0,1.0,4.0,1.0,5.0,0.0,0.0
75%,335.5,3495483.0,63.0,1.0,9.0,1.0,1.0,1.0,1.0,1.0,...,14.0,8.0,20.75,6.0,4.0,8.0,4.0,12.5,3.0,0.0
max,447.0,3991658.0,93.0,1.0,27.0,1.0,1.0,1.0,1.0,1.0,...,46.0,58.0,87.0,34.0,42.0,41.0,52.0,83.0,32.0,1.0


In [6]:
# Data for LME
# Filter for at least two non-NA data points
c3 = c1[c1[total_cols].isna().sum(axis=1) <= 2].copy()

# Create easyc3 with regrouped attributes
cols_to_keep = ['X', 'MRN', 'Age_at_injury', 'Sex', 'ISS', 'CAD', 'Hypertension', 
                'Osteoporosis', 'Diabetes', 'Substance_abuse', 'Alcohol_abuse',
                'Depression', 'Anxiety_disorder', 'Psychosis', 'Malignancy', 
                'Stroke_TIA', 'Previous_orthopedic_trauma', 'Revision_procedure',
                'Function_Baseline', 'Pain_Baseline', 'Total_Baseline',
                'Function_3M', 'Pain_3M', 'Total_3M', 'Function_6M', 'Pain_6M', 
                'Total_6M', 'Function_1Y', 'Pain_1Y', 'Total_1Y',
                'Function_5Y', 'Pain_5Y', 'Total_5Y']

easyc3 = c3[cols_to_keep].copy()

easyc3.describe()

Unnamed: 0,X,MRN,Age_at_injury,Sex,ISS,CAD,Hypertension,Osteoporosis,Diabetes,Substance_abuse,...,Total_3M,Function_6M,Pain_6M,Total_6M,Function_1Y,Pain_1Y,Total_1Y,Function_5Y,Pain_5Y,Total_5Y
count,377.0,377.0,377.0,377.0,377.0,377.0,377.0,377.0,377.0,377.0,...,136.0,373.0,373.0,373.0,312.0,312.0,312.0,302.0,302.0,302.0
mean,225.62069,2573012.0,51.355438,0.31565,9.196286,1.0,1.0,1.0,1.0,1.0,...,25.080882,8.088472,11.962466,20.037534,5.741987,9.298077,15.040064,3.299669,6.215232,9.498344
std,128.305225,1046806.0,15.85526,0.465392,1.860476,0.0,0.0,0.0,0.0,0.0,...,20.438364,9.521281,9.555648,18.35829,8.615208,9.601594,17.629288,6.365561,8.210076,13.529437
min,1.0,15390.0,17.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,114.0,1629372.0,39.0,0.0,9.0,1.0,1.0,1.0,1.0,1.0,...,9.0,1.0,5.0,6.5,0.5,2.0,3.0,0.0,0.0,0.0
50%,226.0,3044895.0,55.0,0.0,9.0,1.0,1.0,1.0,1.0,1.0,...,19.75,4.5,10.0,14.5,2.25,6.0,8.5,0.75,4.0,5.0
75%,333.0,3493755.0,63.0,1.0,9.0,1.0,1.0,1.0,1.0,1.0,...,37.125,11.5,17.0,28.5,7.5,14.0,20.5,4.0,8.0,12.5
max,447.0,3991658.0,93.0,1.0,27.0,1.0,1.0,1.0,1.0,1.0,...,88.5,45.5,45.0,89.5,42.0,46.0,87.0,42.0,41.0,83.0


In [7]:
# Create combined categories
easyc3['SubAbuse'] = ((easyc3['Substance_abuse'] + easyc3['Alcohol_abuse']) != 0).astype(int)
easyc3['Mental_illness'] = ((easyc3['Depression'] + easyc3['Anxiety_disorder'] + 
                            easyc3['Psychosis']) != 0).astype(int)


In [8]:
# Select final columns for easyc3
final_cols = ['X', 'MRN', 'Age_at_injury', 'Sex', 'ISS', 'CAD', 'Hypertension',
              'Osteoporosis', 'Diabetes', 'SubAbuse', 'Mental_illness',
              'Malignancy', 'Stroke_TIA', 'Previous_orthopedic_trauma', 
              'Revision_procedure', 'Function_Baseline', 'Pain_Baseline', 
              'Total_Baseline'] + [col for col in easyc3.columns 
                                  if any(x in col for x in ['_3M', '_6M', '_1Y', '_5Y'])]

easyc3 = easyc3[final_cols]

easyc3.describe()

Unnamed: 0,X,MRN,Age_at_injury,Sex,ISS,CAD,Hypertension,Osteoporosis,Diabetes,SubAbuse,...,Total_3M,Function_6M,Pain_6M,Total_6M,Function_1Y,Pain_1Y,Total_1Y,Function_5Y,Pain_5Y,Total_5Y
count,377.0,377.0,377.0,377.0,377.0,377.0,377.0,377.0,377.0,377.0,...,136.0,373.0,373.0,373.0,312.0,312.0,312.0,302.0,302.0,302.0
mean,225.62069,2573012.0,51.355438,0.31565,9.196286,1.0,1.0,1.0,1.0,1.0,...,25.080882,8.088472,11.962466,20.037534,5.741987,9.298077,15.040064,3.299669,6.215232,9.498344
std,128.305225,1046806.0,15.85526,0.465392,1.860476,0.0,0.0,0.0,0.0,0.0,...,20.438364,9.521281,9.555648,18.35829,8.615208,9.601594,17.629288,6.365561,8.210076,13.529437
min,1.0,15390.0,17.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,114.0,1629372.0,39.0,0.0,9.0,1.0,1.0,1.0,1.0,1.0,...,9.0,1.0,5.0,6.5,0.5,2.0,3.0,0.0,0.0,0.0
50%,226.0,3044895.0,55.0,0.0,9.0,1.0,1.0,1.0,1.0,1.0,...,19.75,4.5,10.0,14.5,2.25,6.0,8.5,0.75,4.0,5.0
75%,333.0,3493755.0,63.0,1.0,9.0,1.0,1.0,1.0,1.0,1.0,...,37.125,11.5,17.0,28.5,7.5,14.0,20.5,4.0,8.0,12.5
max,447.0,3991658.0,93.0,1.0,27.0,1.0,1.0,1.0,1.0,1.0,...,88.5,45.5,45.0,89.5,42.0,46.0,87.0,42.0,41.0,83.0


In [10]:
# Create long format data
id_vars = ['X', 'MRN', 'Age_at_injury', 'Sex', 'ISS', 'CAD', 'Hypertension',
           'Osteoporosis', 'Diabetes', 'SubAbuse', 'Mental_illness',
           'Malignancy', 'Stroke_TIA', 'Previous_orthopedic_trauma',
           'Revision_procedure', 'Function_Baseline', 'Pain_Baseline', 'Total_Baseline']

longc3 = pd.wide_to_long(easyc3, 
                        stubnames=['Function', 'Pain', 'Total'],
                        i=id_vars,
                        j='period',
                        suffix='_(3M|6M|1Y|5Y)',
                        sep='').reset_index()
longc3.describe()

Unnamed: 0,X,MRN,Age_at_injury,Sex,ISS,CAD,Hypertension,Osteoporosis,Diabetes,SubAbuse,...,Malignancy,Stroke_TIA,Previous_orthopedic_trauma,Revision_procedure,Function_Baseline,Pain_Baseline,Total_Baseline,Function,Pain,Total
count,1508.0,1508.0,1508.0,1508.0,1508.0,1508.0,1508.0,1508.0,1508.0,1508.0,...,1508.0,1508.0,1508.0,1508.0,1396.0,1396.0,1396.0,1123.0,1123.0,1123.0
mean,225.62069,2573012.0,51.355438,0.31565,9.196286,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,0.095491,0.553009,1.260745,1.813754,6.60285,9.825467,16.425646
std,128.177452,1045763.0,15.839471,0.464928,1.858623,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.293989,3.023098,3.793702,6.127181,9.110183,9.674829,18.003405
min,1.0,15390.0,17.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,114.0,1629372.0,39.0,0.0,9.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.5,3.0,3.5
50%,226.0,3044895.0,55.0,0.0,9.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,0.0,0.0,0.0,0.0,2.5,7.0,10.5
75%,333.0,3493755.0,63.0,1.0,9.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,0.0,0.0,0.0,1.0,9.0,14.5,23.0
max,447.0,3991658.0,93.0,1.0,27.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,38.5,39.0,69.5,45.5,46.0,89.5


In [11]:
# Convert period to months
period_to_months = {'3M': 3, '6M': 6, '1Y': 12, '5Y': 60}
longc3['month'] = longc3['period'].map(period_to_months)
longc3.describe()

Unnamed: 0,X,MRN,Age_at_injury,Sex,ISS,CAD,Hypertension,Osteoporosis,Diabetes,SubAbuse,...,Stroke_TIA,Previous_orthopedic_trauma,Revision_procedure,Function_Baseline,Pain_Baseline,Total_Baseline,Function,Pain,Total,month
count,1508.0,1508.0,1508.0,1508.0,1508.0,1508.0,1508.0,1508.0,1508.0,1508.0,...,1508.0,1508.0,1508.0,1396.0,1396.0,1396.0,1123.0,1123.0,1123.0,0.0
mean,225.62069,2573012.0,51.355438,0.31565,9.196286,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,0.095491,0.553009,1.260745,1.813754,6.60285,9.825467,16.425646,
std,128.177452,1045763.0,15.839471,0.464928,1.858623,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.293989,3.023098,3.793702,6.127181,9.110183,9.674829,18.003405,
min,1.0,15390.0,17.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,
25%,114.0,1629372.0,39.0,0.0,9.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,0.0,0.0,0.0,0.0,0.5,3.0,3.5,
50%,226.0,3044895.0,55.0,0.0,9.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,0.0,0.0,0.0,0.0,2.5,7.0,10.5,
75%,333.0,3493755.0,63.0,1.0,9.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,0.0,0.0,0.0,1.0,9.0,14.5,23.0,
max,447.0,3991658.0,93.0,1.0,27.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,38.5,39.0,69.5,45.5,46.0,89.5,


In [12]:
# Create lme_longc3 by dropping NA values
lme_longc3 = longc3.dropna(subset=['Total'])
lme_longc3.describe()

Unnamed: 0,X,MRN,Age_at_injury,Sex,ISS,CAD,Hypertension,Osteoporosis,Diabetes,SubAbuse,...,Stroke_TIA,Previous_orthopedic_trauma,Revision_procedure,Function_Baseline,Pain_Baseline,Total_Baseline,Function,Pain,Total,month
count,1123.0,1123.0,1123.0,1123.0,1123.0,1123.0,1123.0,1123.0,1123.0,1123.0,...,1123.0,1123.0,1123.0,1046.0,1046.0,1046.0,1123.0,1123.0,1123.0,0.0
mean,226.367765,2569656.0,52.044524,0.299199,9.180766,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,0.099733,0.625239,1.24283,1.868069,6.60285,9.825467,16.425646,
std,130.285806,1050690.0,15.389876,0.458111,1.750046,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.299777,3.403572,3.819657,6.622803,9.110183,9.674829,18.003405,
min,1.0,15390.0,17.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,
25%,113.0,1628921.0,41.0,0.0,9.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,0.0,0.0,0.0,0.0,0.5,3.0,3.5,
50%,223.0,3015998.0,55.0,0.0,9.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,0.0,0.0,0.0,0.0,2.5,7.0,10.5,
75%,337.0,3497326.0,63.0,1.0,9.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,0.0,0.0,0.0,0.5,9.0,14.5,23.0,
max,447.0,3991658.0,93.0,1.0,27.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,38.5,39.0,69.5,45.5,46.0,89.5,


In [18]:
# count missing values
# lme_longc3.groupby('period')['Total_Baseline'].isnull().value_counts()
# count missing values
print("\nMissing values by period:")
print(lme_longc3.groupby('period')['Total_Baseline'].apply(lambda x: x.isna().sum()))



Missing values by period:
period
_1Y    19
_3M    10
_5Y    20
_6M    28
Name: Total_Baseline, dtype: int64


In [25]:
# LME Plot function

fig = plt.figure(figsize=(10, 20))

# Fit individual linear models for each subject
subjects = lme_longc3['X'].unique()
coefficients = []

for subject in subjects:
    subject_data = lme_longc3[lme_longc3['X'] == subject]
    if len(subject_data) >= 2:  # Need at least 2 points for regression
        X = subject_data['month'].values.reshape(-1, 1)
        y = subject_data['Total'].values
        try:
            slope, intercept = np.polyfit(X.ravel(), y, 1)
            coefficients.append({'subject': subject, 'slope': slope, 
                                'intercept': intercept})
        except:
            continue
    # Plot individual regression lines
coef_df = pd.DataFrame(coefficients)

plt.scatter(coef_df['intercept'], coef_df['slope'], alpha=0.5)
plt.xlabel('Intercept')
plt.ylabel('Slope')
plt.title('Individual Regression Coefficients')

plt.show() 

KeyError: 'intercept'

<Figure size 1000x2000 with 0 Axes>

In [23]:
# Create the plot
lme_plot = plot_individual_slopes()
plt.show()

KeyError: 'intercept'

<Figure size 1000x2000 with 0 Axes>