In [0]:
import os
import subprocess

import mlflow
import mlflow.spark

from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.functions import vector_to_array
from mlflow.exceptions import MlflowException


import pandas as pd
import matplotlib.pyplot as plt

In [0]:
# ------------------------------------------------------------------------------
# MLflow experiment configuration
# ------------------------------------------------------------------------------

# High-level parameters for the trading setup
cutoff_date = "2020-01-01"   # train/test split: Date < cutoff_date trains, Date >= cutoff_date tests
lookahead_days = 5           # label horizon from feature engineering; only logged to MLflow here
long_threshold = 0.55        # go long (+1) when P(up) > long_threshold
short_threshold = 0.45       # go short (-1) when P(up) < short_threshold

feature_cols = ["sma_20", "std_20", "daily_return"]

# Fail fast on an inconsistent configuration before any Spark work runs.
# The position rule tests the long condition first, so overlapping thresholds
# would silently turn would-be shorts into longs instead of raising.
assert 0.0 <= short_threshold <= long_threshold <= 1.0, (
    "expected 0 <= short_threshold <= long_threshold <= 1, "
    f"got short={short_threshold}, long={long_threshold}"
)
assert feature_cols, "feature_cols must list at least one model input"

In [0]:
# ------------------------------------------------------------------------------
# Data loading
# ------------------------------------------------------------------------------

# Load the labeled feature table.
df = spark.table("market.features_labeled")

# Keep only rows where every model input and the target are present.
df_clean = df.dropna(subset=[*feature_cols, "label"])

# Chronological hold-out: rows dated strictly before cutoff_date train, the rest test.
# NOTE(review): run_experiment re-derives its own split from df_clean, so these two
# frames are not used by it.
train, test = (
    df_clean.filter(F.col("Date") < cutoff_date),
    df_clean.filter(F.col("Date") >= cutoff_date),
)

In [0]:
# ------------------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------------------

def compute_sharpe(pdf: pd.DataFrame, ret_col: str = "portfolio_return") -> float | None:
    """
    Compute annualized Sharpe ratio from a pandas DataFrame with a return column.
    Assumes daily returns.
    """
    mean_ret = pdf[ret_col].mean()
    std_ret = pdf[ret_col].std()

    if std_ret and std_ret != 0:
        return (mean_ret / std_ret) * (252 ** 0.5)
    return None


def backtest_from_predictions(preds_sel):
    """
    Turn per-symbol model probabilities into a daily long/short backtest.

    preds_sel: Spark DataFrame with columns:
      Date, symbol, label, prediction, probability, daily_return

    Returns a Spark DataFrame, ordered by Date, with:
      Date, portfolio_return, benchmark_return, cml_return, benchmark_cml_return

    Signal rule (thresholds come from the config cell):
      +1 if P(up) > long_threshold, -1 if P(up) < short_threshold, else 0.

    Look-ahead guard: daily_return on date t is also a model feature at t, so the
    signal formed on t cannot earn that same return. The signal is held from the
    next available date of the same symbol, i.e. the return credited on date t is
    signal(t-1) * daily_return(t); each symbol is flat on its first date.
    NOTE(review): assumes daily_return on row t is the return realized up to t —
    confirm against feature engineering. The label horizon (lookahead_days) is
    longer than this one-row holding period.

    The benchmark is the equal-weight mean of daily_return across all symbols in
    the global df_clean, restricted (by the inner join) to dates the strategy has.
    Cumulative columns are running sums of daily returns (additive, not compounded).
    """
    # Per-symbol time ordering, used to shift each signal forward by one row.
    w_symbol = Window.partitionBy("symbol").orderBy("Date")

    strategy_df = (
        preds_sel
        .withColumn("p_up", vector_to_array("probability")[1])  # P(class = 1)
        .withColumn(
            "signal",
            F.when(F.col("p_up") > long_threshold, 1)
             .when(F.col("p_up") < short_threshold, -1)
             .otherwise(0),
        )
        # Trade yesterday's signal today; no prior signal means no position.
        .withColumn("position", F.coalesce(F.lag("signal").over(w_symbol), F.lit(0)))
        .withColumn("strategy_return", F.col("position") * F.col("daily_return"))
    )

    # Equal-weight daily portfolio return across all symbols traded that day.
    portfolio_df = (
        strategy_df
        .groupBy("Date")
        .agg(F.avg("strategy_return").alias("portfolio_return"))
    )

    # Equal-weight benchmark on daily_return across all symbols.
    benchmark_df = (
        df_clean
        .groupBy("Date")
        .agg(F.avg("daily_return").alias("benchmark_return"))
    )

    comparison_df = (
        portfolio_df
        .join(benchmark_df, on="Date", how="inner")
        .select("Date", "portfolio_return", "benchmark_return")
    )

    # Running totals over the whole timeline (single, unpartitioned window).
    w_date = Window.orderBy("Date").rowsBetween(Window.unboundedPreceding, Window.currentRow)

    # Sorting explicitly so downstream toPandas()/plots see dates in order;
    # a join does not guarantee row order.
    return (
        comparison_df
        .withColumn("cml_return", F.sum("portfolio_return").over(w_date))
        .withColumn("benchmark_cml_return", F.sum("benchmark_return").over(w_date))
        .orderBy("Date")
    )


def get_git_sha() -> str | None:
    try:
        return (
            subprocess.check_output(["git", "rev-parse", "HEAD"])
            .decode("utf-8")
            .strip()
        )
    except Exception:
        return None


