In [1]:
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import pickle
import json

In [2]:
# Domain names. The position in the list is the integer domain flag used
# throughout the notebook (book: 0, movie: 1, music: 2).
data_category = [
    "book",
    "movie",
    "music",
]

1. read_data: 把所有数据读出来，然后存到一个list中
2. filter: 过滤掉交互次数过少的用户/物品的交互（并限制时间范围），返回的还是list
3. id_map: 制作user和item的id映射dict，并把list拆成每个用户按时间排序的交互序列

In [3]:
def read_data(data_category, domain_flag):
    """Load one domain's review file into a list of interaction records.

    Reads ``./raw/douban/{data_category}reviews_cleaned.txt``, skips its first
    line (presumably a header -- confirm against the raw files) and splits every
    remaining line on tabs. The column order depends on ``domain_flag``
    (0: book, 1: movie, 2: music).

    Args:
        data_category: Name used to build the file path, e.g. ``"book"``.
        domain_flag: Integer domain id (0, 1 or 2). Selects the column layout
            and is stored as the last field of every record.

    Returns:
        A list of ``[user_id, item_id, label, time, domain_flag]`` records.
        ``user_id`` and ``item_id`` are ints, ``label`` is a str, and ``time``
        is the time column with quotes, ``-``, ``:`` and spaces stripped,
        converted to int.

    Raises:
        ValueError: If ``domain_flag`` is not 0, 1 or 2, if a book/music line
            does not have the expected number of columns, or if an id/time
            field is not numeric.
    """
    if domain_flag not in (0, 1, 2):
        # Previously an unknown flag fell through every branch and surfaced as
        # a NameError (or silently reused the previous line's values).
        raise ValueError("unknown domain_flag: {!r}".format(domain_flag))

    data_path = r"./raw/douban/{}reviews_cleaned.txt".format(data_category)
    # Context manager so the handle is closed even if reading fails.
    with open(data_path) as f:
        lines = f.readlines()

    data = []
    for line in tqdm(lines[1:]):
        fields = line.strip().split("\t")

        if domain_flag == 0:
            user_id, item_id, rating, label, comment, time, inter_id = fields
        elif domain_flag == 1:
            try:    # some uncleaned data
                user_id, item_id, rating, comment, time, label, _, _, inter_id = fields
            except ValueError:
                # Only a wrong column count is skipped; any other error
                # (e.g. KeyboardInterrupt) is no longer swallowed.
                continue
        else:
            user_id, item_id, rating, label, comment, _, time, inter_id = fields

        user_id = int(user_id.replace('"', ''))
        item_id = int(item_id.replace('"', ''))
        label = label.replace('"', '')
        # Turn the time column into a sortable integer by dropping its separators.
        time = int(time.replace('"', '').replace('-', '').replace(':', '').replace(' ', ''))

        # rating, comment and inter_id are parsed but intentionally not kept.
        data.append([user_id, item_id, label, time, domain_flag])
    return data
        

In [4]:
def count_inter(data):
    """Count how many interactions each user and each item appears in.

    Args:
        data: Iterable of 5-field records whose first two fields are the
            user id and the item id.

    Returns:
        ``(user_count, item_count)``: two dicts mapping an id to its number
        of occurrences, keyed in first-seen order.
    """
    user_count, item_count = {}, {}

    for record in data:
        user_id, item_id, _label, _time, _domain = record
        user_count[user_id] = user_count.get(user_id, 0) + 1
        item_count[item_id] = item_count.get(item_id, 0) + 1

    return user_count, item_count

