diff --git a/src/semantic-router/pkg/cache/milvus_cache.go b/src/semantic-router/pkg/cache/milvus_cache.go index 725f711b..6e5d6af7 100644 --- a/src/semantic-router/pkg/cache/milvus_cache.go +++ b/src/semantic-router/pkg/cache/milvus_cache.go @@ -307,19 +307,11 @@ func (c *MilvusCache) createCollection() error { return err } - // Create index - indexParams := map[string]string{ - "index_type": c.config.Collection.Index.Type, - "metric_type": c.config.Collection.VectorField.MetricType, - "params": fmt.Sprintf(`{"M": %d, "efConstruction": %d}`, - c.config.Collection.Index.Params.M, - c.config.Collection.Index.Params.EfConstruction), - } - - observability.Debugf("MilvusCache.createCollection: creating index for %d-dimensional vectors", actualDimension) - // Create index with updated API - index := entity.NewGenericIndex(c.config.Collection.VectorField.Name, entity.IndexType(c.config.Collection.Index.Type), indexParams) + index, err := entity.NewIndexHNSW(entity.MetricType(c.config.Collection.VectorField.MetricType), c.config.Collection.Index.Params.EfConstruction, c.config.Collection.Index.Params.M) + if err != nil { + return fmt.Errorf("failed to create HNSW index: %w", err) + } if err := c.client.CreateIndex(ctx, c.collectionName, c.config.Collection.VectorField.Name, index, false); err != nil { return err } @@ -517,112 +509,87 @@ func (c *MilvusCache) FindSimilar(model string, query string) ([]byte, bool, err ctx := context.Background() - // Query for completed entries with the same model - // Using Query approach for comprehensive similarity search - queryExpr := fmt.Sprintf("model == \"%s\" && response_body != \"\"", model) - observability.Debugf("MilvusCache.FindSimilar: querying with expr: %s (embedding_dim: %d)", - queryExpr, len(queryEmbedding)) - - // Use Query to get all matching entries, then compute similarity manually - results, err := c.client.Query(ctx, c.collectionName, []string{}, queryExpr, - []string{"query", "response_body", c.config.Collection.VectorField.Name}) + // Define search parameters + searchParam, err := entity.NewIndexHNSWSearchParam(c.config.Search.Params.Ef) + if err != nil { + return nil, false, fmt.Errorf("failed to create search parameters: %w", err) + } + + // Use Milvus Search for efficient similarity search + searchResult, err := c.client.Search( + ctx, + c.collectionName, + []string{}, + fmt.Sprintf("model == \"%s\" && response_body != \"\"", model), + []string{"response_body"}, + []entity.Vector{entity.FloatVector(queryEmbedding)}, + c.config.Collection.VectorField.Name, + entity.MetricType(c.config.Collection.VectorField.MetricType), + c.config.Search.TopK, + searchParam, + ) if err != nil { - observability.Debugf("MilvusCache.FindSimilar: query failed: %v", err) + observability.Debugf("MilvusCache.FindSimilar: search failed: %v", err) atomic.AddInt64(&c.missCount, 1) metrics.RecordCacheOperation("milvus", "find_similar", "error", time.Since(start).Seconds()) metrics.RecordCacheMiss() return nil, false, nil } - if len(results) == 0 { + if len(searchResult) == 0 || searchResult[0].ResultCount == 0 { atomic.AddInt64(&c.missCount, 1) - observability.Debugf("MilvusCache.FindSimilar: no entries found with responses") + observability.Debugf("MilvusCache.FindSimilar: no entries found") metrics.RecordCacheOperation("milvus", "find_similar", "miss", time.Since(start).Seconds()) metrics.RecordCacheMiss() return nil, false, nil } - // Calculate semantic similarity for each candidate - bestSimilarity := float32(-1.0) - var bestResponse string - - // Find columns by type instead of assuming order - var queryColumn *entity.ColumnVarChar - var responseColumn *entity.ColumnVarChar - var embeddingColumn *entity.ColumnFloatVector + bestScore := searchResult[0].Scores[0] + if bestScore < c.similarityThreshold { + atomic.AddInt64(&c.missCount, 1) + observability.Debugf("MilvusCache.FindSimilar: CACHE MISS - best_similarity=%.4f < threshold=%.4f", + bestScore, c.similarityThreshold) + observability.LogEvent("cache_miss", map[string]interface{}{ + "backend": "milvus", + "best_similarity": bestScore, + "threshold": c.similarityThreshold, + "model": model, + "collection": c.collectionName, + }) + metrics.RecordCacheOperation("milvus", "find_similar", "miss", time.Since(start).Seconds()) + metrics.RecordCacheMiss() + return nil, false, nil + } - for _, col := range results { - switch typedCol := col.(type) { - case *entity.ColumnVarChar: - if typedCol.Name() == "query" { - queryColumn = typedCol - } else if typedCol.Name() == "response_body" { - responseColumn = typedCol - } - case *entity.ColumnFloatVector: - if typedCol.Name() == c.config.Collection.VectorField.Name { - embeddingColumn = typedCol - } - } + // Cache Hit + var responseBody []byte + responseBodyColumn, ok := searchResult[0].Fields[0].(*entity.ColumnVarChar) + if ok && responseBodyColumn.Len() > 0 { + responseBody = []byte(responseBodyColumn.Data()[0]) } - if queryColumn == nil || responseColumn == nil || embeddingColumn == nil { - observability.Debugf("MilvusCache.FindSimilar: missing required columns in results") + if responseBody == nil { + observability.Debugf("MilvusCache.FindSimilar: cache hit but response_body is missing or not a string") atomic.AddInt64(&c.missCount, 1) metrics.RecordCacheOperation("milvus", "find_similar", "error", time.Since(start).Seconds()) metrics.RecordCacheMiss() return nil, false, nil } - for i := 0; i < queryColumn.Len(); i++ { - storedEmbedding := embeddingColumn.Data()[i] - - // Calculate dot product similarity score - var similarity float32 - for j := 0; j < len(queryEmbedding) && j < len(storedEmbedding); j++ { - similarity += queryEmbedding[j] * storedEmbedding[j] - } - - if similarity > bestSimilarity { - bestSimilarity = similarity - bestResponse = responseColumn.Data()[i] - } - } - - observability.Debugf("MilvusCache.FindSimilar: best similarity=%.4f, threshold=%.4f (checked %d entries)", - bestSimilarity, c.similarityThreshold, queryColumn.Len()) - - if bestSimilarity >= c.similarityThreshold { - atomic.AddInt64(&c.hitCount, 1) - observability.Debugf("MilvusCache.FindSimilar: CACHE HIT - similarity=%.4f >= threshold=%.4f, response_size=%d bytes", - bestSimilarity, c.similarityThreshold, len(bestResponse)) - observability.LogEvent("cache_hit", map[string]interface{}{ - "backend": "milvus", - "similarity": bestSimilarity, - "threshold": c.similarityThreshold, - "model": model, - "collection": c.collectionName, - }) - metrics.RecordCacheOperation("milvus", "find_similar", "hit", time.Since(start).Seconds()) - metrics.RecordCacheHit() - return []byte(bestResponse), true, nil - } - - atomic.AddInt64(&c.missCount, 1) - observability.Debugf("MilvusCache.FindSimilar: CACHE MISS - best_similarity=%.4f < threshold=%.4f", - bestSimilarity, c.similarityThreshold) - observability.LogEvent("cache_miss", map[string]interface{}{ - "backend": "milvus", - "best_similarity": bestSimilarity, - "threshold": c.similarityThreshold, - "model": model, - "collection": c.collectionName, - "entries_checked": queryColumn.Len(), + atomic.AddInt64(&c.hitCount, 1) + observability.Debugf("MilvusCache.FindSimilar: CACHE HIT - similarity=%.4f >= threshold=%.4f, response_size=%d bytes", + bestScore, c.similarityThreshold, len(responseBody)) + observability.LogEvent("cache_hit", map[string]interface{}{ + "backend": "milvus", + "similarity": bestScore, + "threshold": c.similarityThreshold, + "model": model, + "collection": c.collectionName, }) - metrics.RecordCacheOperation("milvus", "find_similar", "miss", time.Since(start).Seconds()) - metrics.RecordCacheMiss() - return nil, false, nil + metrics.RecordCacheOperation("milvus", "find_similar", "hit", time.Since(start).Seconds()) + metrics.RecordCacheHit() + return responseBody, true, nil } // Close releases all resources held by the cache