# Trauma Transfer Learning Experiment

This notebook compares baseline survival models and a transfer learning approach using Indian and Jordan trauma datasets.

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from lifelines import CoxPHFitter, KaplanMeierFitter
from lifelines.utils import concordance_index
from sksurv.ensemble import GradientBoostingSurvivalAnalysis

sns.set(style='whitegrid')


## Load datasets

In [2]:
india_file = 'trauma_india_brain_injury.csv'
jordan_file = 'traumatic_brain_injury.csv'

cols_india = ['age','sex','sbp_1','hr_1','rr_1','gcs_t_1','doa','toa','dodd','todd','died']
india = pd.read_csv(india_file, usecols=cols_india)

cols_jordan = ['Gender','age of diagnosis','ER-HR','ER-RR','ER-systolic BP','GCS in ER','length of stay in the hospital (in days)','outcome']
jordan = pd.read_csv(jordan_file, usecols=cols_jordan)


In [3]:
india

Unnamed: 0,age,sex,doa,toa,sbp_1,rr_1,hr_1,gcs_t_1,died,dodd,todd
0,40,Male,7/8/2276,16:00,120.0,22.0,70.0,15.0,No,7/21/2276,16:00
1,27,Male,7/31/2280,17:30,130.0,15.0,84.0,5.0,No,9/3/2280,16:00
2,45,Male,8/1/2108,13:40,110.0,,80.0,15.0,No,8/7/2108,12:40
3,50,Male,7/16/2132,21:50,,24.0,88.0,4.0,Yes,7/19/2132,7:30
4,50,Female,7/3/2047,23:50,130.0,22.0,80.0,3.0,No,7/6/2047,12:30
...,...,...,...,...,...,...,...,...,...,...,...
7973,8,Female,9/27/2150,7:00,102.0,,100.0,11.0,No,9/28/2150,12:00
7974,45,Male,9/17/2174,13:20,130.0,14.0,72.0,11.0,No,9/26/2174,10:00
7975,11,Male,9/16/2237,21:50,100.0,,88.0,15.0,No,9/19/2237,10:00
7976,45,Female,,,110.0,24.0,90.0,6.0,Yes,1/1/2196,19:40


In [4]:
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer

# Select only numeric columns for imputation
jordan_numeric = jordan.select_dtypes(include=[np.number])

# Perform iterative imputation
imputer = IterativeImputer(random_state=0)
jordan_imputed_array = imputer.fit_transform(jordan_numeric)

# Replace numeric columns in jordan with imputed values
jordan[jordan_numeric.columns] = jordan_imputed_array

print(jordan.head())

   Gender  age of diagnosis  GCS in ER        ER-HR      ER-RR ER-systolic BP  \
0  Female               6.0       15.0  tachycardia  tachypnea         normal   
1    Male               7.0       15.0       normal     normal         normal   
2    Male              13.0        3.0       normal     normal         normal   
3    Male               9.0       13.0       normal     normal         normal   
4    Male               1.5       15.0       normal     normal         normal   

   length of stay in the hospital (in days)   outcome  
0                                       5.0  survival  
1                                       7.0  survival  
2                                       7.0  survival  
3                                      28.0  survival  
4                                       3.0  survival  


## Feature engineering

In [5]:
# Pediatric reference ranges (example, adjust as needed for your population)
def hr_category(age, hr):
    # Age in years, HR in bpm
    if pd.isnull(hr) or pd.isnull(age):
        return np.nan
    if age < 1:
        if hr < 100: return 'bradycardia'
        elif hr > 160: return 'tachycardia'
        else: return 'normal'
    elif age < 3:
        if hr < 90: return 'bradycardia'
        elif hr > 150: return 'tachycardia'
        else: return 'normal'
    elif age < 6:
        if hr < 80: return 'bradycardia'
        elif hr > 140: return 'tachycardia'
        else: return 'normal'
    elif age < 12:
        if hr < 70: return 'bradycardia'
        elif hr > 120: return 'tachycardia'
        else: return 'normal'
    else:
        if hr < 60: return 'bradycardia'
        elif hr > 100: return 'tachycardia'
        else: return 'normal'

