In [None]:
# Cell 1: Setup and Imports
# ===========================
import numpy as np
import pandas as pd
import xgboost as xgb
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from snowflake.ml.model import custom_model
from snowflake.ml.registry import Registry

# We will use Snowpark for our analyses!
from snowflake.snowpark.context import get_active_session

# Get the active Snowpark session (raises if none is available).
# When the notebook is imported into Snowsight an active session will be present.
session = get_active_session()

# Configuration: registry actions (model logging + remote inference) run by
# default. Set to False to run only the local training / plotting cells.
USE_REGISTRY = True

# Names used only when the session has no current database / schema selected.
DEFAULT_DB_NAME = 'MODELS'
DEFAULT_SCHEMA_NAME = 'XGB_MODELS'

# Use the active session's current database and schema so nothing is hard-coded
# and the notebook adapts to whatever the user selected in Snowsight.
# A SELECT without FROM always returns exactly one row; when no database/schema
# is in use, CURRENT_DATABASE()/CURRENT_SCHEMA() return NULL (None here). The
# fallback therefore has to be keyed on NULL values, not on an empty result.
# (If the query itself fails it raises - it never returns zero rows.)
rows = session.sql("select current_database() as db, current_schema() as schema").collect()
current_db, current_schema = (rows[0][0], rows[0][1]) if rows else (None, None)
DB_NAME = current_db or DEFAULT_DB_NAME
SCHEMA_NAME = current_schema or DEFAULT_SCHEMA_NAME

print(f'USE_REGISTRY={USE_REGISTRY}, DB={DB_NAME}, SCHEMA={SCHEMA_NAME}')

In [None]:
# Cell 2: Step 1 - Generate a Dataset with a Dominant Hidden Signal
# ===================================================================
print("Generating a synthetic dataset where base_margin is critical...")

N_SAMPLES = 10000  # number of synthetic rows
DATA_SEED = 42     # seeds the noise, the hidden-axis shuffle and the split -> reproducible runs
rng = np.random.default_rng(DATA_SEED)

# Define a single input feature that our models will be allowed to see.
X_feature = np.linspace(start=-10, stop=10, num=N_SAMPLES)

# 1. Create a simple, learnable signal based on the input feature.
#    This is the part of the problem that a standard model CAN learn.
y_known_signal = X_feature**2 * 2  # A simple parabola

# 2. Create a dominant "hidden" signal that is INDEPENDENT of the input X_feature.
#    The hidden axis is randomly permuted. Without the shuffle it would be a
#    linspace in the same row order as X_feature - i.e. an exact linear function
#    of X_feature - and a tree model could learn the sine from X_feature alone.
hidden_indices = rng.permutation(np.linspace(0, 100, N_SAMPLES))
y_hidden_signal = np.sin(hidden_indices) * 1000  # A sine wave

# 3. Combine the signals and add random noise to create the final target value, y.
#    The value of y is mostly determined by the hidden signal.
noise = rng.normal(0, 20, N_SAMPLES)
y = y_known_signal + y_hidden_signal + noise

# 4. Create the final feature DataFrame 'X' for the model.
#    CRUCIALLY, it ONLY contains the 'X_feature' and has no information about the hidden signal.
X = pd.DataFrame({'X_feature': X_feature})

# 5. Split all data arrays together.
#    This is critical to ensure that X_train, y_train, and hidden_train all correspond
#    to the same rows, which allows us to use `hidden_train` as the base_margin later.
X_train, X_test, y_train, y_test, hidden_train, hidden_test = train_test_split(
    X, y, y_hidden_signal, test_size=0.2, random_state=DATA_SEED
)

print("\nNew synthetic data generated and split successfully.")

In [None]:
# Cell 3: Step 2 & 3 - Train and Evaluate Both Models
# =====================================================

def _fit_and_score(train_margin=None, test_margin=None):
    """Train a fresh XGBRegressor on the training split and score it on the test split.

    ``train_margin`` / ``test_margin`` are forwarded as ``base_margin`` to
    ``fit`` and ``predict``. Leaving them as ``None`` is the same as not
    passing ``base_margin`` at all, since ``None`` is XGBoost's default.

    Returns:
        (fitted_model, test_predictions, test_mse)
    """
    regressor = xgb.XGBRegressor(random_state=42)
    regressor.fit(X_train, y_train, base_margin=train_margin)
    test_predictions = regressor.predict(X_test, base_margin=test_margin)
    return regressor, test_predictions, mean_squared_error(y_test, test_predictions)


