# Model Promotion & Rollback Demo

This notebook demonstrates ZenML's model promotion workflow:

1. **View model versions** - See all versions and their stages
2. **Promote models** - Move models through staging → production
3. **Rollback** - Revert to previous versions when needed
4. **Audit trail** - Track who promoted what and when

In [None]:
from zenml.client import Client
from zenml.enums import ModelStages

# Shared ZenML client and the registered model this notebook works with.
client = Client()
MODEL_NAME = "breast_cancer_classifier"

# Confirm which stack the client is talking to before doing anything else.
active_stack = client.active_stack_model
print(f"Connected to ZenML: {active_stack.name}")

## 1. View All Model Versions

See all versions and their current stages.

In [None]:
# List all model versions. `list_model_versions` returns one *page* of
# results (not the whole collection), so use `.total` for the count and walk
# every page when printing.
STAGE_EMOJI = {
    "production": "🚀",
    "staging": "🔬",
    "archived": "📦",
}

versions = client.list_model_versions(model_name_or_id=MODEL_NAME)

print(f"📦 Model: {MODEL_NAME}")
print(f"   Total versions: {versions.total}")
print("\n" + "=" * 60)

page = versions
while True:
    for v in page:
        # A missing stage stringifies to "none", which falls back to blank padding.
        stage_emoji = STAGE_EMOJI.get(str(v.stage).lower(), "  ")

        print(f"{stage_emoji} Version {v.number}")
        print(f"   Stage: {v.stage or 'None'}")
        print(f"   Created: {v.created}")
        print("-" * 60)
    if page.index >= page.total_pages:
        break
    page = client.list_model_versions(
        model_name_or_id=MODEL_NAME, page=page.index + 1
    )

## 2. Check Current Production & Staging Models

In [None]:
def get_model_by_stage(stage):
    """Return the version of MODEL_NAME currently in ``stage``, or ``None``.

    ``client.get_model_version`` raises ``KeyError`` when nothing matches;
    that case is turned into ``None`` so callers can branch on it.
    """
    try:
        found = client.get_model_version(MODEL_NAME, stage)
    except KeyError:
        found = None
    return found

production = get_model_by_stage(ModelStages.PRODUCTION)
staging = get_model_by_stage(ModelStages.STAGING)

print("Current Model Stages:")
print("=" * 40)

# Same report for each stage; a blank line separates the entries.
stage_rows = [
    ("🚀", "Production", production),
    ("🔬", "Staging", staging),
]
for i, (emoji, label, model_version) in enumerate(stage_rows):
    if i:
        print()
    if model_version is not None:
        print(f"{emoji} {label}: v{model_version.number}")
        print(f"   Created: {model_version.created}")
    else:
        print(f"{emoji} {label}: (none)")

## 3. View Model Metrics Before Promotion

Always check metrics before promoting to production.

In [None]:
def show_metrics(model_version, title, threshold=0.8):
    """Print the key evaluation metrics recorded on a model version.

    Args:
        model_version: Model version whose ``run_metadata`` mapping holds the
            metrics. Values may be raw numbers or objects exposing ``.value``;
            both shapes are accepted.
        title: Heading printed above the metrics.
        threshold: Minimum value for a metric to be marked ✅ instead of ⚠️.
    """
    print(f"\n📊 {title} (v{model_version.number})")
    print("-" * 40)

    metrics = model_version.run_metadata or {}
    key_metrics = ["accuracy", "precision", "recall", "f1"]
    shown = 0

    for key in key_metrics:
        if key not in metrics:
            continue
        raw = metrics[key]
        # Older metadata entries wrap the number in an object with `.value`.
        raw = getattr(raw, "value", raw)
        try:
            val = float(raw)
        except (TypeError, ValueError):
            # Don't abort the whole report over one malformed entry.
            print(f"   ⚠️ {key}: {raw!r} (not numeric)")
        else:
            status = "✅" if val >= threshold else "⚠️"
            print(f"   {status} {key}: {val:.4f}")
        shown += 1

    if not shown:
        print("   (no key metrics recorded)")

# Staging first, then production; a stage with no model is simply skipped.
for candidate, label in ((staging, "Staging Model"), (production, "Production Model")):
    if candidate:
        show_metrics(candidate, label)

## 4. Promote a Model to Staging

First step: promote the latest version to staging for validation.

In [None]:
# Get the latest model version. Like the stage lookups above, this raises
# KeyError when there is nothing to return (e.g. no versions registered yet).
try:
    latest = client.get_model_version(MODEL_NAME, ModelStages.LATEST)
except KeyError:
    latest = None
    print(f"⚠️  No versions found for model {MODEL_NAME}")
else:
    print(f"Latest version: v{latest.number} (stage: {latest.stage or 'None'})")

# Uncomment to promote to staging:
# if latest is not None:
#     latest.set_stage(ModelStages.STAGING, force=True)
#     print(f"\n✅ Promoted v{latest.number} to staging!")

## 5. Promote Staging to Production

After validation, promote the staging model to production.

In [None]:
# Minimum value each metric must reach before a model may go to production.
PROD_THRESHOLDS = dict(
    accuracy=0.80,
    precision=0.80,
    recall=0.80,
)

