#### 数据整理 --- 百度数据集

In [7]:
import pandas as pd
import numpy as np
import math
import os
import json
from PIL import Image
import requests
from tqdm import tqdm
import json
from IPython.display import display,HTML
from typing import  Optional, List
from io import BytesIO
import base64
from collections import defaultdict

import matplotlib.pyplot as plt
plt.rcParams["font.sans-serif"]=["SimHei"]
plt.rcParams["axes.unicode_minus"]=False 
from glob import glob






In [2]:
#### 可视化函数

def imgGallery(imglist: Optional[List], ids: Optional[List]=None, \
        scores: Optional[List] = None, tags: Optional[List]=None, \
        default_hight: Optional[str]='300px'):

    """ 显示图片：imglist：url列表或图片地址列表  """

    if ids is None:
        ids = []
    if tags is None:
        tags = []
    if scores is None:
        scores = []
    
    figures = []
    for i, img_path in enumerate(imglist):
        mid = ids[i] if ids else ''
        tag = tags[i] if tags else ""
        score = ', score=%.2f' % scores[i] if scores else ""
        file_info = f'<figcaption style="font-size: 1em;"><a target="_blank" href="{img_path}" style="display:inline;margin:1px"/> <br> {mid} {score} </figcaption>'
        desc_img = f'<figcaption style="font-size: 1em;"> <b> {tag} </b> <br> </figcaption>'
        figures.append(
            f"""
                <figure style="margin: 5px !important;">
                    {file_info}
                    <a target="_blank"  style="word-break: normal;"><img src="{img_path}" style="height: {default_hight}"></a>
                    {desc_img}
                </figure>
            """
            )

    
    return display(HTML(
                    data=f"""
                        <div style="display: flex; flex-flow: row wrap; text-align: center;">
                        {''.join(figures)}
                        </div>
                        """
                ))



#### 数据集准备

###### 原来爬虫代码出发百度反爬机制，所以这里就直接给出数据集

In [8]:
## keyword: 爬虫采集时使用的关键词；tag：百度图片中对应的图像描述, 需要注意长度；caption_cn：tag; file_name: 图片文件名称；filepath：图片文件路径
df = pd.read_csv('./baidu_caps.csv')
df


Unnamed: 0,file_name,tag,caption_cn,filepath,keyword
0,94a8f1e5bb3933204eb241fcb3ff9dfe.jpg,红色落地窗帘,红色落地窗帘,./data/img/94a8f1e5bb3933204eb241fcb3ff9dfe.jpg,红色窗帘
1,6b5f0c5976555c6d0a931f9158e25ba0.jpg,"sun zero becca 节能金属扣眼窗帘 砖红色 40"" x 63"" 53386","sun zero becca 节能金属扣眼窗帘 砖红色 40"" x 63"" 53386",./data/img/6b5f0c5976555c6d0a931f9158e25ba0.jpg,红色窗帘
2,7d8468afce98b757bd30f37da30be52e.jpg,欧式窗帘红色,欧式窗帘红色,./data/img/7d8468afce98b757bd30f37da30be52e.jpg,红色窗帘
3,04c3458b1fddcb0ff81a99b57a2223cf.jpg,红色喜庆蕾丝中式客厅窗帘效果图,红色喜庆蕾丝中式客厅窗帘效果图,./data/img/04c3458b1fddcb0ff81a99b57a2223cf.jpg,红色窗帘
4,94c10548670db012e5a2267db6f34df0.jpg,ucharge 遮光窗帘适用于卧室和起居室隔热孔环顶部窗帘 1 个窗格 红色,ucharge 遮光窗帘适用于卧室和起居室隔热孔环顶部窗帘 1 个窗格 红色,./data/img/94c10548670db012e5a2267db6f34df0.jpg,红色窗帘
...,...,...,...,...,...
274,0f4935de3d147b8fac126bdc37c16c36.jpg,3米沙发客厅加工,3米沙发客厅加工,./data/img/0f4935de3d147b8fac126bdc37c16c36.jpg,客厅条纹地毯
275,faeb1100e2ae5a4091bad9ddd741cd9c.jpg,倍丽雅地毯 商用条纹地毯 客厅茶几办公用地毯 满铺方形地毯,倍丽雅地毯 商用条纹地毯 客厅茶几办公用地毯 满铺方形地毯,./data/img/faeb1100e2ae5a4091bad9ddd741cd9c.jpg,客厅条纹地毯
276,e04c103f1130f283e9ce1883a826a4c0.jpg,卡通可爱简约条纹厨房洗手间客厅长条床边地毯进门垫浴室防滑地垫,卡通可爱简约条纹厨房洗手间客厅长条床边地毯进门垫浴室防滑地垫,./data/img/e04c103f1130f283e9ce1883a826a4c0.jpg,客厅条纹地毯
277,feac2451fbca7c4931431590fe8bf2d2.jpg,现代简约客厅沙发黑色条纹地毯贴图,现代简约客厅沙发黑色条纹地毯贴图,./data/img/feac2451fbca7c4931431590fe8bf2d2.jpg,客厅条纹地毯


In [18]:
kw = '枝形吊灯'

imgGallery(imglist=df[df['keyword']==kw]['imgfile'].tolist(), tags=df[df['keyword']==kw]['tag'].tolist(), ids= [os.path.basename(x) for x in df[df['keyword']==kw]['imgfile'].tolist()] )




#### 数据格式转化：



In [10]:
import pickle
import lmdb
import json
from sklearn.model_selection import train_test_split


