# Lindorm 多模态检索

Lindorm 实例借助大模型实现多模态检索。

图片一键入库

多模态检索

    以图搜图
    以文搜图
    增量图片去重导入 


## PicturesImport - 批量导入图片
### 数据源介绍
    数据源: https://www.kaggle.com/datasets/iamsouravbanerjee/animal-image-dataset-90-different-animals
    Zooming in on Wildlife: 5400 Animal Images Across 90 Diverse Classes
    
    下载测试数据并解压，放到 data 目录下
    
### 数据入库
  首先配置待导入的数据目录

  目录中的所有图片经过 Lindorm AI 引擎向量化后，向量会写入 Lindorm

In [None]:
import ipywidgets as widgets

# Text box holding the image directory; later cells read it through
# `dir_input.value`.
dir_input = widgets.Text(
    value="data/animals",
    placeholder="data/animals",
    description="图片目录:",
    disabled=False,
    layout=widgets.Layout(width="auto"),
)

display(dir_input)

In [None]:
import os
# -*- coding: utf-8 -*-
import time
import concurrent
from tqdm import tqdm

from concurrent.futures import ThreadPoolExecutor
from src.lindorm import Lindorm

def handle_picture(file_path: str):
    """Embed one local image and store its vector in Lindorm.

    The image bytes are read once and the file is closed before any call to
    the embedding service, so retries and sleeps never hold the file handle
    open.

    Args:
        file_path: Path of the image on disk. It is also passed to
            ``lindorm.write_doc`` as the document id.

    Returns:
        The result of ``lindorm.write_doc`` once an embedding attempt
        succeeds, or ``None`` when all 5 attempts fail.
    """
    with open(file_path, 'rb') as f:
        content = f.read()

    max_attempts = 5
    for _ in range(max_attempts):
        code, embedding = lindorm.picture_embedding(content)
        if code == 0:
            return lindorm.write_doc(file_path, embedding)
        # With a non-zero code the second tuple element is printed as the
        # error detail.
        print(f"图文向量化失败 {file_path}, error {embedding}")
        time.sleep(1)
    return None


def find_jpg_files(dir: str):
    """Recursively collect the JPEG files below a directory.

    The extension check is case-insensitive and accepts both ``.jpg`` and
    ``.jpeg``, so files such as ``photo.JPG`` or ``<uuid>.jpeg`` are found
    as well.

    Args:
        dir: Root directory to walk.

    Returns:
        A list of paths (each one ``os.path.join(root, file_name)``) in the
        order ``os.walk`` yields them; empty when nothing matches or the
        directory does not exist.
    """
    jpeg_suffixes = (".jpg", ".jpeg")
    jpg_paths = []
    for root, _dirs, files in os.walk(dir):
        jpg_paths.extend(
            os.path.join(root, name)
            for name in files
            if name.lower().endswith(jpeg_suffixes)
        )
    return jpg_paths
            

def import_all_data(dir: str):
    """Vectorize every image found under ``dir`` and write it to Lindorm.

    Files are discovered with ``find_jpg_files`` and processed by
    ``handle_picture`` on a pool of 8 worker threads, with a tqdm progress
    bar. An exception raised while handling one file is reported together
    with that file's path and does not abort the rest of the batch.

    Args:
        dir: Root directory containing the images to import.
    """
    print('导入数据目录:', dir)
    file_paths = find_jpg_files(dir)

    failed_paths = []
    with ThreadPoolExecutor(max_workers=8) as executor:
        # Map each future back to its file so failures can be attributed.
        future_to_path = {executor.submit(handle_picture, path): path for path in file_paths}

        progress = tqdm(
            concurrent.futures.as_completed(future_to_path),
            total=len(file_paths),
            desc="Importing Data",
        )
        for future in progress:
            path = future_to_path[future]
            try:
                future.result()
            except Exception as exc:
                # Batch boundary: record the failure and keep importing.
                failed_paths.append(path)
                print(f"图片导入失败 {path}, error {exc}")

    if failed_paths:
        print(f"导入失败的图片数量: {len(failed_paths)}")
    print("向量化后的数据入库完成")


if __name__ == '__main__':
    print("start")
    lindorm = Lindorm()

    # Always start from a fresh index: drop the current one if it exists,
    # then create a new search index before importing.
    existing_index = lindorm.get_index()
    if existing_index is not None:
        print("索引已存在, 删除索引")
        lindorm.drop_index()

    lindorm.create_search_index()
    print("索引创建完成")

    # The directory comes from the text box shown in the first cell.
    import_all_data(dir_input.value)

## 交互式多模态检索 

提供以图搜图和以文搜图两种检索方式：

- 以图搜图: 上传本地图片, 返回与该图片相似的图片
- 以文搜图: 输入文本, 返回与输入文本相似的图片

In [None]:
import io
import ipywidgets as widgets
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import display, HTML, clear_output
from src.lindorm import Lindorm
    
# Free-text query box for text-to-image search.
text_input = widgets.Text(
    value="",
    placeholder="输入关键字",
    description="以文搜图:",
    disabled=False,
)