In [5]:
def filter(data, user_minmum, item_minimum,
           t_min=20160101, t_max=20161232):
    """Drop interactions of rarely seen users/items and those outside a time window.

    A record is kept when its item occurs more than ``item_minimum`` times and
    its user more than ``user_minmum`` times in ``data`` (both bounds are
    exclusive), and its time lies strictly between ``t_min`` and ``t_max``.
    The counts are computed once over the unfiltered ``data``.

    Note: this function shadows the built-in ``filter`` for the rest of the
    notebook.

    Returns:
        ``(new_data, domain_set)``: the kept records in input order, and a dict
        ``{domain_id: {"user": [...], "item": [...]}}`` for domain ids 0, 1, 2
        listing the user/item id of every kept record (duplicates included).
    """
    user_freq, item_freq = count_inter(data)
    domain_set = {domain: {"user": [], "item": []} for domain in (0, 1, 2)}
    new_data = []

    for record in tqdm(data):
        user_id, item_id, _, stamp, domain_id = record

        keep = (
            item_freq[item_id] > item_minimum
            and user_freq[user_id] > user_minmum
            and t_min < stamp < t_max
        )
        if not keep:
            continue

        new_data.append(record)
        members = domain_set[domain_id]
        members["user"].append(user_id)
        members["item"].append(item_id)

    print("filter done!")

    return new_data, domain_set

In [6]:
def make_sequence(data):
    """Group interactions by user, keeping the input order.

    Args:
        data: Iterable of ``[user_id, item_id, label, time, domain_id]`` records.

    Returns:
        ``(seq, domain_seq)``: two dicts keyed by user id holding, respectively,
        the item ids and the matching domain ids of that user's records.
    """
    seq, domain_seq = {}, {}

    for user_id, item_id, _label, _time, domain_id in tqdm(data):
        seq.setdefault(user_id, []).append(item_id)
        domain_seq.setdefault(user_id, []).append(domain_id)

    return seq, domain_seq

In [7]:
def id_map(data, domain_set):
    """Re-index users and items to consecutive ids and build per-user sequences.

    Users get ids 1, 2, ... in first-seen order. Items get ids 1, 2, ... in
    first-seen order *within their domain*, so the same item id can occur in
    different domains and must be read together with the domain id.

    Args:
        data: Iterable of ``[user_id, item_id, label, time, domain_id]`` records
            with ``domain_id`` in ``{0, 1, 2}``.
        domain_set: Not used by the mapping; kept so existing calls keep working.

    Returns:
        A 5-tuple ``(final_data, final_domain, user_dict, item_dict, item_count)``:

        * ``final_data``: ``{new_user_id: [new_item_id, ...]}`` ordered by time
          (stable sort, ties keep input order).
        * ``final_domain``: ``{new_user_id: [domain_id, ...]}`` aligned with
          ``final_data``.
        * ``user_dict``: ``{"str2id": {raw: new}, "id2str": {new: raw}}``.
        * ``item_dict``: per domain, ``"str2id"``, ``"id2str"`` and ``"id2label"``
          (the distinct labels seen for the item, in first-seen order).
        * ``item_count``: ``{domain_id: next unused item id}``, i.e. the number
          of items in the domain plus one.
    """
    final_data, final_domain = {}, {}
    temp_data = {}
    new_user_id = 1
    item_count = {0: 1, 1: 1, 2: 1}
    item_dict = {
        0: {"str2id": {}, "id2str": {}, "id2label": {},},
        1: {"str2id": {}, "id2str": {}, "id2label": {},},
        2: {"str2id": {}, "id2str": {}, "id2label": {},},
    }
    user_dict = {"str2id": {}, "id2str": {},}

    for inter in tqdm(data):
        user_id, item_id, label, time, domain_id = inter
        domain_items = item_dict[domain_id]

        # Ids start at 1, so None unambiguously means "not mapped yet".
        mapped_item = domain_items["str2id"].get(item_id)
        if mapped_item is None:
            mapped_item = item_count[domain_id]
            domain_items["str2id"][item_id] = mapped_item
            domain_items["id2str"][mapped_item] = item_id
            domain_items["id2label"][mapped_item] = [label]
            item_count[domain_id] += 1
        elif label not in domain_items["id2label"][mapped_item]:
            # add label as the description of items
            domain_items["id2label"][mapped_item].append(label)

        entry = (mapped_item, domain_id, time)
        mapped_user = user_dict["str2id"].get(user_id)
        if mapped_user is None:
            user_dict["str2id"][user_id] = new_user_id
            user_dict["id2str"][new_user_id] = user_id
            temp_data[new_user_id] = [entry]
            new_user_id += 1
        else:
            temp_data[mapped_user].append(entry)

    print("map done!")

    for user_id, inter in tqdm(temp_data.items()):

        # Chronological order; entry layout is (item id, domain id, time).
        inter.sort(key=lambda x: x[2])
        final_data[user_id] = [temp_tuple[0] for temp_tuple in inter]
        final_domain[user_id] = [temp_tuple[1] for temp_tuple in inter]

    print("sort done!")

    return final_data, final_domain, user_dict, item_dict, item_count