def get_cn_clip_df(cap_df, save_dir):
    """ 拆分数据 """
    if not os.path.exists(save_dir): os.mkdir(save_dir)

    idxes = cap_df.index.values
    train_indx, test_idx = train_test_split(idxes.tolist(), test_size=0.1)
    test_idx, valid_idx = train_test_split(test_idx, test_size=0.4)
    print(len(train_indx), len(test_idx), len(valid_idx))
    train_df = cap_df.loc[train_indx]
    test_df = cap_df.loc[test_idx]
    valid_df = cap_df.loc[valid_idx]

    train_df.reset_index(drop=True ,inplace=True)
    test_df.reset_index(drop=True ,inplace=True)
    valid_df.reset_index(drop=True ,inplace=True)
    
    train_df.to_csv(os.path.join(save_dir, "train.csv"), index=False)
    test_df.to_csv(os.path.join(save_dir, "test.csv"), index=False)
    valid_df.to_csv(os.path.join(save_dir, "valid.csv"), index=False)



def img2bytes(img_path):
    img = Image.open(img_path) # 访问图片路径
    img_buffer = BytesIO()
    img.save(img_buffer, format=img.format)
    byte_data = img_buffer.getvalue()
    base64_str = base64.b64encode(byte_data) # bytes
    base64_str = base64_str.decode("utf-8") # str
    del img, byte_data
    return base64_str

def gen_cn_clip_data(df, dt_type='train', save_dir = os.getcwd()):
    """ 产生 图片内容及匹配文本数据 """
    
    images_df = []
    texts_dt = []
    imgs_dict = {}
    if not os.path.exists(save_dir):os.mkdir(save_dir)
    txt_json1 = os.path.join(save_dir, f"{dt_type}_texts.jsonl")
    imgs_tsv = os.path.join(save_dir, f"{dt_type}_imgs.tsv")
    imgdict_js = os.path.join(save_dir, f"{dt_type}_dict.json")
    for i,line in enumerate(df.itertuples()):
        ### 图
        image_id = i
        imgs_dict[image_id] = line[1]
        img_path = line[4]
        print(img_path)
        img_byte = img2bytes(img_path)
        images_df.append({"img_id": i, "img_content": img_byte})
        ### 文
        text = "".join(line[3].split("|||"))
        text_id = i
        jtxt = {"text_id": text_id, "text": text, "image_ids": [image_id]}
        txt_js = json.dumps(jtxt, ensure_ascii=False)
        print(txt_js, file=open(txt_json1, 'a'))
    
    images_df = pd.DataFrame(images_df)
    print(json.dumps(imgs_dict, indent=True), file=open(imgdict_js, 'w'))
    images_df.to_csv(imgs_tsv, sep="\t", header=None, index=False)


# #### 分差数据为 train，test, valid 三份：
out_dir = "./"
get_cn_clip_df(df, out_dir)
### 生成文图数据对：
train_df = pd.read_csv( os.path.join(out_dir, "train.csv") )
test_df = pd.read_csv( os.path.join(out_dir, "test.csv") )
valid_df = pd.read_csv(os.path.join(out_dir, "valid.csv"))



#### 生成训练数据：
dataset_dir = "./data/baidu"
train_df = pd.read_csv( os.path.join(out_dir, "train.csv") )
test_df = pd.read_csv( os.path.join(out_dir, "test.csv") )
valid_df = pd.read_csv(os.path.join(out_dir, "valid.csv"))
gen_cn_clip_data(train_df, dt_type='train', save_dir=dataset_dir)
gen_cn_clip_data(test_df, dt_type='test', save_dir=dataset_dir)
gen_cn_clip_data(valid_df, dt_type='valid', save_dir=dataset_dir)

251 16 12
./data/img/82c937ff3805aa6a618ec9a06756951a.jpg
./data/img/c6d730416ca1f0b4b3fe4e3591bccc78.jpg
./data/img/42aa6aec9fc8bd801367911f0b6be140.jpg
./data/img/3a6b1e7c7dd9e3188bf6c840055e1177.jpg
./data/img/30f2cae71022bf2f2aa590d82c53027b.jpg
./data/img/ab718cc69e6cf91daf1755b618db8725.jpg
./data/img/371bac059c99e3594253eb47aafbe768.jpg
./data/img/541160cafa9652f9a05fb026f1905c4a.jpg
./data/img/fa471fc6f25034082e9bcb000dd7f73a.jpg
./data/img/0d3dc2d613d95bdecce399f1cabbb826.jpg
./data/img/7ad0deaa36e6eff6f4e699d112815d3f.jpg
./data/img/6b288374696d957e6196c5a4f172539e.jpg
./data/img/8c71e1a2d23f26d9a649ae5f620c213d.jpg
./data/img/ab1858499364c3f28185dc08a8814400.jpg
./data/img/4efd9f0b8a2e8a1f4349d53b7dc49a44.jpg
./data/img/0d2dd7e939a03864fc584455caee117e.jpg
./data/img/4d0345f934a297fb027d14be188dc1bd.jpg
./data/img/2905ae07f9feef8655ed5a783d8327f9.jpg
./data/img/30b7b59bc982397e2d9b9be3e77e8b46.jpg
./data/img/6b5f0c5976555c6d0a931f9158e25ba0.jpg
./data/img/0fbdd144d88dec866b5

#### 训练

In [11]:
### 配置参数在 util/params.py 当中，可以按照实际情况进行配置


### 运行方法：
!python .clip_finetune.py 

python: can't open file '.clip_finetune.py': [Errno 2] No such file or directory
