In [None]:
import mlflow
import mlflow.spark
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml import Pipeline

# Spark *datasource* autologging: records the path/version/format of data that
# Spark reads into the MLflow run. It does NOT log MLlib estimator params,
# metrics or fitted models (that is mlflow.pyspark.ml.autolog()); those have
# to be logged explicitly or via that separate autologger.
mlflow.spark.autolog()

# Load the Gold feature table from Unity Catalog
df_gold = spark.read.format("delta").table("main.default.gold_machine_features")

# Feature columns fed to the model
features = ["temperature", "pressure", "vibration", "rpm",
            "temp_roll_avg", "temp_diff", "pressure_roll_avg", "vibration_roll_avg"]

# Keep only label + features and drop rows containing null/NaN in any of them:
# VectorAssembler's default handleInvalid="error" fails on such values.
df_ml = df_gold.select("is_failure", *features).na.drop()

# Class balance in a single aggregation pass instead of three full scans.
# dict.get(1)/get(0) also match 1.0/0.0 keys if the label column is a double.
class_counts = {
    row["is_failure"]: row["count"]
    for row in df_ml.groupBy("is_failure").count().collect()
}
print(f"Total records: {sum(class_counts.values())}")
print(f"Failure cases: {class_counts.get(1, 0)}")
print(f"Non-failure cases: {class_counts.get(0, 0)}")

# Stage 1: pack the individual feature columns into the single vector column
# that MLlib estimators consume.
assembler = VectorAssembler(outputCol="features", inputCols=features)

# Stage 2: random forest with explicitly pinned hyperparameters; the fixed
# seed keeps tree construction reproducible across runs.
rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="is_failure",
    seed=42,
    maxDepth=10,
    numTrees=100,
)

pipeline = Pipeline(stages=[assembler, rf])

# 80/20 hold-out split, seeded so the same input yields the same split.
train, test = df_ml.randomSplit(weights=[0.8, 0.2], seed=42)
n_train, n_test = train.count(), test.count()
print(f"Training set: {n_train}, Test set: {n_test}")

# Train, evaluate and log everything for this model inside one MLflow run.
with mlflow.start_run(run_name="machine_failure_rf") as run:
    # Record hyperparameters explicitly: mlflow.spark.autolog() only captures
    # Spark datasource information, not MLlib estimator params.
    mlflow.log_params({
        "numTrees": rf.getNumTrees(),
        "maxDepth": rf.getMaxDepth(),
        "seed": rf.getSeed(),
        "features": ",".join(features),
    })

    # Train
    model = pipeline.fit(train)

    # Score the held-out split
    predictions = model.transform(test)

    # Threshold-free ranking quality (uses the rawPrediction column).
    auc = BinaryClassificationEvaluator(
        labelCol="is_failure", metricName="areaUnderROC"
    ).evaluate(predictions)

    accuracy = MulticlassClassificationEvaluator(
        labelCol="is_failure", metricName="accuracy"
    ).evaluate(predictions)

    # "f1" is a support-weighted average over BOTH classes, so with few
    # failures it mostly reflects the non-failure class. Kept for continuity.
    f1 = MulticlassClassificationEvaluator(
        labelCol="is_failure", metricName="f1"
    ).evaluate(predictions)

    # Per-class metrics for the failure class (label 1.0). metricLabel
    # defaults to 0.0, so it must be set explicitly.
    failure_metric_names = {
        "test_failure_precision": "precisionByLabel",
        "test_failure_recall": "recallByLabel",
        "test_failure_f1": "fMeasureByLabel",
    }
    failure_scores = {
        key: MulticlassClassificationEvaluator(
            labelCol="is_failure", metricName=metric_name, metricLabel=1.0
        ).evaluate(predictions)
        for key, metric_name in failure_metric_names.items()
    }

    print("\nTest Metrics:")
    print(f"AUC: {auc:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"F1 Score: {f1:.4f}")
    for key, value in failure_scores.items():
        print(f"{key}: {value:.4f}")

    # Log evaluation metrics (original keys preserved, failure-class added)
    mlflow.log_metrics({
        "test_auc": auc,
        "test_accuracy": accuracy,
        "test_f1": f1,
        **failure_scores,
    })

    # Persist the fitted pipeline (assembler + forest) so the run can be used
    # for scoring/registration; nothing else in this run logs the model.
    # NOTE(review): on shared/serverless Databricks compute, log_model may need
    # dfs_tmpdir pointed at a UC volume — confirm for this workspace.
    mlflow.spark.log_model(model, "model")

    # Display sample predictions
    display(predictions.select("is_failure", "prediction", "probability", *features[:3]).limit(10))

print(f"\nMLflow Run ID: {run.info.run_id}")