1、创建表

In [15]:
import psycopg2
import struct

# 1. 设置 PostgreSQL 连接
def create_connection():
    """Open a psycopg2 connection to the ``vectordb`` database.

    Returns:
        The connection object, or ``None`` when connecting raised any
        exception (the error is printed instead of being propagated).
    """
    settings = dict(
        database="vectordb",
        user="vectordb",
        host="172.17.0.2",  # IP of your own Docker container; adjust as needed
        port="5432",
    )
    try:
        return psycopg2.connect(**settings)
    except Exception as exc:
        print(f"连接数据库失败: {exc}")
        return None
def create_table():
    """Create the ``deep_range`` table if it does not exist yet.

    Obtains a connection from ``create_connection()``, executes the DDL,
    commits, and always closes the connection. Errors are printed, not raised.
    """
    conn = create_connection()
    if not conn:
        print("无法连接到数据库，终止操作。")
        return

    # Schema: integer primary key, FLOAT8[] embedding, one integer attribute.
    ddl = """
        CREATE TABLE IF NOT EXISTS deep_range (
            id int PRIMARY KEY,
            image_embedding FLOAT8[] NOT NULL,
            col_1 int NOT NULL
        );
        """
    try:
        cur = conn.cursor()
        cur.execute(ddl)
        conn.commit()
        print("表 deep_range 创建成功！")
        cur.close()
    except Exception as exc:
        print(f"创建表时出错: {exc}")
    finally:
        conn.close()

# Entry point of this cell: create the table when the code runs as __main__.
if __name__ == "__main__":
    create_table()

表 deep_range 创建成功！


1、插入数据

In [None]:
import psycopg2
import struct
import os

# 1. 设置 PostgreSQL 连接
def create_connection():
    """Open a psycopg2 connection to the ``vectordb`` database.

    Returns:
        The connection object, or ``None`` when connecting raised any
        exception (the error is printed instead of being propagated).
    """
    settings = dict(
        database="vectordb",
        user="vectordb",
        host="172.17.0.2",  # IP of your own Docker container; adjust as needed
        port="5432",
    )
    try:
        return psycopg2.connect(**settings)
    except Exception as exc:
        print(f"连接数据库失败: {exc}")
        return None

# 2. 批量插入数据
def insert_data(conn, batch_data):
    """Insert one batch of rows into ``deep_range`` and commit it.

    Args:
        conn: Open database connection (psycopg2-style: ``cursor()``,
            ``commit()``, ``rollback()``).
        batch_data: Sequence of ``(id, image_embedding, col_1)`` tuples.

    Errors are printed rather than raised. On failure the transaction is
    rolled back; without that, psycopg2 keeps the connection in an aborted
    transaction and every later batch on the same connection would fail too.
    """
    insert_query = "INSERT INTO deep_range (id, image_embedding, col_1) VALUES (%s, %s, %s)"
    try:
        cursor = conn.cursor()
        try:
            cursor.executemany(insert_query, batch_data)
            conn.commit()
        finally:
            # Release the cursor on both the success and the error path.
            cursor.close()
        print(f"成功插入 {len(batch_data)} 条数据！")
    except Exception as e:
        print(f"插入数据时出错: {e}")
        try:
            conn.rollback()
        except Exception as rollback_error:
            # Stay best-effort: a broken connection must not turn into a crash here.
            print(f"回滚事务失败: {rollback_error}")

# 3. 读取文件并分批插入数据
def load_and_insert_data(vector_file_path, batch_size=1000):
    """Stream vectors from an ``.fvecs`` file into ``deep_range`` in batches.

    Every record is read as a 4-byte integer dimension followed by that many
    4-byte floats. Row ids are assigned sequentially starting at 0, and the
    same value is stored in ``col_1``.

    Args:
        vector_file_path: Path of the ``.fvecs`` file to load.
        batch_size: Number of rows handed to ``insert_data`` per call.

    Returns:
        None. If no database connection can be opened a message is printed and
        nothing is inserted.
    """
    with open(vector_file_path, "rb") as vector_file:
        conn = create_connection()
        if not conn:
            print("无法连接到数据库，终止操作。")
            return

        batch_data = []
        id_counter = 0  # ids auto-increment from 0
        try:
            while True:
                dim_bytes = vector_file.read(4)
                if not dim_bytes:
                    break  # end of file
                # '<' pins little-endian, the byte order used by the .fvecs
                # format, instead of depending on the host's native order.
                dim = struct.unpack('<i', dim_bytes)[0]
                vector = struct.unpack(f'<{dim}f', vector_file.read(4 * dim))

                batch_data.append((id_counter, list(vector), id_counter))

                # Flush a full batch and start a new one.
                if len(batch_data) >= batch_size:
                    insert_data(conn, batch_data)
                    batch_data = []

                id_counter += 1

            # Flush whatever is left after the last full batch.
            if batch_data:
                insert_data(conn, batch_data)
        finally:
            # Close the connection even when reading or parsing the file raises.
            conn.close()

