1、插入数据

In [None]:
import psycopg2
import struct
import os
# 1. 设置 PostgreSQL 连接
def create_connection():
    """Open a psycopg2 connection to the ``postgres`` database.

    Returns:
        The open connection, or ``None`` when connecting raised any
        exception (the error is printed instead of being propagated).
    """
    connect_kwargs = {
        "database": "postgres",
        "user": "postgres",
        "host": "222.20.98.71",  # database host
        "port": "5432",
    }
    try:
        return psycopg2.connect(**connect_kwargs)
    except Exception as err:
        print(f"连接数据库失败: {err}")
    return None

# 2. 批量插入数据
def insert_data(conn, batch_data):
    """Insert one batch of rows into ``sift_label`` and commit it.

    Args:
        conn: Open DB-API connection (psycopg2 in this script).
        batch_data: Sequence of 21-item rows laid out as
            ``(id, image_embedding, col_1, ..., col_19)``.

    Errors are reported on stdout rather than raised.  On failure the
    transaction is rolled back so the same connection can still be used for
    the next batch; without the rollback the connection stays in an aborted
    transaction and every later statement fails as well.
    """
    # The table is assumed to have columns col_1 .. col_19 for the labels.
    label_columns = ", ".join(f"col_{n}" for n in range(1, 20))
    placeholders = ", ".join(["%s"] * 21)  # id + embedding + 19 labels
    insert_query = (
        f"INSERT INTO sift_label (id, image_embedding, {label_columns}) "
        f"VALUES ({placeholders})"
    )

    cursor = None
    try:
        cursor = conn.cursor()
        cursor.executemany(insert_query, batch_data)
        conn.commit()
        print(f"成功插入 {len(batch_data)} 条数据！")
    except Exception as e:
        print(f"插入数据时出错: {e}")
        try:
            conn.rollback()
        except Exception as rollback_error:
            # Keep the best-effort contract: report, do not raise.
            print(f"回滚事务失败: {rollback_error}")
    finally:
        # Close the cursor on both the success and the failure path.
        if cursor is not None:
            cursor.close()

# 3. 读取文件并分批插入数据
def load_and_insert_data(vector_file_path, label_file_path):
    """Import vectors and their labels into ``sift_label`` in batches.

    The binary vector file and the text label file are read in lock-step:
    the n-th vector record belongs to the n-th label line and is stored with
    ``id = n`` (1-based).  Rows are handed to ``insert_data`` in batches of
    1000.  Reading stops when either file is exhausted.

    Args:
        vector_file_path: Path to the ``.fvecs``-style vector file (records
            of an int32 dimension followed by that many float32 values).
        label_file_path: Path to the text file holding 19
            whitespace-separated label values per line.

    Raises:
        ValueError: If a vector record is truncated or declares a negative
            dimension.
    """
    batch_size = 1000  # rows per INSERT batch
    expected_labels = 19

    with open(vector_file_path, "rb") as vector_file, open(label_file_path, "r") as label_file:
        conn = create_connection()
        if not conn:
            print("无法连接到数据库，终止操作。")
            return

        try:
            batch_data = []
            # The first label line is data, not a header.  enumerate() ties
            # the id to the record position, so a skipped row can no longer
            # shift the ids of the rows that follow it.
            for id_counter, label_line in enumerate(label_file, start=1):
                dim_bytes = vector_file.read(4)
                if not dim_bytes:
                    break  # end of the vector file
                if len(dim_bytes) != 4:
                    raise ValueError(f"向量文件在第 {id_counter} 条记录的维度字段处被截断")
                # .fvecs files are little-endian; do not rely on native order.
                dim = struct.unpack('<i', dim_bytes)[0]
                if dim < 0:
                    raise ValueError(f"第 {id_counter} 条记录的维度无效: {dim}")
                payload = vector_file.read(4 * dim)
                if len(payload) != 4 * dim:
                    raise ValueError(f"向量文件在第 {id_counter} 条记录的数据处被截断")
                vector = struct.unpack(f'<{dim}f', payload)

                label_values = label_line.strip().split()
                if len(label_values) != expected_labels:
                    print(f"警告：标签行中数据数量不正确（应为19个），实际数量为 {len(label_values)}。跳过该行。")
                    continue

                batch_data.append((id_counter, list(vector), *label_values))

                # Flush a full batch and start collecting the next one.
                if len(batch_data) == batch_size:
                    insert_data(conn, batch_data)
                    batch_data = []

            # Flush whatever is left over after the last full batch.
            if batch_data:
                insert_data(conn, batch_data)
        finally:
            # Release the connection even when reading or inserting raised.
            conn.close()

