# 处理 tiny 数据集

## ifashion

In [2]:
import os
os.makedirs('ifashion_tiny/mmssl', exist_ok=True)

### 编码映射

In [3]:
import os
import json

user2id: dict[str, int] = {}
item2id: dict[str, int] = {}

with open("ifashion_tiny/ui_dict.json", 'r') as file:
    ui_dict: dict[str, dict[str, int]] = json.load(file)

item_set = set()
for user, item_ratings in ui_dict.items():
    user2id[user] = len(user2id)
    item_set.update(item_ratings.keys())

for idx, item in enumerate(item_set):
    item2id[item] = idx

os.makedirs("ifashion_tiny/remap", exist_ok=True)
with open("ifashion_tiny/remap/user2id.json", 'w', encoding='utf-8') as f1:
    json.dump(user2id, f1, ensure_ascii=False)
with open("ifashion_tiny/remap/item2id.json", 'w', encoding='utf-8') as f2:
    json.dump(item2id, f2, ensure_ascii=False)

In [5]:
with open("ifashion_tiny/remap/user2id.json", 'r', encoding='utf-8') as f1:
    uid = json.load(f1)
print(len(uid))

with open("ifashion_tiny/remap/item2id.json", 'r', encoding='utf-8') as f2:
    iid = json.load(f2)
print(len(iid))

38403
20000


### 交互数据

In [6]:
import json
from collections import defaultdict

with open("ifashion_tiny/remap/user2id.json", 'r', encoding='utf-8') as f1:
    user2id: dict[str, int] = json.load(f1)

with open("ifashion_tiny/remap/item2id.json", 'r', encoding='utf-8') as f2:
    item2id: dict[str, int] = json.load(f2)

分别将各个数据集进行转换

In [7]:
def pre_save(interactions: dict[str, set]) -> dict[str, list[int]]:
    """将dict.values转为list"""
    it_dict = {}
    for k,v in interactions.items():
        it_dict[k] = list(v)
    return it_dict

def ds_convert_json(txt_path: str, json_path: str, user2id: dict[str, int], item2id: dict[str, int]):
    out_json: dict[str, set] = defaultdict(set[int])
    with open(txt_path, 'r', encoding='utf-8') as f:
        for line in f:
            user, item, _ratings = line.split(' ')
            out_json[str(user2id[user])].add(item2id[item])
    
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(pre_save(out_json), f, ensure_ascii=False)
    print(f"{txt_path} convert to {json_path}")

In [8]:
ds_convert_json(
    'ifashion_tiny/train.txt',
    'ifashion_tiny/mmssl/train.json',
    user2id, item2id
)

ifashion_tiny/train.txt convert to ifashion_tiny/mmssl/train.json


In [9]:
ds_convert_json(
    'ifashion_tiny/val.txt',
    'ifashion_tiny/mmssl/val.json',
    user2id, item2id
)

ifashion_tiny/val.txt convert to ifashion_tiny/mmssl/val.json


In [10]:
ds_convert_json(
    'ifashion_tiny/test.txt',
    'ifashion_tiny/mmssl/test.json',
    user2id, item2id
)

ifashion_tiny/test.txt convert to ifashion_tiny/mmssl/test.json


按照编码映射转换全局交互数据集

In [15]:
with open("ifashion_tiny/ui_dict.json", 'r') as file:
    ui_dict: dict[str, dict[str, int]] = json.load(file)

remap_ui_dict: dict[str, list[int]] = {}
for user, item_ratings in ui_dict.items():
    remap_ui_dict[str(user2id[user])] = [item2id[item] for item in item_ratings.keys()]

with open("ifashion_tiny/remap/remap_ui_list.json", 'w') as file:
    json.dump(remap_ui_dict, file)

In [21]:
with open("ifashion_tiny/remap/remap_ui_list.json", 'r') as file:
    d = json.load(file)
print(len(d))
n=0
for u, il in d.items():
    if len(il) == 0: n+=1
print(n)

38403
31


### 构造 ifashion 稀疏矩阵

In [22]:
import json
from scipy.sparse import csr_matrix

def read_user_items(file_path):
    with open(file_path, 'r') as file:
        user_items = json.load(file)
    return user_items

with open("ifashion_tiny/remap/user2id.json", 'r', encoding='utf-8') as f1:
    user2id: dict[str, int] = json.load(f1)

