## 图片自动去重导入

通过以图搜图，提供图片自动去重功能，相似度 >= 阈值的图片不导入，并返回相似照片

相似度[0, 1], 越大越相似， 1 表示完全相同

In [1]:
import ipywidgets as widgets

# 创建输入框和按钮
dir_input = widgets.Text(
    value='data/jpg',
    placeholder='data/jpg',
    description='图片目录:',
    disabled=False,
    layout=widgets.Layout(width='auto')
)

display(dir_input)

Text(value='data/jpg', description='图片目录:', layout=Layout(width='auto'), placeholder='data/jpg')

In [2]:
import os
import requests
import uuid
import io
import ipywidgets as widgets
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import display, clear_output
from src.lindorm import Lindorm

lindorm = Lindorm()

# 创建一个浮点数输入框，并设定默认值为 0.95
float_input = widgets.FloatText(
    value=0.95,
    placeholder=0.95,
    min=0.0,     # 最小值
    max=1.0,     # 最大值
    step=0.02,    # 步长
    description='相似度阈值:',
    disabled=False
)

# 创建输入框和按钮
url_input = widgets.Text(
    value='',
    placeholder='输入图片 URL',
    description='URL:',
    disabled=False
)

file_input = widgets.FileUpload(
    accept='image/*',
    description='选择文件',
    multiple=False  # 只允许选择一张图片
)

button = widgets.Button(
    description='图片去重导入',
    disabled=False,
    button_style='',
    tooltip='点击显示图片',
    icon='check'
)

output = widgets.Output()

def picture_insert(image_data, image_path):
    # 图片写入本地
    with open(image_path, 'wb') as f:
        f.write(image_data)
        
def http_download(url: str):
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36"
        }
        try:
            response = requests.get(url, timeout=20, headers=headers)
            # 确保请求成功
            response.raise_for_status()
            image = response.content
            return 0, image
        except requests.exceptions.Timeout:
            print(url, "download timeout")
            return -1, "download timeout"
        except requests.exceptions.RequestException as e:
            print(url, "download failed", e)
            return -1, "download failed"

# 写入新图片
def import_new_image(image_data, image_path, picture_embedding):
    # 写入 本地库
    print(f"导入图片到目录中: {image_path}, 导入结果: {picture_insert(image_data, image_path)}")
    # 写入 lindorm
    doc_id = image_path
    print(f"导入图片向量到 lindorm 实例中 {doc_id}, 导入结果: {lindorm.write_doc(doc_id, picture_embedding)}")

# 定义按钮点击事件处理函数
def on_button_clicked(b):
    with output:
        clear_output()  # 清除上次输出
        if url_input.value:
            code, image_data = http_download(url_input.value)
            url_input.value = ''
            img = Image.open(io.BytesIO(image_data))
            image_type = img.format.lower()
        elif file_input.value:
            # 获取上载图像的数据
            uploaded_file = next(iter(file_input.value))
            image_data = bytes(uploaded_file['content'])
            image_type = uploaded_file['type'].split('/')[-1]
        else:
            print("请先输入图片 URL 或选择本地图片")
            return 
        
        if image_data is None:
            print("图片下载失败")
            return
        
        # 显示图片
        img = Image.open(io.BytesIO(image_data))
        plt.imshow(img)
        plt.axis('off')
        plt.show()
        # 计算图片向量
        code, picture_embedding = lindorm.picture_embedding(image_data)
        hits = lindorm.knn_search(picture_embedding)
        
        # 显示 float_input
        print('图片相似度阈值:', float_input.value)
        print('图片类型:', image_type)
        
        if hits is None or len(hits) == 0:
            print("图片搜索失败")
            return
            
        print('search, count', len(hits))
        for hit in hits:
            print(hit.get('_id'), hit.get('_score'))
            
        # 创建一个 3x3 的子图
        fig, axes = plt.subplots(3, 3, figsize=(9, 9))
        
        # 遍历图片 URL
        for ax, hit in zip(axes.flatten(), hits):
            url = hit.get('_id')
            # 在子图中显示图片
            img = Image.open(url)
            # print(url, img)
            ax.imshow(img)
            ax.axis('off')  # 隐藏坐标轴

        plt.tight_layout()  # 调整布局
        plt.show()
        
        if hits[0].get('_score') >= float_input.value:
            print(f"库中存在图片相似度 >= {float_input.value}，不导入")
        else:
            print(f"库中不存在图片相似度 >= {float_input.value}，导入这张图片")
            dir = dir_input.value
            if not os.path.exists(dir):
                os.makedirs(dir)
            pic_name = str(uuid.uuid4()) + "." + image_type
            pic_path = os.path.join(dir, f"{os.path.basename(pic_name)}")
            import_new_image(image_data, pic_path, picture_embedding)
        
# 绑定事件处理函数到按钮
button.on_click(on_button_clicked)

# 显示组件
display(float_input)
display(url_input)
display(file_input)

display(button)
display(output)

FloatText(value=0.95, description='相似度阈值:', step=0.02)

Text(value='', description='URL:', placeholder='输入图片 URL')

FileUpload(value=(), accept='image/*', description='选择文件')

Button(description='图片去重导入', icon='check', style=ButtonStyle(), tooltip='点击显示图片')

Output()