In [7]:
import os

import kfp
from kfp import compiler
from kfp.dsl import component, pipeline

## Component to query Milvus DB

In [8]:
@component(
    base_image='python:3.9',
    packages_to_install=["pymilvus", "numpy"],
)
def query_milvus(
    milvus_host: str = 'standalone-milvus.milvus.svc.cluster.local',
    milvus_port: int = 19530,
    collection_name: str = 'rag_embeddings',
    embedding_dim: int = 768,
    top_k: int = 5,
):
    """Run a smoke-test similarity search against a Milvus collection.

    Connects to Milvus, loads the collection, searches it with a randomly
    generated query vector and prints the returned hits. The query vector is
    random, so the hits only show that the collection is reachable and
    searchable; they are not semantically meaningful.

    Args:
        milvus_host: Hostname of the Milvus service.
        milvus_port: Port of the Milvus service.
        collection_name: Name of the collection to search.
        embedding_dim: Length of the random query vector. Presumably this must
            equal the dimension of the collection's "embedding" field; verify
            against the collection schema.
        top_k: Maximum number of hits to request from the search.
    """
    # Imports live inside the function: a KFP lightweight component is
    # executed in its own container, so it must be self-contained.
    from pymilvus import connections, Collection
    import numpy as np

    alias = "default"

    # Connect to Milvus
    connections.connect(alias=alias, host=milvus_host, port=milvus_port)
    print(f"Connected to Milvus at {milvus_host}:{milvus_port}")

    try:
        collection = Collection(collection_name)
        collection.load()

        # Placeholder query: a random vector instead of a real text embedding.
        query_embedding = np.random.rand(embedding_dim).tolist()

        # NOTE(review): assumes the collection has a vector field named
        # "embedding" whose index uses the L2 metric, plus the three output
        # fields below -- confirm against the ingestion code.
        results = collection.search(
            data=[query_embedding],
            anns_field="embedding",
            param={"metric_type": "L2", "params": {"nprobe": 10}},
            limit=top_k,
            output_fields=["chunk_id", "text", "source_file"]
        )

        # One query vector was submitted, so only results[0] carries hits.
        for hit in results[0]:
            print(f"Chunk ID: {hit.entity.get('chunk_id')}")
            print(f"Text: {hit.entity.get('text')}")
            print(f"Source File: {hit.entity.get('source_file')}")
            print(f"Distance: {hit.distance}\n")
    finally:
        # Always release the client connection, even if the load or search fails.
        connections.disconnect(alias)

## Compile the pipeline and submit a run

In [9]:
# Single-step pipeline: invokes the `query_milvus` component with its
# default arguments. The returned task object is not needed afterwards,
# so it is not bound to a name.
@pipeline(name='query-milvus-pipeline')
def query_milvus_pipeline():
    query_milvus()

# Where the compiled pipeline definition is written.
PIPELINE_PACKAGE_PATH = './pipelines-yaml/query_milvus_pipeline.yaml'

# Client for the Kubeflow Pipelines API (constructed with default settings).
client = kfp.Client()

# Create the output directory first so compilation does not fail on a fresh
# checkout where ./pipelines-yaml/ does not exist yet.
os.makedirs(os.path.dirname(PIPELINE_PACKAGE_PATH), exist_ok=True)
compiler.Compiler().compile(query_milvus_pipeline, PIPELINE_PACKAGE_PATH)

# Submit a run straight from the pipeline function; caching is disabled so
# the query step is actually executed on every submission.
run = client.create_run_from_pipeline_func(
    query_milvus_pipeline,
    arguments={},
    enable_caching=False,
)

