# OpenSearch Neural Sparse Integration Test

This notebook tests the v21.4 Korean Neural Sparse Encoder integration with OpenSearch.

## Prerequisites

1. OpenSearch cluster running (local or remote)
2. ML plugin enabled
3. v21.4 model uploaded to HuggingFace or model artifacts available locally

In [None]:
import sys
from pathlib import Path

def find_project_root():
    """Locate the project root by walking up from the working directory.

    Returns the nearest directory (starting with the current one) that
    contains either a ``pyproject.toml`` or a ``src`` entry. When no ancestor
    qualifies, the grandparent of the working directory is returned instead.
    """
    start_dir = Path.cwd()
    markers = ("pyproject.toml", "src")
    for candidate in (start_dir, *start_dir.parents):
        if any((candidate / marker).exists() for marker in markers):
            return candidate
    return Path.cwd().parent.parent

# Resolve the project root once and place it at the front of sys.path.
PROJECT_ROOT = find_project_root()
sys.path.insert(0, str(PROJECT_ROOT))

import json
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForMaskedLM

print(f"Project root: {PROJECT_ROOT}")

## 1. Configuration

In [None]:
# OpenSearch Configuration
# The base URL is built with plain http:// from the host and port below.
OPENSEARCH_HOST = "localhost"
OPENSEARCH_PORT = 9200
OPENSEARCH_URL = f"http://{OPENSEARCH_HOST}:{OPENSEARCH_PORT}"

# Model paths
# MODEL_PATH: directory the encoder checks for HuggingFace-format weights.
# CHECKPOINT_PATH: training checkpoint the encoder falls back to when
# MODEL_PATH holds no HuggingFace weights.
MODEL_PATH = PROJECT_ROOT / "huggingface" / "v21.4"
CHECKPOINT_PATH = PROJECT_ROOT / "outputs" / "v21.4_korean_enhanced" / "best_model.pt"

# Index configuration
# Name of the index that the cells below create, fill, and query.
INDEX_NAME = "korean_sparse_test_v21_4"

print(f"OpenSearch URL: {OPENSEARCH_URL}")
print(f"Model path: {MODEL_PATH}")
print(f"Index name: {INDEX_NAME}")

## 2. Load Model for Local Inference