def validate_for_production(model_version, thresholds=None):
    """Check whether a model version meets the production thresholds.

    Args:
        model_version: Model version whose ``run_metadata`` mapping holds the
            metrics. Values may be raw numbers or objects exposing ``.value``.
        thresholds: Mapping of metric name -> minimum acceptable value.
            Defaults to ``PROD_THRESHOLDS``.

    Returns:
        ``(passed, failures)``, where ``failures`` is a list of human-readable
        reasons. A required metric that is missing or non-numeric counts as a
        failure: a promotion gate must fail closed, otherwise a model with no
        recorded metrics would sail through.
    """
    if thresholds is None:
        thresholds = PROD_THRESHOLDS
    metrics = model_version.run_metadata or {}
    failures = []

    for metric, threshold in thresholds.items():
        if metric not in metrics:
            failures.append(f"{metric}: missing (required >= {threshold})")
            continue
        raw = metrics[metric]
        # Older metadata entries wrap the number in an object with `.value`.
        raw = getattr(raw, "value", raw)
        try:
            val = float(raw)
        except (TypeError, ValueError):
            failures.append(f"{metric}: non-numeric value {raw!r}")
            continue
        if val < threshold:
            failures.append(f"{metric}: {val:.3f} < {threshold}")

    return not failures, failures

if staging is not None:
    valid, failures = validate_for_production(staging)

    if valid:
        print(f"✅ Staging v{staging.number} PASSES production validation!")
        print("\n   Ready to promote. Uncomment below to proceed.")

        # Uncomment to promote (force=True lets ZenML move whichever version
        # currently holds the production stage out of it):
        # staging.set_stage(ModelStages.PRODUCTION, force=True)
        # print(f"\n🚀 Promoted v{staging.number} to production!")
    else:
        print(f"❌ Staging v{staging.number} FAILS production validation:")
        for failure in failures:
            print(f"   - {failure}")
else:
    print("⚠️  No staging model to promote")

## 6. Rollback Production Model

If the new production model underperforms, roll back to the previous version.

In [None]:
def find_previous_production(current_prod_number):
    """Return the rollback candidate for the current production version.

    The candidate is the highest-numbered version whose number is below
    ``current_prod_number``, whatever its stage (archived or none). This is a
    heuristic: stage history is not consulted, so the returned version may
    never actually have been in production.

    Args:
        current_prod_number: Version number currently in production.

    Returns:
        The candidate model version, or ``None`` if no lower-numbered version
        exists.
    """
    page_number = 1
    while True:
        # list_model_versions returns one page at a time. Keep paging so that
        # a production version older than the newest page is still handled.
        versions = client.list_model_versions(
            model_name_or_id=MODEL_NAME,
            sort_by="desc:number",
            page=page_number,
        )
        for v in versions:
            if v.number < current_prod_number:
                return v
        if page_number >= versions.total_pages:
            return None
        page_number += 1

if production is not None:
    previous = find_previous_production(production.number)

    print("🔄 Rollback Plan:")
    print(f"   Current production: v{production.number}")

    if previous is not None:
        print(f"   Rollback target: v{previous.number}")
        print("\n   To rollback, uncomment below:")

        # Uncomment to rollback. Promoting with force=True makes ZenML archive
        # the version currently in production as part of the same call.
        # Archiving first and promoting second would leave a window with no
        # production model if the second call failed.
        # previous.set_stage(ModelStages.PRODUCTION, force=True)
        # print(f"\n✅ Rolled back to v{previous.number}!")
    else:
        print("   ⚠️  No previous version to rollback to")
else:
    print("⚠️  No production model to rollback")

## 7. Using the Rollback Script

For production use, prefer the CLI script which has more safety checks.

In [None]:
# (heading, commands) pairs for the CLI helper scripts, printed in order.
cli_sections = [
    (
        "# Promote to staging",
        [f"python scripts/promote_model.py --model {MODEL_NAME} --to-stage staging"],
    ),
    (
        "# Promote staging to production",
        [f"python scripts/promote_model.py --model {MODEL_NAME} --from-stage staging --to-stage production"],
    ),
    (
        "# Rollback production (dry-run first)",
        [
            f"python scripts/rollback_model.py --model {MODEL_NAME} --dry-run",
            f"python scripts/rollback_model.py --model {MODEL_NAME}",
        ],
    ),
    (
        "# Rollback to specific version",
        [f"python scripts/rollback_model.py --model {MODEL_NAME} --to-version 3"],
    ),
]

print("CLI commands for model management:")
print("=" * 50)
for heading, commands in cli_sections:
    print()
    print(heading)
    for command in commands:
        print(command)

## 8. Promotion Audit Trail

All promotions are logged for compliance.

In [None]:
print("📜 Model Stage History:")
print("=" * 60)

# Only the 5 most recently updated versions are displayed, so request just
# that many instead of a full page.
versions = client.list_model_versions(
    model_name_or_id=MODEL_NAME,
    sort_by="desc:updated",
    size=5,
)

for v in versions:
    print(f"\nVersion {v.number}:")
    print(f"   Current stage: {v.stage or 'None'}")
    print(f"   Created: {v.created}")
    print(f"   Updated: {v.updated}")

    # Check for rollback metadata
    metadata = v.run_metadata or {}
    if "rollback_event" in metadata:
        rollback = metadata["rollback_event"]
        # Unwrap `.value` the same way the metric cells do.
        rollback = getattr(rollback, "value", rollback)
        if isinstance(rollback, dict):
            print(f"   ⚠️  ROLLBACK from v{rollback.get('from_version')}")
            print(f"      Reason: {rollback.get('reason', 'Not specified')}")
        else:
            print(f"   ⚠️  ROLLBACK recorded: {rollback!r}")

## Summary

This notebook demonstrated:

✅ **View model versions** - See all versions and their stages  
✅ **Check metrics** - Validate before promotion  
✅ **Promote models** - staging → production workflow  
✅ **Rollback** - Revert when models underperform  
✅ **Audit trail** - Track all stage changes  

### Production Best Practices

1. **Always validate in staging first** - Never skip staging
2. **Use thresholds** - Automate go/no-go decisions
3. **Dry-run rollbacks** - Test before executing
4. **Document reasons** - Add rollback reasons for audit
5. **Use CLI scripts** - They have more safety checks than direct API calls