with open("ifashion_tiny/remap/item2id.json", 'r', encoding='utf-8') as f2:
    item2id: dict[str, int] = json.load(f2)

def create_sparse_matrix(user_items: dict[str, list[int]], user2id, item2id):
    # 获取用户和项目的最大编号
    max_user = max(list(user2id.values()))
    max_item = max(list(item2id.values()))

    # 初始化稀疏矩阵
    rows = []
    cols = []
    data = []

    for user, items in user_items.items():
        user_id = int(user)
        for item_id in items:
            rows.append(user_id)
            cols.append(item_id)
            data.append(1.0)  # ratings设为1.0

    # 构造稀疏矩阵
    train_mat = csr_matrix((data, (rows, cols)), shape=(max_user + 1, max_item + 1))
    return train_mat

In [24]:
import pickle

user_items = read_user_items('ifashion_tiny/remap/remap_ui_list.json')

# 构造稀疏矩阵
train_mat = create_sparse_matrix(user_items, user2id, item2id)

# 保存稀疏矩阵到文件
with open('ifashion_tiny/mmssl/train_mat', 'wb') as file:
    pickle.dump(train_mat, file)

In [26]:
with open('ifashion_tiny/mmssl/train_mat', 'rb') as file:
    matrix = pickle.load(file)
print(matrix.shape)

(38403, 20000)


### 多模态数据

In [28]:
import json
from safetensors.numpy import load_file
from tqdm import tqdm
import numpy as np

with open("ifashion_tiny/remap/item2id.json", 'r', encoding='utf-8') as f2:
    item2id = json.load(f2)

image_id2embs: dict[str, np.ndarray] = {}
image_ndarrays = load_file("ifashion_ds/ifashion_image.safetensors")

for item in tqdm(item2id, desc='image embs'):
    image_id2embs[item2id[item]] = np.squeeze(image_ndarrays[item])

# 提取键并排序
sorted_keys = sorted(image_id2embs.keys(), key=int)
sorted_arrays = [image_id2embs[key] for key in sorted_keys]
result_array = np.stack(sorted_arrays, axis=0)  # (all_item_num, dim)

print(result_array.shape)
np.save('ifashion_tiny/mmssl/image_feat.npy', result_array)

image embs: 100%|██████████| 20000/20000 [00:00<00:00, 685198.24it/s]

(20000, 512)





文本特征

In [29]:
import json
from safetensors.numpy import load_file
from safetensors import safe_open
from tqdm import tqdm
import numpy as np

with open("ifashion_tiny/remap/item2id.json", 'r', encoding='utf-8') as f2:
    item2id = json.load(f2)

text_id2embs: dict[str, np.ndarray] = {}
text_ndarrays = load_file("ifashion_ds/ifashion_text.safetensors")

for item in tqdm(item2id, desc='text embs'):
    text_id2embs[item2id[item]] = np.squeeze(text_ndarrays[item])

# 提取键并排序
sorted_keys = sorted(text_id2embs.keys(), key=int)
sorted_arrays = [text_id2embs[key] for key in sorted_keys]
result_array = np.stack(sorted_arrays, axis=0)  # (all_item_num, dim)

print(result_array.shape)
np.save('ifashion_tiny/mmssl/text_feat.npy', result_array)

text embs: 100%|██████████| 20000/20000 [00:00<00:00, 886651.31it/s]

(20000, 1024)





In [30]:
%%bash
rsync -avcP ifashion_tiny/mmssl/* /home/yzh/code/MMSSL/MMSSL/data/

sending incremental file list
image_feat.npy
     40,960,128 100%  394.26MB/s    0:00:00 (xfr#1, to-chk=5/6)
test.json
      1,116,489 100%   10.44MB/s    0:00:00 (xfr#2, to-chk=4/6)
text_feat.npy
     81,920,128 100%  260.42MB/s    0:00:00 (xfr#3, to-chk=3/6)
train.json
      2,032,104 100%    6.35MB/s    0:00:00 (xfr#4, to-chk=2/6)
train_mat
      4,747,227 100%   14.28MB/s    0:00:00 (xfr#5, to-chk=1/6)
val.json
        320,381 100%  983.87kB/s    0:00:00 (xfr#6, to-chk=0/6)

sent 131,128,984 bytes  received 130 bytes  87,419,409.33 bytes/sec
total size is 131,096,457  speedup is 1.00