# Experiment data root, resolved relative to the current working directory.
ROOT_DIR = os.path.abspath(
    os.path.join(os.getcwd(), "../../../data/Experiment")
)

# 4. Locate the SIFT base vectors and their label file, then run the import.
vector_file_path = os.path.join(
    ROOT_DIR, "labelfilterData/datasets/sift/sift_base.fvecs"
)
label_file_path = os.path.join(
    ROOT_DIR, "labelfilterData/labels/sift/labels_with_selectivity.txt"
)
load_and_insert_data(
    vector_file_path=vector_file_path,
    label_file_path=label_file_path,
)

print("数据导入完成。")

2、索引构建时间

In [None]:
import psycopg2
from psycopg2 import OperationalError
import time

# 创建数据库连接
def create_connection():
    """Connect to the ``postgres`` database and report the outcome.

    Returns:
        The open connection, or ``None`` when psycopg2 raised
        ``OperationalError`` (the error is printed).
    """
    try:
        connection = psycopg2.connect(
            database="postgres",
            user="postgres",
            host="222.20.98.71",
            port="5432",
        )
    except OperationalError as err:
        print(f"连接失败: {err}")
        return None

    print("连接成功！")
    return connection

# 执行索引构建SQL
def create_index(conn):
    """Build the ``pase_sift_16_200`` index and print how long it took.

    Args:
        conn: Open DB-API connection (psycopg2 in this script).

    The elapsed time covers the ``CREATE INDEX`` statement and its commit.
    Failures are printed rather than raised, and the transaction is rolled
    back so the connection is not left in an aborted state.
    """
    cursor = conn.cursor()
    try:
        # perf_counter() is monotonic, so the measurement cannot be skewed
        # by system clock adjustments during a long build.
        start_time = time.perf_counter()

        create_index_sql = """
            CREATE INDEX pase_sift_16_200
            ON sift_label
            USING pase_hnsw(image_embedding)
            WITH (dim = 128, base_nb_num = 16, ef_build = 200, ef_search = 200, base64_encoded = 0);
        """
        cursor.execute(create_index_sql)
        conn.commit()

        index_build_time = time.perf_counter() - start_time

        print("索引构建完成！")
        print(f"索引构建时间: {index_build_time:.2f} 秒")

    except Exception as e:
        print(f"索引构建失败: {e}")
        try:
            conn.rollback()
        except Exception as rollback_error:
            # Keep the best-effort contract: report, do not raise.
            print(f"回滚事务失败: {rollback_error}")
    finally:
        cursor.close()

def main():
    """Connect, build the index, and always close the connection afterwards."""
    connection = create_connection()
    if not connection:
        # create_connection() has already reported the failure.
        return

    try:
        create_index(connection)
    finally:
        connection.close()


if __name__ == "__main__":
    main()

3、单线程测试脚本

In [None]:
import psycopg2
import struct
import time


# 1. 设置 PostgreSQL 连接
def create_connection():
    """Return a psycopg2 connection to ``postgres``, or ``None`` on failure.

    Any exception raised while connecting is printed and swallowed.
    """
    conn = None
    try:
        conn = psycopg2.connect(
            database="postgres",
            user="postgres",
            host="222.20.98.71",  # database host
            port="5432",
        )
    except Exception as exc:
        print(f"连接数据库失败: {exc}")
    return conn
    
def read_fvecs_file(file_path):
    """Read every vector stored in an ``.fvecs`` file.

    Each record is a 4-byte int32 dimension ``d`` followed by ``d`` 4-byte
    float32 components.  Values are decoded as little-endian regardless of
    the host byte order.

    Args:
        file_path: Path of the ``.fvecs`` file.

    Returns:
        A list with one tuple of floats per record, in file order.

    Raises:
        ValueError: If a record is truncated or declares a negative
            dimension.
    """
    query_vectors = []
    with open(file_path, "rb") as f:
        while True:
            # Dimension header; an empty read means a clean end of file.
            dim_bytes = f.read(4)
            if not dim_bytes:
                break
            record_no = len(query_vectors) + 1
            if len(dim_bytes) != 4:
                raise ValueError(
                    f"{file_path}: record {record_no} has a truncated dimension header"
                )
            dim = struct.unpack('<i', dim_bytes)[0]
            if dim < 0:
                # A negative size would make read() consume the whole file.
                raise ValueError(
                    f"{file_path}: record {record_no} has invalid dimension {dim}"
                )
            # Vector payload; must be exactly dim float32 values.
            vector_bytes = f.read(dim * 4)
            if len(vector_bytes) != dim * 4:
                raise ValueError(
                    f"{file_path}: record {record_no} is truncated "
                    f"(expected {dim * 4} bytes, got {len(vector_bytes)})"
                )
            query_vectors.append(struct.unpack(f'<{dim}f', vector_bytes))
    return query_vectors
