In [6]:
import json
from ds_util import head

In [5]:
yelp_dataset_path = '/nvme0n1p2/yelp_dataset/'
yelp_photo_path = '/nvme0n1p2/yelp_photos/'

# 交互数据

## Step1：从 review 数据集中提取所有用户，并统计每个用户的评论数量

In [7]:
# 加载 user_id 为集合
user_ids = set()
with open(yelp_dataset_path + 'yelp_academic_dataset_user.json', 'r', encoding='utf-8') as f_users:
    for line in f_users:
        user_data = json.loads(line)
        user_ids.add(user_data['user_id'])

# 统计每个 user_id 的实际 reviews 数量
user_review_count = {}

with open(yelp_dataset_path + 'yelp_academic_dataset_review.json', 'r', encoding='utf-8') as review_file:
    for line in review_file:
        review_data = json.loads(line)
        user_id = review_data['user_id']
        if user_id in user_ids:
            if user_id in user_review_count:
                user_review_count[user_id] += 1
            else:
                user_review_count[user_id] = 1

# 保存为本地文件
with open('yelp_out/user_review_count.txt', 'w', encoding='utf-8') as output_file:
    for user_id, count in user_review_count.items():
        output_file.write(f"{user_id} {count}\n")
    
    print("generate complete")

generate complete


In [18]:
head('yelp_out/user_review_count.txt')

mh_-eMZ6K5RLWhZyISBhwA 28



## Step2：筛选出交互次数（评论数）大于等于 10 次的用户

In [9]:
items = {}
with open('yelp_out/user_review_count.txt', 'r', encoding='utf-8') as file:
    for line in file:
        # count_data -> [user_id, review_count]
        count_data = line.strip().split(' ')
        if int(count_data[1]) >= 10:
            items[count_data[0]] = count_data[1]
        if int(count_data[1]) < 1:
            print(f'user {count_data[0]} has 0 review!')

# 保存为本地文件
with open('yelp_out/core_users.txt', 'w', encoding='utf-8') as output_file:
    for user_id, count in items.items():
        output_file.write(f"{user_id} {count}\n")
    
    print("generate complete")

generate complete


In [19]:
head('yelp_out/core_users.txt')

mh_-eMZ6K5RLWhZyISBhwA 28



## Step3：从 review 数据集中提取所有项目商家，并统计每个商家的评论数量

In [23]:
# 加载 business_id 为集合
business_ids = set()
with open(yelp_dataset_path + 'yelp_academic_dataset_business.json', 'r', encoding='utf-8') as f_item:
    for line in f_item:
        item_data = json.loads(line)
        business_ids.add(item_data['business_id'])

# 统计每个 business_id 的实际 reviews 数量
item_review_count = {}

with open(yelp_dataset_path + 'yelp_academic_dataset_review.json', 'r', encoding='utf-8') as review_file:
    for line in review_file:
        review_data = json.loads(line)
        business_id = review_data['business_id']
        if business_id in business_ids:
            if business_id in item_review_count:
                item_review_count[business_id] += 1
            else:
                item_review_count[business_id] = 1

# 保存为本地文件
with open('yelp_out/item_review_count.txt', 'w', encoding='utf-8') as output_file:
    for business_id, count in item_review_count.items():
        output_file.write(f"{business_id} {count}\n")
    
    print("generate complete")

generate complete


In [27]:
head('yelp_out/item_review_count.txt')

XQfwVwDr-v0ZS3_CbbE5Xw 175



## Step4：筛选出交互次数大于等于 10 次且存在 photos 数据的商家

理论上说，10-core settings 只需要交互次数限制就可以了。但本文由于需要图片模态数据，而实际处理过程中发现如果不添加“有图片”的限制条件，生成的数据集中，将有约 67% 的 business 没有对应的 photos 数据。

In [15]:
items = {}
business_with_photos = set()

try:
    with open(yelp_photo_path + 'photos.json', 'r') as photo_list:
        for line in photo_list:
            data = json.loads(line)
            business_id = data['business_id']
            business_with_photos.add(business_id)

    with open('yelp_out/item_review_count.txt', 'r', encoding='utf-8') as file:
        for line in file:
            # count_data -> [business_id, review_count]
            count_data = line.strip().split(' ')
            if int(count_data[1]) >= 10 and count_data[0] in business_with_photos:
                items[count_data[0]] = count_data[1]
            if int(count_data[1]) < 1:
                print(f'business {count_data[0]} has 0 review!')

    # 保存为本地文件
    with open('yelp_out/re_core_items.txt', 'w', encoding='utf-8') as output_file:
        for business_id, count in items.items():
            output_file.write(f"{business_id} {count}\n")
        
        print("generate complete")

except Exception as e:
    print(f'errors: {e}')

generate complete


In [30]:
head('yelp_out/core_items.txt')

XQfwVwDr-v0ZS3_CbbE5Xw 175



## Step5：交叉过滤-从 review 数据集中提取出 10-core 交互记录

注意，yelp 数据集自带换行符

In [46]:
with open(yelp_dataset_path + 'yelp_academic_dataset_review.json', 'r', encoding='utf-8') as review_file:
    for line in review_file:
        if '\n' in line: 
            print('yes')
            break

yes


此外，由于存在一个用户对同一商家的多条评论记录，因此构造过滤 review 数据集时，需考虑以下几方面：
1. 用户和项目都是核心用户/项目
2. (user_id ,business_id) 不重复
3. 针对重复的情况，取评分最高项

In [16]:
core_users = set()
items = set()
# 保证一个用户对一个项目只有一条交互记录
user_item_pairs = {}

with open('yelp_out/core_users.txt', 'r', encoding='utf-8') as file:
    for line in file:
        user_id = line.strip().split(' ')[0]
        core_users.add(user_id)

with open('yelp_out/re_core_items.txt', 'r', encoding='utf-8') as file:
    for line in file:
        business_id = line.strip().split(' ')[0]
        items.add(business_id)

filter_reviews = []  # -> [user_id, business_id, stars, review_id]
index = 0
with open(yelp_dataset_path + 'yelp_academic_dataset_review.json', 'r', encoding='utf-8') as review_file:
    for line in review_file:
        review_data = json.loads(line)
        user_id = review_data['user_id']
        business_id = review_data['business_id']

        if user_id in core_users and business_id in items:
            if (user_id, business_id) not in user_item_pairs:
                # 用哈希表同时记录评分和索引，以便后续操作
                user_item_pairs[(user_id, business_id)] = [review_data['stars'], index]
                # filter_reviews.append([user_id, business_id, review_data['stars'], review_data['review_id']])
                filter_reviews.append([user_id, business_id, review_data['stars']])
                index += 1
            else:
                if review_data['stars'] > user_item_pairs[(user_id, business_id)][0]:
                    pair_index = user_item_pairs[(user_id, business_id)][1]
                    # 更新数据
                    # filter_reviews[pair_index] = [user_id, business_id, review_data['stars'], review_data['review_id']]
                    filter_reviews[pair_index] = [user_id, business_id, review_data['stars']]

# 持久化过滤后的评论数据集
with open('yelp_out/re_yelp_interactions.txt', 'w', encoding='utf-8') as out_file:
    for record in filter_reviews:
        out_file.write(' '.join(map(str, record)) + '\n')
    print('Filter reviews done')


Filter reviews done


In [18]:
head('yelp_out/re_yelp_interactions.txt', 2)