# Resolve the experiment data root relative to the current working directory.
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), "../../../data/Experiment"))
# 4. Load the base vectors and insert them into the database.
vector_file_path = os.path.join(ROOT_DIR, "rangefilterData/datasets/deep/deep_base.fvecs")
load_and_insert_data(vector_file_path)

# NOTE: printed unconditionally, i.e. also when load_and_insert_data returned
# early because no database connection could be opened.
print("数据导入完成。")

2、计算索引构建时间

In [23]:
import psycopg2
from psycopg2 import OperationalError
import time
import gc
import resource
import os

# 创建数据库连接
def create_connection():
    """Connect to the ``vectordb`` database and report the outcome.

    Returns:
        The psycopg2 connection on success, or ``None`` when an
        ``OperationalError`` was raised (the error is printed).
    """
    try:
        connection = psycopg2.connect(
            database="vectordb",
            user="vectordb",
            host="172.17.0.2",
            port="5432",
        )
    except OperationalError as err:
        print(f"连接失败: {err}")
        return None
    print("连接成功！")
    return connection

# 执行索引构建SQL
def create_index(conn):
    """Build the HNSW index on ``deep_range.image_embedding`` and print the build time.

    Args:
        conn: Open database connection (psycopg2-style).

    The elapsed time covers executing the ``CREATE INDEX`` statement plus the
    commit. Failures are printed, not raised. Garbage collection is paused
    only for the duration of the measurement and restored afterwards, so the
    rest of the kernel session is not left with GC switched off.
    """
    create_index_sql = """
            CREATE INDEX vbase_deep_16_200
            ON deep_range
            USING hnsw(image_embedding)
            WITH (
                dimension = 96,
                distmethod = 'l2_distance'
            );
        """
    cursor = conn.cursor()
    gc_was_enabled = gc.isenabled()
    gc.disable()  # keep collector pauses out of the timed section
    try:
        # perf_counter is monotonic, so the interval cannot be skewed by
        # wall-clock adjustments during a multi-minute build.
        start_time = time.perf_counter()
        cursor.execute(create_index_sql)
        conn.commit()
        index_build_time = time.perf_counter() - start_time

        print("索引构建完成！")
        print(f"索引构建时间: {index_build_time:.2f} 秒")
    except Exception as e:
        print(f"索引构建失败: {e}")
    finally:
        cursor.close()
        # Restore the collector to the state it had on entry.
        if gc_was_enabled:
            gc.enable()

def main():
    """Connect, build the index, and always close the connection afterwards."""
    connection = create_connection()
    if not connection:
        # Nothing to do without a connection; the failure was already reported.
        return

    try:
        create_index(connection)
    finally:
        connection.close()

# Entry point of this cell: run the index build when the code runs as __main__.
if __name__ == "__main__":
    main()

连接成功！
索引构建完成！
索引构建时间: 438.92 秒


3、脚本单线程搜索

In [None]:
import psycopg2
import struct
import time
import os

# 1. 设置 PostgreSQL 连接
def create_connection():
    """Open a psycopg2 connection to the ``vectordb`` database.

    Returns:
        The connection object, or ``None`` when connecting raised any
        exception (the error is printed instead of being propagated).
    """
    settings = dict(
        database="vectordb",
        user="vectordb",
        host="172.17.0.2",  # your database host
        port="5432",
    )
    try:
        return psycopg2.connect(**settings)
    except Exception as exc:
        print(f"连接数据库失败: {exc}")
        return None

def read_fvecs_file(file_path):
    """Read an ``.fvecs`` file and return all complete vectors.

    Each record is a 4-byte little-endian integer dimension followed by that
    many 4-byte little-endian floats.

    Args:
        file_path: Path of the ``.fvecs`` file.

    Returns:
        A list of tuples of floats, one tuple per vector, in file order.
        Reading stops at the first incomplete record (short header or short
        payload), so a truncated tail is ignored rather than raising.

    Raises:
        ValueError: If a record header holds a negative dimension.
    """
    query_vectors = []
    with open(file_path, "rb") as f:
        while True:
            # Dimension header.
            dim_bytes = f.read(4)
            if len(dim_bytes) < 4:
                break
            dim = struct.unpack('<i', dim_bytes)[0]
            if dim < 0:
                # A negative size would make read() consume the rest of the file.
                raise ValueError(f"无效的向量维度: {dim}")
            # Vector payload; compare lengths so a zero-dimension record is
            # not mistaken for end of file and a partial record is not unpacked.
            vector_bytes = f.read(dim * 4)
            if len(vector_bytes) < dim * 4:
                break
            query_vectors.append(struct.unpack(f'<{dim}f', vector_bytes))
    return query_vectors