In [None]:
# Read every domain listed in data_category and concatenate the records into one
# flat list. Extending in the loop avoids hard-coding the number of domains
# (the previous `all_data[0] + all_data[1] + all_data[2]` assumed exactly three).
all_data = []
for flag, category in enumerate(data_category):
    data = read_data(category, flag)
    all_data.extend(data)
    print("{} is done".format(category))

In [9]:
# Distribution of the dataset's interactions over time (currently disabled)
# time_list = []
# for inter in all_data[0]:
#     _, _, _, inter_time, _ = inter
#     time_list.append(inter_time)

In [10]:
# plt.hist(time_list, bins=10)

In [None]:
# Keep users with more than 1 and items with more than 2 interactions (within
# filter's default time window), then re-index and build per-user sequences.
new_data, domain_set = filter(all_data, user_minmum=1, item_minimum=2)
final_data, final_domain, user_dict, item_dict, item_count = id_map(new_data, domain_set)

# Replace the item_count returned by id_map with the number of distinct items
# per domain, and store it alongside the item mappings.
item_count = {
    domain_id: len(set(members["item"]))
    for domain_id, members in domain_set.items()
}
item_dict["item_count"] = item_count

In [12]:
# seq, domain_seq = make_sequence(all_data)

In [None]:
# Number of users present in both the book and movie domains, followed by the
# number of distinct users in each of the two domains.
book_users = set(domain_set[0]["user"])
movie_users = set(domain_set[1]["user"])
len(book_users & movie_users), len(book_users), len(movie_users)

In [None]:
# Sanity check of the mapping: the largest id assigned in each domain should
# match that domain's item count.
print(item_count)
tuple(max(item_dict[domain_id]["str2id"].values()) for domain_id in (0, 1, 2))

把所有数据先存下来
可以使用final_domain去进行数据筛选

In [19]:
# Persist the id mappings as JSON and the per-user item/domain sequences as a
# pickle. Note that JSON stores every dict key as a string.
id_maps = {"user_dict": user_dict, "item_dict": item_dict}
with open("./handled/id_map.json", "w") as json_file:
    json.dump(id_maps, json_file)

sequences = (final_data, final_domain)
with open("./handled/douban_all.pkl", "wb") as pickle_file:
    pickle.dump(sequences, pickle_file)

In [20]:
def _int_keys(mapping):
    """Return a copy of ``mapping`` with its keys converted back to int."""
    return {int(key): value for key, value in mapping.items()}


with open("./handled/id_map.json", "r") as f:
    map_dict = json.load(f)

# JSON object keys are always strings, so the int user/item/domain ids used as
# keys before saving come back as "0", "1", ... . Restore them; otherwise
# lookups such as item_dict[0]["str2id"][item_id] with int ids fail after a reload.
user_dict = {name: _int_keys(table) for name, table in map_dict["user_dict"].items()}

item_dict = {}
for key, value in map_dict["item_dict"].items():
    if key == "item_count":
        # {domain_id: number of items}
        item_dict[key] = _int_keys(value)
    else:
        # key is a domain id; value holds the str2id / id2str / id2label tables
        item_dict[int(key)] = {name: _int_keys(table) for name, table in value.items()}

# pickle can execute arbitrary code on load: only open files written by this notebook.
with open("./handled/douban_all.pkl", "rb") as f:
    final_data, final_domain = pickle.load(f)

筛选book-movie两个domain

这里选的是book和movie两个domain