# --- Model 1: Standard XGBoost (sees X_feature only) ---
print("\n--- Training Model 1: Standard XGBoost Regressor ---")
model_1, y_pred_1, mse_1 = _fit_and_score()

print("\n--- Performance of Model 1 (Standard XGBoost) ---")
print(f"Mean Squared Error (MSE): {mse_1:.4f}  <-- Very high, as it cannot predict the dominant hidden signal.")

# --- Model 2: XGBoost with base_margin ---
print("\n--- Training Model 2: XGBoost Regressor with base_margin ---")

# The hidden signal is supplied as the base_margin, handing the model the
# "secret key" for both training and prediction.
model_2, y_pred_2, mse_2 = _fit_and_score(hidden_train, hidden_test)

print("\n--- Performance of Model 2 (XGBoost with base_margin) ---")
print(f"Mean Squared Error (MSE): {mse_2:.4f}  <-- Much better, as it only needs to learn the simple remaining signal.")

In [None]:
# Cell 4: Step 4 - Visualize and Compare the Dramatic Difference
# ================================================================
print("\nGenerating visualization...")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8), sharey=True)
fig.suptitle("Test Set Performance: Demonstrating the Power of base_margin", fontsize=18)

# The test split is shuffled, so reorder everything by feature value; otherwise
# the prediction lines would be drawn in random row order. `sorted_indices`
# holds positions, which both .iloc and NumPy fancy indexing consume.
sorted_indices = X_test['X_feature'].argsort()
X_test_sorted, y_test_sorted, y_pred_1_sorted, y_pred_2_sorted = (
    X_test.iloc[sorted_indices],
    y_test[sorted_indices],
    y_pred_1[sorted_indices],
    y_pred_2[sorted_indices],
)

# One entry per panel: (axes, sorted predictions, line colour, title).
# Model 1 sees only X_feature; Model 2 additionally receives the hidden signal
# through base_margin.
panels = (
    (ax1, y_pred_1_sorted, 'red', f"Model 1: Standard XGBoost \nMSE = {mse_1:.4f}"),
    (ax2, y_pred_2_sorted, 'green', f"Model 2: XGBoost with base_margin does better\nMSE = {mse_2:.4f}"),
)
for panel_ax, panel_preds, line_colour, panel_title in panels:
    panel_ax.scatter(X_test_sorted['X_feature'], y_test_sorted, s=5, alpha=0.2, label='Actual Values')
    panel_ax.plot(X_test_sorted['X_feature'], panel_preds, color=line_colour, linewidth=1, label='Predicted Values')
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("Feature")
    if panel_ax is ax1:
        # The y-axis is shared, so only the left panel carries a label.
        panel_ax.set_ylabel("Target")
    panel_ax.legend()
    panel_ax.grid(True, linestyle='--', alpha=0.6)

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

In [None]:
# Cell 5: Define and Log the Custom Model (sub-steps 4a-4d)
# ==============================================================
from snowflake.ml.model import custom_model
from snowflake.ml.registry import Registry

# --- 4a. Define the ModelContext with the in-memory model ---
# The trained base_margin regressor is handed over directly as a keyword
# argument; the wrapper reads it back via context["xgb_model_for_inference"].
model_context = custom_model.ModelContext(xgb_model_for_inference=model_2)

# --- 4b. Define the Custom Model Class ---
# Wrapper that lets the registry model accept base_margin as an ordinary input column.
class PassThroughMarginModel(custom_model.CustomModel):
    """Registry wrapper that forwards a BASE_MARGIN input column to XGBoost.

    The wrapped regressor was trained with an external base_margin, so the
    margin must be supplied again at inference time. Callers send it as a
    ``BASE_MARGIN`` column next to the features; this class splits it out
    before calling the underlying model.
    """

    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        # The trained regressor stored in the context under this key.
        self.model = self.context["xgb_model_for_inference"]

    @custom_model.inference_api
    def predict(self, X: pd.DataFrame) -> pd.DataFrame:
        """Predict from features plus a BASE_MARGIN column.

        Args:
            X: model features plus a ``BASE_MARGIN`` column.

        Returns:
            A DataFrame with a single ``PREDICTION`` column.
        """
        margin_column = 'BASE_MARGIN'
        # Everything except the margin is a model feature; drop() returns a copy.
        features_only = X.drop(columns=[margin_column])
        scores = self.model.predict(features_only, base_margin=X[margin_column])
        return pd.DataFrame({'PREDICTION': scores})

