In [None]:
import pickle
import pandas as pd
import numpy as np
import xgboost as xgb
from sklearn.metrics import classification_report, confusion_matrix, balanced_accuracy_score
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.model_selection import GridSearchCV, StratifiedKFold


# Load the feature table that every later step (grid search, per-fold
# evaluation) slices from.
df_mine = pd.read_csv("/Users/utkarshbansal/Desktop/APLASIA/final.csv")

# Column roles: the continuous target that gets binned into classes, and the
# identifier column that is dropped from the feature matrix.
target_col = 'target_first_episode_duration'
id_col_name = 'chemo_hadm_id'

# Fail fast on a malformed input file. Both columns are required further down,
# so check both here instead of letting a missing id column surface later as
# a KeyError from DataFrame.drop.
for required_col in (target_col, id_col_name):
    if required_col not in df_mine.columns:
        raise ValueError(f"Column '{required_col}' not found")


# Discretise the target into three ordinal classes. pd.cut uses right-closed
# intervals by default, so the bins are:
#   0 -> (-1, 9]    1 -> (9, 20]    2 -> (20, inf)
bins = [-1, 9, 20, np.inf]
labels = [0, 1, 2]
binned_target = pd.cut(df_mine[target_col], bins=bins, labels=labels)

# Rows whose target is missing or falls at/below the lowest bin edge come back
# as NaN from pd.cut; casting those to int would fail with an opaque message,
# so report the actual problem instead.
unbinned_mask = binned_target.isna()
if unbinned_mask.any():
    raise ValueError(
        f"{int(unbinned_mask.sum())} row(s) in '{target_col}' are missing or "
        f"<= {bins[0]} and cannot be assigned to a class"
    )

df_mine['target_class'] = binned_target.astype(int)


print(" STARTING GLOBAL GRID SEARCH ".center(60, '='))
print("Optimizing parameters on the entire dataset first...")

# NOTE(review): the search below is fitted on every row of df_mine, which
# includes rows that the per-fold loop later in this script uses as test data.
# The tuned hyper-parameters have therefore seen the test rows - confirm this
# is acceptable for the reported metrics.

# Feature matrix / label vector for the whole dataset: everything except the
# identifier, the raw target and the derived class label is a feature.
drop_cols_global = [id_col_name, target_col, 'target_class']
y_global = df_mine['target_class']
X_global = df_mine.drop(columns=drop_cols_global)

# 'balanced' sample weights are inversely proportional to class frequency, so
# rows of the rarer classes count for more during fitting.
global_weights = compute_sample_weight(class_weight='balanced', y=y_global)

# Candidate hyper-parameters: 3 * 3 * 2 * 2 * 2 = 72 combinations.
param_grid = dict(
    max_depth=[3, 4, 5],
    learning_rate=[0.01, 0.05, 0.1],
    n_estimators=[100, 200],
    colsample_bytree=[0.5, 0.8],
    subsample=[0.8, 1.0],
)

# Base 3-class XGBoost model; the grid search clones it for each candidate.
base_classifier = xgb.XGBClassifier(
    objective='multi:softmax',
    num_class=3,
    random_state=42,
    eval_metric='mlogloss',
)

# Stratified 3-fold CV for every candidate.
cv_splitter = StratifiedKFold(n_splits=3)

grid_search = GridSearchCV(
    estimator=base_classifier,
    param_grid=param_grid,
    scoring='balanced_accuracy',
    cv=cv_splitter,
    verbose=1,
    n_jobs=-1,
)

# Run the exhaustive search; the sample weights are forwarded to each fit.
grid_search.fit(X_global, y_global, sample_weight=global_weights)

# Winning configuration, reused by the per-fold models further down.
best_params = grid_search.best_params_
print(f"\n>>> BEST PARAMETERS FOUND: {best_params}")
print(f">>> BEST BALANCED ACCURACY: {grid_search.best_score_:.4f}")
print("=" * 60 + "\n")


# Pre-computed split definitions, one pickle per fold. Each pickle is unpacked
# into (train_ids, val_ids, test_ids).
fold_files = [f'fold_{i}.pkl' for i in range(5)]
fold_scores = []
long_class_recalls = []

# Fixed class order for every report. Passing these explicitly keeps
# classification_report / confusion_matrix working (and the matrix 3x3) even
# when a test split happens to contain no sample of some class.
class_labels = [0, 1, 2]
class_names = ['Short', 'Medium', 'Long']

# Columns that must never reach the model as features.
drop_cols = [id_col_name, target_col, 'target_class']


def filter_my_data(original_pairs):
    """Return the rows of ``df_mine`` selected by a fold's id list.

    Element ``[1]`` of every entry in *original_pairs* is taken as the id and
    matched against the ``id_col_name`` column of ``df_mine``.
    """
    valid_hadm_ids = {pair[1] for pair in original_pairs}
    return df_mine[df_mine[id_col_name].isin(valid_hadm_ids)]