def read_txt_file(file_path):
    """Parse a query-condition file into one token list per non-blank line.

    Each line is stripped and split on whitespace; blank lines are dropped.
    All tokens are kept as strings.
    """
    with open(file_path, "r") as handle:
        stripped_rows = (raw.strip() for raw in handle.readlines())
        # Every surviving row contributes all of its tokens as one condition.
        return [row.split() for row in stripped_rows if row]
def execute_query(conn, query, params, output_file, j_value, total_time, k):
    """Run one timed search query and append its ids to ``output_file``.

    Args:
        conn: Open DB-API connection (psycopg2 in this script).
        query: SQL text with ``%s`` placeholders.
        params: Values bound to the placeholders.
        output_file: Path of the result file; one line is appended per call.
        j_value: Integer assigned to the ``hnsw.ef_search`` session setting.
        total_time: Accumulated query time so far, in seconds.
        k: Index of the query; the query with ``k == 0`` is executed but
            not added to ``total_time``.

    Returns:
        ``total_time`` plus this query's execute+fetch time (unchanged when
        ``k == 0`` or when the query failed).

    Each returned id is written as ``id - 1``.  A failed query is reported
    on stdout, rolled back so later queries on the same connection can still
    run, and recorded as an empty line so that line ``n`` of the output
    always belongs to query ``n``.
    """
    cursor = conn.cursor()
    line_written = False
    try:
        # Planner settings: steer the query towards the index.
        cursor.execute("SET enable_seqscan = off;")
        cursor.execute("SET enable_indexscan = on;")
        # int() guarantees only a number is ever spliced into the statement.
        cursor.execute(f"SET hnsw.ef_search = {int(j_value)};")

        # Time only the query itself; perf_counter() is monotonic.
        start_time = time.perf_counter()
        cursor.execute(query, params)
        result = cursor.fetchall()
        elapsed = time.perf_counter() - start_time
        print(f"查询执行时间: {elapsed:.6f} 秒")
        if k != 0:
            total_time += elapsed

        # One line per query: the returned ids, shifted by one.
        ids = [str(row[0] - 1) for row in result]
        with open(output_file, "a") as f:
            f.write(" ".join(ids) + "\n")
        line_written = True
    except Exception as e:
        print(f"查询失败: {e}")
        try:
            # Leave the aborted transaction so the next query can run.
            conn.rollback()
            if not line_written:
                # Placeholder row keeps results aligned with query indices.
                with open(output_file, "a") as f:
                    f.write("\n")
        except Exception as cleanup_error:
            # Keep the best-effort contract: report, do not raise.
            print(f"查询失败后的清理出错: {cleanup_error}")
    finally:
        cursor.close()
    return total_time
def main():
    """Run the filtered top-10 search benchmark for every configuration.

    For each query set in ``list_1`` and each ``ef_search`` value in
    ``list_3`` this opens a fresh connection, runs one query per condition
    line (paired by position with the vectors of ``sift_query.fvecs``), and
    writes ``result/<set>_16_200_<ef>.out`` containing one id line per query
    followed by a blank line and a ``QPS: <value>`` line.
    """
    # Local import: this notebook cell's own import list does not bring in os.
    import os

    ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), "../../../data/Experiment"))
    fvecs_file = os.path.join(ROOT_DIR, "labelfilterData/datasets/sift/sift_query.fvecs")

    # list_1: query-set names; list_2: filter column(s) used by the set at
    # the same position; list_3: ef_search values to sweep.
    list_1 = ["1", "3_1", "3_2", "3_3", "3_4", "4", "5_1", "5_2", "5_3", "5_4"]
    list_2 = [["col_1"], ["col_16"], ["col_17"], ["col_18"], ["col_19"], ["col_8"], ["col_2"], ["col_3"], ["col_5"], ["col_1"]]
    list_3 = [100, 150, 200, 250, 400, 1000]

    query_vectors = read_fvecs_file(fvecs_file)

    # open(..., "w") below fails if the result directory does not exist yet.
    result_dir = os.path.join(os.getcwd(), "result")
    os.makedirs(result_dir, exist_ok=True)

    for i, i_value in enumerate(list_1):
        columns = list_2[i]  # table columns matched against the condition values
        for j_value in list_3:
            connection = create_connection()
            if not connection:
                continue
            try:
                txt_file = os.path.join(ROOT_DIR, f"labelfilterData/query_label/sift/{i_value}.txt")
                output_file = os.path.join(result_dir, f"{i_value}_16_200_{j_value}.out")
                conditions = read_txt_file(txt_file)

                # Start from an empty result file.  The handle is closed
                # right away because execute_query() reopens the file in
                # append mode for every query.
                with open(output_file, "w"):
                    pass

                total_time = 0
                for k, condition_values in enumerate(conditions):
                    query_vector = query_vectors[k]
                    # One "<column> = %s" term per supplied condition value.
                    conditions_sql = " AND ".join(
                        f"{col} = %s" for col in columns[:len(condition_values)]
                    )
                    query = f"""
                                SELECT id
                                FROM sift_label
                                WHERE {conditions_sql}
                                ORDER BY image_embedding <?> '{','.join(map(str, query_vector))}'::pase ASC
                                LIMIT 10;
                            """
                    print(f"正在执行查询条件: {condition_values}...")
                    total_time = execute_query(connection, query, tuple(condition_values), output_file, j_value, total_time, k)

                # execute_query() leaves the first query (k == 0) out of
                # total_time, so it must be left out of the count as well.
                timed_queries = max(len(conditions) - 1, 0)
                qps = timed_queries / total_time if total_time != 0 else 0
                with open(output_file, "a") as f:
                    f.write(f"\nQPS: {qps}\n")
            finally:
                connection.close()


