1、建表

In [4]:
from pymilvus import MilvusClient, DataType
import time

# 1. Set up a Milvus client
client = MilvusClient(
    uri="http://222.20.98.71:19530"
)

# 3. Create a collection in customized setup mode
# 3.1. Create schema
schema = MilvusClient.create_schema(
    auto_id=False,
    enable_dynamic_field=True,
)
# 3.2. Add fields to schema
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)  # 主键字段
schema.add_field(field_name="image_embedding", datatype=DataType.FLOAT_VECTOR, dim=200)  # 向量字段
schema.add_field(field_name="col_1", datatype=DataType.INT64)  # 整数字段



# 3.3 Prepare index parameters
index_params = client.prepare_index_params()
# 3.4 Add indexs

index_params.add_index(
    field_name="image_embedding", 
    index_type="IVF_FLAT",
    metric_type="L2",
    params={
        "nlist": 3000
    }
)

client.create_collection(
    collection_name="t2i_label",
    schema=schema,
    index_params=index_params
)

time.sleep(1)   # 休眠一秒钟

res = client.get_load_state(
    collection_name="t2i_label"
)

print(res)


{'state': <LoadState: Loaded>}


2、插入数据

In [None]:
import struct
from pymilvus import MilvusClient
import os

# 1. 设置 Milvus 客户端
client = MilvusClient(
    uri="http://192.168.191.160:19530"
)

# 2. 定义集合名称和批量插入的批量大小
collection_name = "t2i_label"
batch_size = 1000

# 3. 解析文件并分批插入数据
def load_and_insert_data(vector_file_path, label_file_path):
    # 打开向量文件和标签文件
    with open(vector_file_path, "rb") as vector_file, open(label_file_path, "r") as label_file:
        batch_data = []
        id_counter = 1  # id 从 1 开始自增

        # 跳过标签文件的第一行
        next(label_file)

        while True:
            # 读取向量数据
            dim = vector_file.read(4)
            if not dim:
                break  # 文件结束
            dim = struct.unpack('i', dim)[0]
            vector = struct.unpack('f' * dim, vector_file.read(4 * dim))

            # 读取标签数据
            label_line = next(label_file, None)
            if label_line is None:
                break  # 文件结束
            label_line = label_line.strip()

            # 将解析后的数据追加到批量数据列表
            data_entry = {
                "id": id_counter,  # 使用自增的 id
                "image_embedding": list(vector),  # 向量数据
                "col_1": int(label_line)  # 单个标签值
            }

            batch_data.append(data_entry)

            # id 自增
            id_counter += 1

            # 每批次插入一次数据
            if len(batch_data) >= batch_size:
                client.insert(collection_name=collection_name, data=batch_data)
                print(f"插入了 {len(batch_data)} 行数据。")
                batch_data = []

        # 插入剩余数据
        if batch_data:
            client.insert(collection_name=collection_name, data=batch_data)
            print(f"插入了 {len(batch_data)} 行数据。")

ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), "../../../data/Experiment"))
# 4. 执行数据加载和插入
vector_file_path = os.path.join(ROOT_DIR, "labelfilterData/datasets/audio/audio_base.fvecs")
label_file_path = os.path.join(ROOT_DIR, "labelfilterData/labels/text2image/label_1.txt")
load_and_insert_data(vector_file_path, label_file_path)

print("数据导入完成。")

# 建索引前，文件大小11533

3、单线程脚本测试

In [None]:
from pymilvus import MilvusClient
import time
import numpy as np
import psutil  # 用于监控进程内存
import os

# 获取当前进程
process = psutil.Process(os.getpid())
peak_memory_mb = 0  # 用于记录峰值内存（单位 MB）

# 定义一个函数来更新峰值内存
def update_peak_memory():
    global peak_memory_mb
    memory_info = process.memory_info()
    current_memory_mb = memory_info.rss / 1024 / 1024  # 将字节转换为 MB
    peak_memory_mb = max(peak_memory_mb, current_memory_mb)

# 1. Set up a Milvus client
client = MilvusClient(
    uri="http://192.168.191.160:19530"
)