In [None]:
## Restrict final_data and final_domain to the book (0) and movie (1) domains
bm_data, bm_domain = {}, {}
for user_id, domains in tqdm(final_domain.items()):
    domains = np.array(domains)
    items = np.array(final_data[user_id])
    # Positions of the interactions that belong to domain 0 or 1.
    keep = np.where((domains == 0) | (domains == 1))
    bm_data[user_id] = items[keep]
    bm_domain[user_id] = domains[keep]

In [None]:
# Mean domain id per user: 0 means only book interactions, 1 only movie ones.
domain_stats = [np.mean(domains) for domains in bm_domain.values()]

In [None]:
# Users per overlap status: book-only (mean 0), movie-only (mean 1), and the
# total number of users; the remainder is active in both domains or in neither.
domain_stats = np.array(domain_stats)
np.count_nonzero(domain_stats == 0), np.count_nonzero(domain_stats == 1), domain_stats.size

In [None]:
# Length statistics of the book/movie sequences: mean, then a histogram
inter_len = [len(sequence) for sequence in bm_data.values()]
print(np.mean(inter_len))
plt.hist(inter_len, bins=30)

In [None]:
# Share of users whose sequence is longer than 200 interactions
inter_len = np.array(inter_len)
np.count_nonzero(inter_len > 200) / len(inter_len)

In [27]:
# Persist the book/movie sequences and their domain flags
book_movie = (bm_data, bm_domain)
with open("./handled/book_movie.pkl", "wb") as pickle_file:
    pickle.dump(book_movie, pickle_file)

In [None]:
# Check for repeated interactions (currently disabled)
# _, i_counts = np.unique(bm_data[0], return_counts=True)
# np.sum(i_counts), len(i_counts)

需要构造cold-start场景

In [None]:
# Read the book and movie files again and merge them into one flat list that
# serves as the source for the cold-start split.
cold_all_data = []
for flag, category in enumerate(["book", "movie"]):
    data = read_data(category, flag)
    cold_all_data.extend(data)
    print("{} is done".format(category))

In [12]:
def cold_filter(data, item_dict, t_min=20170101, t_max=20171232):
    """Select interactions inside a time window whose item is known to ``item_dict``.

    A record is kept when its time lies strictly between ``t_min`` and ``t_max``
    and its item id is a key of ``item_dict[domain_id]["str2id"]``.

    Args:
        data: Iterable of ``[user_id, item_id, label, time, domain_id]`` records.
        item_dict: Mapping indexed by domain id whose entries have a ``"str2id"`` dict.
        t_min: Exclusive lower bound on the record time.
        t_max: Exclusive upper bound on the record time.

    Returns:
        ``(new_data, domain_set)``: the kept records in input order, and a dict
        ``{domain_id: {"user": [...], "item": [...]}}`` for domain ids 0, 1, 2
        listing the user/item id of every kept record (duplicates included).
    """
    domain_set = {domain: {"user": [], "item": []} for domain in (0, 1, 2)}
    new_data = []

    for record in tqdm(data):
        user_id, item_id, _, stamp, domain_id = record

        # The time window is tested first, so item_dict is only consulted for
        # records that fall inside it.
        if not (t_min < stamp < t_max):
            continue
        if item_id not in item_dict[domain_id]["str2id"]:
            continue

        new_data.append(record)
        members = domain_set[domain_id]
        members["user"].append(user_id)
        members["item"].append(item_id)

    print("filter done!")

    return new_data, domain_set

In [None]:
# Keep the interactions from cold_filter's default window whose items exist in item_dict.
new_cold_data, cold_domain_set = cold_filter(cold_all_data, item_dict)

# NOTE(review): id_map assigns fresh ids starting from 1 here, so the item ids in
# final_cold_data are not the ones stored in item_dict -- confirm this is intended.
final_cold_data, final_cold_domain, *_ = id_map(new_cold_data, cold_domain_set)

In [None]:
# Build a small cold-start set: for each user, take the prefix of the sequence
# up to the first domain switch and append the first interaction after the
# switch as the target in the new domain.
random_num = {0: 0, 1: 0}   # kept samples per target domain
filter_cold_data, filter_cold_domain = {}, {}

