# Fine-Tuning Embeddings for JIRA Search

This demo shows how to **fine-tune a custom embedding model** using Snowflake ML Jobs and use it with **Cortex Search BYO Embedding**.

## Architecture Overview

```
┌─────────────────┐     ┌─────────────────┐     ┌─────────────────┐
│   ML Job #1     │     │   ML Job #2     │     │  Cortex Search  │
│   Fine-tune     │ ──> │   Log Model     │ ──> │  BYO Embedding  │
│   Embedder      │     │   to Registry   │     │                 │
└─────────────────┘     └─────────────────┘     └─────────────────┘
        │                       │                       │
        ▼                       ▼                       ▼
   Job Stage              Model Registry          Search Service
   (artifacts)            (versioned model)      (semantic search)
```

## Why Fine-Tune Embeddings?

Pre-trained embedding models work well on general text, but fine-tuning on **your domain data** can improve search relevance. It lets the model:

- Learn domain terminology ("P0 bug" = critical issue)
- Understand relationships (auth issues ↔ SSO issues)
- Define similarity for your use case

## What We'll Use

| Component | Purpose |
|-----------|--------|
| [ML Jobs](https://docs.snowflake.com/developer-guide/snowflake-ml/ml-jobs/overview) | Run training on GPU compute pools |
| [Model Registry](https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/overview) | Version and deploy the model |
| [Cortex Search](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-search/cortex-search-overview) | Semantic search with BYO embedding |

---
## Prerequisites

Before running this demo, ensure you have:

1. **GPU Compute Pool** - For training
2. **External Access Integration** - For PyPI and HuggingFace access
3. **Warehouse** - For data operations

In [None]:
from snowflake.snowpark.context import get_active_session
session = get_active_session()

# Configuration - update these for your environment
DATABASE = "JIRA_EMBEDDING_DEMO"
SCHEMA = "PUBLIC"
COMPUTE_POOL = "JIRA_TRAINING_POOL"  # GPU compute pool
WAREHOUSE = "JIRA_DEMO_WH"
EAI = "ALLOW_ALL_EAI"  # External access integration

session.use_database(DATABASE)
session.use_schema(SCHEMA)
print(f"Using {DATABASE}.{SCHEMA}")

---
## Step 1: Prepare Training Data

We have 100 synthetic JIRA tickets. For fine-tuning, we create **training pairs** of similar tickets.

### Why Training Pairs?

The model learns via **contrastive learning** using `MultipleNegativesRankingLoss`:
- **Anchor**: A ticket summary
- **Positive**: Summary + description of a *similar* ticket (same issue type)
- **Negatives**: Other tickets in the batch (implicitly)
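
To make the implicit negatives concrete, here is a minimal pure-Python sketch of an in-batch ranking loss of the kind `MultipleNegativesRankingLoss` computes: scaled cosine similarities between each anchor and every positive in the batch, followed by cross-entropy with the matching positive as the target. The function names, the toy 2-D vectors, and the scale of 20 are illustrative assumptions, not the library's implementation.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def in_batch_ranking_loss(anchors, positives, scale=20.0):
    """Mean cross-entropy where positives[i] is the target for anchors[i]
    and every other positive in the batch acts as a negative."""
    losses = []
    for i, anchor in enumerate(anchors):
        logits = [scale * cosine(anchor, p) for p in positives]
        # Numerically stable log-sum-exp over the batch
        max_logit = max(logits)
        log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
        losses.append(log_sum_exp - logits[i])
    return sum(losses) / len(losses)

# Toy batch of two pairs (assumed embeddings, not real model output)
anchors = [[1.0, 0.0], [0.0, 1.0]]
aligned = [[0.9, 0.1], [0.1, 0.9]]   # each positive is close to its own anchor
swapped = [[0.1, 0.9], [0.9, 0.1]]   # each positive is close to the wrong anchor

loss_aligned = in_batch_ranking_loss(anchors, aligned)
loss_swapped = in_batch_ranking_loss(anchors, swapped)
print(f"aligned: {loss_aligned:.4f}  swapped: {loss_swapped:.4f}")
```

The loss is small when each anchor ranks its own positive above the rest of the batch and large when it does not, which is the signal that pulls matching pairs together during training.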

### Training Logic

We pair tickets by **issue type** (BUG, FEATURE, TASK, etc.):
- Tickets of the same type should have **similar** embeddings
- Different types should be **further apart** in embedding space

This nudges the model toward domain-specific similarity. Because the pairing only uses `ISSUE_TYPE`, the signal is at the type level:
- "Login timeout" ↔ "SSO authentication failure" (both bugs)
- "Add dark mode" ↔ "Implement theme toggle" (both features)
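
The pairing rule can be sketched in plain Python. This mirrors the intent of the SQL self-join used later in this step, applied to a toy list of tickets; the helper name, field names, and sample tickets are made up for illustration.

```python
from itertools import permutations

def build_training_pairs(tickets):
    """Return (anchor, positive) pairs for every ordered pair of distinct
    tickets that share an issue type."""
    pairs = []
    for a, b in permutations(tickets, 2):
        if a["issue_type"] == b["issue_type"]:
            anchor = a["summary"]
            positive = f'{b["summary"]} {b["description"]}'
            pairs.append((anchor, positive))
    return pairs

# Hypothetical sample tickets
tickets = [
    {"key": "J-1", "issue_type": "BUG", "summary": "Login timeout",
     "description": "Session expires early."},
    {"key": "J-2", "issue_type": "BUG", "summary": "SSO authentication failure",
     "description": "SAML assertion rejected."},
    {"key": "J-3", "issue_type": "FEATURE", "summary": "Add dark mode",
     "description": "Users want a dark theme."},
]

pairs = build_training_pairs(tickets)
for anchor, positive in pairs:
    print(f"{anchor!r} -> {positive!r}")
```

A type with n tickets yields n × (n − 1) ordered pairs, so the pair count grows quickly; the SQL version caps it with a random sample of 500.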

### Why This Works

Pre-trained models are tuned for general-purpose text similarity and know nothing about your ticket taxonomy. Fine-tuning teaches:
- Domain vocabulary ("P0" = critical, "SSO" = auth)
- Business similarity (bugs similar to bugs, features to features)
- Search relevance for *your* use case

In [None]:
-- View ticket distribution
SELECT ISSUE_TYPE, COUNT(*) AS TICKET_COUNT
FROM JIRA_TICKETS
GROUP BY ISSUE_TYPE
ORDER BY TICKET_COUNT DESC

In [None]:
-- Create training pairs: tickets of same type should be similar
CREATE OR REPLACE TABLE TRAINING_PAIRS AS
SELECT 
    a.SUMMARY AS ANCHOR,
    b.SUMMARY || ' ' || b.DESCRIPTION AS POSITIVE
FROM JIRA_TICKETS a
JOIN JIRA_TICKETS b 
    ON a.ISSUE_TYPE = b.ISSUE_TYPE 
    AND a.ISSUE_KEY != b.ISSUE_KEY
ORDER BY RANDOM()
LIMIT 500;

SELECT COUNT(*) AS NUM_TRAINING_PAIRS FROM TRAINING_PAIRS;

---
## Step 2: Submit Training Job

### What is ML Jobs?

[ML Jobs](https://docs.snowflake.com/developer-guide/snowflake-ml/ml-jobs/overview) runs Python workloads on Snowflake's managed GPU compute pools:

- **`submit_directory()`** - Upload code directory and run
- **Automatic Snowflake session** - Scripts can query data directly
- **Artifact storage** - Outputs saved to job stage (`job._stage_path`)

### Training Script Overview

Our `train.py` script (written out inline in the next cell; it would live at `src/train.py` in a real project):
1. Loads training pairs from `TRAINING_PAIRS` table
2. Fine-tunes `all-MiniLM-L6-v2` with `MultipleNegativesRankingLoss`
3. Saves model to `MLRS_STAGE_RESULT_PATH` (job output stage)

In [None]:
import os
import tempfile

from snowflake.ml import jobs

# Training code
# In a real setup, this would live in a local src/ directory.
# For this demo, we create it inline.

TRAIN_SCRIPT = '''
import os
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
from snowflake.snowpark import Session

session = Session.builder.getOrCreate()
df = session.sql("SELECT * FROM TRAINING_PAIRS").to_pandas()
print(f"Loaded {len(df)} training pairs")

examples = [
    InputExample(texts=[row["ANCHOR"], row["POSITIVE"]])
    for _, row in df.iterrows()
]

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
print(f"Base model dimension: {model.get_sentence_embedding_dimension()}")

train_dataloader = DataLoader(examples, shuffle=True, batch_size=16)
train_loss = losses.MultipleNegativesRankingLoss(model)

output_path = os.environ.get("MLRS_STAGE_RESULT_PATH", "/tmp/output")
model_path = os.path.join(output_path, "model")

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=3,
    warmup_steps=10,
    output_path=model_path,
    show_progress_bar=True
)
print(f"Training complete! Model saved to {model_path}")
'''

REQUIREMENTS = 'sentence-transformers>=2.2.0\ntorch\n'

# Create temp directory with training code
payload_dir = tempfile.mkdtemp()
with open(os.path.join(payload_dir, "train.py"), "w") as f:
    f.write(TRAIN_SCRIPT)
with open(os.path.join(payload_dir, "requirements.txt"), "w") as f:
    f.write(REQUIREMENTS)

print(f"Training code prepared at {payload_dir}")
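
As a rough sanity check on the hyperparameters in `TRAIN_SCRIPT`, the arithmetic below estimates how many optimizer steps the run takes. It assumes `TRAINING_PAIRS` holds the full 500 rows allowed by the `LIMIT 500` query and that the `DataLoader` keeps the final partial batch (PyTorch's default, `drop_last=False`); the helper name is hypothetical.

```python
import math

def training_schedule(num_pairs, batch_size, epochs, warmup_steps):
    """Estimate step counts for a simple fixed-epoch training run."""
    steps_per_epoch = math.ceil(num_pairs / batch_size)  # last partial batch counts
    total_steps = steps_per_epoch * epochs
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_fraction": warmup_steps / total_steps,
    }

# Values taken from the training script; num_pairs is an assumption
schedule = training_schedule(num_pairs=500, batch_size=16, epochs=3, warmup_steps=10)
print(schedule)
```

Under these assumptions the 10 warmup steps cover roughly a tenth of the run. If the table ends up with fewer rows, the warmup fraction grows accordingly.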

In [None]:
# Submit the training job
train_job = jobs.submit_directory(
    payload_dir,
    entrypoint="train.py",
    compute_pool=COMPUTE_POOL,
    stage_name="TRAINING_STAGE",
    external_access_integrations=[EAI],
    session=session,
)

print("Job submitted!")
print(f"  Job ID: {train_job.id}")
print(f"  Stage path: {train_job._stage_path}")

In [None]:
# Wait for training to complete (typically 5-10 minutes)
print("Waiting for training job to complete...")
status = train_job.wait()
print(f"\nJob finished with status: {status}")

if status == "FAILED":
    print("\n" + "="*60)
    print(train_job.get_logs())
else:
    print(f"\nModel saved to: {train_job._stage_path}/output/")
    # Show last part of logs
    logs = train_job.get_logs()
    print(f"\nTraining logs (last 1000 chars):\n{logs[-1000:]}")

---
## Step 3: Log Model to Registry

### What is Model Registry?

[Model Registry](https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/overview) provides:
- **Versioning** - Track model iterations
- **Lineage** - Know what data produced each model
- **Inference** - Run predictions via `mv.run()` or deploy as service

### CustomModel for Embeddings

Since `SentenceTransformer` isn't a natively supported model type, we wrap it in a `CustomModel` with an `encode()` inference API. Once the model is logged, callers can invoke it with `mv.run(..., function_name="encode")`.

In [None]:
import os
import tempfile

from snowflake.ml import jobs

# Log model script
LOG_SCRIPT = '''
import os
import argparse
import pandas as pd
from sentence_transformers import SentenceTransformer
from snowflake.snowpark import Session
from snowflake.ml.fileset.sfcfs import SFFileSystem
from snowflake.ml.registry import Registry
from snowflake.ml.model import custom_model

class JiraEmbedder(custom_model.CustomModel):
    """Custom model that exposes encode() for generating embeddings."""
    
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)
        self.model = SentenceTransformer(context.path("model"))
        self.embedding_dim = self.model.get_sentence_embedding_dimension()

    @custom_model.inference_api
    def encode(self, input_df: pd.DataFrame) -> pd.DataFrame:
        texts = input_df["text"].tolist()
        embeddings = self.model.encode(texts, show_progress_bar=False)
        return pd.DataFrame({"embedding": [emb.tolist() for emb in embeddings]})

parser = argparse.ArgumentParser()
parser.add_argument("model_stage_path")
parser.add_argument("--model-name", required=True)
parser.add_argument("--version", default="v1")
args = parser.parse_args()

session = Session.builder.getOrCreate()

# Download model from stage
local_dir = "/tmp/model_download"
os.makedirs(local_dir, exist_ok=True)
fs = SFFileSystem(snowpark_session=session)
fs.get(args.model_stage_path.rstrip("/") + "/", local_dir, recursive=True)

# Find model directory
model_path = local_dir
for root, dirs, files in os.walk(local_dir):
    if "config.json" in files:
        model_path = root
        break

print(f"Model found at: {model_path}")

# Create and test custom model
context = custom_model.ModelContext(artifacts={"model": model_path})
embedder = JiraEmbedder(context)

test_result = embedder.encode(pd.DataFrame({"text": ["test"]}))
print(f"Embedding dimension: {len(test_result.iloc[0]['embedding'])}")

# Log to registry
registry = Registry(session=session)
mv = registry.log_model(
    embedder,
    model_name=args.model_name,
    version_name=args.version,
    sample_input_data=pd.DataFrame({"text": ["sample text"]}),
    pip_requirements=["sentence-transformers>=2.2.0", "torch"],
    target_platforms=["SNOWPARK_CONTAINER_SERVICES"],
    comment="Fine-tuned JIRA embedder"
)
print(f"Model logged: {mv.model_name} version {mv.version_name}")
print(f"Functions: {mv.show_functions()}")
'''

# Create payload
log_payload_dir = tempfile.mkdtemp()
with open(os.path.join(log_payload_dir, "log_model.py"), "w") as f:
    f.write(LOG_SCRIPT)

print(f"Log model script prepared at {log_payload_dir}")
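
`LOG_SCRIPT` locates the model by walking the download directory and taking the first folder that contains `config.json`. A saved model can also hold `config.json` files in subfolders, so a slightly more defensive variant, sketched below, picks the shallowest match. The helper name `find_model_dir` and the directory layout in the demo are assumptions for illustration.

```python
import tempfile
from pathlib import Path

def find_model_dir(root, marker="config.json"):
    """Return the shallowest directory under root that contains marker, or None."""
    root = Path(root)
    candidates = [path.parent for path in root.rglob(marker)]
    if not candidates:
        return None
    # Prefer fewer path components; break ties alphabetically for determinism
    return min(candidates, key=lambda d: (len(d.relative_to(root).parts), str(d)))

# Demo on a throwaway directory tree (layout is hypothetical)
with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp) / "output" / "model"
    (model_dir / "1_Pooling").mkdir(parents=True)
    (model_dir / "config.json").write_text("{}")
    (model_dir / "1_Pooling" / "config.json").write_text("{}")

    empty_dir = Path(tmp) / "empty"
    empty_dir.mkdir()

    found = find_model_dir(tmp)
    found_relative = found.relative_to(tmp).as_posix()
    missing = find_model_dir(empty_dir)

print(found_relative, missing)
```

Returning `None` when nothing matches lets the caller fail with a clear message instead of silently passing the download root to `SentenceTransformer`.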

In [None]:
# Get model path from training job (pass output/ dir, script finds model inside)
model_stage_path = f"{train_job._stage_path}/output/"
print(f"Model stage path: {model_stage_path}")

# Submit logging job
log_job = jobs.submit_directory(
    log_payload_dir,
    entrypoint="log_model.py",
    args=[model_stage_path, "--model-name", "JIRA_EMBEDDER", "--version", "v3"],
    compute_pool=COMPUTE_POOL,
    stage_name="LOG_MODEL_STAGE",
    external_access_integrations=[EAI],
    pip_requirements=["sentence-transformers>=2.2.0", "torch", "snowflake-ml-python>=1.6.0"],
    session=session,
)

print(f"Log job submitted: {log_job.id}")

In [None]:
# Wait for logging to complete
print("Waiting for model logging job...")
status = log_job.wait()
print(f"\nJob finished with status: {status}")
print(f"\nLogs:\n{log_job.get_logs()[-2000:]}")

---
## Step 4: Deploy Embedding Service

For search workloads, we deploy the model as a **service** on SPCS:

- **`create_service()`** - Deploys model to compute pool
- **`mv.run(..., service_name=...)`** - Runs inference via the service
- **Auto-suspend** - Service suspends when idle (saves cost)

For search this is a better fit than `run_batch()`, which writes its output to a stage: the service returns results directly as a DataFrame.

In [None]:
from snowflake.ml.registry import Registry
import pandas as pd

# Load the model from registry
registry = Registry(session=session)
mv = registry.get_model("JIRA_EMBEDDER").version("v3")  # must match the version logged in Step 3

print("Model loaded from registry")
print(f"Available functions: {mv.show_functions()}")

# Create inference service (reuse if exists)
SERVICE_NAME = "JIRA_EMBEDDER_SVC"

# Check if service already exists (handle both column name cases)
existing = mv.list_services()
service_exists = False
if not existing.empty:
    # Column might be 'service_name' or 'SERVICE_NAME'
    cols = existing.columns.str.upper()
    if 'SERVICE_NAME' in cols:
        col_name = existing.columns[cols.tolist().index('SERVICE_NAME')]
        service_exists = SERVICE_NAME in existing[col_name].values

if service_exists:
    print(f"Service {SERVICE_NAME} already exists - reusing it")
else:
    # min_instances=0 enables auto-suspend (default: 30 min idle)
    # To change suspend time: ALTER SERVICE <name> SET AUTO_SUSPEND_SECS = 300
    print(f"Creating service {SERVICE_NAME}... (first run builds image ~10-20 min)")
    mv.create_service(
        service_name=SERVICE_NAME,
        service_compute_pool=COMPUTE_POOL,
        image_build_compute_pool=COMPUTE_POOL,
        min_instances=0,  # Enables auto-suspend after 30 min idle
        max_instances=1,
    )
    print("Service created! (auto-suspends after 30 min idle)")

In [None]:
# Generate embeddings for all JIRA tickets
tickets_df = session.table("JIRA_TICKETS").to_pandas()

# Combine summary and description into the single "text" column
# that the model's encode() function expects
texts_df = pd.DataFrame({
    "text": tickets_df["SUMMARY"] + " " + tickets_df["DESCRIPTION"]
})

print(f"Generating embeddings for {len(texts_df)} tickets...")
embeddings_df = mv.run(
    texts_df, 
    function_name="encode",
    service_name=SERVICE_NAME
)
print(f"Generated {len(embeddings_df)} embeddings")

# Name column to indicate source fields
EMBEDDING_COL = "EMBEDDING_SUMMARY_DESC"
dim = len(embeddings_df.iloc[0]['embedding'])
print(f"Embedding dimension: {dim}")
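
A `VECTOR(FLOAT, n)` column has a fixed dimension, so every embedding must have the same length as the first one. It can be worth validating that before building the table. A minimal sketch, with a hypothetical helper name and toy data:

```python
def embedding_dimension(embeddings):
    """Return the common length of all embeddings, or raise ValueError."""
    if not embeddings:
        raise ValueError("no embeddings to validate")
    dims = {len(embedding) for embedding in embeddings}
    if len(dims) != 1:
        raise ValueError(f"inconsistent embedding dimensions: {sorted(dims)}")
    return dims.pop()

# Toy embeddings standing in for real model output
toy_embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
toy_dim = embedding_dimension(toy_embeddings)
print(f"All embeddings have dimension {toy_dim}")
```

In this notebook the same check could be applied as `embedding_dimension(list(embeddings_df["embedding"]))`.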

In [None]:
# Add embeddings to tickets and store.
# .values assigns by position, so a differing index on embeddings_df cannot misalign rows.
tickets_df[EMBEDDING_COL] = embeddings_df["embedding"].values

# Create table with VECTOR column
session.sql(f"""
    CREATE OR REPLACE TABLE JIRA_WITH_EMBEDDINGS (
        ISSUE_KEY VARCHAR,
        ISSUE_TYPE VARCHAR,
        PRIORITY VARCHAR,
        STATUS VARCHAR,
        COMPONENT VARCHAR,
        SUMMARY VARCHAR,
        DESCRIPTION VARCHAR,
        {EMBEDDING_COL} VECTOR(FLOAT, {dim})
    )
""").collect()

def sql_str(value) -> str:
    """Render a value as a SQL string literal, doubling embedded single quotes."""
    if value is None:
        return "NULL"
    return "'" + str(value).replace("'", "''") + "'"

# Insert row by row using SELECT (VALUES doesn't support VECTOR).
# Fine for ~100 demo rows; too slow for large tables.
text_cols = ["ISSUE_KEY", "ISSUE_TYPE", "PRIORITY", "STATUS", "COMPONENT", "SUMMARY", "DESCRIPTION"]
print(f"Inserting {len(tickets_df)} tickets with embeddings...")
for _, row in tickets_df.iterrows():
    emb_str = ",".join(str(x) for x in row[EMBEDDING_COL])
    # Escape every text field, not only SUMMARY and DESCRIPTION
    values = ", ".join(sql_str(row[col]) for col in text_cols)
    session.sql(f"""
        INSERT INTO JIRA_WITH_EMBEDDINGS
        SELECT {values}, [{emb_str}]::VECTOR(FLOAT, {dim})
    """).collect()

print("Done! Table JIRA_WITH_EMBEDDINGS ready.")

---
## Step 5: Create Cortex Search with BYO Embedding

[Cortex Search](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-search/cortex-search-overview) is Snowflake's managed search service. The **BYO Embedding** feature lets you use pre-computed embeddings instead of Snowflake's built-in models.

The `VECTOR INDEXES` clause points Cortex Search at our pre-computed embedding column.

In [None]:
-- Create Cortex Search service with BYO embedding
-- TEXT INDEXES: columns for keyword search
-- VECTOR INDEXES: our pre-computed embedding column
CREATE OR REPLACE CORTEX SEARCH SERVICE JIRA_SEARCH
    TEXT INDEXES (SUMMARY)
    VECTOR INDEXES (EMBEDDING_SUMMARY_DESC)
    ATTRIBUTES ISSUE_TYPE, PRIORITY, COMPONENT
    WAREHOUSE = JIRA_DEMO_WH
    TARGET_LAG = '1 hour'
AS SELECT 
    ISSUE_KEY,
    ISSUE_TYPE,
    PRIORITY,
    COMPONENT,
    SUMMARY,
    DESCRIPTION,
    EMBEDDING_SUMMARY_DESC
FROM JIRA_WITH_EMBEDDINGS

---
## Step 6: Search!

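To search, we embed the query with the **same model** that produced the ticket embeddings and find similar tickets. The `search_jira()` helper in the next cell ranks rows of `JIRA_WITH_EMBEDDINGS` directly with `VECTOR_COSINE_SIMILARITY`; it does not call the `JIRA_SEARCH` service.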

In [None]:
def search_jira(query: str, top_k: int = 5):
    """Search JIRA tickets using our fine-tuned embedder."""
    
    # Embed query with our model (via service)
    query_df = pd.DataFrame({"text": [query]})
    query_emb = mv.run(
        query_df, 
        function_name="encode",
        service_name=SERVICE_NAME
    ).iloc[0]["embedding"]
    emb_str = ",".join(str(x) for x in query_emb)
    
    # Vector similarity search
    results = session.sql(f"""
        SELECT 
            ISSUE_KEY,
            ISSUE_TYPE,
            PRIORITY,
            SUMMARY,
            ROUND(VECTOR_COSINE_SIMILARITY(
                EMBEDDING_SUMMARY_DESC, 
                [{emb_str}]::VECTOR(FLOAT, {len(query_emb)})
            ), 3) AS SCORE
        FROM JIRA_WITH_EMBEDDINGS
        ORDER BY SCORE DESC
        LIMIT {top_k}
    """).to_pandas()
    
    return results
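
For intuition, the ranking that the SQL performs with `VECTOR_COSINE_SIMILARITY`, `ORDER BY SCORE DESC`, and `LIMIT` can be reproduced in plain Python. The toy 3-D vectors, ticket keys, and helper names below are invented for illustration; real embeddings from `all-MiniLM-L6-v2` have 384 dimensions.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def top_k(query_vec, rows, k=5):
    """Score every (key, vector) row against the query and keep the k best."""
    scored = [(key, round(cosine_similarity(query_vec, vec), 3)) for key, vec in rows]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]

# Toy "table" of ticket embeddings (hypothetical keys and vectors)
rows = [
    ("JIRA-1", [1.0, 0.0, 0.0]),
    ("JIRA-2", [0.0, 1.0, 0.0]),
    ("JIRA-3", [1.0, 1.0, 0.0]),
]

results = top_k([1.0, 0.0, 0.0], rows, k=2)
print(results)
```

Cosine similarity depends only on the angle between vectors, not their length, so a ticket scores highly when its embedding points in the same direction as the query's.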

In [None]:
# Try some searches
print("Query: 'authentication login SSO error'\n")
print(search_jira("authentication login SSO error").to_string(index=False))

In [None]:
print("Query: 'payment checkout failure'\n")
print(search_jira("payment checkout failure").to_string(index=False))

In [None]:
print("Query: 'performance slow loading'\n")
print(search_jira("performance slow loading").to_string(index=False))

---
## Summary

We built an end-to-end pipeline following the [official ML Jobs pattern](https://github.com/Snowflake-Labs/sf-samples/tree/main/samples/ml/ml_jobs/llm_finetune):

| Step | What | How |
|------|------|-----|
| 1 | **Fine-tune** | ML Job with `submit_directory()` → model saved to job stage |
| 2 | **Log model** | Separate ML Job downloads from stage, wraps in `CustomModel`, logs to Registry |
| 3 | **Deploy service** | `mv.create_service()` → model runs on SPCS (auto-suspends when idle) |
| 4 | **Generate embeddings** | `mv.run(df, service_name=...)` → results direct to DataFrame |
| 5 | **Search** | Cortex Search service with `VECTOR INDEXES` on the pre-computed embeddings; `search_jira()` ranks with `VECTOR_COSINE_SIMILARITY` |

### Key Patterns

- **`job._stage_path`** - Access training artifacts from completed job
- **`SFFileSystem.get()`** - Download from Snowflake stages
- **`CustomModel`** - Wrap non-native models with `@inference_api`
- **`create_service()`** - Deploy for recurring inference (auto-suspends)
- **`run_batch()`** - Use for one-off large batch jobs (outputs to stage)

### Cleanup

When completely done with the embedder:
```python
mv.delete_service("JIRA_EMBEDDER_SVC")
```

### Learn More

- [ML Jobs Documentation](https://docs.snowflake.com/developer-guide/snowflake-ml/ml-jobs/overview)
- [Model Registry Overview](https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/overview)
- [Cortex Search with BYO Embedding](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-search/cortex-search-overview)
- [Official LLM Fine-Tuning Sample](https://github.com/Snowflake-Labs/sf-samples/tree/main/samples/ml/ml_jobs/llm_finetune)

In [None]:
-- SQL alternative to mv.delete_service("JIRA_EMBEDDER_SVC")
DROP SERVICE JIRA_EMBEDDER_SVC;