# 2. Function to read fvecs file
def read_fvecs(file_path, num_vectors):
    data = []
    with open(file_path, 'rb') as f:
        for _ in range(num_vectors):
            dim = np.frombuffer(f.read(4), dtype=np.int32)[0]
            vector = np.frombuffer(f.read(dim * 4), dtype=np.float32)
            data.append(vector.tolist())
    update_peak_memory()  # 检查内存
    return data

def read_txt_file(file_path):
    """读取 .txt 文件并返回查询条件列表"""
    conditions = []
    with open(file_path, "r") as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip()
            if line:  # 跳过空行
                # 用空格分割，获取两个值作为范围
                parts = line.split()
                if len(parts) == 2:
                    conditions.append((int(parts[0]), int(parts[1])))  # 存储范围值
    return conditions

list_1 = ["2"]
list_3 = [50 ,80 ,100 ,150 ,200 ,300 ,500 ,1000]

num_vectors = 10000
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), "../../../data/Experiment"))
    # 文件路径
fvecs_file = os.path.join(ROOT_DIR, "labelfilterData/datasets/audio/audio_query.fvecs")
data = read_fvecs(fvecs_file, num_vectors)
# File path and number of vectors to load
for i, i_value in enumerate(list_1):
    for j_value in list_3:
        txt_file = os.path.join(ROOT_DIR, f"labelfilterData/query_label/text2image/{i_value}.txt")
        output_file = os.path.join(os.getcwd(), "result", f"{i_value}_16_200_{j_value}.out")
        conditions = read_txt_file(txt_file)
        print(f"Loaded {len(data)} vectors and {len(conditions)} filter conditions.")

        # Ensure the number of filters matches the number of query vectors
        if len(conditions) != len(data):
            raise ValueError("The number of filters must match the number of query vectors.")

        # Load the collection
        client.load_collection(collection_name="t2i_label")

        # List to store all results
        all_results = []

        # Start timing
        start_time_1 = time.perf_counter()

        # Perform search for each vector
        for k in range(len(conditions)):
            lower_bound, upper_bound = conditions[k]
            conditions_sql = f"id >= {lower_bound} && id <= {upper_bound}"
            start_time = time.perf_counter()
            res = client.search(
                collection_name="t2i_label",
                data=[data[k]],
                filter=conditions_sql,
                anns_field="image_embedding",
                limit=10,
                search_params={"metric_type": "L2", "params": {"nprobe": j_value}},
                output_fields=["id"]
            )
            end_time = time.perf_counter()
            print(f"Search completed in {end_time - start_time} seconds.")
            # Extract ids from the result structure
            result_ids = []
            for item in res[0]:
                result_ids.append(str(item["id"] - 1))

            # Join the IDs with space and add to all_results
            all_results.append(" ".join(result_ids))
            update_peak_memory()  # 在每次搜索后检查内存

        # End timing   
        end_time_1 = time.perf_counter()
        # Write all results to the output file
        with open(output_file, 'w') as f_out:
            # 写入搜索结果
            for result in all_results:
                f_out.write(result + "\n")
            # 追加 QPS 和 Peak RES memory
            f_out.write(f"QPS: {10000 / (end_time_1 - start_time_1):.2f}")
            f_out.write(f"Peak RES memory usage: {peak_memory_mb:.2f} MB\n")


Loaded 10000 vectors and 10000 filter conditions.
Search completed in 0.007511022035032511 seconds.
Search completed in 0.00803903816267848 seconds.
Search completed in 0.007604395970702171 seconds.
Search completed in 0.007705677999183536 seconds.
Search completed in 0.007769239833578467 seconds.
Search completed in 0.006352720083668828 seconds.
Search completed in 0.0060203298926353455 seconds.
Search completed in 0.006483841920271516 seconds.
Search completed in 0.008961732964962721 seconds.
Search completed in 0.005574007984250784 seconds.
Search completed in 0.006483352975919843 seconds.
Search completed in 0.006192775210365653 seconds.
Search completed in 0.006137422984465957 seconds.
Search completed in 0.007081619929522276 seconds.
Search completed in 0.006713980110362172 seconds.
Search completed in 0.006205965997651219 seconds.
Search completed in 0.006113593000918627 seconds.
Search completed in 0.006063461070880294 seconds.
Search completed in 0.006713188951835036 seconds.


