In [1]:
import copy
import json
import os
import pickle
import time

import jsonlines
import numpy as np
import pandas as pd
from tqdm import tqdm

In [2]:
# Raw item key -> dict of metadata attributes (e.g. title, brand, category, description).
# The context manager closes the file instead of leaving it to the garbage collector.
with open("./handled/item2attributes_A.json", "r") as attr_file:
    data = json.load(attr_file)

In [None]:
# Number of items that have an attribute record.
len(data.keys())

In [None]:
# Union of every item's attribute dict: each attribute key seen in any item, mapped to
# the value from the last item that has it. Handy for inspecting the metadata schema.
example_dict = {
    attr_name: attr_value
    for item_attrs in tqdm(data.values())
    for attr_name, attr_value in item_attrs.items()
}

In [None]:
# Map each raw item key to its internal id and keep the first 100 characters of its title.
with open("./handled/id_map.json", "r") as id_map_file:
    id_map = json.load(id_map_file)["item_dict"]["0"]["str2id"]
title_data = {}
for key, value in tqdm(data.items()):
    if "title" not in value:
        # Attributes are optional (get_attri handles missing ones); an item without a
        # title is simply absent here and gets the "no name" placeholder downstream.
        continue
    title_data[id_map[key]] = value["title"][:100]

In [None]:
# Items that have an id in id_map but no attribute record in `data` end up without a name.
print(f"the number of items that do not have name: {len(id_map) - len(data)}")

In [None]:
# One title per internal item id (ids run 1..len(id_map)); ids with no title get "no name".
# `item_id` (not `id`) so the built-in id() is not shadowed for the rest of the kernel session.
title_list = [title_data.get(item_id, "no name") for item_id in range(1, len(id_map) + 1)]

assert len(title_list) == len(id_map)

with open("./handled/title_A.pkl", "wb") as f:
    pickle.dump(title_list, f)

In [None]:
# All attribute names seen across items (example_dict merges every item's attribute dict).
example_dict.keys()

In [None]:
# Peek at one description entry. example_dict keeps the value from the last item that had
# a "description", so this is that item's first entry. Raises KeyError if no item has a
# description, and IndexError if the stored value is empty.
example_dict["description"][0]

In [12]:
def get_attri(item_str, attri, item_info, max_len=100):
    """Fill the ``<ATTRI>`` placeholder in a template with one attribute value.

    Args:
        item_str: template text containing ``f"<{attri.upper()}>"`` (possibly several times).
        attri: attribute name, used as the key into ``item_info``.
        item_info: the item's metadata dict.
        max_len: values longer than this are replaced by ``"unknown"`` (they are not
            truncated). Length is ``len(value)`` for sized values (str, list, ...) and
            ``len(str(value))`` for everything else (e.g. a numeric price).

    Returns:
        ``item_str`` with every occurrence of the placeholder replaced by ``str(value)``,
        or by ``"unknown"`` when the attribute is missing, None, or too long.
    """
    placeholder = f"<{attri.upper()}>"
    value = item_info.get(attri)
    if value is None:
        return item_str.replace(placeholder, "unknown")

    # Numbers have no len(); measure their text form instead of raising TypeError.
    size = len(value) if hasattr(value, "__len__") else len(str(value))
    if size > max_len:
        return item_str.replace(placeholder, "unknown")
    return item_str.replace(placeholder, str(value))