In [None]:
class KoreanNeuralSparseEncoder:
    """Korean Neural Sparse Encoder for OpenSearch integration.

    Wraps a masked-LM model and its tokenizer and converts text into a
    ``{token: weight}`` mapping using SPLADE-style activations
    (``log(1 + ReLU(logits))``, max-pooled over the sequence).
    """

    # Literal special-token strings that must never appear in the output.
    # Kept as a fallback for tokenizers that do not expose ``all_special_ids``.
    _SPECIAL_TOKENS = frozenset({'[CLS]', '[SEP]', '[PAD]', '[UNK]', '[MASK]'})

    def __init__(self, model_path: Path, device: str = 'cpu'):
        """Load the model and tokenizer and move the model to ``device``.

        Args:
            model_path: Directory expected to hold HuggingFace-format weights
                (``model.safetensors`` or ``pytorch_model.bin``). When neither
                file exists, the module-level ``CHECKPOINT_PATH`` training
                checkpoint is loaded instead.
            device: Torch device string the model is moved to.
        """
        self.device = device
        model_path = Path(model_path)

        # Load from HuggingFace format or checkpoint
        has_hf_weights = any(
            (model_path / name).exists()
            for name in ("model.safetensors", "pytorch_model.bin")
        )
        if has_hf_weights:
            self.model = AutoModelForMaskedLM.from_pretrained(model_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        else:
            # Load from training checkpoint.
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
            config = checkpoint.get('config', {})
            model_name = config.get('model_name', 'skt/A.X-Encoder-base')

            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)

            # Load trained weights; strip the leading 'model.' that the keys
            # may carry so they match the bare HuggingFace model.
            prefix = 'model.'
            state_dict = checkpoint['model_state_dict']
            new_state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }
            self.model.load_state_dict(new_state_dict, strict=True)

        self.model = self.model.to(device)
        self.model.eval()
        self.relu = nn.ReLU()

    @torch.no_grad()
    def encode(self, text: str, top_k: int = 20) -> dict:
        """
        Encode text to a sparse token-weight mapping for OpenSearch.

        Args:
            text: Input text; tokenized with truncation to 64 tokens.
            top_k: Maximum number of vocabulary entries to consider. Values
                larger than the vocabulary are clamped; values <= 0 yield an
                empty result.

        Returns:
            Dictionary mapping token strings to weights rounded to 4 decimals.
            Only strictly positive weights are kept (``rank_features`` fields
            accept only positive values). When several vocabulary entries
            decode to the same string, the highest weight wins.
        """
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=64
        ).to(self.device)

        outputs = self.model(**inputs)
        logits = outputs.logits

        # SPLADE: log(1 + ReLU(x))
        token_scores = torch.log1p(self.relu(logits))

        # Mask padding
        mask = inputs['attention_mask'].unsqueeze(-1).float()
        token_scores = token_scores * mask

        # Max pooling over the sequence; only the first batch item is used.
        sparse_repr = token_scores.max(dim=1).values[0]  # [vocab_size]

        # Clamp so that a top_k above the vocabulary size does not make
        # topk() raise.
        k = max(0, min(top_k, sparse_repr.numel()))
        if k == 0:
            return {}
        top_values, top_indices = sparse_repr.topk(k)

        # Prefer the tokenizer's own special ids; the literal list above only
        # covers BERT-style names.
        special_ids = set(getattr(self.tokenizer, 'all_special_ids', None) or ())

        result = {}
        for idx, val in zip(top_indices.tolist(), top_values.tolist()):
            if not val > 0:
                continue
            if idx in special_ids:
                continue
            token = self.tokenizer.decode([idx]).strip()
            if not token or token in self._SPECIAL_TOKENS:
                continue
            weight = round(val, 4)
            # Skip weights that round down to zero, and keep the first
            # (highest, since topk is sorted descending) weight when two ids
            # decode to the same string.
            if weight > 0 and token not in result:
                result[token] = weight

        return result

    def encode_batch(self, texts: list, top_k: int = 20) -> list:
        """Encode multiple texts one at a time; returns one dict per text."""
        return [self.encode(text, top_k) for text in texts]


# Initialize encoder
# Use CUDA when torch reports it as available; otherwise run on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = KoreanNeuralSparseEncoder(MODEL_PATH, device)
print(f"Encoder initialized on {device}")

In [None]:
# Test encoding
# Smoke-test the encoder on a single Korean query.
test_text = "당뇨병 환자의 인슐린 치료"
sparse_vector = encoder.encode(test_text)

print(f"Input: {test_text}")
print(f"Sparse vector ({len(sparse_vector)} tokens):")
# Sort by descending weight and show at most the ten heaviest tokens.
for token, weight in sorted(sparse_vector.items(), key=lambda x: -x[1])[:10]:
    print(f"  {token}: {weight}")

## 3. OpenSearch Connection Test

In [None]:
import requests

def check_opensearch_connection():
    """Check if OpenSearch is accessible.

    Sends a GET to ``OPENSEARCH_URL`` with a 5 second timeout and prints the
    outcome. Every failure mode (connection refused, timeout, non-200 status,
    non-JSON body) is reported and mapped to ``False`` instead of raising.

    Returns:
        True if the server answered HTTP 200 with a JSON body, else False.
    """
    try:
        response = requests.get(OPENSEARCH_URL, timeout=5)
    except requests.exceptions.ConnectionError:
        print(f"Cannot connect to OpenSearch at {OPENSEARCH_URL}")
        print("Make sure OpenSearch is running.")
        return False
    except requests.exceptions.RequestException as exc:
        # Read timeouts, invalid URLs, etc. are not ConnectionError subclasses.
        print(f"Request to OpenSearch at {OPENSEARCH_URL} failed: {exc}")
        return False

    if response.status_code != 200:
        print(f"OpenSearch at {OPENSEARCH_URL} returned HTTP {response.status_code}")
        return False

    try:
        info = response.json()
    except ValueError:
        # requests raises a ValueError subclass when the body is not JSON.
        print(f"Unexpected non-JSON response from {OPENSEARCH_URL}")
        return False

    print("OpenSearch connected!")
    print(f"  Version: {info.get('version', {}).get('number', 'unknown')}")
    print(f"  Cluster: {info.get('cluster_name', 'unknown')}")
    return True