In [0]:
# ------------------------------------------------------------------------------
# Core MLflow experiment function
# ------------------------------------------------------------------------------

def _plot_cumulative_returns(pdf_compare: pd.DataFrame):
    """Build the strategy-vs-benchmark cumulative-return figure; the caller must close it."""
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(pdf_compare["Date"], pdf_compare["cml_return"], label="ML strategy")
    ax.plot(
        pdf_compare["Date"],
        pdf_compare["benchmark_cml_return"],
        label="Benchmark",
    )
    ax.set_title("Cumulative returns: ML strategy vs. benchmark")
    ax.set_xlabel("Date")
    ax.set_ylabel("Cumulative return")
    ax.legend()
    fig.tight_layout()
    return fig


def _metric_value(value) -> float:
    """Map an undefined metric (None) to NaN so it can still be passed to mlflow.log_metric."""
    return float("nan") if value is None else float(value)


def run_experiment(
    reg_param: float,
    elastic_net_param: float,
    max_iter: int,
    registered_model_name: str = "ml_trading_lr_v1",
) -> str:
    """
    Run a single logistic regression experiment and log everything to MLflow.

    Trains on rows of the global ``df_clean`` dated before ``cutoff_date``,
    predicts the remaining rows, backtests the predictions with
    ``backtest_from_predictions`` and logs params, lineage tags, metrics,
    a cumulative-return plot and (when the cluster allows it) the fitted model.

    Parameters
    ----------
    reg_param, elastic_net_param, max_iter
        Passed to ``LogisticRegression`` as regParam / elasticNetParam / maxIter.
    registered_model_name
        Name passed to ``mlflow.spark.log_model`` for model registration.

    Returns
    -------
    str
        The MLflow run ID, so sweep callers can compare or reload runs.
    """
    git_sha = get_git_sha()

    with mlflow.start_run(
        run_name=f"logreg_r{reg_param}_en{elastic_net_param}_iter{max_iter}"
    ) as run:
        # Params and tags for lineage (one batched call instead of one per param)
        mlflow.log_params(
            {
                "reg_param": reg_param,
                "elastic_net_param": elastic_net_param,
                "max_iter": max_iter,
                "lookahead_days": lookahead_days,
                "long_threshold": long_threshold,
                "short_threshold": short_threshold,
                "cutoff_date": cutoff_date,
                "feature_cols": ",".join(feature_cols),
            }
        )

        if git_sha:
            mlflow.set_tag("git_commit", git_sha)
        mlflow.set_tags(
            {
                "dataset_table": "market.features_labeled",
                "model_type": "logistic_regression",
                "project": "ml_trading_databricks",
            }
        )

        # Build feature assembler and model
        assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
        lr = LogisticRegression(
            labelCol="label",
            featuresCol="features",
            regParam=reg_param,
            elasticNetParam=elastic_net_param,
            maxIter=max_iter,
        )

        # Chronological split on the assembled data (same cutoff as the data-loading cell)
        dataset = assembler.transform(df_clean)
        train_local = dataset.filter(F.col("Date") < cutoff_date)
        test_local = dataset.filter(F.col("Date") >= cutoff_date)

        model = lr.fit(train_local)

        # Out-of-sample predictions, trimmed to what the backtest needs
        preds_sel = model.transform(test_local).select(
            "Date",
            "symbol",
            "label",
            "prediction",
            "probability",
            "daily_return",
        )

        # Backtest (collected to pandas once; reused for metrics and the plot)
        pdf_compare = backtest_from_predictions(preds_sel).toPandas()

        # Metrics; undefined values (e.g. an empty test period) are logged as NaN
        sharpe = compute_sharpe(pdf_compare, ret_col="portfolio_return")
        benchmark_sharpe = compute_sharpe(pdf_compare, ret_col="benchmark_return")
        mean_ret = pdf_compare["portfolio_return"].mean()
        std_ret = pdf_compare["portfolio_return"].std()

        mlflow.log_metric("sharpe", _metric_value(sharpe))
        mlflow.log_metric("benchmark_sharpe", _metric_value(benchmark_sharpe))
        mlflow.log_metric("mean_daily_return", _metric_value(mean_ret))
        mlflow.log_metric("std_daily_return", _metric_value(std_ret))

        # Plot cumulative returns; always release the figure, even if logging fails
        fig = _plot_cumulative_returns(pdf_compare)
        try:
            mlflow.log_figure(fig, "cumulative_returns.png")
        finally:
            plt.close(fig)

        # Log and register the model (may fail on shared/serverless clusters)
        try:
            mlflow.spark.log_model(
                spark_model=model,
                artifact_path="model",
                registered_model_name=registered_model_name,
            )
        except MlflowException as e:
            print(
                "Skipping Spark model logging on this cluster. "
                "Params/metrics/plots are still logged to MLflow. "
                f"Reason: {e}"
            )

    return run.info.run_id


In [0]:
# ------------------------------------------------------------------------------
# Hyperparameter sweep
# ------------------------------------------------------------------------------

if __name__ == "__main__":
    # One MLflow run per entry; each dict is passed straight through as
    # keyword arguments to run_experiment, in the order listed.
    configs = [
        dict(reg_param=0.0, elastic_net_param=0.0, max_iter=50),
        dict(reg_param=0.01, elastic_net_param=0.0, max_iter=50),
        dict(reg_param=0.1, elastic_net_param=0.0, max_iter=50),
        dict(reg_param=0.01, elastic_net_param=0.5, max_iter=100),
    ]

    for cfg in configs:
        run_experiment(**cfg)