def rr_category(age, rr):
    # Age in years, RR in breaths/min
    if pd.isnull(rr) or pd.isnull(age):
        return np.nan
    if age < 1:
        if rr < 30: return 'bradypnea'
        elif rr > 60: return 'tachypnea'
        else: return 'normal'
    elif age < 3:
        if rr < 24: return 'bradypnea'
        elif rr > 40: return 'tachypnea'
        else: return 'normal'
    elif age < 6:
        if rr < 22: return 'bradypnea'
        elif rr > 34: return 'tachypnea'
        else: return 'normal'
    elif age < 12:
        if rr < 18: return 'bradypnea'
        elif rr > 30: return 'tachypnea'
        else: return 'normal'
    else:
        if rr < 12: return 'bradypnea'
        elif rr > 20: return 'tachypnea'
        else: return 'normal'

def sbp_category(age, sbp):
    # Age in years, SBP in mmHg
    if pd.isnull(sbp) or pd.isnull(age):
        return np.nan
    if age < 1:
        if sbp < 70: return 'hypotension'
        else: return 'normal'
    elif age < 10:
        if sbp < 70 + 2*age: return 'hypotension'
        else: return 'normal'
    else:
        if sbp < 90: return 'hypotension'
        else: return 'normal'



In [6]:
def parse_datetime(date_col, time_col):
    dt = pd.to_datetime(date_col + ' ' + time_col, errors='coerce')
    return dt

india['admit_time'] = parse_datetime(india['doa'], india['toa'])
india['discharge_time'] = parse_datetime(india['dodd'], india['todd'])
india['los'] = (india['discharge_time'] - india['admit_time']).dt.total_seconds() / 3600
india['los'] = india['los'].fillna(0)
india['event'] = (india['died'] == 'Yes').astype(int)

jordan['los'] = pd.to_numeric(jordan['length of stay in the hospital (in days)'], errors='coerce')
jordan['event'] = (jordan['outcome'] == 'died').astype(int)

india['sex'] = india['sex'].map({'Male':1,'Female':0})
jordan['Gender'] = jordan['Gender'].map({'Male':1,'Female':0})

features_india = india[['age','sex','sbp_1','hr_1','rr_1','gcs_t_1']].copy()
features_jordan = jordan[['age of diagnosis','Gender','ER-HR','ER-RR','ER-systolic BP','GCS in ER']].copy()
features_jordan.columns = ['age','sex','hr','rr','sbp','gcs']
features_india.columns = ['age','sex','sbp','hr','rr','gcs']


  dt = pd.to_datetime(date_col + ' ' + time_col, errors='coerce')
  dt = pd.to_datetime(date_col + ' ' + time_col, errors='coerce')
  dt = pd.to_datetime(date_col + ' ' + time_col, errors='coerce')


In [None]:
# Ensure numeric conversion for Indian vital signs before categorization
features_india['hr'] = pd.to_numeric(features_india['hr'], errors='coerce')
features_india['rr'] = pd.to_numeric(features_india['rr'], errors='coerce')
features_india['sbp'] = pd.to_numeric(features_india['sbp'], errors='coerce')
features_india['age'] = pd.to_numeric(features_india['age'], errors='coerce')

# Now apply the categorization functions
features_india['hr_cat'] = [hr_category(a, h) for a, h in zip(features_india['age'], features_india['hr'])]
features_india['rr_cat'] = [rr_category(a, r) for a, r in zip(features_india['age'], features_india['rr'])]
features_india['sbp_cat'] = [sbp_category(a, s) for a, s in zip(features_india['age'], features_india['sbp'])]

# For Jordan, use the existing categorical columns (assumed to be ER-HR, ER-RR, ER-systolic BP)
features_jordan['hr_cat'] = features_jordan['hr']
features_jordan['rr_cat'] = features_jordan['rr']
features_jordan['sbp_cat'] = features_jordan['sbp']


# Use correct argument for OneHotEncoder for scikit-learn >= 1.2
from sklearn.preprocessing import OneHotEncoder

cat_cols = ['hr_cat', 'rr_cat', 'sbp_cat']
enc = OneHotEncoder(sparse_output=False, handle_unknown='ignore')