if __name__ == "__main__":
    main()

4、召回率计算脚本

In [None]:
import numpy as np

def read_ivecs(fname):
    """Load an ``.ivecs`` file, keeping at most 10 values per record.

    Every record is an int32 length followed by that many int32 values.

    Returns:
        ``np.array`` built from the per-record slices, in file order.
    """
    rows = []
    with open(fname, "rb") as f:
        while True:
            # Record header: the number of values that follow.
            header = np.fromfile(f, 'int32', 1)
            if header.size == 0:
                break  # nothing left to read
            width = header[0]

            # Record body; only the leading 10 entries are retained.
            rows.append(np.fromfile(f, 'int32', width)[:10])

    return np.array(rows)

def read_output_file(fname):
    """Parse a result file into one list of ints per line.

    When the file has more than two lines, the last two lines are dropped
    before parsing.
    """
    with open(fname, 'r') as f:
        rows = f.readlines()
    # Drop the two trailing lines only when there is something left after.
    body = rows[:-2] if len(rows) > 2 else rows
    return [[int(token) for token in row.split()] for row in body]

def extract_qps(fname):
    """Return the text after ``QPS:`` on the file's last line.

    Returns:
        The stripped value as a string, or ``"N/A"`` when the file is empty
        or its last line does not contain ``QPS:``.
    """
    with open(fname, 'r') as f:
        rows = f.readlines()
    if not rows:
        return "N/A"

    tail = rows[-1].strip()  # last line without surrounding whitespace
    if "QPS:" not in tail:
        return "N/A"
    return tail.split("QPS:")[-1].strip()

# 定义结果文件路径
# This notebook cell's own import list only brings in numpy.
import os

# Input root and the report file that collects recall and QPS per run.
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), "../../../data/Experiment"))
recall_file = os.path.join(os.getcwd(), "result", "single_qps.out")

with open(recall_file, "w") as output_file:
    # list_1: query-set names; list_2: ef_search values of the benchmark runs.
    list_1 = ["1", "3_1", "3_2", "3_3", "3_4", "4", "5_1", "5_2", "5_3", "5_4"]
    list_2 = [100, 150, 200, 250, 400, 1000]

    for i_value in list_1:
        # The ground truth depends only on the query set, so it is read once
        # per set rather than once per ef_search value.
        gt_file = os.path.join(ROOT_DIR, f"labelfilterData/gt/sift/gt-query_set_{i_value}.ivecs")
        gt_data = read_ivecs(gt_file)
        total_queries = len(gt_data)

        for j_value in list_2:
            # Per-query result ids and the QPS line of this run.
            result_file = os.path.join(os.getcwd(), "result", f"{i_value}_16_200_{j_value}.out")
            result_data = read_output_file(result_file)
            qps_value = extract_qps(result_file)

            # Count ground-truth ids that appear in the returned top 10.
            correct_matches = 0
            for i, gt_row in enumerate(gt_data):
                if i >= len(result_data):
                    # Queries without a result row contribute no matches.
                    break
                top_10 = result_data[i][:10]
                correct_matches += sum(1 for gt_val in gt_row if gt_val in top_10)

            # Each query can contribute at most 10 matches.
            recall = correct_matches / (total_queries * 10) if total_queries else 0.0

            # extract_qps() returns "N/A" when the run has no QPS line.
            try:
                qps_text = f"{float(qps_value):>8.2f}"
            except ValueError:
                qps_text = f"{qps_value:>8}"

            output_file.write(f"Recall Rate {i_value:<6} and {j_value:<4}: {recall:>6.4f}, QPS: {qps_text}\n")