# Iterate over the dict itself instead of range(1, len(final_cold_data)): the
# ids produced by id_map start at 1, so that range ended at len - 1 and the last
# user was never examined.
for i, random_inter in tqdm(final_cold_data.items()):
    random_domain = final_cold_domain[i]
    temp_seq, temp_domain = [], []
    flag = False    # mark whether has cross domain

    if len(random_inter) < 3:
        continue

    # Copy interactions until the next one belongs to a different domain.
    for j in range(len(random_inter)-1):
        temp_seq.append(random_inter[j])
        temp_domain.append(random_domain[j])
        if random_domain[j] != random_domain[j+1]:
            flag = True
            break

    # Keep the user only if the single-domain prefix has more than 3 interactions,
    # a switch was found, and the target domain has fewer than 60 samples so far.
    if len(temp_seq) > 3 and flag and random_num[random_domain[j+1]]<60:
        temp_seq.append(random_inter[j+1])
        temp_domain.append(random_domain[j+1])
        filter_cold_data[i] = np.array(temp_seq)
        filter_cold_domain[i] = np.array(temp_domain)
        random_num[random_domain[j+1]] += 1

    # Stop once more than 100 samples have been collected in total.
    if (random_num[0]+random_num[1]) > 100:
        break

In [37]:
# Domain of the last (target) interaction of every cold-start sample
domain_flag = [meta_domain_seq[-1] for meta_domain_seq in filter_cold_domain.values()]

In [None]:
# Targets are flagged 0 (book) or 1 (movie), so the sum is the number of movie targets
np.asarray(domain_flag).sum()

In [39]:
# Persist the cold-start sequences and their domain flags
cold_split = (filter_cold_data, filter_cold_domain)
with open("./handled/book_movie_cold.pkl", "wb") as pickle_file:
    pickle.dump(cold_split, pickle_file)

statistics

In [None]:
# For every user: tally the domain of the last interaction, record the full
# sequence length, and record how many interactions fall in domains 0 or 2.
test_domain = dict.fromkeys((0, 1, 2), 0)
inter_len, new_inter_len = [], []
for domain_inter in final_domain.values():
    test_domain[domain_inter[-1]] += 1
    inter_len.append(len(domain_inter))
    domain_arr = np.array(domain_inter)
    new_inter_len.append(np.count_nonzero((domain_arr == 0) | (domain_arr == 2)))
test_domain

In [None]:
# Cap sequence lengths at 100 so the long tail lands in the last histogram bin
inter_len = np.minimum(np.array(inter_len), 100)

In [62]:
# Same cap of 100 for the lengths restricted to domains 0 and 2
new_inter_len = np.minimum(np.array(new_inter_len), 100)

In [None]:
# Histogram of the capped sequence lengths
plt.hist(x=new_inter_len, bins=30)

In [None]:
# Number of distinct users per domain
{domain_id: len(set(members["user"])) for domain_id, members in domain_set.items()}

In [None]:
# Group the unfiltered interactions by user, assigning consecutive user ids from 1.
# NOTE: this rebinds user_dict, replacing the mapping produced by id_map.
user_inter = {}
start_id = 1
user_dict = {"str2id": {}, "id2str": {}}
for inter in tqdm(all_data):
    # read_data emits 5-field records [user, item, label, time, domain]; the
    # previous 4-name unpacking raised ValueError. The label is not needed here.
    user_id, item_id, _label, time, domain = inter

    if user_id not in user_dict["str2id"].keys():
        user_dict["str2id"][user_id] = start_id
        user_dict["id2str"][start_id] = user_id
        user_inter[start_id] = [(item_id, time, domain)]
        start_id += 1
    else:
        user_inter[user_dict["str2id"][user_id]].append((item_id, time, domain))

In [None]:
# Order every user's interactions chronologically; entries are (item, time, domain)
for interactions in tqdm(user_inter.values()):
    interactions.sort(key=lambda entry: entry[1])