bcjbaE6dDog4jkNY91ncLQ e4Vwtrqf-wpJfwesgvdgxQ 4.0

smOvOajNG0lS4Pq7d8g4JQ RZtGWDLCAtuipwaZ-UfjmQ 4.0



## Step6：划分数据集

参考 [SELFRec issue 54](https://github.com/Coder-Yu/SELFRec/issues/54)，领域内在得到最终结果时，似乎会把训练集和验证集合并，所以这里先按 8:2 划分训练集跟测试集，然后在将训练集划分为 7:1

In [19]:
import pandas as pd
from sklearn.model_selection import train_test_split

data = pd.read_csv('yelp_out/re_yelp_interactions.txt', delimiter=' ', header=None)

# 检查重复行
print(f"原始数据重复行数: {data.duplicated().sum()}")

# 划分训练集和测试集，比例为8:2
temp_data, test_data = train_test_split(data, test_size=0.2, random_state=114514)

temp_data.to_csv('yelp_ds_re/train_data.txt', index=False, header=False, sep=' ')
test_data.to_csv('yelp_ds_re/test_data.txt', index=False, header=False, sep=' ')

# 查看划分后的数据集大小
print(f"训练集大小: {len(temp_data)}")
print(f"测试集大小: {len(test_data)}")

# 检查是否有重复项
print(f"训练数据重复行数: {temp_data.duplicated().sum()}")
print(f"测试数据重复行数: {test_data.duplicated().sum()}")
duplicates_in_train = temp_data.merge(test_data, how='inner')
print(f'训练集和测试集中的重复行:\n{duplicates_in_train}')


原始数据重复行数: 0
训练集大小: 1429167
测试集大小: 612501
训练数据重复行数: 0
测试数据重复行数: 0
训练集和测试集中的重复行:
Empty DataFrame
Columns: [0, 1, 2]
Index: []


# 图像数据

1. 根据生成的 interactions 数据，对第二列建立集合，即数据集中所有 business_id
2. 将 business_id 代入 `photos/` 检索出所有 photo_id 并保存文件

## 图像记录数据展示

In [6]:
head(yelp_photo_path + 'photos.json', jsonf=True)

{
  "photo_id": "zsvj7vloL4L5jhYyPIuVwg",
  "business_id": "Nk-SJhPlDBkAZvfsADtccA",
  "caption": "Nice rock artwork everywhere and craploads of taps.",
  "label": "inside"
}


## 商家-图片映射记录

从最终的**交互数据集**中，提取出商家-图片(一对多)的映射关系。

In [20]:

photo_ids = set()
items = set()
item2photos = {}

try:
    with open('yelp_out/re_yelp_interactions.txt', 'r') as interact_file:
        for line in interact_file:
            items.add(line.split(' ')[1])
            item2photos[line.split(' ')[1]] = []
    
    # item: photo_id1, photo_id2, ...
    with open('yelp_out/item2photos.txt', 'w', encoding='utf-8') as out_file:
        with open(yelp_photo_path + 'photos.json', 'r', encoding='utf-8') as photos:
            for line in photos:
                json_data = json.loads(line)
                photo_id = json_data['photo_id']
                business_id = json_data['business_id']

                # 存在 photo_id 重复的情况
                if photo_id not in photo_ids and business_id in items:
                    item2photos[business_id].append(photo_id)
                else:
                    continue
        
        for item in item2photos:
            out_file.write(item + ' ' + ' '.join(item2photos[item]) + '\n')
        print('generate complete')
except Exception as e:
    print(f'error: {e}')

generate complete


In [21]:
with open('yelp_out/item2photos.txt', 'r') as f:
    no_photos_item = 0
    total_item = 0

    for line in f:
        total_item += 1
        parts = line.split(' ')
        business_id = parts[0]
        photo_ids = parts[1:]

        if photo_ids[0] == '\n':
            no_photos_item += 1

    print(f'{no_photos_item} business without any photo')
    print(f'occupy {no_photos_item/total_item*100}%')

0 business without any photo
occupy 0.0%


## 使用 [CLIP-ViT](https://huggingface.co/openai/clip-vit-base-patch32) 对图像进行编码

- [tutorial 1](https://medium.com/@highsunday0630/image-embedding-1-clip%E6%A8%A1%E5%9E%8B%E6%8F%90%E5%8F%96-image-embedding-%E4%B8%A6%E4%BB%A5-tensorboard-%E8%A6%96%E8%A6%BA%E5%8C%96%E6%95%88%E6%9E%9C-dc281370d7d8)
- [tutorial 2](https://blog.csdn.net/qq_37756660/article/details/135979873)

In [1]:
from transformers import CLIPProcessor, CLIPModel
from safetensors.torch import save_file
from PIL import Image
from safetensors import safe_open
import os
import tqdm
import torch

In [2]:
model_path = "/home/yzh/code/SELFRec/model/clip-vit-base-patch32"
# 图片存储路径
directory = '/nvme0n1p2/yelp_photos/photos'

In [3]:
photos = set()  # 所有需要处理的图片
with open('yelp_out/item2photos.txt', 'r', encoding='utf-8') as file:
    for line in file:
        for photo_id in line.strip().split(' ')[1:]:
            photos.add(photo_id)
print('total record photos:', len(photos))

total record photos: 190902


yelp 数据集中的 photo 数据存在脏数据(无法识别、损坏)

> 我是没想到公开数据集连这个都不处理好😅

先写个日志看看有多少脏数据吧（手可别抖，下面的日志模块代码切记执行一次，这不比别的，多执行几次会多出几个logger，到时候你的日志里都是重复的信息）

In [4]:
import logging
photo_logger = logging.getLogger('error_photo')
photo_logger.setLevel(logging.INFO)
sh = logging.FileHandler('../log/error_photo.log')
sh.setLevel(logging.WARNING)
formatter = logging.Formatter('%(name)s - %(message)s')
sh.setFormatter(formatter)
photo_logger.addHandler(sh)

In [5]:
# 加载模型
# model = CLIPModel.from_pretrained(model_path).to(f'cuda:{device_ids[0]}')
model = CLIPModel.from_pretrained(model_path).to('cuda')
processor = CLIPProcessor.from_pretrained(model_path)

# 定义批量大小
batch_size = 64

# 初始化计数器
photo_num = 0
processed_num = 0
error_num = 0
photo_embs = {}

# 读取目录中的图片
with os.scandir(directory) as entries:
    images = []  # 待处理图片列表
    photo_ids = []  # 图片ID列表

    # 遍历目录中的所有文件
    for entry in tqdm.tqdm(entries):
        photo_id = entry.name.replace('.jpg', '')
        if photo_id not in photos:
            continue
        photo_num += 1
        try:
            images.append(Image.open(entry.path))
            photo_ids.append(photo_id)
        except Exception as e:
            photo_logger.warning(f'{entry.name} error: {e}')
            error_num += 1
            continue

        # 当达到批量大小时，处理这批图片
        if len(images) == batch_size:
            inputs = processor(images=images, return_tensors='pt')
            inputs = {k: v.to('cuda') for k, v in inputs.items()}
            with torch.no_grad():
                # 使用get_image_features获取图像特征
                outputs = model.get_image_features(pixel_values=inputs['pixel_values'])
            for idx, pid in enumerate(photo_ids):
                photo_embs[pid] = outputs[idx]
            processed_num += batch_size
            
            # 清空列表
            images = []
            photo_ids = []

    # 处理剩余的图片
    if images:
        inputs = processor(images=images, return_tensors='pt')
        inputs = {k: v.to('cuda') for k, v in inputs.items()}
        with torch.no_grad():
            # 使用get_image_features获取图像特征
            outputs = model.get_image_features(pixel_values=inputs['pixel_values'])
        for index, pid in enumerate(photo_ids):
            photo_embs[pid] = outputs[index]
        processed_num += len(images)

print(f'total {photo_num} photos, processed: {processed_num}, errors: {error_num}, len of photo_embs: {len(photo_embs)}')
# 保存为文件
save_file(photo_embs, 'photo_embs.safetensors')

  return self.fget.__get__(instance, owner)()
200098it [27:28, 121.40it/s]


total 190902 photos, processed: 190797, errors: 105, len of photo_embs: 190797


In [7]:
with safe_open('photo_embs.safetensors', framework='pt') as f:
    total_num = len(f.keys())
    photo_num = 0
    for k in f.keys():
        if k in photos: photo_num += 1
    for k in f.keys():
        print(f.get_tensor(k).shape)
        break
    print(f'{photo_num/total_num*100}% photos are in the dataset')

torch.Size([512])
100.0% photos are in the dataset


按照原来的方案，有105张损坏图像，现在从映射数据 `item2photots.txt` 中删除无效图片。

In [12]:
error_ids = set()  # 错误图像id集合
with open('../log/error_photo.log', 'r') as f:
    for line in f:
        photo_id = line.strip().split(' - ')[1].split(' ')[0].replace('.jpg', '')
        if photo_id not in error_ids:
            error_ids.add(photo_id)
        else:
            print(f'exist duplicate error photo: {photo_id}')

In [17]:
error_items = set()
with open('yelp_out/re_item2photos.txt', 'w') as out_file:
    with open('yelp_out/item2photos.txt', 'r') as f:
        for line in f:
            new_line = []
            item_id = line.strip().split(' ')[0]
            photo_ids = line.strip().split(' ')[1:]
            filter_ids = [id for id in photo_ids if id not in error_ids]
            if len(filter_ids) == 0:
                print(f'{item_id} remain no photo😭')
                error_items.add(item_id)
                continue
            new_line.append(item_id)
            new_line.extend(filter_ids)
            out_file.write(' '.join(new_line) + '\n')        

WWs1xspH1d-NCIWmXM40RQ remain no photo😭
qJKMyChtpyqPvf7HNfvq4A remain no photo😭
vpkCctZV4_q7iUkmtdZkzQ remain no photo😭
dBa7aJXV50TZEtInwdbvfg remain no photo😭
LIh33t2G-y0C1H3o41xJSQ remain no photo😭
djn6PlsuFw_Z_gRA55QDcg remain no photo😭
a1Bd6IhR_Bsthhff9VGLoA remain no photo😭


In [18]:
with open('yelp_out/re_item2photos.txt', 'r') as file:
    for line in file:
        item = line.strip().split(' ')[0]
        photos = line.strip().split(' ')[1:]
        if len(photos) == 0:
            print(f'{item} has no photo')
        for photo in photos:
            if photo in error_ids:
                print(f'{item} has error photo')

然后从交互数据中删除 过滤损毁图像后 无剩余图像的 item 对应的交互记录

In [20]:
with open('yelp_out/filter_yelp_interactions.txt', 'w') as out_file:
    with open('yelp_out/re_yelp_interactions.txt', 'r') as f:
        for line in f:
            item_id = line.strip().split(' ')[1]
            if item_id in error_items:
                print(f'delete error item {item_id}')
                continue
            out_file.write(line)

delete error item WWs1xspH1d-NCIWmXM40RQ
delete error item WWs1xspH1d-NCIWmXM40RQ
delete error item WWs1xspH1d-NCIWmXM40RQ
delete error item WWs1xspH1d-NCIWmXM40RQ
delete error item qJKMyChtpyqPvf7HNfvq4A
delete error item qJKMyChtpyqPvf7HNfvq4A
delete error item qJKMyChtpyqPvf7HNfvq4A
delete error item qJKMyChtpyqPvf7HNfvq4A
delete error item qJKMyChtpyqPvf7HNfvq4A
delete error item qJKMyChtpyqPvf7HNfvq4A
delete error item qJKMyChtpyqPvf7HNfvq4A
delete error item qJKMyChtpyqPvf7HNfvq4A
delete error item qJKMyChtpyqPvf7HNfvq4A
delete error item qJKMyChtpyqPvf7HNfvq4A
delete error item vpkCctZV4_q7iUkmtdZkzQ
delete error item dBa7aJXV50TZEtInwdbvfg
delete error item dBa7aJXV50TZEtInwdbvfg
delete error item vpkCctZV4_q7iUkmtdZkzQ
delete error item dBa7aJXV50TZEtInwdbvfg
delete error item dBa7aJXV50TZEtInwdbvfg
delete error item vpkCctZV4_q7iUkmtdZkzQ
delete error item dBa7aJXV50TZEtInwdbvfg
delete error item dBa7aJXV50TZEtInwdbvfg
delete error item dBa7aJXV50TZEtInwdbvfg
delete error ite

## 返回 Step6 重新划分数据集

In [2]:
import pandas as pd
from sklearn.model_selection import train_test_split
import os

data = pd.read_csv('yelp_out/filter_yelp_interactions.txt', delimiter=' ', header=None)

# 检查重复行
print(f"原始数据重复行数: {data.duplicated().sum()}")

# 划分临时集和测试集，比例为8:2
temp_data, test_data = train_test_split(data, test_size=0.2, random_state=114514)
# 将临时集划分为训练集和验证集，比例为7:1
train_data, val_data = train_test_split(temp_data, test_size=0.125, random_state=114514)

os.makedirs('yelp_ds', exist_ok=True)
temp_data.to_csv('yelp_ds/merge_train_data.txt', index=False, header=False, sep=' ')
train_data.to_csv('yelp_ds/train_data.txt', index=False, header=False, sep=' ')
val_data.to_csv('yelp_ds/val_data.txt', index=False, header=False, sep=' ')
test_data.to_csv('yelp_ds/test_data.txt', index=False, header=False, sep=' ')

# 查看划分后的数据集大小
print(f"训练集大小: {len(train_data)}")
print(f"验证集大小: {len(val_data)}")
print(f"测试集大小: {len(test_data)}")

# 检查是否有重复项
print(f"训练集重复行数: {train_data.duplicated().sum()}")
print(f"验证集重复行数: {val_data.duplicated().sum()}")
print(f"测试集重复行数: {test_data.duplicated().sum()}")

# 检查交叉重复项
duplicates_in_all = train_data.merge(val_data, how='inner').merge(test_data, how='inner')
print(f"训练集、验证集和测试集之间重复行数: {duplicates_in_all.shape[0]}")

原始数据重复行数: 0
训练集大小: 1429113
验证集大小: 204159
测试集大小: 408318
训练集重复行数: 0
验证集重复行数: 0
测试集重复行数: 0
训练集、验证集和测试集之间重复行数: 0


# RAG

## 生成 RAG 数据

从 business 数据中提取，以最终交互数据为过滤条件，预计就是下面这样
```
business_id categories
...
```

In [57]:
import json

yelp_dataset_path = '/nvme0n1p2/yelp_dataset/'

items = set()
item_categories = {}
with open('yelp_out/filter_yelp_interactions.txt', 'r') as ui_file:
    for line in ui_file: items.add(line.split(' ')[1].strip())
    print(f'total items: {len(items)}')
    with open(yelp_dataset_path + 'yelp_academic_dataset_business.json', 'r') as f:
        for line in f:
            json_data = json.loads(line)
            business_id = json_data['business_id']
            categories = json_data['categories']
            if business_id in items: item_categories[business_id] = categories

# 保存item_categories到文件
with open('yelp_out/yelp_text.json', 'w', encoding='utf-8') as text_file:
    json.dump(item_categories, text_file, ensure_ascii=True)

total items: 33183


In [58]:
with open('yelp_out/yelp_text.json', 'r', encoding='utf-8') as text_file:
    data = json.load(text_file)
print(len(data))

33183


## 使用 LLM 增强交互数据

### RAG增强(deprecated)

In [23]:
from langchain_community.document_loaders import JSONLoader
from pprint import pprint

loader = JSONLoader(
    file_path="RAG/yelp_rag.json",
    jq_schema="{item_id: .business_id, categories: .categories}",
    json_lines=True,
    text_content=False
)
data = loader.load()

In [24]:
for idx, d in enumerate(data):
    pprint(d)
    if idx > 5: break

Document(metadata={'source': '/home/yzh/code/SELFRec/mk_dataset/RAG/yelp_rag.json', 'seq_num': 1}, page_content='{"item_id": "MTSW4McQd7CbVtyjqoe9mw", "categories": "Restaurants, Food, Bubble Tea, Coffee & Tea, Bakeries"}')
Document(metadata={'source': '/home/yzh/code/SELFRec/mk_dataset/RAG/yelp_rag.json', 'seq_num': 2}, page_content='{"item_id": "bBDDEgkFA1Otx9Lfe7BZUQ", "categories": "Ice Cream & Frozen Yogurt, Fast Food, Burgers, Restaurants, Food"}')
Document(metadata={'source': '/home/yzh/code/SELFRec/mk_dataset/RAG/yelp_rag.json', 'seq_num': 3}, page_content='{"item_id": "eEOYSgkmpB90uNA7lDOMRA", "categories": "Vietnamese, Food, Restaurants, Food Trucks"}')
Document(metadata={'source': '/home/yzh/code/SELFRec/mk_dataset/RAG/yelp_rag.json', 'seq_num': 4}, page_content='{"item_id": "il_Ro8jwPlHresjw9EGmBg", "categories": "American (Traditional), Restaurants, Diners, Breakfast & Brunch"}')
Document(metadata={'source': '/home/yzh/code/SELFRec/mk_dataset/RAG/yelp_rag.json', 'seq_num':

这里的 Document 足够小了，所以无需 split to chunks

In [None]:
from langchain_community.embeddings import OllamaEmbeddings
from langchain_chroma import Chroma

# 实例化嵌入模型
ollama_emb = OllamaEmbeddings(base_url="http://localhost:11434", model="qwen2", num_gpu=2, show_progress=True)

# 使用文档块创建向量数据库并持久化
persist_directory = "RAG"
# If a persist_directory is specified, the collection will be persisted there. Otherwise, the data will be ephemeral in-memory.
vector_db = Chroma.from_documents(documents=data, 
                                     embedding=ollama_emb, 
                                     persist_directory=persist_directory)

In [None]:

# 实例化嵌入模型
ollama_emb = OllamaEmbeddings(base_url="http://localhost:11434", model="qwen2", num_gpu=2, show_progress=True)

# 使用文档块创建向量数据库并持久化
persist_directory = "RAG"
# If a persist_directory is specified, the collection will be persisted there. Otherwise, the data will be ephemeral in-memory.
vector_db = Chroma.from_documents(documents=data, 
                                     embedding=ollama_emb, 
                                     persist_directory=persist_directory)

In [30]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms.ollama import Ollama
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.output_parsers import  StrOutputParser

template = """User had interacted with the following businesses:
{item_list}

The categories of these businesses are in {context}. Please output the business_id only from the following candidate, but not user history.
{candidates}

Output format:
business_ids separated by ','. Nothing else. Please do not output other thing else, do not give reasoning.
"""
prompt = ChatPromptTemplate.from_template(template)
model = Ollama(base_url="http://localhost:11434", model="qwen2")

retriever = vector_db.as_retriever()
setup_and_retrieval = RunnableParallel(
    {
        "item_list": RunnablePassthrough(),
        "context": retriever,
        "candidates": RunnablePassthrough()
    }
)


In [37]:
retriever.get_relevant_documents('e4Vwtrqf-wpJfwesgvdgxQ')

  retriever.get_relevant_documents('e4Vwtrqf-wpJfwesgvdgxQ')
OllamaEmbeddings: 100%|██████████| 1/1 [00:02<00:00,  2.73s/it]


[Document(metadata={'seq_num': 32384, 'source': '/home/yzh/code/SELFRec/mk_dataset/RAG/yelp_rag.json'}, page_content='{"item_id": "QKtbbF5-qny5h90Qs3erXw", "categories": "Indian, Restaurants"}'),
 Document(metadata={'seq_num': 14523, 'source': '/home/yzh/code/SELFRec/mk_dataset/RAG/yelp_rag.json'}, page_content='{"item_id": "Qz6-OxFp9PhGwMZG5geqpw", "categories": "Italian, American (New), Restaurants"}'),
 Document(metadata={'seq_num': 1677, 'source': '/home/yzh/code/SELFRec/mk_dataset/RAG/yelp_rag.json'}, page_content='{"item_id": "QlI4_BHwxb5UplGwd4vE0w", "categories": "Mexican, Restaurants"}'),
 Document(metadata={'seq_num': 22024, 'source': '/home/yzh/code/SELFRec/mk_dataset/RAG/yelp_rag.json'}, page_content='{"item_id": "Q46KberieM6ziVYME6CHEQ", "categories": "Mexican, Restaurants"}')]

In [31]:
item_list = []
candidates = []
with open('yelp_out/filter_yelp_interactions.txt', 'r') as f:
    for line in f:
        if line.split(' ')[0].strip() == 'bcjbaE6dDog4jkNY91ncLQ':
            item_list.append(line.split(' ')[1].strip())
        else:  # 非交互项
            if len(candidates) <= 5:
                candidates.append(line.split(' ')[1].strip())

In [36]:
test_chain = setup_and_retrieval | prompt
output = test_chain.invoke({
    "item_list": item_list,
    "candidates": candidates
})
print(output)

OllamaEmbeddings: 100%|██████████| 1/1 [00:03<00:00,  3.69s/it]

messages=[HumanMessage(content='User had interacted with the following businesses:\n{\'item_list\': [\'e4Vwtrqf-wpJfwesgvdgxQ\', \'gUyfJlJRxu1fHuZ4dpBheQ\', \'5AOkxsg6UJQ_CoJTMBDUmQ\', \'xkYOPbA8AL4jcQIN3xveoQ\', \'yJ2ZRXx01eF40eRQFqIBeQ\', \'ZyOqGKdr5JetY4jgD_UoGw\', \'D73evJ9PZKxO3E5TaThe3w\', \'ew5TyXOlyCpCRptye1LdxA\', \'q4aAaxdN4wmUZoC6sKEwsw\', \'4WdDY97x4GdMYtyk1KQMnw\', \'pym7c6ZFEtmoH16xN2ApBg\', \'mhrW9O0O5hXGXGnEYBVoag\', \'fIBMKVl-dyb3KyM11UBJPQ\', \'7gtWQMLOEwCxh1I5j6uB4g\', \'qt_E6txwQ1h62wyv8701UQ\', \'UjFLIhKTOiFcQiziOA9rgA\', \'URxxeb2R60AH81IcuxJAvQ\', \'HIomEsnJRxw0861yD87Qgw\', \'f9H3wpzWG_apxoumWB-Dvg\', \'al3Ri6TEqa2rBzjHsn0T_g\', \'mRpk0A4u0hnF0lNe1h4hGg\', \'dul6XjaCh1GgA-YHpuChUg\', \'5OpNE-GEP1unD89k61XbVQ\', \'c-Drp2IuAXSqjvyzvOPBzQ\', \'N1we1YLrBxPOoenxJwzdOA\', \'Z4PF4EtM12L7nwOHZHFJNA\', \'oFbwMxqaCJfIzAEmwaXD3Q\', \'uJITgt5t7j-KpDChsXPV5w\', \'SjgRHmQ_ClUlECE2JkY8ng\', \'Hj_-qd7KyQPRqTWzWgsFag\', \'hLr6cRTANzEll1hGlmbgHA\', \'nHsoeVL1dXs9ZNjmdrlPuA\', \'e




In [32]:
chain = setup_and_retrieval | prompt | model | StrOutputParser()

result = chain.invoke({
    "item_list": item_list,
    "candidates": candidates
})

OllamaEmbeddings: 100%|██████████| 1/1 [00:05<00:00,  5.68s/it]


In [33]:
print(result)

Tq4dHsAaxAXSqjvyzvOPBzQ, N1we1YLrBxPOoenxJwzdOA, Z4PF4EtM12L7nwOHZHFJNA, oFbwMxqaCJfIzAEmwaXD3Q, uJITgt5t7j-KpDChsXPV5w, SjgRHmQ_ClUlECE2JkY8ng, Hj_-qd7KyQPRqTWzWgsFag, hLr6cRTANzEll1hGlmbgHA, nHsoeVL1dXs9ZNjmdrlPuA, ecI3FBTM0f99Fnml3kNKfg, 8kh6Z3c8UHQKmsy0_TbOnA, RQpOPNHJReRnrsCD-2qEoA, jKTWcdyXPw_cGUp9fKqapQ, Se9CEgJEVxcWax1fStWuQA, R46XVcmUzy8qeerHyZQtEg, dN7AMKUhwTa4Kk0bdMRPgA, BFihjoRdU-jmdbvIEqEsxQ, 3S-u4euLhybQzOuaTAZOpg, SkjUwG0FerzrxnIV8N56CA, 5HZPNcMR5dHQ1OyOb-RDgw, Q9GU2OvZObDVyA00ZJkFaA, cTSczU-9-cYUEM2DlNJcQw, qjP2XXjtLdlZ20SISqtAAA, vvOzblHBA2HHsCb7CMSDbQ, ncI768qIjMFnMwYMppB4tw, 2DTkzhmMpv5fIPKheePClA, 4vak1jxwM6dQ-pNQQ5U8Vw, RKfpN_TqD3wa58kgvnR1lQ, 8UPv1p9GW-BiZtQqUt8nOA, ZfzTw5exIOalHsGlK99y-w, dju1isgEvDd74tLTDkk5DA, dGeXdSMah56gEHwZNaRQKA, 7H1b6TZ-LNxyGx1cv9suJQ, StqG4cdKhTHmGyS7PSimdA, if57kE6_VfR1nI1X93oHEA, hJTwBhYBTkiHaDMml_v_sw, TUTQeLjq1UbkR5r8mOvMqw, PKZwdGTapRvFsBYh0zQXpw, 2x4atI8B9Z0g61bgEOO2Uw, WXgV2lOUgas7DzTLeDau-w, F3b3-mmClvVPUT0WvK_guA, JjUNJCyGQlCxMwO

这么看 RAG 效果并不好啊。。。而且文本内容太短了，效率远不如直接拼接 prompt，还浪费我几个小时做 embedding。。。焯

不过 LangChain 的辅助功能还是能用，比如模板之类的，比手搓要强

### prompt

In [38]:
yelp_rag = {}
with open('RAG/yelp_rag.json', 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line)
        yelp_rag[data['business_id']] = data['categories']        

In [42]:
for idx, (k, v) in enumerate(yelp_rag.items()):
    print(idx, (k, v))
    if idx > 3: break

0 ('MTSW4McQd7CbVtyjqoe9mw', 'Restaurants, Food, Bubble Tea, Coffee & Tea, Bakeries')
1 ('bBDDEgkFA1Otx9Lfe7BZUQ', 'Ice Cream & Frozen Yogurt, Fast Food, Burgers, Restaurants, Food')
2 ('eEOYSgkmpB90uNA7lDOMRA', 'Vietnamese, Food, Restaurants, Food Trucks')
3 ('il_Ro8jwPlHresjw9EGmBg', 'American (Traditional), Restaurants, Diners, Breakfast & Brunch')
4 ('0bPLkL0QhhPO5kt1_EXmNQ', 'Food, Delis, Italian, Bakeries, Restaurants')


In [24]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms.ollama import Ollama
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import  StrOutputParser

In [53]:
ollama_model = Ollama(
    #base_url="http://172.16.110.34:45665",  # node09
    base_url="http://localhost:11434",  # whr
    model="qwen2.5:0.5b",
    num_gpu=2,
    # It is recommended to set this value to the number of physical CPU cores
    num_thread=48,
    # system prompt (overrides what is defined in the Modelfile)
    system='You are a recommendation system and required to recommend user with businesses based on user history that each business with id, categories.'
)

In [50]:
template = """User had interacted with the [history] businesses. Output user's favorite and least favorite business_ids from the [candidates], not from the user [history]. Each line is "business_id: categories". Output format are two business_ids separated by ','. The first is favorite and the second is least favorite. No any other things. No reasoning.
[history]
{item_list}

[candidates]
{candidates}
"""
prompt = ChatPromptTemplate.from_template(template)
chain = RunnablePassthrough() | prompt | ollama_model | StrOutputParser()

In [39]:
test_template = """Answer the following question directly without reasoning:
{question}
"""
test_prompt = ChatPromptTemplate.from_template(test_template)
test_chain = RunnablePassthrough() | test_prompt | ollama_model | StrOutputParser()

生成模板数据

In [41]:
import json

item_list = []
candidates = []

# 生成id数据
with open('yelp_out/filter_yelp_interactions.txt', 'r') as f:
    for line in f:
        if line.split(' ')[0].strip() == 'bcjbaE6dDog4jkNY91ncLQ':
            # 以这个user为例，找出所有交互的items
            item_list.append(line.split(' ')[1].strip())
        else:  # 非交互项，取前20个试试
            if len(candidates) < 3 :
                candidates.append(line.split(' ')[1].strip())

# 添加categories数据
categories = {}
with open('yelp_out/yelp_rag.json', 'r') as f:
    for line in f:
        data = json.loads(line)
        categories[data['business_id']] = data['categories']

item_list = [f'{item}: {categories[item]}' for item in item_list]
filter_item = item_list[:3]
candidates = [f'{item}: {categories[item]}' for item in candidates]

print(item_list[0])
print(candidates[0])

e4Vwtrqf-wpJfwesgvdgxQ: Sandwiches, Beer, Wine & Spirits, Bars, Food, Restaurants, American (Traditional), Nightlife
RZtGWDLCAtuipwaZ-UfjmQ: Pizza, Restaurants, Italian, Salad


item_list 应该是个多行字符串，分别为 id 和 categories。candidates 同理。

In [23]:
t_output = test_chain.invoke({
    "question": 'Which number is bigger between 9.9 and 9.11?',
})
print(t_output)

9.11 is bigger than 9.9.


In [51]:
formatted_item_list = "\n".join(filter_item)
formatted_candidates = "\n".join(candidates)
prompt_chain = RunnablePassthrough() | prompt
out_prompt = prompt_chain.invoke({
    "item_list": formatted_item_list,
    "candidates": formatted_candidates
})
print(out_prompt.to_string())

Human: User had interacted with the [history] businesses. Output user's favorite and least favorite business_ids from the [candidates], not from the user [history]. Each line is "business_id: categories". Output format are two business_ids separated by ','. The first is favorite and the second is least favorite. No any other things. No reasoning.
[history]
e4Vwtrqf-wpJfwesgvdgxQ: Sandwiches, Beer, Wine & Spirits, Bars, Food, Restaurants, American (Traditional), Nightlife
gUyfJlJRxu1fHuZ4dpBheQ: Mexican, Restaurants, Latin American
5AOkxsg6UJQ_CoJTMBDUmQ: American (New), Restaurants, Cajun/Creole

[candidates]
RZtGWDLCAtuipwaZ-UfjmQ: Pizza, Restaurants, Italian, Salad
otQS34_MymijPTdNBoBdCw: Restaurants, Tacos, Mexican, Hot Dogs, Breakfast & Brunch, Steakhouses
rBdG_23USc7DletfZ11xGA: Wine Bars, Bars, Nightlife, American (New), Mediterranean, Restaurants



In [54]:
output = chain.invoke({
    "item_list": formatted_item_list,
    "candidates": formatted_candidates
})

print(output)

RZtGWDLCAtuipwaZ-UfjmQ, otQS34_MymijPTdNBoBdCw


### 数据增强

In [None]:
import json

item_list = []
candidates = []

# 生成id数据
with open('mk_dataset/yelp_ds_final/train_data.txt', 'r') as f:
    for line in f:
        if line.split(' ')[0].strip() == 'bcjbaE6dDog4jkNY91ncLQ':
            # 以这个user为例，找出所有交互的items
            item_list.append(line.split(' ')[1].strip())
        else:  # 非交互项，取前20个试试
            if len(candidates) < 20 :
                candidates.append(line.split(' ')[1].strip())

# 添加categories数据
categories = {}
with open('yelp_out/yelp_rag.json', 'r') as f:
    for line in f:
        data = json.loads(line)
        categories[data['business_id']] = data['categories']
item_list = [f'{item}: {categories[item]}' for item in item_list]
candidates = [f'{item}: {categories[item]}' for item in candidates]

print(item_list[0])
print(candidates[0])

In [5]:
import numpy as np
import scipy.stats as stats

# 模型 A 和模型 B 的 Recall@20 值
recall_A = np.array([0.1480,0.1481,0.1482])
recall_B = np.array([0.14649,0.14650,0.14651])

# 计算每次实验的差异
differences = recall_A - recall_B

# 进行配对 t 检验
t_stat, p_value = stats.ttest_rel(recall_A, recall_B)

print(f"t 统计量: {t_stat}, p-value: {p_value}")

t 统计量: 30.792014356777727, p-value: 0.0010530218790085494


### 数据增强2

In [1]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms.ollama import Ollama
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import  StrOutputParser
import json

In [2]:
ollama_model = Ollama(
    #base_url="http://172.16.110.34:45665",  # node09
    base_url="http://localhost:11434",  # whr
    model="qwen2.5:0.5b",
    num_gpu=2,
    # It is recommended to set this value to the number of physical CPU cores
    num_thread=48,
    # system prompt (overrides what is defined in the Modelfile)
    system='You are a recommendation system and required to recommend user with businesses based on user history that each business with id, categories.'
)

试一下能不能从文件读取模板

In [19]:
template = """User had interacted with the following businesses. Each line is "business_id: categories". Please summarize this user preference with historical information. Give answers directly. Do not output any other things. Do not give reasoning.

[history]
{history}
"""

template_file = '/home/yzh/code/SELFRec/conf/aug_prompt.txt'
with open(template_file, 'r') as f:
    template = f.read()
print(template)

User had interacted with the following businesses. Each line is "business_id: categories". Please summarize this user preference categories with historical information. Give answers directly. Do not output any other things. Do not give reasoning.

[history]
{history}

[output format]
categorie1, categorie2, ...


In [14]:
with open('yelp_out/yelp_text.json', 'r') as text_file:
    yelp_text = json.load(text_file)

In [20]:
prompt = ChatPromptTemplate.from_template(template)
chain = RunnablePassthrough() | prompt | ollama_model | StrOutputParser()

In [6]:
item_list = []

# 生成id数据
with open('yelp_out/filter_yelp_interactions.txt', 'r') as f:
    for line in f:
        if line.split(' ')[0].strip() == 'bcjbaE6dDog4jkNY91ncLQ':
            # 以这个user为例，找出所有交互的items
            item_list.append(line.split(' ')[1].strip())

item_list = [f'{item}: {yelp_text[item]}' for item in item_list]
filter_item = item_list[:3]

print(filter_item)

['e4Vwtrqf-wpJfwesgvdgxQ: Sandwiches, Beer, Wine & Spirits, Bars, Food, Restaurants, American (Traditional), Nightlife', 'gUyfJlJRxu1fHuZ4dpBheQ: Mexican, Restaurants, Latin American', '5AOkxsg6UJQ_CoJTMBDUmQ: American (New), Restaurants, Cajun/Creole']


In [21]:
form_item_list = '\n'.join(filter_item)
prompt_chain = RunnablePassthrough() | prompt
out_prompt = prompt_chain.invoke({
    "history": form_item_list,
})
print(out_prompt.to_string())

Human: User had interacted with the following businesses. Each line is "business_id: categories". Please summarize this user preference categories with historical information. Give answers directly. Do not output any other things. Do not give reasoning.

[history]
e4Vwtrqf-wpJfwesgvdgxQ: Sandwiches, Beer, Wine & Spirits, Bars, Food, Restaurants, American (Traditional), Nightlife
gUyfJlJRxu1fHuZ4dpBheQ: Mexican, Restaurants, Latin American
5AOkxsg6UJQ_CoJTMBDUmQ: American (New), Restaurants, Cajun/Creole

[output format]
categorie1, categorie2, ...


In [22]:
output = chain.invoke({
    "history": form_item_list,
})

print(output)

American (Traditional), Mexican, American (New)


### 合并文件

由于存在请求中断的情况，所以分批次产生了若干个生成文件，为不遗漏，通过id进行控制

In [1]:
import json
with open('yelp_out/yelp_user_history.json', 'r') as file:
    data = json.load(file)
print(len(data))

116507


In [2]:
user2id = {}
id2user = {}
for i, user in enumerate(data):
    user2id[user] = i
    id2user[i] = user

In [37]:
user2id['_fCu_7tmTX-DevfSyoyqsg']


116449

In [38]:
user2id['V9fW3-fJ-sEMz_ewPpzXXg']

116506

In [15]:
id2user[3339]

'oe3JA8llbDetMWxSPjJHVA'

生成完毕，统计数据量

In [51]:
8649+30000+77800+58

116507

In [29]:
with open('/nvme0n1p2/yelp_out/yelp_user_preference-20240929_0832.json', 'r') as file:
    for line in file:
        data: dict = json.loads(line)
        if list(data.keys())[0] == '-YOWyjJ0bOdSN0LfDSLC4Q':
            print(data)

{'-YOWyjJ0bOdSN0LfDSLC4Q': ''}


合并生成文件(文件不大，就不写流了)

In [49]:
file_paths = [
    '/nvme0n1p2/yelp_out/yelp_user_preference-20240930_1006.json',  # 0-8648
    '/nvme0n1p2/yelp_out/yelp_user_preference-20240928_2011.json',  # 8649-38648
    '/nvme0n1p2/yelp_out/yelp_user_preference-20240929_0832.json',  # 38649-116448
    '/nvme0n1p2/yelp_out/yelp_user_preference-20240930_0837.json'   # 116449-116506
]

In [50]:
%%bash
wc -l /nvme0n1p2/yelp_out/yelp_user_preference-20240930_1006.json
wc -l /nvme0n1p2/yelp_out/yelp_user_preference-20240928_2011.json
wc -l /nvme0n1p2/yelp_out/yelp_user_preference-20240929_0832.json
wc -l /nvme0n1p2/yelp_out/yelp_user_preference-20240930_0837.json

8649 /nvme0n1p2/yelp_out/yelp_user_preference-20240930_1006.json
30000 /nvme0n1p2/yelp_out/yelp_user_preference-20240928_2011.json
77800 /nvme0n1p2/yelp_out/yelp_user_preference-20240929_0832.json
58 /nvme0n1p2/yelp_out/yelp_user_preference-20240930_0837.json


In [51]:
8649+30000+77800+58

116507

In [52]:
%%bash
head /nvme0n1p2/yelp_out/yelp_user_preference-20240930_0837.json -n 1
tail /nvme0n1p2/yelp_out/yelp_user_preference-20240930_0837.json -n 1

{"_fCu_7tmTX-DevfSyoyqsg": "restaurants, bars"}
{"V9fW3-fJ-sEMz_ewPpzXXg": "Music Venues, Bars"}


json line -> json

In [56]:
import json

user_preferences: dict[str, str] = {}  # user->categories
with open('yelp_user_preferences.v2.json', 'w', encoding='utf-8') as fout:
    for path in file_paths:
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    data: dict[str, str] = json.loads(line)  # user->categories
                    for k,v in data.items():  # 其实只有一组
                        if k in user_preferences: # 应该不会有
                            print('exist duplicate user: ', k)
                        user_preferences[k] = v
                except Exception as e:
                    with open('error_preferences.txt', 'a', encoding='utf-8') as errors:
                        errors.write(f"{path.split('-')[1]}: {line}")
    
    json.dump(user_preferences, fout, ensure_ascii=False)


In [57]:
error = 0
with open('yelp_user_preferences.v1.json', 'r') as file:
    data = json.load(file)
    for k, v in data.items():
        if v.strip() == "":
            error += 1
    print(f"total {len(data)} users, {error} empty fields")

total 116507 users, 73 empty fields


执行 `python ollama_aug.py --type specific` 后，更新 `yelp_user_preferences.json`。

In [60]:
import json

with open('yelp_user_preferences.v1.json', 'r') as origin:
    origin_data: dict[str, str] = json.load(origin)
    print(f"before: {len(origin_data)}")

    fix_path = '/nvme0n1p2/yelp_out/yelp_user_preference-20241007_1708.json'
    fix_data: dict[str, str] = {}
    with open(fix_path, 'r') as f:
        for line in f:
            data: dict[str, str] = json.loads(line)
            for k,v in data.items():
                fix_data[k] = v
    print(f"fix: {len(fix_data)}")

    # 更新数据
    with open('yelp_user_preferences.v2.json', 'w') as output:
        for k in fix_data.keys():
            if k in origin_data:
                origin_data[k] = fix_data[k]
            else:
                print(f"{k} is in fix data but not in origin data!")
        json.dump(origin_data, output, ensure_ascii=False)
    print(f"after: {len(origin_data)}")

before: 116507
fix: 73
after: 116507


In [61]:
error = 0
with open('yelp_user_preferences.v2.json', 'r') as file:
    data = json.load(file)
    for k, v in data.items():
        if v.strip() == "":
            error += 1
    print(f"total {len(data)} users, {error} empty fields")

total 116507 users, 0 empty fields


再次添加，直接构建全局的吧，反正差不了几个

In [15]:
import json

with open('yelp_out/yelp_user_preferences.v2.json', 'r') as origin:
    origin_data: dict[str, str] = json.load(origin)
    print(f"before: {len(origin_data)}")

    fix_path = 'yelp_out/preferences/yelp_user_preference-20241022_2023.json'
    fix_data: dict[str, str] = {}
    with open(fix_path, 'r') as f:
        data: dict[str, str] = json.load(f)
        for k,v in data.items():
            fix_data[k] = v
    print(f"fix: {len(fix_data)}")

    # 更新数据
    with open('yelp_user_preferences.v3.json', 'w') as output:
        for k, v in fix_data.items():
            if k in origin_data:
                print('exist')
            else:
                origin_data[k] = v
        json.dump(origin_data, output, ensure_ascii=False)
    print(f"after: {len(origin_data)}")

before: 116507
fix: 508
after: 117015


## 文本嵌入生成

[stella_en_1.5B_v5](https://huggingface.co/bennegeek/stella_en_1.5B_v5)

In [18]:
from sentence_transformers import SentenceTransformer

query_prompt_name = "s2s_query"
model = SentenceTransformer('/home/yzh/code/SELFRec/model/stella_en_1.5B_v5', device="cuda:1", trust_remote_code=True)

  from tqdm.autonotebook import tqdm, trange


生成 user_pref_embs

In [16]:
import json
import torch
import numpy as np
from safetensors.torch import save_file, load_file

# 读取数据
with open('yelp_out/yelp_user_preferences.v3.json', 'r', encoding='utf-8') as file:
    user_preferences: dict[str, str] = json.load(file)

In [3]:
user_pre_embs: dict[str, torch.Tensor] = {}
user_pre_list: list[str] = []
for prefs in user_preferences.values():
    user_pre_list.append(prefs)
print(f'total text: {len(user_pre_list)}')

# 模型编码
pre_embs = model.encode(user_pre_list, prompt_name=query_prompt_name, device='cuda:1', batch_size=16, convert_to_tensor=True)
print(f"output embedding shape: {pre_embs.shape}")

total text: 116507
output embedding shape: torch.Size([116507, 1024])


In [20]:
from tqdm import tqdm

pref_embs = load_file("yelp_ds/user_pre_embs.safetensors", device="cuda:1")
for user in tqdm(user_preferences, desc='fix user pref'):
    if user in pref_embs: continue
    emb = model.encode(user_preferences[user], prompt_name=query_prompt_name, device='cuda:1', convert_to_tensor=True)
    pref_embs[user] = emb

metadata = {
    "type": "pt",
    "user num": "117015",
    "dim": "1024",
    "build by": "https://github.com/sun2ot",
    "time": "2024-10-22 20:51"
}
save_file(pref_embs, "user_pre_embs.safetensors")

fix user pref: 100%|██████████| 117015/117015 [00:15<00:00, 7533.11it/s]  


In [22]:
print(len(pref_embs))

117015


In [14]:
from safetensors.torch import load_file

loaded = load_file('user_pre_embs.safetensors', device='cuda:1')
print(len(loaded))
print(loaded['GziM44xJcoR4jJByq10NQA'])

116507
tensor([-2.0780,  1.0266, -0.3841,  ..., -1.3006, -2.8999,  3.2373],
       device='cuda:1')


In [13]:
from safetensors import safe_open

with safe_open('user_pre_embs.safetensors', framework='pt', device='cuda:1') as f: # type: ignore
    print(len(f.keys()))
    print(f.get_tensor('GziM44xJcoR4jJByq10NQA'))

116507
tensor([-2.0780,  1.0266, -0.3841,  ..., -1.3006, -2.8999,  3.2373],
       device='cuda:1')


如法炮制, 生成 text_embs

In [5]:
import json
import torch
import numpy as np
from safetensors.torch import save_file

with open('yelp_out/yelp_text.json', 'r', encoding='utf-8') as text_file:
    item_text: dict[str, str] = json.load(text_file)

item_text_embs: dict[str, torch.Tensor] = {}
item_text_list: list[str] = []
for text in item_text.values():
    item_text_list.append(text)
print(f'total text: {len(item_text_list)}')

text_embs = model.encode(item_text_list, device='cuda:1', prompt_name=query_prompt_name, batch_size=32, convert_to_tensor=True)
print(f"output embedding shape: {text_embs.shape}")

for idx, item in enumerate(item_text.keys()):
    item_text_embs[item] = text_embs[idx]
save_file(item_text_embs, 'item_text_embs.safetensors')

total text: 33183
output embedding shape: torch.Size([33183, 1024])


In [7]:
from safetensors import safe_open

with safe_open('item_text_embs.safetensors', framework='pt', device='cuda:1') as f: # type: ignore
    print(len(f.keys()))
    print(f.get_tensor('MTSW4McQd7CbVtyjqoe9mw'))

33183
tensor([-0.8660,  2.7190, -0.9495,  ...,  1.4988, -0.4805,  2.8815],
       device='cuda:1')


# 缩小数据集

先构建交互字典，便于后续处理

In [1]:
import json
from collections import defaultdict

ui_dict = defaultdict(dict)
with open('yelp_ds/filter_yelp_interactions.txt', 'r') as file:
    for line in file:
        user, item, ratings = line.strip().split(' ')
        ui_dict[user][item] = ratings

# 持久化
with open('yelp_ds/ui_dict.json', 'w', encoding='utf-8') as file:
    json.dump(ui_dict, file, ensure_ascii=False)

In [2]:
with open('yelp_ds/ui_dict.json', 'r', encoding='utf-8') as file:
    ui_data = json.load(file)
print(len(ui_data))
del ui_data

117015


In [6]:
import json
import random

with open('yelp_ds/ui_dict.json', 'r', encoding='utf-8') as file:
    # 有评分, 浮点型, 以字符串保存
    ui_dict: dict[str, dict[str, str]] = json.load(file)
origin = len(ui_dict)
print(f"origin: {origin}")

not_core_user=0
for user, item_ratings in list(ui_dict.items()):
    if len(item_ratings) < 15:
        not_core_user+=1
        # 先删除交互不足15次的user
        del ui_dict[user]
print(f"del not_core_user: {not_core_user}")

for user, item_ratings in ui_dict.items():
    if len(item_ratings) < 10: raise Exception('exit user it < 10')
    if len(item_ratings) > 20:
        # 保留至多20次交互
        select_items = random.sample(list(item_ratings.keys()), 20)
        ui_dict[user] = {item: item_ratings[item] for item in select_items}

item_set = set()
it_num = 0
for user, item_ratings in ui_dict.items():
    it_num += len(item_ratings)
    item_set.update(item_ratings.keys())

print(f"users: {len(ui_dict)}, items: {len(item_set)}, it_num: {it_num}")


# 持久化
with open('yelp_tiny/ui_dict.json', 'w', encoding='utf-8') as file:
    json.dump(ui_dict, file, ensure_ascii=False)

origin: 117015
del not_core_user: 79618
users: 37397, items: 32491, it_num: 707178


In [7]:
with open('yelp_tiny/ui_dict.json', 'r', encoding='utf-8') as file:
    d = json.load(file)
print(len(d))

37397


#### 划分数据集

In [8]:
with open('yelp_tiny/ui_dict.json', 'r', encoding='utf-8') as file:
    ui_dict = json.load(file)

train_set = []
val_set = []
test_set = []
not_core = 0

for user, item_ragings in ui_dict.items():
    items = list(item_ragings.keys())
    random.shuffle(items)  # 打乱交互数据
    it_num = len(items)
    if it_num < 10: not_core+=1  # 不足10个会导致对应测试集为空
    train_num = int(it_num * 0.7)
    val_num = int(it_num * 0.1) + train_num

    # 将每个用户的交互记录划分到三个子集
    for item in items[:train_num]: train_set.append((user, item, item_ragings[item]))
    for item in items[train_num:val_num]: val_set.append((user, item, item_ragings[item]))
    for item in items[val_num:]: test_set.append((user, item, item_ragings[item]))
print(not_core)

0


In [9]:
print(len(train_set))
print(len(val_set))
print(len(test_set))

488814
62360
156004


In [10]:
from tqdm import tqdm
def save_dataset(dataset: list[tuple[str, str, str]], filename: str):
    with open(filename, "w", encoding='utf-8') as f:
        for user, item, ratings in tqdm(dataset):
            f.write(f"{user} {item} {ratings}\n")

In [11]:
save_dataset(train_set, 'yelp_tiny/train.txt')
save_dataset(val_set, 'yelp_tiny/val.txt')
save_dataset(test_set, 'yelp_tiny/test.txt')

100%|██████████| 488814/488814 [00:00<00:00, 2066120.92it/s]
100%|██████████| 62360/62360 [00:00<00:00, 1542025.35it/s]
100%|██████████| 156004/156004 [00:00<00:00, 1704365.84it/s]


In [12]:
merge_train = []
for i in train_set: merge_train.append(i)
for j in val_set: merge_train.append(j)
print(len(merge_train))
save_dataset(merge_train, 'yelp_tiny/merge_train.txt')

551174


100%|██████████| 551174/551174 [00:00<00:00, 1696203.46it/s]


### 图像模态二次处理

yelp 图像模态数据需要再处理下，多图直接合一，避免模型过程中处理，影响代码复用性

In [15]:
from safetensors import safe_open
from safetensors.torch import save_file
import torch

item2image: dict[str, list[str]] = {}
with safe_open('yelp_ds/photo_embs.safetensors', 'pt', device="cuda:0") as image_safetensors: # type: ignore
    item_set = set()
    with open('yelp_out/re_item2photos.txt', 'r') as map_file:
        for line in map_file:
            item = line.strip().split(' ')[0]
            item_set.add(item)
            images = line.strip().split(' ')[1:]
            item2image[item] = images

    merge_image_tensor: dict[str, torch.Tensor] = {}
    for item in tqdm(item_set, desc='item images'):
        # 这里是全局图像, 不是 tiny
        try:
            merge_image_tensor[item] = torch.mean(
                torch.stack([image_safetensors.get_tensor(image) for image in item2image[item]]), dim=0
            )
        except Exception as e:
            print(e)
            exit(-1)
    
    save_file(merge_image_tensor, 'yelp_ds/item_image_emb.safetensors')

item images: 100%|██████████| 33183/33183 [00:09<00:00, 3548.65it/s]