def read_range_file(file_path):
    """Parse a range file into a list of ``(start, end)`` integer pairs.

    Every non-blank line is expected to hold exactly two whitespace-separated
    integers ("start end"); blank lines are skipped.
    """
    bounds = []
    with open(file_path, "r") as handle:
        for raw in handle.readlines():
            text = raw.strip()
            if not text:
                continue  # skip blank lines
            lo, hi = (int(token) for token in text.split())
            bounds.append((lo, hi))
    return bounds

def execute_query(conn, query, params, output_file, j_value, total_time, k):
    """Run one timed search query and append the returned ids to ``output_file``.

    Args:
        conn: Open database connection (psycopg2-style).
        query: SQL text of the search query.
        params: Parameters bound to ``query``.
        output_file: Path of the result file; one line is appended per
            successful query.
        j_value: Value assigned to ``hnsw.ef_search`` before the query runs.
        total_time: Running sum of query times in seconds.
        k: Index of the query; the query with ``k == 0`` is executed but its
            time is not added to ``total_time``.

    Returns:
        The updated ``total_time``. It is returned unchanged when ``k == 0``
        or when the query failed.

    Failures are printed, not raised. The transaction is rolled back on
    failure; otherwise psycopg2 leaves the connection in an aborted
    transaction and every later query on it would fail as well.
    """
    cursor = conn.cursor()
    try:
        cursor.execute("SET enable_seqscan = off;")  # disable sequential scans
        cursor.execute("SET enable_indexscan = on;")  # enable index scans
        # int() guarantees that only a number is interpolated into the statement.
        cursor.execute(f"SET hnsw.ef_search = {int(j_value)};")

        # Time only the query execution and the fetch of its rows.
        start_time = time.perf_counter()
        cursor.execute(query, params)
        result = cursor.fetchall()
        elapsed = time.perf_counter() - start_time
        print(f"查询执行时间: {elapsed:.6f} 秒")
        if k != 0:
            total_time += elapsed

        # One line per query: the first column of every row, minus 1.
        ids = [str(row[0] - 1) for row in result]
        with open(output_file, "a") as f:
            f.write(" ".join(ids) + "\n")
    except Exception as e:
        print(f"查询失败: {e}")
        try:
            conn.rollback()
        except Exception as rollback_error:
            # Stay best-effort: a broken connection must not turn into a crash here.
            print(f"回滚事务失败: {rollback_error}")
    finally:
        cursor.close()
    return total_time

def main():
    """Run the range-filtered top-10 search benchmark and write result files.

    For each range file ("2" and "8") and each ``hnsw.ef_search`` value, a
    fresh connection is opened, every (query vector, range) pair is executed
    through ``execute_query``, and the returned ids plus a trailing ``QPS:``
    line are written to ``result/<i>_16_200_<j>.out`` under the current
    working directory.
    """
    ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), "../../../data/Experiment"))
    fvecs_file = os.path.join(ROOT_DIR, "rangefilterData/datasets/deep/deep_query.fvecs")
    range_dir = os.path.join(ROOT_DIR, "rangefilterData/query_range/deep")

    # Create the output directory up front; opening a result file inside a
    # missing directory would otherwise raise FileNotFoundError.
    result_dir = os.path.join(os.getcwd(), "result")
    os.makedirs(result_dir, exist_ok=True)

    list_1 = ["2", "8"]
    list_3 = [20, 30, 50, 86, 150, 250, 400, 500]

    # Query vectors are shared by every configuration.
    query_vectors = read_fvecs_file(fvecs_file)

    for i_value in list_1:
        # The range file depends only on i_value, so read it once per i_value.
        range_file = os.path.join(range_dir, f"deep-96-euclidean_queries_2pow-{i_value}_ranges.txt")
        ranges = read_range_file(range_file)
        if len(ranges) > len(query_vectors):
            print(f"范围数量 ({len(ranges)}) 多于查询向量数量 ({len(query_vectors)})，多余的范围将被忽略。")
        # Pair the k-th vector with the k-th range; zip stops at the shorter
        # input instead of raising IndexError.
        workload = list(zip(query_vectors, ranges))
        # execute_query does not time the query with k == 0, so it must not be
        # counted in the QPS numerator either.
        timed_queries = max(len(workload) - 1, 0)

        for j_value in list_3:
            connection = create_connection()
            if not connection:
                continue
            try:
                output_file = os.path.join(result_dir, f"{i_value}_16_200_{j_value}.out")
                # Start from an empty result file on every run.
                with open(output_file, "w"):
                    pass

                total_time = 0
                for k, (query_vector, (start, end)) in enumerate(workload):
                    query = f"""
                            SELECT id, image_embedding <-> ARRAY[{', '.join(map(str, query_vector))}] AS distance
                            FROM deep_range
                            WHERE col_1 BETWEEN %s AND %s
                            ORDER BY distance
                            LIMIT 10;
                        """
                    print(f"正在执行范围查询: {start} 到 {end}...")
                    total_time = execute_query(connection, query, (start, end), output_file, j_value, total_time, k)

                qps = timed_queries / total_time if total_time != 0 else 0
                # Append the QPS trailer after the per-query result lines.
                with open(output_file, "a") as f:
                    f.write(f"\nQPS: {qps}\n")
            finally:
                connection.close()

