diff --git a/examples/http_client/main.go b/examples/http_client/main.go index 092c7a7..ee331b5 100644 --- a/examples/http_client/main.go +++ b/examples/http_client/main.go @@ -125,7 +125,7 @@ func main() { } // Discover tools - tools, err := client.SearchTools("", 10) + tools, err := client.SearchTools("http", 10) if err != nil { log.Fatalf("search: %v", err) } diff --git a/src/tag/tag_search.go b/src/tag/tag_search.go index 96a2d77..5336b9c 100644 --- a/src/tag/tag_search.go +++ b/src/tag/tag_search.go @@ -30,7 +30,7 @@ func NewTagSearchStrategy(repo ToolRepository, descriptionWeight float64) *TagSe // SearchTools returns tools ordered by relevance to the query, using explicit tags and description keywords. func (s *TagSearchStrategy) SearchTools(ctx context.Context, query string, limit int) ([]Tool, error) { // Normalize query - queryLower := strings.ToLower(query) + queryLower := strings.ToLower(strings.TrimSpace(query)) words := s.wordRegex.FindAllString(queryLower, -1) queryWordSet := make(map[string]struct{}, len(words)) for _, w := range words { @@ -43,57 +43,74 @@ func (s *TagSearchStrategy) SearchTools(ctx context.Context, query string, limit return nil, err } - // SUTCP each tool - type sUTCPdTool struct { - t Tool - sUTCP float64 + // Compute SUTCP score for each tool + type scoredTool struct { + tool Tool + score float64 } - var sUTCPd []sUTCPdTool + var scored []scoredTool + for _, t := range tools { - var sUTCP float64 + var score float64 - // SUTCP from tags + // Match against tags for _, tag := range t.Tags { tagLower := strings.ToLower(tag) + + // Direct substring match if strings.Contains(queryLower, tagLower) { - sUTCP += 1.0 + score += 1.0 } - // Partial matches on tag words + + // Word-level overlap tagWords := s.wordRegex.FindAllString(tagLower, -1) for _, w := range tagWords { if _, ok := queryWordSet[w]; ok { - sUTCP += s.descriptionWeight + score += s.descriptionWeight } } } - // SUTCP from description + // Match against description if t.Description != "" { descWords := s.wordRegex.FindAllString(strings.ToLower(t.Description), -1) for _, w := range descWords { if len(w) > 2 { if _, ok := queryWordSet[w]; ok { - sUTCP += s.descriptionWeight + score += s.descriptionWeight } } } } - sUTCPd = append(sUTCPd, sUTCPdTool{t: t, sUTCP: sUTCP}) + scored = append(scored, scoredTool{tool: t, score: score}) } - // Sort by descending sUTCP - sort.Slice(sUTCPd, func(i, j int) bool { - return sUTCPd[i].sUTCP > sUTCPd[j].sUTCP + // Sort descending by score + sort.Slice(scored, func(i, j int) bool { + return scored[i].score > scored[j].score }) - // Return up to limit + // Collect only positive matches var result []Tool - for i, st := range sUTCPd { - if i >= limit { - break + for _, st := range scored { + if st.score > 0 { + result = append(result, st.tool) + if len(result) >= limit { + break + } } - result = append(result, st.t) } + + // If no matches, fallback to top N (for discoverability) + if len(result) == 0 && len(scored) > 0 { + for i, st := range scored { + if i >= limit { + break + } + result = append(result, st.tool) + } + } + return result, nil } diff --git a/utcp_client.go b/utcp_client.go index 0387d89..74ce758 100644 --- a/utcp_client.go +++ b/utcp_client.go @@ -517,26 +517,21 @@ func (c *UtcpClient) CallTool( return fn(ctx, args) } -func (c *UtcpClient) SearchTools(query string, limit int) ([]Tool, error) { - tools, err := c.searchStrategy.SearchTools(context.Background(), query, limit) +func (c *UtcpClient) SearchTools(providerPrefix string, limit int) ([]Tool, error) { + all, err := c.toolRepository.GetTools(context.Background()) if err != nil { return nil, err } - - // Convert []*Tool to []Tool if needed - result := make([]Tool, len(tools)) - for i, tool := range tools { - switch t := any(tool).(type) { - case Tool: - result[i] = t - case *Tool: - result[i] = *t - default: - // fallback (shouldn't happen) - result[i] = Tool{} + var filtered []Tool + for _, t := range all { + if strings.HasPrefix(t.Name, providerPrefix+".") { + filtered = append(filtered, t) } } - return result, nil + if len(filtered) == 0 { + return nil, fmt.Errorf("no tools found for provider %q", providerPrefix) + } + return filtered, nil } // ----- variable substitution src -----