In [1]:
import pickle
from pymongo import MongoClient
from tqdm import tqdm
from collections import defaultdict
from scipy.sparse import csr_matrix
import numpy as np
import random

In [2]:
### 1. Def getHistory_user_lists
def getHistory_user_lists():
    """Collect, for every user, the businesses they reviewed and the ratings given.

    Reads the module-level ``user`` and ``review`` collections: every ``user``
    document is visited once and ``review`` is queried by that user's ``user_id``.

    Returns:
        tuple: ``(history_u_lists, history_ur_lists)``, two ``defaultdict(list)``
        keyed by ``int(newId)`` of the user. Position ``i`` of the first list is
        ``int(newBusinessId)`` of the user's i-th review; position ``i`` of the
        second list is ``int(stars) - 1`` of that same review. A user without
        reviews maps to an empty list.
    """
    history_u_lists = defaultdict(list)
    history_ur_lists = defaultdict(list)
    bar = tqdm(total=user.count_documents({}), desc='Get History_user_lists')
    try:
        tempIds = user.find({}, no_cursor_timeout=True, batch_size=10)
        # The cursor is opened with no_cursor_timeout=True, so it has to be closed
        # explicitly on every path, including when the loop raises.
        try:
            for item in tempIds:
                user_key = int(item['newId'])
                visited, ratings = [], []
                for history in review.find({'user_id': item['user_id']}):
                    visited.append(int(history['newBusinessId']))
                    ratings.append(int(history['stars']) - 1)
                # Assign (not extend) so every user gets an entry, even an empty one.
                history_u_lists[user_key] = visited
                history_ur_lists[user_key] = ratings
                bar.update(1)
        finally:
            tempIds.close()
    finally:
        bar.close()
    return history_u_lists, history_ur_lists

In [3]:
### 2. Def getHistory_poi_lists
def getHistory_poi_lists():
    """Collect, for every business (POI), the users who reviewed it and their ratings.

    Reads the module-level ``business`` and ``review`` collections: every
    ``business`` document is visited once and ``review`` is queried by that
    document's ``business_id``.

    Returns:
        tuple: ``(history_v_lists, history_vr_lists)``, two ``defaultdict(list)``
        keyed by ``int(newId)`` of the business. Position ``i`` of the first list
        is ``int(newUserId)`` of the i-th review of that business; position ``i``
        of the second list is ``int(stars) - 1`` of that same review. A business
        without reviews maps to an empty list.
    """
    history_v_lists = defaultdict(list)
    history_vr_lists = defaultdict(list)
    bar = tqdm(total=business.count_documents({}), desc='Get History_poi_lists')
    try:
        tempIds = business.find({}, no_cursor_timeout=True, batch_size=10)
        # The cursor is opened with no_cursor_timeout=True, so it has to be closed
        # explicitly on every path, including when the loop raises.
        try:
            for item in tempIds:
                poi_key = int(item['newId'])
                visitors, ratings = [], []
                for history in review.find({'business_id': item['business_id']}):
                    visitors.append(int(history['newUserId']))
                    ratings.append(int(history['stars']) - 1)
                # Assign (not extend) so every POI gets an entry, even an empty one.
                history_v_lists[poi_key] = visitors
                history_vr_lists[poi_key] = ratings
                bar.update(1)
        finally:
            tempIds.close()
    finally:
        bar.close()
    return history_v_lists, history_vr_lists