# Entry point of this cell: run the search benchmark when the code runs as __main__.
if __name__ == "__main__":
    main()

4、脚本计算召回率

In [None]:
import numpy as np

def read_ivecs(fname):
    """Read an ``.ivecs`` file and return the first 10 entries of every record.

    Each record is one int32 width followed by ``width`` int32 values.
    Reading stops when no further width header can be read.

    Returns:
        A numpy array built from the per-record slices.
    """
    rows = []
    with open(fname, "rb") as f:
        while True:
            header = np.fromfile(f, 'int32', 1)
            if header.size == 0:
                break  # end of file
            width = header[0]
            record = np.fromfile(f, 'int32', width)
            # Keep at most the first 10 entries of the record.
            rows.append(record[:10])
    return np.array(rows)
def read_txt(fname):
    """Read a whitespace-separated integer text file, keeping 10 values per line.

    Returns:
        A numpy array with one row per line, each row holding at most the
        first 10 integers of that line.
    """
    with open(fname, "r") as f:
        rows = [[int(token) for token in line.strip().split()][:10] for line in f]
    return np.array(rows)

def read_output_file(fname):
    """Read the per-query id lists from a search result file.

    A result file holds one line of space-separated ids per query, optionally
    followed by a blank line and a final ``QPS: <value>`` line. Only that
    trailer is removed: a file without it keeps every row, and a blank line
    that belongs to a query with no results is preserved as an empty list.

    Returns:
        A list with one list of ints per query line.
    """
    with open(fname, 'r') as f:
        lines = f.readlines()

    # Strip the trailer only when it is really there, instead of always
    # dropping the last two lines.
    if lines and lines[-1].strip().startswith("QPS:"):
        lines.pop()
        if lines and not lines[-1].strip():
            lines.pop()  # blank separator written just before the QPS line

    return [list(map(int, line.split())) for line in lines]

def extract_qps(fname):
    """Return the QPS value stored on the last line of a result file.

    Returns:
        The text after ``QPS:`` on the last line (stripped), or ``"N/A"``
        when the file is empty or its last line carries no ``QPS:`` marker.
    """
    with open(fname, 'r') as f:
        lines = f.readlines()

    if not lines:
        return "N/A"

    tail = lines[-1].strip()
    if "QPS:" not in tail:
        return "N/A"
    return tail.split("QPS:")[-1].strip()

import os  # this cell uses os.path but did not import os itself

ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), "../../../data/Experiment"))
recall_file = os.path.join(os.getcwd(), "result", "single_qps.out")

# Append one "recall + QPS" summary line per (range file, ef_search) pair.
with open(recall_file, "a") as output_file:
    list_1 = ["2", "8"]
    list_2 = [20, 30, 50, 86, 150, 250, 400, 500]

    for i_value in list_1:
        # The ground truth depends only on i_value, so read it once per i_value.
        gt_file = os.path.join(ROOT_DIR, f"rangefilterData/gt/deep/gt-query_set_{i_value}.ivecs")
        gt_data = read_ivecs(gt_file)
        total_queries = len(gt_data)

        for j_value in list_2:
            # Per-query result ids and the QPS trailer of the matching run.
            result_file = os.path.join(os.getcwd(), "result", f"{i_value}_16_200_{j_value}.out")
            result_data = read_output_file(result_file)
            qps_value = extract_qps(result_file)

            correct_matches = 0
            for i, gt_row in enumerate(gt_data):
                if i >= len(result_data):
                    # Fewer result rows than ground-truth rows: the missing
                    # queries contribute zero matches instead of an IndexError.
                    break
                # Undo the -1 offset that was applied to the ids when the
                # result file was written, then compare with the ground truth.
                result_row = [val + 1 for val in result_data[i][:10]]
                correct_matches += sum(1 for gt_val in gt_row if gt_val in result_row)

            # Each query contributes up to 10 ground-truth entries.
            recall = correct_matches / (total_queries * 10) if total_queries else 0.0

            # extract_qps returns "N/A" when the trailer is missing; print it
            # as text instead of crashing in float().
            try:
                qps_text = f"{float(qps_value):>8.2f}"
            except ValueError:
                qps_text = f"{qps_value:>8}"

            output_file.write(f"Recall Rate {i_value:<6} and {j_value:<4}: {recall:>6.4f}, QPS: {qps_text}\n")