# MLflow Results Analysis

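This notebook fetches all runs from MLflow experiment 413909140794424369 and compares, within each run group (the `group_id` tag), how runs with and without `model/post_mp_layer/use_batchnorm` rank by test ROC AUC.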

## 1. Setup and Connection

In [None]:
import os
import warnings
from collections import Counter

import mlflow
import pandas as pd
import pydash as pyd
import seaborn as sns
from mlflow import MlflowClient

# Silence all warnings to keep notebook output readable.
# NOTE(review): this also hides pandas/seaborn deprecation warnings; consider
# narrowing to specific categories.
warnings.filterwarnings("ignore")

# MLflow configuration. The standard MLflow environment variables take precedence,
# so the notebook can be pointed at another server/experiment without editing it
# (and an existing MLFLOW_TRACKING_URI is no longer silently overridden by
# mlflow.set_tracking_uri below). The literals are the defaults when unset.
TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://127.0.0.1:5001")
EXPERIMENT_ID = os.environ.get("MLFLOW_EXPERIMENT_ID", "413909140794424369")

print(f"Connecting to MLflow server at: {TRACKING_URI}")
print(f"Target experiment ID: {EXPERIMENT_ID}")

In [None]:
# Set tracking URI and create client
# Point MLflow at the configured server; a client built without an explicit
# tracking_uri uses the URI set here.
mlflow.set_tracking_uri(TRACKING_URI)
client = MlflowClient()

# Fetch the experiment up front: every later cell depends on it, so report the
# failure and re-raise rather than continuing with a broken connection.
try:
    experiment = client.get_experiment(EXPERIMENT_ID)
    summary_lines = [
        f"✓ Connected to experiment: {experiment.name}",
        f"  Experiment ID: {experiment.experiment_id}",
        f"  Lifecycle stage: {experiment.lifecycle_stage}",
        f"  Artifact location: {experiment.artifact_location}",
    ]
    print("\n".join(summary_lines))
except Exception as exc:
    print(f"✗ Error accessing experiment: {exc}")
    raise

## 2. Data Extraction

In [None]:
# Fetch every (active) run in the experiment. search_runs returns at most
# `max_results` runs per call, so follow the page token until the server reports
# no further pages instead of silently truncating at a fixed cap.
print("Fetching all runs from the experiment...")
page_size = 1000
runs = []
page_token = None
while True:
    page = client.search_runs(
        experiment_ids=[EXPERIMENT_ID],
        max_results=page_size,
        page_token=page_token,
    )
    runs.extend(page)
    # An empty/None token means this was the last page.
    page_token = page.token
    if not page_token:
        break

print(f"Found {len(runs)} runs in the experiment")

if not runs:
    print("No runs found in this experiment.")
else:
    print("Run status breakdown:")
    # Counter keeps first-seen order, so statuses print in the order runs were returned.
    status_counts = Counter(run.info.status for run in runs)
    for status, count in status_counts.items():
        print(f"  {status}: {count} runs")

In [None]:
# One row per run: the grouping tag, the metric being compared, the hyperparameter
# under study, and the run status (used later to drop incomplete groups).
# pyd.get returns None for any missing tag/metric/param instead of raising.
RUN_COLUMNS = ["group_id", "test_roc_auc", "use_batchnorm", "run_status"]
rows = [
    {
        "group_id": pyd.get(r, "data.tags.group_id"),
        "test_roc_auc": pyd.get(r, "data.metrics.test_roc_auc"),
        "use_batchnorm": pyd.get(r, "data.params.model/post_mp_layer/use_batchnorm"),
        "run_status": pyd.get(r, "info.status"),
    }
    for r in runs
]
# Explicit columns keep the schema when `runs` is empty (otherwise sort_values
# raises KeyError). Casting the metric to float turns missing values (None) into
# NaN so the rank below is numeric even if no run logged the metric.
df = pd.DataFrame(rows, columns=RUN_COLUMNS).astype({"test_roc_auc": "float64"})
df = df.sort_values(by=["group_id", "use_batchnorm"])
# Rank runs within each group by test ROC AUC: 1 = best, ties share the average
# rank, and runs with a missing metric get a NaN rank.
df["rank"] = df.groupby("group_id")["test_roc_auc"].rank(ascending=False)

In [None]:
# filter out groups in which some runs failed
# Keep only groups in which every run has status FINISHED: any failed, killed,
# or still-running run disqualifies the whole group. Rows with no group_id are
# dropped too, because groupby excludes missing keys.
# NOTE(review): a group that contains only one run still passes and gets rank 1,
# which may bias the comparison. Confirm every group has both batchnorm variants.
run_finished = df["run_status"].eq("FINISHED")
group_all_finished = run_finished.groupby(df["group_id"]).all()
complete_group_ids = group_all_finished[group_all_finished].index
df = df[df["group_id"].isin(complete_group_ids)]

In [None]:
# Average within-group rank per batchnorm setting (lower = better on average).
df.groupby("use_batchnorm")["rank"].agg("mean")

In [None]:
# Distribution of within-group ranks (1 = best test ROC AUC) for each batchnorm
# setting. Passing the setting as `x` (rather than hue-only) works across seaborn
# versions. cut=0 clips the KDE to the observed ranks; by default the tails
# would extend to impossible values such as ranks below 1.
ax = sns.violinplot(data=df, x="use_batchnorm", y="rank", cut=0)
ax.set(
    title="Within-group rank by use_batchnorm (1 = best test ROC AUC)",
    xlabel="use_batchnorm",
    ylabel="Rank within group",
);