In [4]:
### 3. Def train_test_split
def train_test_split(trainPercent=0.8, seed=None):
    """Split each user's reviews at random into a training and a test portion.

    Reads every document of the module-level ``review`` collection, groups the
    reviews by ``newUserId`` and, for a user with ``n`` reviews, draws
    ``int(n * trainPercent)`` of them for training; the rest go to the test set.
    Because of the ``int()`` truncation, a user with very few reviews may
    contribute test rows only.

    Args:
        trainPercent: Fraction of each user's reviews used for training; must
            lie in [0, 1]. Defaults to 0.8.
        seed: Optional seed for a private ``random.Random`` so the split can be
            reproduced. When ``None`` the global ``random`` module state is used.

    Returns:
        tuple: ``(train_u, train_v, train_r, train_s, train_t,
        test_u, test_v, test_r, test_s, test_t)`` - ten parallel lists holding,
        per review, the user id, ``newBusinessId``, ``stars`` (float),
        ``sentiment_vector`` and ``timeProb`` (float).

    Raises:
        ValueError: If ``trainPercent`` is outside [0, 1].
    """
    if not 0.0 <= trainPercent <= 1.0:
        raise ValueError('trainPercent must be within [0, 1], got %r' % (trainPercent,))
    # A private generator keeps a seeded split from disturbing the global state.
    rng = random if seed is None else random.Random(seed)

    matrix = defaultdict(list)
    train_u, train_v, train_r, train_s, train_t = [], [], [], [], []
    test_u, test_v, test_r, test_s, test_t = [], [], [], [], []

    # Fetch all reviews, grouped by user.
    bar = tqdm(total=review.count_documents({}), desc='Train_Test_Split get all data')
    try:
        tempIds = review.find({}, no_cursor_timeout=True, batch_size=10)
        # The cursor is opened with no_cursor_timeout=True, so it has to be closed
        # explicitly on every path, including when the loop raises.
        try:
            for item in tempIds:
                matrix[int(item['newUserId'])].append({
                    'newBusinessId': int(item['newBusinessId']),
                    'stars': float(item['stars']),
                    'sentiment_vector': item['sentiment_vector'],
                    'timeProb': float(item['timeProb'])
                })
                bar.update(1)
        finally:
            tempIds.close()
    finally:
        bar.close()

    for user_id, user_data in matrix.items():
        # Highest ratings first; this only fixes the order of the emitted rows,
        # membership in train/test is decided by the random draw below.
        user_data.sort(key=lambda x: x['stars'], reverse=True)
        train_size = int(len(user_data) * trainPercent)
        train_indices_set = set(rng.sample(range(len(user_data)), train_size))

        for i, item in enumerate(user_data):
            if i in train_indices_set:
                train_u.append(user_id)
                train_v.append(item['newBusinessId'])
                train_r.append(item['stars'])
                train_s.append(item['sentiment_vector'])
                train_t.append(item['timeProb'])
            else:
                test_u.append(user_id)
                test_v.append(item['newBusinessId'])
                test_r.append(item['stars'])
                test_s.append(item['sentiment_vector'])
                test_t.append(item['timeProb'])

    return train_u, train_v, train_r, train_s, train_t, test_u, test_v, test_r, test_s, test_t

In [5]:
### 4. Def getSocial_adj_lists
def getSocial_adj_lists():
    """Build the user-to-friends adjacency sets.

    Reads every document of the module-level ``user`` collection.

    Returns:
        defaultdict(set): maps ``int(newId)`` of each user to the set of
        ``int`` values found in that document's ``newFriends`` field. A user
        with no friends maps to an empty set.
    """
    social_adj_lists = defaultdict(set)
    bar = tqdm(total=user.count_documents({}), desc='Get Social_adj_lists')
    try:
        tempIds = user.find({}, no_cursor_timeout=True, batch_size=10)
        # The cursor is opened with no_cursor_timeout=True, so it has to be closed
        # explicitly on every path, including when the loop raises.
        try:
            for item in tempIds:
                social_adj_lists[int(item['newId'])] = {int(friend) for friend in item['newFriends']}
                bar.update(1)
        finally:
            tempIds.close()
    finally:
        bar.close()
    return social_adj_lists

In [6]:
### 5. Def getPOI_adj_lists
def getPOI_adj_lists():
    """Build the POI-to-neighbors adjacency sets.

    Reads every document of the module-level ``business`` collection.

    Returns:
        defaultdict(set): maps ``int(newId)`` of each business to the set of
        ``int`` values found in that document's ``newNeighbors`` field. A
        business with no neighbors maps to an empty set.
    """
    poi_adj_lists = defaultdict(set)
    bar = tqdm(total=business.count_documents({}), desc='Get POI_adj_lists')
    try:
        tempIds = business.find({}, no_cursor_timeout=True, batch_size=10)
        # The cursor is opened with no_cursor_timeout=True, so it has to be closed
        # explicitly on every path, including when the loop raises.
        try:
            for item in tempIds:
                poi_adj_lists[int(item['newId'])] = {int(neighbor) for neighbor in item['newNeighbors']}
                bar.update(1)
        finally:
            tempIds.close()
    finally:
        bar.close()
    return poi_adj_lists