In [13]:
def get_feat(item_str, feat, item_info, max_len=128):
    """Fill the ``<FEAT>`` placeholder in a template with a list-valued feature.

    Each list element is followed by ``"; "`` (e.g. ``["a", "b"]`` -> ``"a; b; "``).

    Args:
        item_str: template text containing ``f"<{feat.upper()}>"``.
        feat: feature name, used as the key into ``item_info``.
        item_info: the item's metadata dict; ``item_info[feat]`` must be a list.
        max_len: maximum length of the returned string (input length limitation).

    Returns:
        The filled template, or ``""`` when the item has no such feature (the whole
        sentence is dropped). If the filled template would exceed ``max_len``, only the
        inserted feature text is shortened, so the template's own text (including its
        trailing ". \\n") is kept and the next segment is not glued onto a cut-off word.

    Raises:
        AssertionError: if ``item_info[feat]`` is not a list.
    """
    if feat not in item_info:
        return ""

    assert isinstance(item_info[feat], list)
    placeholder = f"<{feat.upper()}>"
    feat_str = "".join(f"{meta_feat}; " for meta_feat in item_info[feat])
    new_str = item_str.replace(placeholder, feat_str)
    if len(new_str) <= max_len:
        return new_str

    n_slots = item_str.count(placeholder)
    if n_slots:
        # Characters left for the inserted text after the fixed template text.
        fixed_len = len(item_str) - n_slots * len(placeholder)
        budget = max(0, (max_len - fixed_len) // n_slots)
        new_str = item_str.replace(placeholder, feat_str[:budget])
    # Final hard cap, e.g. when the template alone is longer than max_len.
    return new_str[:max_len]

In [None]:
# Sentence templates; each <NAME> placeholder is filled by get_attri / get_feat with the
# attribute of the same (lower-cased) name. Every attribute appears once, under its own label.
# NOTE: the embedding cache (./handled/item_emb_A.pkl) is keyed by item only, so changing
# these templates requires deleting that cache to re-embed the new prompts.
prompt_template = "The electronic item has following attributes: \n name is <TITLE>; brand is <BRAND>; price is <PRICE>; date is <DATE>. \n"
feat_template = "The item has following features: <CATEGORY>. \n"
desc_template = "The item has following descriptions: <DESCRIPTION>. \n"

In [None]:
# Build one prompt per item: attribute sentence + category sentence + description sentence.
# Strings are immutable, so the templates can be used directly without copying.
item_data = {}
for key, value in tqdm(data.items()):
    attr_part = prompt_template
    # Placeholders are filled in this fixed order; "rank" is deliberately not used.
    for attr_name in ("title", "brand", "date", "price"):
        attr_part = get_attri(attr_part, attr_name, value)

    category_part = get_feat(feat_template, "category", value)
    description_part = get_feat(desc_template, "description", value)

    item_data[key] = attr_part + category_part + description_part

In [16]:
# Character length of every item prompt, for a quick size check.
len_list = [len(item_text) for item_text in item_data.values()]

In [None]:
# Average prompt length in characters.
np.asarray(len_list).mean()

In [18]:
# Persist the per-item prompt strings keyed by raw item key; the context manager
# guarantees the file is flushed and closed rather than relying on garbage collection.
with open("./handled/item_str_A_truncate.json", "w") as item_str_file:
    json.dump(item_data, item_str_file)

In [19]:
# convert to jsonline
def save_data(data_path, data):
    '''Write each record of `data` as one line of a JSON Lines file.

    Args:
        data_path: file name; it is appended to "./handled/".
        data: iterable of JSON-serialisable records.
    '''
    with jsonlines.open("./handled/" + data_path, "w") as writer:
        for record in data:
            writer.write(record)

with open("./handled/id_map.json", "r") as id_map_file:
    id_map = json.load(id_map_file)["item_dict"]["0"]["str2id"]

# One record per item, ordered by internal id:
# {"input": prompt text, "target": "", "item": raw item key, "item_id": internal id}
json_data = sorted(
    (
        {"input": value, "target": "", "item": key, "item_id": id_map[key]}
        for key, value in item_data.items()
    ),
    key=lambda x: x["item_id"],
)
save_data("item_str_A_truncate.jsonline", json_data)

In [20]:
import requests
import json

In [21]:
def get_response(prompt, url=None, api_key=None, timeout=60):
    """Request an embedding of `prompt` from an embeddings HTTP endpoint.

    The request body is ``{"model": "text-embedding-ada-002", "input": prompt}`` and the
    response is expected to look like ``{"data": [{"embedding": [...]}, ...]}``.

    Args:
        prompt: text to embed.
        url: endpoint URL; defaults to the EMBEDDING_API_URL environment variable.
        api_key: full value of the Authorization header (e.g. "Bearer <key>");
            defaults to the EMBEDDING_API_KEY environment variable. Kept out of the
            notebook so credentials are never committed.
        timeout: seconds to wait for the server before giving up (without a timeout,
            requests can block forever on a stalled connection).

    Returns:
        The embedding vector (a list of floats) of the first result.

    Raises:
        requests.HTTPError: on a 4xx/5xx response, instead of a confusing KeyError.
        requests.RequestException: on connection problems or timeouts.
        KeyError / IndexError: if the response body has an unexpected shape.
    """
    if url is None:
        url = os.environ.get("EMBEDDING_API_URL", "")
    if api_key is None:
        api_key = os.environ.get("EMBEDDING_API_KEY", "")

    payload = json.dumps({
        "model": "text-embedding-ada-002",
        "input": prompt,
    })
    headers = {
        'Authorization': api_key,
        'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',
        'Content-Type': 'application/json',
    }

    response = requests.request("POST", url, headers=headers, data=payload, timeout=timeout)
    response.raise_for_status()
    re_json = json.loads(response.text)

    return re_json["data"][0]["embedding"]

In [22]:
# Start with an empty embedding store; the cache-loading cell replaces it when a cache exists.
item_emb = dict()

In [None]:
# Keys of items whose prompt is longer than 4096 characters (these are truncated before embedding).
value_list = [key for key, value in tqdm(item_data.items()) if len(value) > 4096]

In [24]:
# Resume from the embedding cache written by an earlier (possibly interrupted) run.
# NOTE: pickle.load can execute arbitrary code; only load a cache this notebook wrote.
if os.path.exists("./handled/item_emb_A.pkl"):    # check whether some item emb exist in cache
    with open("./handled/item_emb_A.pkl", "rb") as cache_file:
        item_emb = pickle.load(cache_file)
else:
    item_emb = {}

In [None]:
# Fetch an embedding for every item that is not cached yet.
# - Failures (network/HTTP errors, malformed responses) are retried with exponential
#   backoff, and the partial cache is checkpointed to disk after every failure.
# - Only `Exception` is caught, so Jupyter's interrupt (KeyboardInterrupt) still stops the cell.
# - The loop condition checks item keys, not sizes, so stale extra keys in the cache
#   cannot cause an endless loop or an early exit.
# - The cache is saved once more after success.
MAX_CONSECUTIVE_FAILURES = 10

count = 1
consecutive_failures = 0
while any(key not in item_emb for key in item_data):
    try:
        for key, value in tqdm(item_data.items()):
            if key in item_emb:
                continue
            if len(value) > 4096:    # cap very long prompts (assumed API input limit)
                value = value[:4095]
            item_emb[key] = get_response(value)
            count += 1
            consecutive_failures = 0
    except Exception as err:
        with open("./handled/item_emb_A.pkl", "wb") as cache_file:
            pickle.dump(item_emb, cache_file)
        consecutive_failures += 1
        if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
            raise RuntimeError(
                f"embedding requests failed {consecutive_failures} times in a row; "
                f"cache with {len(item_emb)} embeddings saved to ./handled/item_emb_A.pkl"
            ) from err
        time.sleep(min(2 ** consecutive_failures, 60))

with open("./handled/item_emb_A.pkl", "wb") as cache_file:
    pickle.dump(item_emb, cache_file)

In [None]:
# Stack the embeddings into one matrix ordered by internal item id (ids run 1..len(id_map);
# id2str keys are strings because they come from JSON). Items without an embedding get a
# zero row of the same width.
with open("./handled/id_map.json", "r") as id_map_file:
    id_map = json.load(id_map_file)["item_dict"]["0"]["id2str"]

# Computed once: the original rebuilt list(item_emb.values()) for every missing item.
zero_emb = [0] * len(next(iter(item_emb.values()))) if item_emb else None

emb_list = []
for item_id in range(1, len(id_map) + 1):   # `item_id`, so the built-in id() is not shadowed
    item_key = id_map[str(item_id)]
    if item_key in item_emb:
        emb_list.append(item_emb[item_key])
    elif zero_emb is None:
        raise ValueError(
            f"no embedding for item {item_key!r} and item_emb is empty, "
            "so the embedding width for a zero row is unknown"
        )
    else:
        emb_list.append(zero_emb)

emb_list = np.array(emb_list)
with open("./handled/itm_emb_np_A.pkl", "wb") as emb_file:
    pickle.dump(emb_list, emb_file)

In [None]:
# Make sure the number of LLM embeddings matches the number of items (one row per item).
assert emb_list.shape[0] == len(id_map)