# Upload control for image-to-image search; a single image at a time.
file_input = widgets.FileUpload(
    accept="image/*",
    description="以图搜图",
    multiple=False,
)

# Search buttons. Only pure vector search is offered; no button with the
# description "RRF融合检索" is created here.
search_buttons = [
    widgets.Button(description="纯向量检索", button_style=""),
]

# Shared output area that the event handlers print and plot into.
output = widgets.Output()

lindorm = Lindorm()

def show_hits(hits):
    """Print the search hits and draw their images in a three-column grid.

    The grid is sized from the number of hits, and every cell has its axes
    hidden, so cells without an image stay blank.

    Args:
        hits: Search results. Each hit is read with ``.get('_id')`` (opened
            as an image path) and ``.get('_score')``. ``None`` or an empty
            result is reported as a failed search.
    """
    if hits is None or len(hits) == 0:
        print("图片搜索失败")
        return

    print('search, count', len(hits))
    for hit in hits:
        print(hit.get('_id'), hit.get('_score'))

    columns = 3
    rows = (len(hits) + columns - 1) // columns  # ceiling division
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    _, axes = plt.subplots(rows, columns, figsize=(25, 15), squeeze=False)
    cells = axes.flatten()

    for ax, hit in zip(cells, hits):
        ax.imshow(Image.open(hit.get('_id')))
    for ax in cells:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# 定义按钮点击事件处理函数
def on_button_clicked(b):
    """Run a text-to-image search when a search button is clicked.

    The text in ``text_input`` is embedded and searched with the strategy
    selected by the clicked button's description; results are rendered
    into ``output`` through ``show_hits``.

    Args:
        b: The clicked button; its ``description`` selects the search type.
    """
    with output:
        clear_output()  # drop whatever the previous search rendered
        query = text_input.value
        if not query:
            print("请先输入描述文字")
            return

        code, embedding = lindorm.text_embedding(query)
        # NOTE(review): assumes text_embedding follows the same convention as
        # picture_embedding (non-zero code = failure, payload = error) - confirm.
        if code != 0:
            print(f"文本向量化失败, error {embedding}")
            return

        if b.description == "纯向量检索":
            hits = lindorm.knn_search(embedding)
        elif b.description == "RRF融合检索":
            hits = lindorm.rrf_search(query, embedding)
        else:
            # Without this branch `hits` would be unbound below.
            print(f"未知的检索方式: {b.description}")
            return

        print("文本搜图")
        show_hits(hits)
        
# 定义图片显示函数
def show_uploaded_image(change):
    """Preview an uploaded image and search for similar images.

    Registered as an observer of ``file_input``'s ``value``. The first
    uploaded file is displayed, embedded and used for a vector search whose
    results are rendered through ``show_hits``.

    Args:
        change: The change notification; ``change['new']`` holds the
            uploaded files and is indexed positionally.
    """
    with output:
        clear_output()  # drop whatever the previous search rendered
        uploads = change['new']
        if not uploads:
            return

        # Materialise the upload as plain bytes so both PIL and the
        # embedding call receive the same, reusable buffer.
        content = bytes(uploads[0]['content'])

        image = Image.open(io.BytesIO(content))
        plt.imshow(image)
        plt.axis('off')
        plt.show()

        code, embedding = lindorm.picture_embedding(content)
        if code != 0:
            # With a non-zero code the second element is the error detail.
            print(f"图文向量化失败, error {embedding}")
            return

        hits = lindorm.knn_search(embedding)
        print("图片上传检索")
        show_hits(hits)
            
        
# Register the handler function object itself. A lambda would look up the
# global name `on_button_clicked` at click time, and that name is redefined
# further down in this notebook.
for button in search_buttons:
    button.on_click(on_button_clicked)

# Run an image search whenever a new file is uploaded.
file_input.observe(show_uploaded_image, names='value')

display(file_input, text_input, *search_buttons, output)

## 增量图片去重导入

通过以图搜图，提供图片自动去重功能，相似度 >= 阈值的图片不导入，并返回相似照片

相似度取值范围为 [0, 1]，越大越相似，1 表示完全相同。

In [None]:
import os
import requests
import uuid
import io
import ipywidgets as widgets
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import display, clear_output
from src.lindorm import Lindorm
from src import Config

lindorm = Lindorm()

# Similarity threshold, default 0.95. BoundedFloatText is used because plain
# FloatText is unbounded, so the [0.0, 1.0] range would not be enforced.
float_input = widgets.BoundedFloatText(
    value=0.95,
    min=0.0,
    max=1.0,
    step=0.02,
    description='相似度阈值:',
    disabled=False
)

# URL of a remote image to import.
url_input = widgets.Text(
    value='',
    placeholder='输入图片 URL',
    description='URL:',
    disabled=False
)

# Alternative to the URL: upload a single local image.
file_input = widgets.FileUpload(
    accept='image/*',
    description='选择文件',
    multiple=False
)

# Starts the dedup-and-import flow.
button = widgets.Button(
    description='图片去重导入',
    disabled=False,
    button_style='',
    tooltip='点击显示图片',
    icon='check'
)

# Output area that the click handler prints and plots into.
output = widgets.Output()