In [7]:
### 6. Def getRatings_list
def getRatings_list():
    """Return a fresh dict mapping each known rating value to an integer index.

    The whole-star values 1.0-5.0 take indices 0-4; the half-star values
    1.5, 2.5 and 3.5 follow with indices 5-7.
    """
    ordered_ratings = (1.0, 2.0, 3.0, 4.0, 5.0, 1.5, 2.5, 3.5)
    return {rating: index for index, rating in enumerate(ordered_ratings)}

In [8]:
### Select the database to use
client = MongoClient(host='127.0.0.1', port=27017)
db = client['Yelp_Final']
business = db['business']
review = db['review']
user = db['user']
sentiment = db['sentiment']

# Path and file name of the pickle that stores the prepared data
dir_data = './data/final.pickle'

In [18]:
# Build every structure from MongoDB.
history_u_lists, history_ur_lists = getHistory_user_lists()
history_v_lists, history_vr_lists = getHistory_poi_lists()
train_u, train_v, train_r, train_s, train_t, test_u, test_v, test_r, test_s, test_t = train_test_split()
social_adj_lists = getSocial_adj_lists()
poi_adj_lists = getPOI_adj_lists()
ratings_list = getRatings_list()

# The order of this tuple is the on-disk format: the cell that loads the pickle
# unpacks it positionally, so keep both in sync.
# `with` guarantees the file is flushed and closed even if pickling fails.
with open(dir_data, 'wb') as data_file:
    pickle.dump(
        (
            history_u_lists, history_ur_lists,
            history_v_lists, history_vr_lists,
            train_u, train_v, train_r, train_s, train_t,
            test_u, test_v, test_r, test_s, test_t,
            social_adj_lists, poi_adj_lists, ratings_list,
        ),
        data_file,
    )
print('Save data to', dir_data, 'successfully!')

Get History_user_lists: 100%|██████████| 1779/1779 [01:16<00:00, 23.14it/s]
Get History_poi_lists: 100%|██████████| 11456/11456 [06:01<00:00, 31.71it/s]
Train_Test_Split get all data: 100%|██████████| 61189/61189 [00:12<00:00, 5035.39it/s]
Get Social_adj_lists: 100%|██████████| 1779/1779 [00:00<00:00, 35340.57it/s]
Get POI_adj_lists: 100%|██████████| 11456/11456 [00:00<00:00, 20214.13it/s]


Save data to ./data/final.pickle successfully!


In [3]:
dir_data = './data/final.pickle'
# NOTE: pickle.load can execute arbitrary code embedded in the file; only load
# pickles that were produced by this notebook.
# `with` guarantees the file handle is released even if unpickling fails.
with open(dir_data, 'rb') as data_file:
    # Unpacked positionally, in the same order the tuple was dumped.
    # NOTE(review): the saving cell calls the 16th element `poi_adj_lists`; it is
    # bound to `POI_adj_lists` here. Name kept as-is so later cells keep working.
    (
        history_u_lists, history_ur_lists,
        history_v_lists, history_vr_lists,
        train_u, train_v, train_r, train_s, train_t,
        test_u, test_v, test_r, test_s, test_t,
        social_adj_lists, POI_adj_lists, ratings_list,
    ) = pickle.load(data_file)

In [5]:
# Peek at the business ids reviewed by the user whose newId is 0.
print(history_u_lists[0])

[2061, 8895, 7433, 10878, 10878, 10974, 7535, 864, 3682, 9353, 4041, 2672, 1914, 8245, 9576, 3498, 8190, 9982, 9162, 6290, 5987, 8352, 8624, 2469, 8399, 3571, 5787, 221, 8322, 8389]


In [9]:
# Count the users whose review history holds more than k entries.
k = 15
needCalcUser = sum(1 for visited in history_u_lists.values() if len(visited) > k)
print(needCalcUser)

1668


In [23]:
# NOTE(review): getSentiment_list is not defined in the visible part of this
# notebook; on a fresh kernel this cell fails unless it is defined elsewhere - confirm.
poi_sentiment = getSentiment_list()
# `with` guarantees the file is flushed and closed even if pickling fails.
with open('./data/AvgSentiment.pickle', 'wb') as data_file:
    pickle.dump(poi_sentiment, data_file)
print('Save AvgSentiment successfully!')