4、单线程脚本计算召回率和QPS

In [None]:
import numpy as np

def read_ivecs(fname):
    with open(fname, "rb") as f:
        data = []
        while True:
            try:
                # Read the dimension
                width = np.fromfile(f, 'int32', 1)[0]
                
                # Read the vector data
                vector = np.fromfile(f, 'int32', width)
                
                # If the vector is longer than 10, we only take the first 10 elements
                data.append(vector[:10])
            except IndexError:
                break  # End of file

    return np.array(data)
def read_txt(file_path):
    """
    读取 .txt 文件，每行包含 10 个元素，返回二维列表
    """
    data = []
    with open(file_path, "r") as f:
        for line in f:
            line = line.strip()
            if line:  # 跳过空行
                # 将每行的 10 个元素分割并转换为整数
                elements = list(map(int, line.split()))
                if len(elements) == 10:  # 确保每行有 10 个元素
                    data.append(elements)
                else:
                    raise ValueError(f"文件格式错误：每行必须包含 10 个元素，但发现 {len(elements)} 个元素")
    return data

def read_output_file(fname):
    with open(fname, 'r') as f:
        lines = f.readlines()  # 读取所有行
        if len(lines) > 1 and ("QPS:" in lines[-1] or "Peak" in lines[-1]):  # 如果文件行数大于1，忽略最后一行
            lines = lines[:-1]
        return [list(map(int, line.split())) for line in lines]

def extract_qps(fname):
    with open(fname, 'r') as f:
        lines = f.readlines()
        if lines:
            last_line = lines[-1].strip()  # 获取最后一行并去除首尾空格
            if "QPS:" in last_line:
                qps_part = last_line.split("QPS:")[1].split("Peak")[0].strip()  # 提取 QPS 值
                return qps_part
    return "N/A"  # 如果没有找到 QPS 值，返回 N/A

def extract_res(fname):
    with open(fname, 'r') as f:
        lines = f.readlines()
        if lines:
            last_line = lines[-1].strip()  # 获取最后一行并去除首尾空格
            if "Peak RES memory usage:" in last_line:
                res_part = last_line.split("Peak RES memory usage:")[1].split("MB")[0].strip()  # 提取 RES 值
                return res_part
    return "N/A"  # 如果没有找到 RES 值，返回 N/A

# 定义结果文件路径
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), "../../../data/Experiment"))
recall_file = os.path.join(os.getcwd(), "result", f"single_qps.out")

# 打开文件用于写入结果
with open(recall_file, "a") as output_file:
    list_1 = ["2"]
    list_2 = [50 ,80 ,100 ,150 ,200 ,300 ,500 ,1000]

    for i_value in list_1:
        for j_value in list_2:
            # Read ground truth
            gt_file = os.path.join(ROOT_DIR, f"labelfilterData/gt/text2image/gt-query_set_{i_value}.ivecs")
            gt_data = read_ivecs(gt_file)

            # Read result file
            result_file = os.path.join(os.getcwd(), "result", f"{i_value}_16_200_{j_value}.out")
            result_data = read_output_file(result_file)

            # Extract QPS value
            qps_value = extract_qps(result_file)
            res = extract_res(result_file)

            # Calculate recall
            correct_matches = 0
            valid_queries = min(len(gt_data), len(result_data))  # Only process queries with both gt and result data
            
            # For each query in ground truth that has corresponding result data
            for i in range(valid_queries):
                # Check if any of the first 10 elements from ground truth are in the result's top 10
                correct_matches += sum(1 for gt_val in gt_data[i] if gt_val in result_data[i][:10])

            # Recall calculation (using valid queries count)
            recall = correct_matches / (valid_queries * 10)  # Since each query has 10 elements to match

            # Print for debugging
            print(f"Processing {i_value}_{j_value}: GT size={len(gt_data)}, Result size={len(result_data)}, Valid queries={valid_queries}")
            
            # 写入结果到文件
            output_file.write(f"Recall Rate {i_value:<6} and {j_value:<4}: {recall:>6.4f}, QPS: {float(qps_value) if qps_value != 'N/A' else 0:>8.2f}, RES: {float(res) if res != 'N/A' else 0:>8.2f} MB\n")