def picture_insert(image_data, image_path):
    """Write raw image bytes to a local file.

    Args:
        image_data: The bytes to store.
        image_path: Destination path; an existing file is overwritten.

    Returns:
        The number of bytes written, so callers that print the import
        result show something meaningful instead of ``None``.
    """
    with open(image_path, 'wb') as f:
        return f.write(image_data)
        
def http_download(url: str):
    """Download ``url`` over HTTP with a browser-like User-Agent.

    Args:
        url: Address of the resource to fetch.

    Returns:
        ``(0, content_bytes)`` on success, ``(-1, "download timeout")`` when
        the request times out (20 second timeout), or
        ``(-1, "download failed")`` for any other request error, including a
        non-success HTTP status.
    """
    request_headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36"
    }
    try:
        resp = requests.get(url, timeout=20, headers=request_headers)
        # Turn HTTP error statuses into exceptions handled below.
        resp.raise_for_status()
        payload = resp.content
    except requests.exceptions.Timeout:
        # Handled before the generic case so timeouts get their own message.
        print(url, "download timeout")
        return -1, "download timeout"
    except requests.exceptions.RequestException as err:
        print(url, "download failed", err)
        return -1, "download failed"
    return 0, payload

# 写入新图片
def import_new_image(image_data, image_path, picture_embedding):
    """Store a new image locally and write its vector to Lindorm.

    Args:
        image_data: Raw bytes of the image.
        image_path: Local destination path; also used as the document id.
        picture_embedding: The image's embedding, passed to
            ``lindorm.write_doc``.
    """
    # Step 1: save the image file into the local image directory.
    local_result = picture_insert(image_data, image_path)
    print(f"导入图片到目录中: {image_path}, 导入结果: {local_result}")

    # Step 2: write the vector to Lindorm, keyed by the local path.
    doc_id = image_path
    write_result = lindorm.write_doc(doc_id, picture_embedding)
    print(f"导入图片向量到 lindorm 实例中 {doc_id}, 导入结果: {write_result}")

# 定义按钮点击事件处理函数
def on_button_clicked(b):
    """Import an image unless a sufficiently similar one is already indexed.

    The image comes from ``url_input`` (downloaded) or, if no URL is given,
    from ``file_input`` (uploaded). It is displayed, embedded and searched
    for; the hits are drawn in a three-column grid. If the first hit's score
    reaches the threshold in ``float_input`` the image is skipped, otherwise
    it is saved under ``dir_input.value`` with a random UUID file name and
    its vector is written to Lindorm.

    Args:
        b: The clicked button (unused).
    """
    with output:
        clear_output()  # drop whatever the previous run rendered
        image_type = None
        if url_input.value:
            code, image_data = http_download(url_input.value)
            url_input.value = ''
            # On failure http_download returns an error string rather than
            # bytes, so the status code must be checked before decoding.
            if code != 0:
                print("图片下载失败")
                return
        elif file_input.value:
            uploaded_file = next(iter(file_input.value))
            image_data = bytes(uploaded_file['content'])
            image_type = uploaded_file['type'].split('/')[-1]
        else:
            print("请先输入图片 URL 或选择本地图片")
            return

        # Decode once; downloaded images take their type from the decoder.
        img = Image.open(io.BytesIO(image_data))
        if image_type is None:
            image_type = img.format.lower()

        plt.imshow(img)
        plt.axis('off')
        plt.show()

        code, picture_embedding = lindorm.picture_embedding(image_data)
        if code != 0:
            # With a non-zero code the second element is the error detail.
            print(f"图文向量化失败, error {picture_embedding}")
            return
        hits = lindorm.knn_search(picture_embedding)

        print('图片相似度阈值:', float_input.value)
        print('图片类型:', image_type)

        if hits is None or len(hits) == 0:
            print("图片搜索失败")
            return

        print('search, count', len(hits))
        for hit in hits:
            print(hit.get('_id'), hit.get('_score'))

        # Size the grid from the number of hits; squeeze=False keeps `axes`
        # two-dimensional even for a single row.
        columns = 3
        rows = (len(hits) + columns - 1) // columns
        _, axes = plt.subplots(rows, columns, figsize=(25, 15), squeeze=False)
        cells = axes.flatten()
        for ax, hit in zip(cells, hits):
            ax.imshow(Image.open(hit.get('_id')))
        for ax in cells:
            ax.axis('off')  # also hides cells that received no image

        plt.tight_layout()
        plt.show()

        # NOTE(review): assumes hits are ordered best match first - confirm.
        if hits[0].get('_score') >= float_input.value:
            print(f"库中存在图片相似度 >= {float_input.value}，不导入")
        else:
            print(f"库中不存在图片相似度 >= {float_input.value}，导入这张图片")
            target_dir = dir_input.value
            os.makedirs(target_dir, exist_ok=True)
            pic_name = f"{uuid.uuid4()}.{image_type}"
            pic_path = os.path.join(target_dir, pic_name)
            import_new_image(image_data, pic_path, picture_embedding)
        
# 绑定事件处理函数到按钮
button.on_click(on_button_clicked)

# 显示组件
display(float_input)
display(url_input)
display(file_input)

display(button)
display(output)