# Probe the cluster once; later cells check this flag before calling OpenSearch.
opensearch_available = check_opensearch_connection()

## 4. Create Index with Sparse Vector (`rank_features`) Mapping

In [None]:
def create_sparse_index(index_name: str, timeout: float = 10.0):
    """Create (or recreate) the test index with a ``rank_features`` field.

    Any existing index of the same name is deleted first. The new index has
    one shard, no replicas, a ``text`` field, a ``rank_features`` field named
    ``sparse_embedding`` and a ``keyword`` field named ``category``.

    Args:
        index_name: Name of the index to create.
        timeout: Seconds to wait for each HTTP call; without a timeout
            requests can block indefinitely.

    Returns:
        True if the create request answered HTTP 200, otherwise False.
    """
    index_url = f"{OPENSEARCH_URL}/{index_name}"

    # Create index with mapping
    mapping = {
        "settings": {
            "number_of_shards": 1,
            "number_of_replicas": 0
        },
        "mappings": {
            "properties": {
                "text": {
                    "type": "text",
                    "analyzer": "standard"
                },
                "sparse_embedding": {
                    "type": "rank_features"
                },
                "category": {
                    "type": "keyword"
                }
            }
        }
    }

    try:
        # Delete if exists; a 404 simply means there was nothing to delete.
        delete_response = requests.delete(index_url, timeout=timeout)
        if delete_response.status_code not in (200, 404):
            print(f"Warning: could not delete existing index: {delete_response.text}")

        response = requests.put(
            index_url,
            json=mapping,
            headers={"Content-Type": "application/json"},
            timeout=timeout
        )
    except requests.exceptions.RequestException as exc:
        print(f"Error creating index: {exc}")
        return False

    if response.status_code == 200:
        print(f"Index '{index_name}' created successfully!")
        return True

    print(f"Error creating index: {response.text}")
    return False

# Only touch the cluster when the connection check succeeded.
if opensearch_available:
    create_sparse_index(INDEX_NAME)

## 5. Index Test Documents

In [None]:
# Test documents covering various domains
# Each entry has a "text" to encode and a "category" label; five documents per
# category (medical, legal, tech).
TEST_DOCUMENTS = [
    # Medical domain
    {"text": "당뇨병 환자는 인슐린 주사를 맞아야 합니다.", "category": "medical"},
    {"text": "고혈압 증상으로는 두통과 어지러움이 있습니다.", "category": "medical"},
    {"text": "암 환자의 항암 치료 부작용을 관리하는 방법", "category": "medical"},
    {"text": "감기 증상 완화를 위한 약물 복용 방법", "category": "medical"},
    {"text": "폐렴 질환의 진단과 치료 과정", "category": "medical"},

    # Legal domain
    {"text": "부동산 계약 해지 시 위약금 규정", "category": "legal"},
    {"text": "임대차 보호법에 따른 세입자 권리", "category": "legal"},
    {"text": "상속법에 따른 유산 분배 절차", "category": "legal"},
    {"text": "저작권 침해에 대한 법적 대응 방안", "category": "legal"},
    {"text": "노동법 위반 시 처벌 규정 안내", "category": "legal"},

    # Technology domain
    {"text": "데이터베이스 최적화를 위한 인덱스 설계", "category": "tech"},
    {"text": "머신러닝 모델 학습을 위한 데이터 전처리", "category": "tech"},
    {"text": "클라우드 서버 보안 설정 가이드", "category": "tech"},
    {"text": "웹 애플리케이션 성능 최적화 방법", "category": "tech"},
    {"text": "API 개발을 위한 RESTful 설계 원칙", "category": "tech"},
]

print(f"Prepared {len(TEST_DOCUMENTS)} test documents")

