In [1]:
# select_best_model.ipynb

import mlflow
import mlflow.pytorch
from mlflow.tracking import MlflowClient
import pandas as pd
import torch
from torchvision import datasets, transforms
import os

# Step 1: Set tracking URI and define experiment names
# Honour an MLFLOW_TRACKING_URI environment variable when present (remote server,
# CI) and fall back to the local development server otherwise.
TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")
mlflow.set_tracking_uri(TRACKING_URI)
print("Tracking URI:", mlflow.get_tracking_uri())

# The names of your experiments
experiments = ["FashionMNIST-MLP", "FashionMNIST-CNN"]

# Step 2: Find the best run across specified experiments
# Gather one record per run that logged `test_accuracy`; the winner is picked
# in the next cell. The client is created after set_tracking_uri so that it
# talks to the server configured above.
all_runs_data = []
client = MlflowClient()

for exp_name in experiments:
    exp = client.get_experiment_by_name(exp_name)
    if exp is None:
        print(f"Warning: Experiment '{exp_name}' not found. Skipping.")
        continue

    # Highest accuracy first. search_runs returns a single page of results
    # (bounded by its max_results argument), so ordering ensures the
    # best-scoring runs are the ones returned.
    runs = client.search_runs(
        experiment_ids=[exp.experiment_id],
        order_by=["metrics.test_accuracy DESC"],
    )
    for run in runs:
        run_metrics = run.data.metrics
        # Runs that never logged the metric cannot be ranked.
        if "test_accuracy" not in run_metrics:
            continue
        all_runs_data.append({
            "run_id": run.info.run_id,
            "experiment_name": exp_name,
            "test_accuracy": run_metrics["test_accuracy"],
            "artifact_uri": run.info.artifact_uri,
        })


  import pkg_resources  # noqa: TID251


Tracking URI: http://localhost:5000


In [2]:
# Step 3: Convert runs to DataFrame and select best
# Explicit columns keep the frame well-formed even when no runs were collected,
# and rows with a NaN accuracy are dropped so they can never be selected.
df = pd.DataFrame(
    all_runs_data,
    columns=["run_id", "experiment_name", "test_accuracy", "artifact_uri"],
).dropna(subset=["test_accuracy"])
if df.empty:
    raise ValueError("No runs with 'test_accuracy' found in specified experiments.")

# idxmax returns the first row holding the maximum, so ties resolve
# deterministically (experiment order, then MLflow's result order) instead of
# depending on sort_values' default quicksort, which is not stable.
best_run = df.loc[df["test_accuracy"].idxmax()]

model_name = "FashionMNIST-BestModel"
model_uri = f"runs:/{best_run['run_id']}/model"

print(f"✅ Best model found from experiment: {best_run['experiment_name']}")
print(f"🔖 Best Run ID: {best_run['run_id']}")
print(f"🎯 Best Test Accuracy: {best_run['test_accuracy']:.4f}")
print(f"📦 Model Artifact URI for registration: {model_uri}")

# Step 4: Register and promote the best model
try:
    # mlflow.register_model creates the registered model automatically if it
    # does not exist yet, so no separate get/create step is needed.
    result = mlflow.register_model(model_uri=model_uri, name=model_name)
    print(f"📥 Registered new model version: {result.version}")

    # Promote to Production, archiving any version currently in that stage.
    # NOTE: model stages are deprecated in recent MLflow releases in favour of
    # aliases; kept because the inference cell loads models:/<name>/Production.
    client.transition_model_version_stage(
        name=model_name,
        version=result.version,
        stage="Production",
        archive_existing_versions=True
    )
    print(f"🚀 Promoted version {result.version} of {model_name} to Production.")

except Exception as e:
    print("Error during model registration or promotion:", e)
    # Re-raise: continuing would let the inference cell silently load whatever
    # version was previously in Production.
    raise


Registered model 'FashionMNIST-BestModel' already exists. Creating a new version of this model...
2025/07/26 10:20:24 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: FashionMNIST-BestModel, version 1


✅ Best model found from experiment: FashionMNIST-CNN
🔖 Best Run ID: 9166d159b653492f94e82121918e684f
🎯 Best Test Accuracy: 0.9103
📦 Model Artifact URI for registration: runs:/9166d159b653492f94e82121918e684f/model
📥 Registered new model version: 1
🚀 Promoted version 1 of FashionMNIST-BestModel to Production.


Created version '1' of model 'FashionMNIST-BestModel'.
  client.transition_model_version_stage(


In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

# Step 5: Load production model
prod_model_uri = f"models:/{model_name}/Production"
try:
    model = mlflow.pytorch.load_model(prod_model_uri)
except Exception as e:
    print(f"❌ Error loading production model: {e}")
    raise
model.eval()  # disable dropout / batch-norm updates for inference
print(f"✅ Loaded model from URI: {prod_model_uri}")

# Step 6: Run inference on 1 batch of test data
# Seed for the sample shuffle so Restart & Run All picks the same image.
SAMPLE_SEED = 42

# NOTE(review): this preprocessing must match what the model saw during
# training — confirm against the training notebooks.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

test_dataset = datasets.FashionMNIST(root="./data", train=False, transform=transform, download=True)
sample_generator = torch.Generator().manual_seed(SAMPLE_SEED)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True, generator=sample_generator)

x, y_true = next(iter(test_loader))

# Put the input on the same device as the loaded weights (CPU or GPU).
first_param = next(model.parameters(), None)
if first_param is not None:
    x = x.to(first_param.device)

with torch.no_grad():
    logits = model(x)
    probs = F.softmax(logits, dim=1)
    pred_class = torch.argmax(probs, dim=1)

print(f"👕 True label: {y_true.item()}, 🔍 Predicted: {pred_class.item()}")


  latest = client.get_latest_versions(name, None if stage is None else [stage])


✅ Loaded model from URI: models:/FashionMNIST-BestModel/Production
👕 True label: 5, 🔍 Predicted: 5