for fold_num, file_path in enumerate(fold_files):

    try:
        with open(file_path, 'rb') as f:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load fold files that come from a trusted source.
            original_train_ids, original_val_ids, original_test_ids = pickle.load(f)
    except FileNotFoundError:
        print(f"Skipping Fold {fold_num}: File {file_path} not found.")
        continue

    df_train = filter_my_data(original_train_ids)
    df_val = filter_my_data(original_val_ids)
    df_test = filter_my_data(original_test_ids)

    print(f" PROCESSING FOLD {fold_num} ".center(60, '#'))
    print(f"Samples -> Train: {len(df_train)}, Val: {len(df_val)}, Test: {len(df_test)}")

    # Guard clauses: bail out before computing weights or fitting anything
    # when a split is empty.
    if df_train.empty or df_val.empty:
        print(f"Fold {fold_num}: Not enough train/val samples.")
        continue
    if df_test.empty:
        print(f"Fold {fold_num}: No test samples found.")
        continue

    X_train = df_train.drop(columns=drop_cols)
    X_val = df_val.drop(columns=drop_cols)
    X_test = df_test.drop(columns=drop_cols)

    y_train = df_train['target_class']
    y_val = df_val['target_class']
    y_test = df_test['target_class']

    # Re-balance using the class frequencies of this fold's training split.
    weights_train = compute_sample_weight(class_weight='balanced', y=y_train)

    # Same configuration as the grid-search winner, plus early stopping
    # monitored on the validation split.
    model = xgb.XGBClassifier(
        **best_params,
        objective='multi:softmax',
        num_class=3,
        random_state=42,
        eval_metric='mlogloss',
        early_stopping_rounds=20,
    )
    model.fit(
        X_train, y_train,
        sample_weight=weights_train,
        eval_set=[(X_val, y_val)],
        verbose=False,
    )

    y_pred = model.predict(X_test)

    bal_acc = balanced_accuracy_score(y_test, y_pred)
    fold_scores.append(bal_acc)
    print(f"Fold {fold_num} Balanced Accuracy: {bal_acc:.4f}")

    # Detailed per-class report. Without labels=, a fold lacking one class
    # raises ValueError because target_names no longer matches the classes.
    report = classification_report(
        y_test, y_pred,
        labels=class_labels,
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )

    # Recall for 'Long' is undefined when the fold has no 'Long' case, so
    # such folds are left out of the average instead of counting as 0.
    if report['Long']['support'] > 0:
        recall_long = report['Long']['recall']
        long_class_recalls.append(recall_long)
        print(f"Fold {fold_num} Long Class Recall:   {recall_long:.4f}")
    else:
        print(f"Fold {fold_num} Long Class Recall:   n/a (no 'Long' cases in test split)")

    # Row-level view of the predictions, used to list the missed 'Long' cases.
    fold_results = pd.DataFrame({
        'chemo_hadm_id': df_test[id_col_name].to_numpy(),
        'Actual_Class': y_test.to_numpy(),
        'Predicted_Class': y_pred,
        'Actual_Days': df_test[target_col].to_numpy(),
    })

    missed_long = fold_results[
        (fold_results['Actual_Class'] == 2) & (fold_results['Predicted_Class'] != 2)
    ]

    if not missed_long.empty:
        print(f"\n>>> Fold {fold_num} MISSED {len(missed_long)} 'LONG' CASES:")
        print(missed_long.to_string(index=False))
    else:
        print("\n>>> Excellent! No 'Long' cases were missed in this fold.")

    print("\nConfusion Matrix:")
    print(confusion_matrix(y_test, y_pred, labels=class_labels))
    print("-" * 40 + "\n")

# Cross-fold summary; each average is printed only if at least one fold
# contributed a value, so np.mean is never called on an empty list.
print("\n" + "=" * 60)
if fold_scores:
    print(f"FINAL AVERAGE BALANCED ACCURACY: {np.mean(fold_scores):.4f}")
if long_class_recalls:
    print(f"FINAL AVERAGE LONG CLASS RECALL: {np.mean(long_class_recalls):.4f}")
print("=" * 60)

Optimizing parameters on the entire dataset first...
Fitting 3 folds for each of 72 candidates, totalling 216 fits

>>> BEST PARAMETERS FOUND: {'colsample_bytree': 0.5, 'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 200, 'subsample': 0.8}
>>> BEST BALANCED ACCURACY: 0.5941

#################### PROCESSING FOLD 0 #####################
Samples -> Train: 733, Val: 79, Test: 198
Fold 0 Balanced Accuracy: 0.5825
Fold 0 Long Class Recall:   0.2500

>>> Fold 0 MISSED 3 'LONG' CASES:
 chemo_hadm_id  Actual_Class  Predicted_Class  Actual_Days
      21176832             2                1         41.0
      25841662             2                1         24.0
      27576498             2                0         25.0

Confusion Matrix:
[[156  19   1]
 [  6  11   1]
 [  1   2   1]]
----------------------------------------

#################### PROCESSING FOLD 1 #####################
Samples -> Train: 721, Val: 85, Test: 204
Fold 1 Balanced Accuracy: 0.5136
Fold 1 Long Class Recall:   0.25