In [None]:
"""
Train a classification model (Logistic Regression) with MLflow tracking.
Saves joblib model artifact and logs via MLflow to local mlruns/.
"""
import os
from pathlib import Path

import joblib
import mlflow
import mlflow.sklearn
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

from src.preprocess import build_preprocessor
from src.utils import load_raw

# Tracking URI for the local "mlruns" directory, resolved against the working
# directory at import time. Path.as_uri() produces a well-formed
# "file:///..." URI (forward slashes, percent-encoded reserved characters),
# whereas concatenating "file://" with an OS path yields a malformed URI on
# Windows ("file://C:\...").
MLFLOW_URI = Path("mlruns").absolute().as_uri()

def train(
    train_csv="data/processed/train.csv",
    val_csv="data/processed/val.csv",
    artifact_dir="models",
    run_name="logreg_baseline",
):
    """Grid-search a logistic-regression pipeline and record the run in MLflow.

    Reads the training and validation CSVs, searches over the regularisation
    strength ``C`` with 3-fold cross-validation scored by ROC AUC, evaluates
    the best estimator on the validation split, logs the best parameters, the
    validation metrics and the model to MLflow, and dumps the same estimator
    to ``<artifact_dir>/model.joblib``.

    Args:
        train_csv: Path of the CSV used for fitting.
        val_csv: Path of the CSV used to compute validation metrics.
        artifact_dir: Directory (created if missing) receiving ``model.joblib``.
        run_name: Name given to the MLflow run.

    Returns:
        None. Results are logged to MLflow, written to disk and printed.
    """
    os.makedirs(artifact_dir, exist_ok=True)
    mlflow.set_tracking_uri(MLFLOW_URI)
    mlflow.set_experiment("placement_local_experiment")

    train_frame = pd.read_csv(train_csv)
    val_frame = pd.read_csv(val_csv)

    # Reuse the preprocessing definition from the preprocess stage so that
    # training sees exactly the same feature handling.
    preprocessor, num_feats, cat_feats = build_preprocessor()
    feature_cols = num_feats + cat_feats
    target_col = "PlacementStatus"

    X_train, y_train = train_frame[feature_cols], train_frame[target_col]
    X_val, y_val = val_frame[feature_cols], val_frame[target_col]

    pipeline = Pipeline(
        steps=[
            ("preprocessor", preprocessor),
            ("clf", LogisticRegression(max_iter=1000, solver="liblinear")),
        ]
    )
    param_grid = {"clf__C": [0.1, 1.0, 10.0], "clf__penalty": ["l2"]}
    model_path = os.path.join(artifact_dir, "model.joblib")

    with mlflow.start_run(run_name=run_name) as run:
        search = GridSearchCV(
            estimator=pipeline,
            param_grid=param_grid,
            scoring="roc_auc",
            cv=3,
            n_jobs=-1,
        )
        search.fit(X_train, y_train)
        best_model = search.best_estimator_

        # Hard labels feed accuracy; the second predict_proba column feeds ROC AUC.
        val_labels = best_model.predict(X_val)
        val_scores = best_model.predict_proba(X_val)[:, 1]
        acc = accuracy_score(y_val, val_labels)
        roc = roc_auc_score(y_val, val_scores)

        mlflow.log_params(search.best_params_)
        for metric_name, metric_value in (("val_accuracy", acc), ("val_roc_auc", roc)):
            mlflow.log_metric(metric_name, float(metric_value))
        mlflow.sklearn.log_model(best_model, artifact_path="model")

        # Standalone joblib copy (for FastAPI), independent of the MLflow store.
        joblib.dump(best_model, model_path)
        print(f"Run ID: {run.info.run_id}, Val Accuracy: {acc:.4f}, Val ROC AUC: {roc:.4f}")

# Script entry point: run training with the default paths and run name when
# this file is executed directly rather than imported.
if __name__ == "__main__":
    train()