# Fit on combined categories to ensure same columns
enc.fit(pd.concat([features_india[cat_cols], features_jordan[cat_cols]], axis=0))

X_india_cat = enc.transform(features_india[cat_cols])
X_jordan_cat = enc.transform(features_jordan[cat_cols])

# Replace original columns with encoded
import numpy as np
X_india_final = np.concatenate([features_india[['age','sex','gcs']].fillna(0).values, X_india_cat], axis=1)
X_jordan_final = np.concatenate([features_jordan[['age','sex','gcs']].fillna(0).values, X_jordan_cat], axis=1)

# Update y as before
y_india = np.array([(bool(e), t) for e,t in zip(india['event'], india['los']/24)], dtype=[('event', bool), ('time', float)])
y_jordan = np.array([(bool(e), t) for e,t in zip(jordan['event'], jordan['los'])], dtype=[('event', bool), ('time', float)])


TypeError: OneHotEncoder.__init__() got an unexpected keyword argument 'sparse'

## Exploratory analysis

In [None]:
print('India rows:', len(india))
print('Jordan rows:', len(jordan))

print('Indian LOS summary (hours):')
print(india['los'].describe())
print('Jordan LOS summary (days):')
print(jordan['los'].describe())

fig, axes = plt.subplots(1,2, figsize=(12,4))
axes[0].hist(india['los']/24, bins=30, color='skyblue')
axes[0].set_title('India LOS (days)')
axes[1].hist(jordan['los'], bins=15, color='salmon')
axes[1].set_title('Jordan LOS (days)')
plt.show()


In [None]:
kmf = KaplanMeierFitter()
kmf.fit(durations=india['los']/24, event_observed=india['event'], label='India')
ax = kmf.plot()
kmf.fit(durations=jordan['los'], event_observed=jordan['event'], label='Jordan')
kmf.plot(ax=ax)
ax.set_xlabel('Time (days)')
ax.set_ylabel('Survival probability')
plt.show()


## Baseline Cox model on Jordan data

In [None]:
baseline_data = features_jordan.copy()
baseline_data['duration'] = jordan['los']
baseline_data['event'] = jordan['event']

cph = CoxPHFitter()
cph.fit(baseline_data, duration_col='duration', event_col='event')
print(cph.summary)
print('Baseline C-index:', concordance_index(baseline_data['duration'], -cph.predict_partial_hazard(baseline_data), baseline_data['event']))


## Transfer learning with gradient boosting

In [None]:
X_india = features_india.fillna(features_india.median())
y_india = np.array([(bool(e), t) for e,t in zip(india['event'], india['los']/24)], dtype=[('event', bool), ('time', float)])

X_jordan = features_jordan.fillna(features_jordan.median())
y_jordan = np.array([(bool(e), t) for e,t in zip(jordan['event'], jordan['los'])], dtype=[('event', bool), ('time', float)])

# baseline without transfer
gb_baseline = GradientBoostingSurvivalAnalysis(random_state=0)
gb_baseline.fit(X_jordan, y_jordan)
base_cindex = gb_baseline.score(X_jordan, y_jordan)
print('Baseline gradient boosting C-index:', base_cindex)

# pretrain on India then fine-tune on Jordan
gb_transfer = GradientBoostingSurvivalAnalysis(random_state=0, n_estimators=100, warm_start=True)
gb_transfer.fit(X_india, y_india)
# add more estimators for fine-tuning
gb_transfer.set_params(n_estimators=150)
gb_transfer.fit(X_jordan, y_jordan)
transfer_cindex = gb_transfer.score(X_jordan, y_jordan)
print('Transfer learning C-index:', transfer_cindex)


In [None]:
plt.step(*gb_transfer.predict_survival_function(X_jordan.iloc[:5]).T)
plt.xlabel('Time (days)')
plt.ylabel('Survival probability')
plt.title('Example predicted curves (Jordan)')
plt.show()


In [None]:
import pyreadstat

china_file = 'trauma_china.sav'
china_df, meta = pyreadstat.read_sav(china_file)
china_df.to_csv('trauma_china.csv', index=False)
print('Saved trauma_china.csv')