# Custom Clustering Models with Snowflake Model Registry

Group unstructured text into topics by embedding it with `AI_EMBED` and clustering the embeddings with custom models stored in the Snowflake Model Registry.

**What you'll build:**

| Model | Algorithm | Runtime | Use Case |
|-------|-----------|---------|----------|
| `AI_CLUSTER_KMEANS` | K-Means | Warehouse | Group into a **specified number** of clusters |
| `AI_CLUSTER_DEEP` | HDBSCAN | SPCS | **Auto-discover** clusters + outliers |

**Three Steps:**
1. **Build** - Define custom model classes
2. **Register** - Log to Model Registry with versioning
3. **Use** - Run inference via SQL or SPCS batch jobs

In [None]:
# Install required packages
!pip install --upgrade snowflake-ml-python hdbscan -q
print("Packages installed")

---
## Setup

In [None]:
# Core imports
import pandas as pd
import numpy as np
from snowflake.snowpark.context import get_active_session
from snowflake.ml.registry import Registry
from snowflake.ml.model import custom_model

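# Configuration
DATABASE = "HALEY_DEMOS"
SCHEMA = "CLUSTERING"
COMPUTE_POOL = "MLOPS_COMPUTE_POOL"

# Initialize the session and point the registry at the schema the models live in,
# so later references to {DATABASE}.{SCHEMA}.<model> resolve to the same objects
session = get_active_session()
session.sql(f"CREATE SCHEMA IF NOT EXISTS {DATABASE}.{SCHEMA}").collect()
reg = Registry(session=session, database_name=DATABASE, schema_name=SCHEMA)

print(f"Connected: {session.get_current_account()}")
print(f"Location: {DATABASE}.{SCHEMA}")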

---
## Sample Data

In [None]:
# Create sample text data across distinct topics
sample_texts = [
    # Technology
    "Machine learning models require large datasets for training",
    "Neural networks can recognize patterns in complex data",
    "Cloud computing enables scalable data processing",
    "APIs allow different software systems to communicate",
    
    # Finance
    "Stock market volatility affects investment portfolios",
    "Interest rates influence borrowing costs for businesses",
    "Diversification reduces risk in investment strategies",
    "Quarterly earnings reports drive stock price movements",
    
    # Healthcare
    "Clinical trials test the efficacy of new medications",
    "Patient outcomes improve with early disease detection",
    "Electronic health records streamline medical documentation",
    "Preventive care reduces long-term healthcare costs",
    
    # Environment
    "Renewable energy sources reduce carbon emissions",
    "Climate change impacts agricultural productivity",
    "Sustainable practices minimize environmental footprint",
    "Conservation efforts protect endangered species"
]

# Create table
session.sql(f"CREATE SCHEMA IF NOT EXISTS {DATABASE}.{SCHEMA}").collect()
df = session.create_dataframe([[t] for t in sample_texts], schema=["TEXT_CONTENT"])
df.write.mode("overwrite").save_as_table(f"{DATABASE}.{SCHEMA}.SAMPLE_TEXTS")

print(f"Created {len(sample_texts)} sample records")
session.table(f"{DATABASE}.{SCHEMA}.SAMPLE_TEXTS").show(5)

In [None]:
# Generate embeddings for model registration
embeddings_df = session.sql(f"""
    SELECT 
        TEXT_CONTENT,
        AI_EMBED('snowflake-arctic-embed-l-v2.0', TEXT_CONTENT)::ARRAY AS EMBEDDING
    FROM {DATABASE}.{SCHEMA}.SAMPLE_TEXTS
""").to_pandas()

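# ARRAY columns may come back as JSON strings, so parse before measuring the dimension
import json
first_embedding = embeddings_df['EMBEDDING'].iloc[0]
if isinstance(first_embedding, str):
    first_embedding = json.loads(first_embedding)
print(f"Generated {len(embeddings_df)} embeddings (dim={len(first_embedding)})")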
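
Semi-structured columns such as `ARRAY` often reach pandas as JSON text rather than Python lists, which is why the model classes in Step 1 parse each value before clustering. The following is a minimal local sketch of that parsing step. The sample strings and the `parse_embedding` helper are made up for illustration and do not come from a real `AI_EMBED` call.

```python
import json
import numpy as np

# Hypothetical values, shaped like an ARRAY column after to_pandas()
raw_column = ["[0.1, 0.2, 0.3]", "[0.4, 0.5, 0.6]"]

def parse_embedding(value):
    """Return a list of floats from a JSON string or an already-parsed sequence."""
    return json.loads(value) if isinstance(value, str) else list(value)

# Stack the parsed rows into a (rows, dimension) matrix for clustering
matrix = np.array([parse_embedding(v) for v in raw_column], dtype=float)
print(matrix.shape)
```

Accepting both strings and sequences keeps the same helper usable whether the caller passes `TO_JSON(...)` text or native arrays.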

---
# STEP 1: BUILD MODELS

Define custom model classes using `snowflake.ml.model.custom_model`.

### Model 1: K-Means

**Use when:** You know how many groups you want.

- Partitions data into exactly K clusters
- Uses `@partitioned_inference_api` for batch processing
- Runs on **Warehouse** (TABLE_FUNCTION)
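
Before wrapping the algorithm in a model class, it can help to see the clustering step on its own. The sketch below runs scikit-learn's `KMeans` locally on made-up 2-D vectors (real embeddings have far more dimensions) and derives a distance-to-center column from `transform()`, under the assumption that the nearest centroid is the one each row is assigned to.

```python
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans

# Toy 2-D "embeddings": two well-separated groups (made-up values)
embeddings = np.array([
    [0.0, 0.1], [0.1, 0.0], [0.05, 0.05],
    [5.0, 5.1], [5.1, 5.0], [5.05, 5.05],
])

# K can never exceed the number of rows, so clamp the requested value
n_clusters = min(2, len(embeddings))

kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
labels = kmeans.fit_predict(embeddings)

# transform() gives each row's distance to every centroid;
# the row-wise minimum is the distance to the nearest one
distances = kmeans.transform(embeddings).min(axis=1)

result = pd.DataFrame({
    "ROW_INDEX": range(len(embeddings)),
    "CLUSTER_ID": labels,
    "DISTANCE_TO_CENTER": np.round(distances, 4),
})
print(result)
```

Cluster IDs are arbitrary labels: the same grouping can come back as 0/1 or 1/0, so compare memberships rather than label values.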

In [None]:
class AIClusterKMeans(custom_model.CustomModel):
    """K-Means clustering for embeddings."""
    
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        self.default_n_clusters = 4
    
    def _parse_embedding(self, emb):
        if isinstance(emb, str):
            import json
            return json.loads(emb)
        return list(emb)
    
    @custom_model.partitioned_inference_api
    def predict(self, input_df: pd.DataFrame) -> pd.DataFrame:
        """Cluster embeddings into K groups.
        
        Input: ROW_INDEX, EMBEDDING (JSON string), N_CLUSTERS (optional)
        Output: ROW_INDEX, CLUSTER_ID, DISTANCE_TO_CENTER
        """
        from sklearn.cluster import KMeans
        import numpy as np
        
        row_index = input_df['ROW_INDEX'].values if 'ROW_INDEX' in input_df.columns else range(len(input_df))
        embeddings = np.array([self._parse_embedding(e) for e in input_df['EMBEDDING']])
        n_clusters = int(input_df['N_CLUSTERS'].iloc[0]) if 'N_CLUSTERS' in input_df.columns else self.default_n_clusters
        n_clusters = min(n_clusters, len(embeddings))
        
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        labels = kmeans.fit_predict(embeddings)
        distances = np.min(kmeans.transform(embeddings), axis=1)
        
        return pd.DataFrame({
            'ROW_INDEX': row_index,
            'CLUSTER_ID': labels,
            'DISTANCE_TO_CENTER': np.round(distances, 4)
        })

print("Defined: AIClusterKMeans")

### Model 2: HDBSCAN

**Use when:** You want the algorithm to discover natural groupings.

- Automatically finds clusters + identifies outliers
- Uses `@inference_api` for standard inference
- Runs on **SPCS** (requires `hdbscan` pip package)
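
To make the output contract concrete, the sketch below shapes HDBSCAN-style results into the three output columns without running the algorithm. The `labels` and `probabilities` arrays are hypothetical stand-ins for what a fitted clusterer exposes as `labels_` and `probabilities_`. A label of `-1` marks a point that belongs to no cluster.

```python
import numpy as np
import pandas as pd

# Hypothetical values standing in for clusterer.labels_ and clusterer.probabilities_
labels = np.array([0, 0, 0, 1, 1, 1, -1])
probabilities = np.array([1.0, 0.92, 0.88, 1.0, 0.95, 0.9, 0.0])

result = pd.DataFrame({
    "CLUSTER_ID": labels,
    "IS_OUTLIER": labels == -1,          # -1 means "noise", not a cluster
    "PROBABILITY": np.round(probabilities, 4),
})

# Count real clusters only; the -1 label must not be counted as a cluster
n_clusters = result.loc[~result["IS_OUTLIER"], "CLUSTER_ID"].nunique()
n_outliers = int(result["IS_OUTLIER"].sum())
print(f"{n_clusters} clusters, {n_outliers} outliers")
```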

In [None]:
class AIClusterDeep(custom_model.CustomModel):
    """HDBSCAN clustering - auto-discovers clusters and outliers."""
    
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        self.default_min_cluster_size = 3
    
    def _parse_embedding(self, emb):
        if isinstance(emb, str):
            import json
            return json.loads(emb)
        return list(emb)
    
    @custom_model.inference_api
    def predict(self, input_df: pd.DataFrame) -> pd.DataFrame:
        """Discover clusters in embedding data.
        
        Input: EMBEDDING, MIN_CLUSTER_SIZE (optional)
        Output: CLUSTER_ID (-1=outlier), IS_OUTLIER, PROBABILITY
        """
        import hdbscan
        import numpy as np
        
        embeddings = np.array([self._parse_embedding(e) for e in input_df['EMBEDDING']])
        min_size = int(input_df['MIN_CLUSTER_SIZE'].iloc[0]) if 'MIN_CLUSTER_SIZE' in input_df.columns else self.default_min_cluster_size
        
        clusterer = hdbscan.HDBSCAN(
            min_cluster_size=min_size,
            metric='euclidean',
            prediction_data=True
        )
        clusterer.fit(embeddings)
        
        return pd.DataFrame({
            'CLUSTER_ID': clusterer.labels_,
            'IS_OUTLIER': (clusterer.labels_ == -1),
            'PROBABILITY': np.round(clusterer.probabilities_, 4)
        })

print("Defined: AIClusterDeep")

---
# STEP 2: REGISTER MODELS

Log models to Snowflake Model Registry with version control.

In [None]:
# Prepare sample inputs for model validation
import json

# Serialize embeddings as JSON strings, matching the TO_JSON(...) columns used at inference time
embeddings_list = [
    e if isinstance(e, str) else json.dumps(list(e))
    for e in embeddings_df['EMBEDDING'].tolist()[:10]
]

sample_input_kmeans = pd.DataFrame({
    'ROW_INDEX': list(range(10)),
    'EMBEDDING': embeddings_list,
    'N_CLUSTERS': [4] * 10
})

sample_input_deep = pd.DataFrame({
    'EMBEDDING': embeddings_list,
    'MIN_CLUSTER_SIZE': [3] * 10
})

print("Sample inputs ready")

In [None]:
# Register K-Means (Warehouse runtime)
print("Registering AI_CLUSTER_KMEANS...")

mv_kmeans = reg.log_model(
    AIClusterKMeans(custom_model.ModelContext()),
    model_name="AI_CLUSTER_KMEANS",
    version_name="V1",
    sample_input_data=sample_input_kmeans,
    options={"function_type": "TABLE_FUNCTION"},
    conda_dependencies=["scikit-learn", "numpy", "pandas"],
    comment="K-Means clustering for AI_EMBED embeddings"
)

session.sql(f"ALTER MODEL {DATABASE}.{SCHEMA}.AI_CLUSTER_KMEANS SET DEFAULT_VERSION = V1").collect()
print(f"Registered: {mv_kmeans.model_name} V1 (Warehouse)")

In [None]:
# Register HDBSCAN (SPCS runtime - pip package required)
print("Registering AI_CLUSTER_DEEP...")

mv_deep = reg.log_model(
    AIClusterDeep(custom_model.ModelContext()),
    model_name="AI_CLUSTER_DEEP",
    version_name="V1",
    sample_input_data=sample_input_deep,
    pip_requirements=["hdbscan", "scikit-learn", "numpy", "pandas"],
    target_platforms=["SNOWPARK_CONTAINER_SERVICES"],
    comment="HDBSCAN clustering - auto-discovers clusters and outliers"
)

print(f"Registered: {mv_deep.model_name} V1 (SPCS)")

In [None]:
# Verify registered models
print("Registered Models:")
print("=" * 50)

for model_name in ["AI_CLUSTER_KMEANS", "AI_CLUSTER_DEEP"]:
    try:
        model = reg.get_model(model_name)
        versions = model.show_versions()
        print(f"\n{model_name}")
        print(f"  Versions: {', '.join(versions['name'].tolist())}")
    except Exception as e:
        # Surface the underlying error instead of assuming the model is missing
        print(f"\n{model_name}: lookup failed ({e})")

---
# STEP 3: USE MODELS

Run inference using the registered models.

## 3a. K-Means (SQL - Instant)

Runs on warehouse via SQL TABLE function. Fast and simple.
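
The query in the next cell relies on `OVER (PARTITION BY ...)`: every row that shares a partition key is handed to one `predict` call, and each partition is clustered on its own. The sketch below emulates that behavior locally with a pandas `groupby`. It is an illustration, not how Snowflake executes the table function, and `predict_partition` is a hypothetical stand-in for the model.

```python
import pandas as pd

rows = pd.DataFrame({
    "ROW_INDEX": [1, 2, 3, 4, 5, 6],
    "BATCH_ID":  [1, 1, 1, 2, 2, 2],
    "VALUE":     [0.1, 0.2, 0.3, 5.1, 5.2, 5.3],
})

def predict_partition(partition: pd.DataFrame) -> pd.DataFrame:
    """Stand-in for predict(): it sees only the rows of one partition."""
    center = partition["VALUE"].mean()
    return pd.DataFrame({
        "ROW_INDEX": partition["ROW_INDEX"].values,
        "ROWS_SEEN": len(partition),
        "DISTANCE_TO_CENTER": (partition["VALUE"] - center).abs().round(4).values,
    })

# One call per distinct BATCH_ID, then stitch the outputs back together
results = pd.concat(
    [predict_partition(part) for _, part in rows.groupby("BATCH_ID")],
    ignore_index=True,
)
print(results)
```

With a constant key such as `1 AS BATCH_ID`, all rows land in a single partition, so the model clusters the whole dataset at once. A real partition column (for example a tenant or category) would produce independent clusterings per group.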

In [None]:
# Test K-Means clustering via SQL
print("Running K-Means clustering...\n")

results_kmeans = session.sql(f"""
    WITH input_data AS (
        SELECT 
            ROW_NUMBER() OVER (ORDER BY TEXT_CONTENT) AS ROW_INDEX,
            TEXT_CONTENT
        FROM {DATABASE}.{SCHEMA}.SAMPLE_TEXTS
    ),
    cluster_results AS (
        SELECT r.*
        FROM (
            SELECT 
                ROW_NUMBER() OVER (ORDER BY TEXT_CONTENT) AS ROW_INDEX,
                1 AS BATCH_ID,
                TO_JSON(AI_EMBED('snowflake-arctic-embed-l-v2.0', TEXT_CONTENT)::ARRAY) AS EMBEDDING,
                4 AS N_CLUSTERS
            FROM {DATABASE}.{SCHEMA}.SAMPLE_TEXTS
        ) src,
        TABLE(MODEL({DATABASE}.{SCHEMA}.AI_CLUSTER_KMEANS, V1)!PREDICT(
              src.ROW_INDEX, src.EMBEDDING, src.N_CLUSTERS)
              OVER (PARTITION BY src.BATCH_ID)) r
    )
    SELECT 
        c.CLUSTER_ID,
        i.TEXT_CONTENT,
        c.DISTANCE_TO_CENTER
    FROM input_data i
    JOIN cluster_results c ON i.ROW_INDEX = c.ROW_INDEX
    ORDER BY c.CLUSTER_ID, c.DISTANCE_TO_CENTER
""").to_pandas()

# Display by cluster
for cluster_id in sorted(results_kmeans['CLUSTER_ID'].unique()):
    cluster_data = results_kmeans[results_kmeans['CLUSTER_ID'] == cluster_id]
    print(f"\nCluster {cluster_id} ({len(cluster_data)} items)")
    for _, row in cluster_data.iterrows():
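        text = row['TEXT_CONTENT']
        # Add an ellipsis only when the text was actually truncated
        print(f"  - {text[:60]}{'...' if len(text) > 60 else ''}")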

## 3b. HDBSCAN - Prepare Input

SPCS models run from a container image, which is built the first time a batch job runs for a model version (about 2-3 minutes). This step only stages the input table and fetches the model version.

In [None]:
# Prepare input data with embeddings
print("Preparing input data...")

session.sql(f"""
    CREATE OR REPLACE TABLE {DATABASE}.{SCHEMA}.HDBSCAN_INPUT AS
    SELECT 
        TEXT_CONTENT,
        TO_JSON(AI_EMBED('snowflake-arctic-embed-l-v2.0', TEXT_CONTENT)::ARRAY) AS EMBEDDING,
        2 AS MIN_CLUSTER_SIZE
    FROM {DATABASE}.{SCHEMA}.SAMPLE_TEXTS
""").collect()

print(f"Created: {DATABASE}.{SCHEMA}.HDBSCAN_INPUT")

In [None]:
# Get model version reference
mv_deep = reg.get_model("AI_CLUSTER_DEEP").version("V1")
print(f"Model: {mv_deep.model_name} V1")

## 3c. HDBSCAN - Run Batch Inference

Submit a batch job to the SPCS compute pool. The container image is built on the first run and reused on later runs.

In [None]:
# Import batch inference specs
from snowflake.ml.model._client.model.batch_inference_specs import OutputSpec, SaveMode

# Input data
input_df = session.table(f"{DATABASE}.{SCHEMA}.HDBSCAN_INPUT").select("EMBEDDING", "MIN_CLUSTER_SIZE")

# Output location
output_stage = f"@{DATABASE}.{SCHEMA}.NOTEBOOKS/hdbscan_results/"

print(f"Submitting batch job to {COMPUTE_POOL}...")
print("(First run builds container image - takes ~2-3 min)")
print("(Subsequent runs are faster)\n")

job = mv_deep.run_batch(
    compute_pool=COMPUTE_POOL,
    X=input_df,
    output_spec=OutputSpec(
        stage_location=output_stage,
        mode=SaveMode.OVERWRITE
    )
)

print(f"Job ID: {job.id}")
print("Waiting for completion...")

In [None]:
# Wait for job completion
job.wait()
print("Batch job complete!")

In [None]:
# List output files
session.sql(f"LIST {output_stage}").show()

In [None]:
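# Read results - collect every parquet file from the LIST output
files = session.sql(f"LIST {output_stage}").collect()
parquet_files = [f['name'] for f in files if f['name'].endswith('.parquet')]
if not parquet_files:
    raise RuntimeError(f"No parquet output found under {output_stage}")

# LIST names start with the stage name, so prefix them with the database and schema
results_hdbscan = pd.concat(
    [session.read.parquet(f"@{DATABASE}.{SCHEMA}.{name}").to_pandas() for name in parquet_files],
    ignore_index=True
)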

# Summary
n_clusters = len([c for c in results_hdbscan['CLUSTER_ID'].unique() if c != -1])
n_outliers = (results_hdbscan['CLUSTER_ID'] == -1).sum()
print(f"\nResults: {n_clusters} clusters, {n_outliers} outliers")
print(results_hdbscan[['CLUSTER_ID', 'IS_OUTLIER', 'PROBABILITY']])

## 3d. Cleanup - Cancel/View Jobs

Batch jobs auto-terminate when complete. Use these commands to manage jobs.

In [None]:
# View recent batch jobs
from snowflake.ml.jobs import list_jobs

print("Recent batch jobs:")
print(list_jobs())

In [None]:
# Cancel a running job (if needed)
# job.cancel()

# View job logs (for debugging)
# print(job.get_logs())

print("Batch jobs auto-terminate - no manual cleanup needed!")

---
## Usage Reference

### K-Means (SQL)

```sql
WITH input_data AS (
    SELECT 
        ROW_NUMBER() OVER (ORDER BY id) AS ROW_INDEX,
        1 AS BATCH_ID,
        TO_JSON(AI_EMBED('snowflake-arctic-embed-l-v2.0', your_text)::ARRAY) AS EMBEDDING,
        5 AS N_CLUSTERS
    FROM your_table
)
SELECT r.*
FROM input_data src,
TABLE(MODEL(db.schema.AI_CLUSTER_KMEANS, V1)!PREDICT(
    src.ROW_INDEX, src.EMBEDDING, src.N_CLUSTERS)
    OVER (PARTITION BY src.BATCH_ID)) r
```
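
When the same query is needed for several tables, one option is to render it from Python. The helper below is a hypothetical sketch (`build_kmeans_query` is not part of any Snowflake library) that fills in the table, column, and cluster count and returns the SQL text shown above.

```python
def build_kmeans_query(model_fqn: str, table: str, text_col: str,
                       order_col: str, n_clusters: int) -> str:
    """Render the K-Means table-function query (hypothetical helper)."""
    if n_clusters < 1:
        raise ValueError("n_clusters must be a positive integer")
    return f"""
WITH input_data AS (
    SELECT
        ROW_NUMBER() OVER (ORDER BY {order_col}) AS ROW_INDEX,
        1 AS BATCH_ID,
        TO_JSON(AI_EMBED('snowflake-arctic-embed-l-v2.0', {text_col})::ARRAY) AS EMBEDDING,
        {int(n_clusters)} AS N_CLUSTERS
    FROM {table}
)
SELECT r.*
FROM input_data src,
TABLE(MODEL({model_fqn}, V1)!PREDICT(
    src.ROW_INDEX, src.EMBEDDING, src.N_CLUSTERS)
    OVER (PARTITION BY src.BATCH_ID)) r
""".strip()

query = build_kmeans_query(
    "db.schema.AI_CLUSTER_KMEANS", "your_table", "your_text", "id", 5
)
print(query)
```

The identifiers are interpolated directly into the SQL text, so this sketch is only suitable for trusted, hard-coded names, never for user-supplied input.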

### HDBSCAN (Python - run_batch)

```python
from snowflake.ml.model._client.model.batch_inference_specs import OutputSpec, SaveMode

mv = reg.get_model("AI_CLUSTER_DEEP").version("V1")

job = mv.run_batch(
    compute_pool="YOUR_COMPUTE_POOL",
    X=input_dataframe,  # Must have EMBEDDING column
    output_spec=OutputSpec(
        stage_location="@your_stage/output/",
        mode=SaveMode.OVERWRITE
    )
)

job.wait()
results = session.read.parquet("@your_stage/output/")
```

---
## Summary

| Model | Runtime | Invocation | Packages |
|-------|---------|------------|----------|
| `AI_CLUSTER_KMEANS` | Warehouse | SQL TABLE function | Anaconda |
| `AI_CLUSTER_DEEP` | SPCS | `mv.run_batch()` | pip |

**Model Registry Benefits:**
- Version control for iterations
- Lineage tracking
- Centralized access control
- Usage monitoring via `ACCOUNT_USAGE`