In [2]:
# 02_supervised.ipynb

%load_ext autoreload
%autoreload 2

import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

from setup_utils import load_intermediate_data, plr, save_intermediate_data

# Read the frame stored at 'data/intermediate_unsupervised.csv' via the setup_utils loader.
df = load_intermediate_data('data/intermediate_unsupervised.csv')

# Vital-sign columns: heart rate is the regression target, the remaining ones are the inputs.
cols_vitals = ['heart_rate', 'resp_rate', 'mbp', 'temperature', 'spo2']
target_col = 'heart_rate'
non_hr_vitals = [name for name in cols_vitals if name != target_col]

# The per-method outlier flags are cast from boolean to 0/1 integers so they can be added up.
outlier_flags = ['kmeans_outlier', 'dbscan_outlier', 'umap_dbscan_outlier']
for flag in outlier_flags:
    df[flag] = df[flag].astype(int)

# Example outlier score: how many of the three methods flagged the row (0-3).
# The builtin sum() folds the columns with `+`, i.e. 0 + kmeans + dbscan + umap_dbscan.
df['outlier_score'] = sum(df[flag] for flag in outlier_flags)

# One-hot encode the outlier combination; the source column is replaced by
# indicator columns prefixed with 'outlier_comb'.
df = pd.get_dummies(
    df,
    columns=['outlier_combination'],
    prefix='outlier_comb',
)

# NOTE: only the non-heart-rate vitals feed the model below; the engineered
# outlier columns stay in `df` but are not part of X.
X = df[non_hr_vitals]  # Features
y = df[target_col]     # Target

# Hold out 20% of the rows for evaluation; the fixed seed keeps the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Random forest with 100 trees and a fixed seed. fit() hands back the fitted
# estimator, so construction and training chain into a single expression.
model = RandomForestRegressor(n_estimators=100, random_state=42).fit(X_train, y_train)

# Predictions on the held-out rows, scored in the evaluation step that follows.
y_pred = model.predict(X_test)

# Model quality on the held-out rows.
model_mse = mean_squared_error(y_test, y_pred)
model_rmse = np.sqrt(model_mse)  # same units as the target
model_r2 = r2_score(y_test, y_pred)

# Naive reference: predict the training-set mean for every held-out row.
train_mean = y_train.mean()
baseline_pred = [train_mean] * len(y_test)
baseline_mse = mean_squared_error(y_test, baseline_pred)

# Emit the report one line at a time, baseline first so the model's MSE can be
# compared against it directly.
report_lines = [
    "Random Forest Results:",
    f"Baseline MSE: {baseline_mse:.2f}",
    f"Model MSE: {model_mse:.2f}",
    f"Model RMSE: {model_rmse:.2f}",
    f"R-squared (R²): {model_r2:.2f}",
]
for line in report_lines:
    print(line)

# plr() is imported from setup_utils; judging by the cell output it appears to
# print a run timestamp -- TODO confirm against setup_utils.
plr()

# (Optional) Persist the frame, including the outlier columns added above, so
# it can be reloaded from disk if needed.
supervised_path = 'data/intermediate_supervised.csv'
save_intermediate_data(df, supervised_path)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Random Forest Results:
Baseline MSE: 237.34
Model MSE: 227.85
Model RMSE: 15.09
R-squared (R²): 0.04
@24/12/12 05:51:11 Eastern Standard Time