In [None]:
def index_documents(index_name: str, documents: list, timeout: float = 10.0):
    """Index documents with sparse embeddings.

    Each document is encoded with the module-level ``encoder`` (top 30
    tokens) and stored under its list position as the document id. Failures
    for individual documents are printed and skipped so the remaining
    documents are still indexed. The index is refreshed at the end.

    Args:
        index_name: Target index.
        documents: Dicts with ``"text"`` and ``"category"`` keys.
        timeout: Seconds to wait for each HTTP call; without a timeout
            requests can block indefinitely.

    Returns:
        Number of documents that were indexed successfully.
    """
    indexed = 0
    for i, doc in enumerate(documents):
        # Generate sparse embedding
        sparse_vector = encoder.encode(doc["text"], top_k=30)

        # Prepare document
        doc_body = {
            "text": doc["text"],
            "sparse_embedding": sparse_vector,
            "category": doc["category"]
        }

        # Index document
        try:
            response = requests.put(
                f"{OPENSEARCH_URL}/{index_name}/_doc/{i}",
                json=doc_body,
                headers={"Content-Type": "application/json"},
                timeout=timeout
            )
        except requests.exceptions.RequestException as exc:
            # A transport error on one document should not abort the batch.
            print(f"Error indexing doc {i}: {exc}")
            continue

        if response.status_code in (200, 201):
            indexed += 1
        else:
            print(f"Error indexing doc {i}: {response.text}")

    # Refresh index so the new documents become searchable right away.
    try:
        refresh_response = requests.post(
            f"{OPENSEARCH_URL}/{index_name}/_refresh",
            timeout=timeout
        )
        if refresh_response.status_code != 200:
            print(f"Warning: index refresh failed: {refresh_response.text}")
    except requests.exceptions.RequestException as exc:
        print(f"Warning: index refresh failed: {exc}")

    print(f"Indexed {indexed}/{len(documents)} documents")
    return indexed

# Only index when the connection check succeeded.
if opensearch_available:
    index_documents(INDEX_NAME, TEST_DOCUMENTS)

## 6. Sparse Search Test

In [None]:
def sparse_search(index_name: str, query: str, top_k: int = 5, timeout: float = 10.0):
    """
    Perform sparse vector search using rank_feature queries.

    The query is encoded with the module-level ``encoder`` (top 20 tokens);
    each token becomes one ``rank_feature`` clause on
    ``sparse_embedding.<token>`` boosted by the token weight, combined in a
    ``bool``/``should`` query.

    Args:
        index_name: Index to search.
        query: Query text.
        top_k: Maximum number of hits to return.
        timeout: Seconds to wait for the HTTP call; without a timeout
            requests can block indefinitely.

    Returns:
        List of raw hit dicts (with ``_score`` and ``_source``); an empty list
        when the query encodes to no tokens or the request fails.
    """
    # Encode query
    query_vector = encoder.encode(query, top_k=20)

    # Nothing to match on: skip the round-trip instead of sending a bool
    # query with no clauses.
    if not query_vector:
        return []

    # Build rank_feature queries
    should_queries = [
        {
            "rank_feature": {
                "field": f"sparse_embedding.{token}",
                "boost": weight
            }
        }
        for token, weight in query_vector.items()
    ]

    # Search query
    search_body = {
        "query": {
            "bool": {
                "should": should_queries,
                "minimum_should_match": 1
            }
        },
        "size": top_k,
        "_source": ["text", "category"]
    }

    try:
        response = requests.post(
            f"{OPENSEARCH_URL}/{index_name}/_search",
            json=search_body,
            headers={"Content-Type": "application/json"},
            timeout=timeout
        )
    except requests.exceptions.RequestException as exc:
        print(f"Search error: {exc}")
        return []

    if response.status_code != 200:
        print(f"Search error: {response.text}")
        return []

    return response.json().get("hits", {}).get("hits", [])


def print_search_results(query: str, results: list):
    """Print a query header followed by a ranked listing of its hits.

    Args:
        query: Query text shown in the header.
        results: Hit dicts; ``_score`` and the ``text``/``category`` entries
            of ``_source`` are read, each with a default when missing.
    """
    separator = "-" * 60
    print(f"\nQuery: {query}")
    print(separator)

    if results:
        for rank, entry in enumerate(results, start=1):
            relevance = entry.get("_score", 0)
            fields = entry.get("_source", {})
            print(f"{rank}. [{fields.get('category', '')}] {fields.get('text', '')}")
            print(f"   Score: {relevance:.4f}")
    else:
        print("No results found")