# --- 4c. Prepare for Logging ---
# Sample input = training features plus the BASE_MARGIN column the wrapper expects.
# `hidden_train` is a NumPy array, so it is assigned by position, matching the
# row order train_test_split produced for X_train.
X_train_with_margin = X_train.copy()
X_train_with_margin['BASE_MARGIN'] = hidden_train

MODEL_NAME = "XGB_WITH_PASSTHROUGH_MARGIN"
MODEL_VERSION_NAME = "v1"

# Initialize registry handle using the active session's DB/schema.
# `registry` is always defined (None when disabled) so later checks are safe.
registry = None
if USE_REGISTRY:
    registry = Registry(session=session, database_name=DB_NAME, schema_name=SCHEMA_NAME)

# --- 4d. Log the Model (conditional) ---
# `model_version` is always defined after this cell (None when nothing was logged).
model_version = None
if registry is not None:
    # log_model fails when this model already has a version with the same name,
    # which would break a second "Run All". Reuse the existing version instead.
    # Best-effort probe: any lookup failure means "not logged yet".
    try:
        model_version = registry.get_model(MODEL_NAME).version(MODEL_VERSION_NAME)
        print(f"\nModel {MODEL_NAME} version {MODEL_VERSION_NAME} already exists - reusing it "
              "(delete it in the registry to log a freshly trained model).")
    except Exception:
        model_version = None

    if model_version is None:
        print("\nLogging the custom model that accepts base_margin...")
        custom_model_to_log = PassThroughMarginModel(context=model_context)

        model_version = registry.log_model(
            model=custom_model_to_log,
            model_name=MODEL_NAME,
            version_name=MODEL_VERSION_NAME,
            comment="Custom XGBoost model that accepts a base_margin column during inference.",
            # NOTE(review): these dependencies are unpinned; consider pinning xgboost to the
            # locally used version so the serialized model loads with the same library.
            conda_dependencies=["scikit-learn", "xgboost", "pandas"],
            sample_input_data=X_train_with_margin.head(100),
            options={'relax_version': False}
        )
        print("Custom model logged successfully!")
else:
    print("USE_REGISTRY is False or registry unavailable - skipping model logging.")

In [None]:
# Cell 6: Make Predictions with the Logged Model in Python
# ==========================================================
import pandas as pd
from snowflake.ml.registry import Registry
from snowflake.snowpark import Session

# `session`, USE_REGISTRY, DB_NAME and SCHEMA_NAME come from Cell 1. The registry
# handle is re-created here so this cell can be re-run on its own.
registry = Registry(session=session, database_name=DB_NAME, schema_name=SCHEMA_NAME) if USE_REGISTRY else None

# --- Step 1: Reference the specific model version ---
# The `model_version` object from the previous cell could be used directly;
# looking it up by name also works after a kernel restart.
retrieved_model_version = None
if registry is not None:
    model_name = "XGB_WITH_PASSTHROUGH_MARGIN"
    version_name = "v1"  # The version logged in the previous cell
    try:
        retrieved_model_version = registry.get_model(model_name).version(version_name)
        print(f"Retrieved model: {retrieved_model_version.model_name}, Version: {retrieved_model_version.version_name}")
    except Exception as e:
        # Best effort: e.g. the logging cell was skipped, so the model does not exist yet.
        print(f"Could not retrieve model from registry: {e}.")
else:
    print("USE_REGISTRY is False - skipping registry retrieval and inference via registry.")

# --- Step 2: Prepare the input data ---
# The input must use the exact column names of the logged signature, including
# the special 'BASE_MARGIN' column. `hidden_test` is positionally aligned with X_test.
X_test_with_margin = X_test.copy()
X_test_with_margin['BASE_MARGIN'] = hidden_test
X_test_with_margin = X_test_with_margin.reset_index(drop=True)

# --- Step 3: Run prediction (via registry) ---
# A retrieved version implies USE_REGISTRY is True, so one check suffices.
inference_data_snowpark = None
if retrieved_model_version is not None:
    inference_data_snowpark = session.create_dataframe(X_test_with_margin)
    print("\nRunning inference via Snowpark Python API...")
    predictions_df = retrieved_model_version.run(inference_data_snowpark)
    # Show the first few predictions
    predictions_df.show()
else:
    print("Registry-based inference was not executed. If you want to test registry inference, set USE_REGISTRY = True and ensure the model is logged.")