# Model Training

### Set Up MLflow

In [None]:
import mlflow
import os

# Point MLflow at the SQLite tracking database (path is relative to the
# notebook location).
mlflow.set_tracking_uri(uri="sqlite:///../mlflow.db")

# Choose the experiment that runs are recorded under (update to your project name).
mlflow.set_experiment(experiment_name="YourProjectName")

### Import Libraries

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import joblib

from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, 
    roc_auc_score, confusion_matrix, classification_report
)

# Add your model imports here
# from sklearn.ensemble import RandomForestClassifier
# from xgboost import XGBClassifier
# from lightgbm import LGBMClassifier

### Load Preprocessed Data

In [None]:
# Locations of the processed train/test CSV files and the name of the label column.
TRAIN_DATA_PATH = '../data/processed/train_fe.csv'
TEST_DATA_PATH = '../data/processed/test_fe.csv'
TARGET_COL = 'target'  # Update to your target column name

# Read both splits from disk, train first.
train_df, test_df = (
    pd.read_csv(csv_path) for csv_path in (TRAIN_DATA_PATH, TEST_DATA_PATH)
)

# Separate the feature columns from the target column for each split.
X_train, y_train = train_df.drop(TARGET_COL, axis="columns"), train_df[TARGET_COL]
X_test, y_test = test_df.drop(TARGET_COL, axis="columns"), test_df[TARGET_COL]

# Report the resulting shapes as a quick sanity check.
print(f"Train shapes: {X_train.shape}, {y_train.shape}")
print(f"Test shapes:  {X_test.shape}, {y_test.shape}")

### Load Preprocessing Pipeline (Optional)

In [None]:
# Load preprocessing artifacts if needed for inference
# PREPROCESSING_PATH = '../models/preprocessing.joblib'
# preprocessing_pipeline = joblib.load(PREPROCESSING_PATH)
# print("Preprocessing pipeline loaded successfully")

### Define Model Hyperparameters

In [None]:
# Descriptive metadata about how the data was prepared; logged alongside the
# model parameters so each run records its data provenance.
data_params = {
    "data_split_ratio": "80/20",
    "preprocessing_steps": "Describe your preprocessing steps",
    "feature_engineering": "Describe your FE approach",
}

# Placeholder hyperparameters -- replace the keys and values with the ones
# your model actually accepts.
model_params = {
    "param1": "value1",
    "param2": "value2",
    "random_state": 42,
}

### Model Training & Evaluation

In [None]:
# Train and evaluate the baseline model inside a single MLflow run so that the
# parameters, metrics, and artifacts logged below are attached to the same run.
with mlflow.start_run(run_name="EXP_01_BaselineModel"):

    # --- Log Parameters ---
    mlflow.log_params(model_params)
    mlflow.log_params(data_params)

    mlflow.log_param("input_rows", X_train.shape[0])
    mlflow.log_param("input_cols", X_train.shape[1])
    # Column names are stored as a JSON artifact rather than a param: backend
    # stores cap the length of a param value, and the stringified column list
    # of a wide feature-engineered frame can exceed that cap and abort the run.
    mlflow.log_dict(
        {"column_names": X_train.columns.tolist()},
        "column_names.json",
    )

    # Optional: Log notebooks as artifacts
    # mlflow.log_artifact("01_eda.ipynb", artifact_path="code_snapshot")
    # mlflow.log_artifact("02_preprocessing.ipynb", artifact_path="code_snapshot")

    # --- Train Model ---
    print("Training model...")
    # model = YourModel(**model_params)
    # model.fit(X_train, y_train)

    # --- Make Predictions ---
    # y_pred = model.predict(X_test)
    # y_prob = model.predict_proba(X_test)[:, 1]  # For binary classification

    # --- Log Metrics ---
    # metrics = {
    #     "accuracy": accuracy_score(y_test, y_pred),
    #     "precision": precision_score(y_test, y_pred, average='weighted', zero_division=0),
    #     "recall": recall_score(y_test, y_pred, average='weighted', zero_division=0),
    #     "f1_score": f1_score(y_test, y_pred, average='weighted', zero_division=0),
    #     "roc_auc": roc_auc_score(y_test, y_prob)  # For binary classification
    # }
    # mlflow.log_metrics(metrics)
    # print(f"Logged Metrics: {metrics}")

    # --- Log Model ---
    # mlflow.sklearn.log_model(model, "model")

    # --- Log Confusion Matrix ---
    # cm = confusion_matrix(y_test, y_pred)
    # fig = plt.figure(figsize=(8, 6))
    # sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    # plt.ylabel('Actual')
    # plt.xlabel('Predicted')
    # plt.title('Confusion Matrix')
    # mlflow.log_figure(fig, "confusion_matrix.png")
    # plt.close(fig)

    # --- Feature Importance (if applicable) ---
    # if hasattr(model, 'feature_importances_'):
    #     importance_df = pd.DataFrame({
    #         'feature': X_train.columns,
    #         'importance': model.feature_importances_
    #     }).sort_values('importance', ascending=False)
    #
    #     fig, ax = plt.subplots(figsize=(10, 8))
    #     top_n = 20
    #     plot_data = importance_df.head(top_n).sort_values('importance', ascending=True)
    #     ax.barh(plot_data['feature'], plot_data['importance'])
    #     ax.set_xlabel('Importance')
    #     ax.set_title(f'Top {top_n} Feature Importances')
    #     mlflow.log_figure(fig, 'feature_importance.png')
    #     plt.close(fig)
    #
    #     # Save importance data
    #     os.makedirs('../importance', exist_ok=True)
    #     importance_path = '../importance/feature_importance.csv'
    #     importance_df.to_csv(importance_path, index=False)
    #     mlflow.log_artifact(importance_path, artifact_path='feature_importance')

    print("=== Run Complete ===")
    # Read the URI back from MLflow instead of repeating the literal, so the
    # hint cannot drift from the tracking store that was actually configured.
    print(
        "Check MLflow UI for results: "
        f"mlflow ui --backend-store-uri {mlflow.get_tracking_uri()}"
    )

### Model Comparison (Optional)

In [None]:
# Compare multiple models in separate runs
# models = {
#     "RandomForest": RandomForestClassifier(**rf_params),
#     "XGBoost": XGBClassifier(**xgb_params),
#     "LightGBM": LGBMClassifier(**lgbm_params)
# }
# 
# for model_name, model in models.items():
#     with mlflow.start_run(run_name=f"EXP_01_{model_name}"):
#         # Training and logging code here
#         pass