In [None]:
# Test queries
TEST_QUERIES = [
    # Single term queries (problem terms from v21.3)
    "인슐린",
    "데이터베이스",
    "증상",
    "질환",
    "추천",

    # Short phrase queries
    "당뇨병 치료",
    "법적 분쟁",
    "서버 보안",

    # Natural language queries
    "암 환자 식단 추천",
    "부동산 계약서 작성 방법",
]

# Run every query and print its top 3 hits, or explain how to enable the
# tests when the connection check failed.
if opensearch_available:
    print("=" * 60)
    print("Sparse Search Results")
    print("=" * 60)

    for query in TEST_QUERIES:
        results = sparse_search(INDEX_NAME, query, top_k=3)
        print_search_results(query, results)
else:
    print("OpenSearch not available. Skipping search tests.")
    print("\nTo run these tests:")
    print("1. Start OpenSearch locally or connect to a remote cluster")
    print("2. Update OPENSEARCH_HOST and OPENSEARCH_PORT")
    print("3. Re-run this notebook")

## 7. Evaluation Metrics

In [None]:
# Relevance ground truth for evaluation
# Maps each query string to the list of document categories counted as a
# correct top hit for that query.
RELEVANCE_GT = {
    "인슐린": ["medical"],
    "데이터베이스": ["tech"],
    "증상": ["medical"],
    "질환": ["medical"],
    "당뇨병 치료": ["medical"],
    "법적 분쟁": ["legal"],
    "서버 보안": ["tech"],
    "암 환자 식단 추천": ["medical"],
    "부동산 계약서 작성 방법": ["legal"],
}

def evaluate_search(query: str, results: list, expected_categories: list):
    """Score a result list by the category of its first hit.

    Args:
        query: Query text (not used in the computation).
        results: Hit dicts; only the first one is inspected.
        expected_categories: Categories that count as relevant.

    Returns:
        1.0 when the first hit's ``_source.category`` is one of the expected
        categories; 0.0 otherwise, including when there are no results.
    """
    if results:
        best_hit = results[0]
        predicted = best_hit.get("_source", {}).get("category", "")
        if predicted in expected_categories:
            return 1.0
    return 0.0


# Compute Precision@1 over RELEVANCE_GT: a query counts as correct when the
# category of its single top hit is in the expected list.
if opensearch_available:
    print("\n" + "=" * 60)
    print("Search Relevance Evaluation")
    print("=" * 60)

    total_queries = 0
    correct = 0

    for query, expected in RELEVANCE_GT.items():
        results = sparse_search(INDEX_NAME, query, top_k=1)
        score = evaluate_search(query, results, expected)

        total_queries += 1
        # score is 0.0 or 1.0, so correct accumulates as a float.
        correct += score

        status = "✓" if score == 1.0 else "✗"
        top_cat = results[0].get("_source", {}).get("category", "N/A") if results else "N/A"
        print(f"{status} '{query}' -> {top_cat} (expected: {expected})")

    precision = correct / total_queries * 100 if total_queries > 0 else 0
    print(f"\nPrecision@1: {precision:.1f}% ({int(correct)}/{total_queries})")

## 8. Cleanup

In [None]:
def cleanup_index(index_name: str, timeout: float = 10.0):
    """Delete test index.

    Args:
        index_name: Name of the index to delete.
        timeout: Seconds to wait for the HTTP call; without a timeout
            requests can block indefinitely.

    Prints whether the deletion succeeded; transport errors are reported
    rather than raised.
    """
    try:
        response = requests.delete(f"{OPENSEARCH_URL}/{index_name}", timeout=timeout)
    except requests.exceptions.RequestException as exc:
        print(f"Cleanup failed: {exc}")
        return

    if response.status_code == 200:
        print(f"Index '{index_name}' deleted")
    else:
        print(f"Cleanup failed: {response.text}")

# Uncomment to cleanup:
# if opensearch_available:
#     cleanup_index(INDEX_NAME)

In [None]:
# Closing banner and a list of follow-up steps.
print("\n" + "=" * 60)
print("OpenSearch Integration Test Complete!")
print("=" * 60)
print("\nNext Steps:")
print("1. Deploy model to OpenSearch ML node")
print("2. Register model with OpenSearch ML plugin")
print("3. Create neural sparse ingest pipeline")
print("4. Index production documents with sparse embeddings")