In [1]:
import pickle
import json
import time, datetime
import numpy as np
import torch
from copy import deepcopy

def read_pickle(file):
    with open(file, 'rb') as f:
        ret = pickle.load(f)
    return ret

def write_pickle(file, data):
    with open(file, 'wb') as fw:
        pickle.dump(data, fw)
        

def read_json(file):
    with open(file, 'r') as f:
        ret = [json.loads(line) for line in f]
    return ret

def write_json(file, data):
    with open(file, 'w', encoding='utf-8') as fw:
        for item in data:
            line = json.dumps(item, ensure_ascii=False)
            fw.write(line + '\n')    # here we need line+'\n'

def read_by_line(file):
    with open(file, 'r') as f:
        ret = [line for line in f.readlines()]
    return ret

def write_by_line(file, data):
    with open(file, 'w') as fw:
        for line in data:
            fw.write(line)    # here does not need line+'\n'


gettime = lambda: time.time()

# Read JSON file
output: metadata, books

In [296]:
filter_path = 'filtered/'
# dataset_path = 'Electronics/'
dataset_path = 'Books/'

In [2]:
t0 = gettime()
# metadata_str = read_by_line('metadata.json')
metadata = read_json('metadata-modified.json')
t1 = gettime()
print("time cost:", t1 - t0)

time cost: 313.7264778614044


In [124]:
# Since in the original 'metadata.json' file, every entity name is enclosed in single quotes, which is illegal in json.
# if no metadata-modified.json file, we have to treat the json file as a txt file.
# we then regenerate the modified json file so that we can read it faster the next time.
t0 = gettime()
metadata_dict = [eval(m) for m in metadata_str]
print("time cost:", gettime() - t0)
t0 = gettime()
write_json('metadata-modified.json', metadata_dict)
print("time cost:", gettime() - t0)

time cost: 356.1625485420227


In [3]:
t0 = gettime()
book_reviews = read_json('Books/reviews_Books_5.json')
print("time cost:", gettime() - t0)

time cost: 169.2364764213562


In [None]:
item_ids = set(item['asin'] for item in metadata)

# Get Book Metadata and ReviewerIDs

All of the books have 'categories' tag.

Most of them have 'related' tag.

Only 20 books have 'brand' tag.

In [68]:
reviewed_book_ids = set(review['asin'] for review in book_reviews if review['asin'] in item_ids)
books = [item for item in metadata if item['asin'] in reviewed_book_ids]
print("number of Books in metadata being reviewed:", len(reviewed_book_ids))
print("number of books:", len(books))

users = list(set(review['reviewerID'] for review in book_reviews))
print(len(users))

number of Books in metadata being reviewed: 367982
number of books: 367982
603668


In [69]:
cnt_b = 0
cnt_r = 0
cnt_c = 0
for item in books:
    if 'brand' in item.keys():
        cnt_b += 1
    if 'related' in item.keys():
        cnt_r += 1
    if 'categories' in item.keys():
        cnt_c += 1
print("books with brand: %d, books with related: %d, books with categories: %d"%(cnt_b, cnt_r, cnt_c))

books with brand: 20, books with related: 343733, books with categories: 367982


In [78]:
write_json('Books/Books_metadata.json', books)

# Electronics Review Data Preprocessing

In [None]:
t0 = gettime()
elec_reviews = read_json('Electronics/reviews_Electronics_5.json')
print(len(elec_reviews))
print("time cost:", gettime() - t0)
i = 1
print(elec_reviews[i].keys())

In [156]:
# print(elec_reviews[1])

In [74]:
reviewed_elecs = set(review['asin'] for review in elec_reviews if review['asin'] in item_ids)
print("number of Electronics in metadata being reviewed:", len(reviewed_elecs))

number of Electronics in metadata being reviewed: 63001


In [75]:
reviewed_elec_ids = set(review['asin'] for review in elec_reviews if review['asin'] in item_ids)
elecs = [item for item in metadata if item['asin'] in reviewed_elec_ids]
print("number of Electronics in metadata being reviewed:", len(reviewed_elec_ids))
print("number of elecs:", len(elecs))

users = list(set(review['reviewerID'] for review in elec_reviews))
print(len(users))

number of Electronics in metadata being reviewed: 63001
number of elecs: 63001
192403


In [77]:
cnt_b = 0
cnt_r = 0
cnt_c = 0
for item in elecs:
    if 'brand' in item.keys():
        cnt_b += 1
    if 'related' in item.keys():
        cnt_r += 1
    if 'categories' in item.keys():
        cnt_c += 1
print("elecs with brand: %d, elecs with related: %d, elecs with categories: %d" % (cnt_b, cnt_r, cnt_c))

elecs with brand: 29751, elecs with related: 60682, elecs with categories: 63001


In [79]:
write_json('Electronics/Electronics_metadata.json', elecs)

## Filter rare nodes

*asin*: unique number for every item in Amazon

*reviewerID*: unique ID for each user

At first we build the list of all asins and users, then we filter the reviews.

What we may need: **user_id**, **item_id**, **category**, **brand**, however, '**related**' in metadata is also interesting... but we ignore it at first.

'related' includes 'bought_together', 'also_viewed', 'also bought', 'buy_after_viewing', ''

output: filtered_reviews, filtered_users, filtered_items, filtered_metadata

In [298]:
# items_origin = elecs
# reviews_origin = elec_reviews

items_origin = books
reviews_origin = book_reviews

print(len(items_origin))
print(len(reviews_origin))

367982
8898041


In [299]:
def filter_rare_node(items, reviews, user_threshold, item_threshold, related_threshold):
    continue_filter = True
    filtered_user_ids = set()
    filtered_item_asins = set()
    item_relate = {}
    
    while(continue_filter):
        t0 = gettime()
        
        continue_filter = False
        user_interact_num = {}
        item_interact_num = {}
        user_item_interact = set()
        
        #------------------------------------
        # filter step 1
        # rough filter
        # filter the active users and items
        # according to the description of the dataset, each user and item has at least 5 reviews
        #------------------------------------
        t1 = gettime()
        for review in reviews:
            if not review['unixReviewTime']:
                continue
            user_id = review['reviewerID']
            item_id = review['asin']
            user_item = str(user_id)+str(item_id)
            if user_item not in user_item_interact:
                user_interact_num[user_id] = user_interact_num.get(user_id, 0) + 1
                item_interact_num[item_id] = item_interact_num.get(item_id, 0) + 1
                user_item_interact.add(user_item)
        filtered_review_users = set(u for u in user_interact_num.keys() if user_interact_num[u]>=user_threshold)
        filtered_review_items = set(b for b in item_interact_num.keys() if item_interact_num[b]>=item_threshold)
        print("step 1 time cost:", gettime() - t1)
        
        print("len filtered_review_users: %d, len filtered_review_items: %d" %(len(filtered_review_users), len(filtered_review_items)))
        
        if (filtered_user_ids != filtered_review_users) or (filtered_item_asins != filtered_review_items):
            continue_filter = True
            
        #------------------------------------
        # filter step 2
        # filter items
        # keep related items all included in one set
        #------------------------------------
        t1 = gettime()
        item_related_dict = {}    # {item_id:related}
        for item in items:
            item_id = item['asin']
            if item_id not in filtered_review_items:
                continue
            if 'related' not in item.keys():    # in case some items have no tag 'related'
                item['related'] = {}
            filtered_related = [asin for subtag in item['related'].values() for asin in subtag if asin in filtered_review_items]
            if len(filtered_related) >= related_threshold:
                item_related_dict[item_id] = filtered_related
                
        continue_inside = True
        while (continue_inside):
            relates = {}    # {item_id:related items}
            continue_inside = False
            for item, item_related in item_related_dict.items():
                filtered_related = [relate for relate in item_related if relate in item_related_dict.keys()]
#                 print(len(filtered_related))
                if len(filtered_related) >= related_threshold:
                    relates[item] = filtered_related
                else:
                    continue_inside = True
                    
            print("len(relates):", len(relates))
            item_related_dict = deepcopy(relates)

        item_relate = deepcopy(item_related_dict)
        filtered_item_asins = set(item_related_dict.keys())
        print("step 2 time cost:", gettime() - t1)
        
        #------------------------------------
        # filter step 3
        # filter users
        #------------------------------------
        t1 = gettime()
        filtered_user_ids = set(review['reviewerID'] for review in reviews \
                                if (review['asin'] in filtered_item_asins) \
                                and (review['reviewerID'] in filtered_review_users))
        print("step 3 time cost:", gettime() - t1)
        
        #------------------------------------
        # filter step 4
        # filter reviews
        # make sure that 'reviewerID' in filtered_user_ids and 'asin' in filtered_item_asins
        #------------------------------------
        t1 = gettime()
        filtered_review = []
        user_item_interact = set()
        for review in reviews:
            if 'unixReviewTime' not in review.keys():
                continue
            if (review['reviewerID'] in filtered_user_ids) and (review['asin'] in filtered_item_asins):
                user_id = review['reviewerID']
                item_id = review['asin']
                user_item = str(user_id) + '/' + str(item_id)
                if user_item not in user_item_interact:    # remove duplication
                    filtered_review.append(review)
                    user_item_interact.add(user_item)
        reviews = deepcopy(filtered_review)
        print("step 4 time cost:", gettime() - t1)
        
        print(len(list(filtered_user_ids)))
        print(len(list(filtered_item_asins)))
        print(len(reviews))
        print('time cost:', gettime() - t0)
        print('filter loop')
        
    print('filter complete')
    
#     filtered_items = [item['asin'] for item in items if item['asin'] in filtered_item_asins]
    filtered_items = []
    for item in items:
        filtered_item = {}
        if item['asin'] in filtered_item_asins:
            filtered_item['asin'] = item['asin']
            if 'categories' in item.keys():
                filtered_item['categories'] = item['categories']
            if 'brand' in item.keys():
                filtered_item['brand'] = item['brand']
            if 'related' in item.keys():
                filtered_item['related'] = item_relate[item['asin']]
            filtered_items.append(filtered_item)
    
    print(len(filtered_user_ids))
    print(len(filtered_item_asins))
    print(len(reviews))
    
    return list(filtered_user_ids), filtered_items, filtered_review

In [None]:
# users_large, items_large, reviews_large = filter_rare_node(items_origin, reviews_origin, 8, 8, 2)    # Electronics
users_large, items_large, reviews_large = filter_rare_node(items_origin, reviews_origin, 10, 10, 2)      # Books

step 1 time cost: 21.25064706802368
len filtered_review_users: 219385, len filtered_review_items: 196511
len(relates): 173859
len(relates): 173820
len(relates): 173815
len(relates): 173815
step 2 time cost: 100.29750967025757
step 3 time cost: 5.9981536865234375
step 4 time cost: 121.28477644920349
219309
173815
5228831
time cost: 248.87074494361877
filter loop
step 1 time cost: 9.323015689849854
len filtered_review_users: 171152, len filtered_review_items: 131547
len(relates): 129904
len(relates): 129892
len(relates): 129891
len(relates): 129890
len(relates): 129890
step 2 time cost: 39.171366930007935
step 3 time cost: 4.517160177230835
step 4 time cost: 194.57264232635498
171152
129890
4572166
time cost: 248.56316661834717
filter loop
step 1 time cost: 9.089890003204346
len filtered_review_users: 154576, len filtered_review_items: 118802
len(relates): 118218
len(relates): 118212
len(relates): 118211
len(relates): 118211
step 2 time cost: 29.52898359298706
step 3 time cost: 3.5465202

In [None]:
print(items_large[0]['related'])

In [None]:
t0 = gettime()
write_pickle(dataset_path+filter_path+'users-large.pickle', users_large)
write_pickle(dataset_path+filter_path+'items-large.pickle', items_large)
write_pickle(dataset_path+filter_path+'reviews-large.pickle', reviews_large)
print("time cost:", gettime() - t0)

In [None]:
users_small, items_small, reviews_small = filter_rare_node(items_origin, reviews_origin, 13, 12, 4)

In [None]:
t0 = gettime()
write_pickle(dataset_path+filter_path+'users-middle.pickle', users_small)
write_pickle(dataset_path+filter_path+'items-middle.pickle', items_small)
write_pickle(dataset_path+filter_path+'reviews-middle.pickle', reviews_small)
print("time cost:", gettime() - t0)

In [268]:
users = users_small
items = items_small
reviews = reviews_small

# users = users_large
# items = items_large
# reviews = reviews_large

print(len(users))
print(len(items))
print(len(reviews))

21272
13596
359189


# Filter Result Insight

In [269]:
brands = set(item['brand'] for item in items if 'brand' in item.keys())
item_with_brands = [item for item in items if 'brand' in item.keys()]
print(len(brands))
print(len(item_with_brands))

1281
8156


In [270]:
item_with_categories = [item for item in items if 'categories' in item.keys()]
print(len(item_with_categories))

13596


In [295]:
categories = set(cat for cat in item['categories'][0] for item in items)
print(len(categories))

5


In [271]:
item_with_related = [item for item in items if 'related' in item.keys()]
print(len(item_with_related))

13596


# Compute id2ind and ind2id

In [272]:
uinds = [i for i in range(len(users))]
uid2ind = {user:ind for user, ind in zip(users, uinds)}
ind2uid = {ind:user for user, ind in zip(users, uinds)}
print(ind2uid[0])

AKPKPMWM6IIU5


In [273]:
b_inds = [i for i in range(len(items))]
bid2ind = {item['asin']:ind for item, ind in zip(items, b_inds)}
ind2bid = {ind:item['asin'] for item, ind in zip(items, b_inds)}
print(ind2bid[0])

0972683275


In [274]:
# brands = set(item['brand'] for item in items if 'brand' in item.keys())
# print(len(metadata))
brands = [item['brand'] for item in items if 'brand' in item.keys()]
print(len(brands))
brands = list(set(brands))

br_inds = [i for i in range(len(brands))]
br_id2ind = {brand:ind for brand, ind in zip(brands, br_inds)}
ind2br_id = {ind:brand for brand, ind in zip(brands, br_inds)}
print(len(brands))
print(ind2br_id[3])

8156
1281
Blue Crane Digital


In [275]:
categories = set(category for item in items for category in item['categories'][0] if 'categories' in item.keys())
ca_inds = [i for i in range(len(categories))]
ca_id2ind = {category:ind for category, ind in zip(categories, ca_inds)}
ind2ca_id = {ind:category for category, ind in zip(categories, ca_inds)}
print(len(ind2ca_id))
print(ind2ca_id[0])

580
Screen Protectors


In [276]:
print(items[2357]['categories'][0])

['Electronics', 'Accessories & Supplies', 'Audio & Video Accessories', 'Cables & Interconnects', 'Video Cables']


In [277]:
adj_path = 'adjs/'
datas = [uid2ind, ind2uid, bid2ind, ind2bid, br_id2ind, ind2br_id, ca_id2ind, ind2ca_id]
filenames = ['user_id2index', 'index2user_id', 'item_id2index', 'index2item_id', 'brand_id2index', 'index2brand_id', 'categories_id2index', 'index2categories_id']
for i in range(8):
    write_pickle(dataset_path+adj_path+filenames[i]+'.pickle', datas[i])

# Split data and compute adj matrix

In [278]:
def dataset_split(reviews, uid2ind, bid2ind, train_ratio, valid_ratio, test_ratio, n_neg_sample):
    selected_reviews = []
    
    for review in reviews:
        filtered_review = {}
        filtered_review['reviewerID'] = uid2ind[review['reviewerID']]
        filtered_review['asin'] = bid2ind[review['asin']]
        filtered_review['rate'] = 1.0
#         filtered_review['unixReviewTime'] = time.mktime(datetime.datetime.strptime(review['date'], '%Y-%m-%d %H:%M:%S').timetuple())
        filtered_review['unixReviewTime'] = review['unixReviewTime']
        selected_reviews.append(filtered_review)
        
    selected_reviews_sorted = sorted(selected_reviews, key=lambda k: k['unixReviewTime']) # use the earlier data to train and the later data to test
    n_reviews = len(selected_reviews_sorted)
    train_size = int(n_reviews * train_ratio)
    valid_size = int(n_reviews * valid_ratio)
    train_data = [selected_reviews_sorted[i] for i in range(train_size)]
    valid_data = [selected_reviews_sorted[i] for i in range(train_size, train_size + valid_size)]
    test_data = [selected_reviews_sorted[i] for i in range(train_size + valid_size, n_reviews)]
    
    selected_users = set()
    selected_businesses = set()
    for review in train_data:
        selected_users.add(review['reviewerID'])
        selected_businesses.add(review['asin'])
        
    eval_datas = [valid_data, test_data]
#     selected_eval_datas = [[] for _ in range(len(eval_datas))]
    selected_eval_datas = [[], []]
    for eval_index in range(len(eval_datas)):
        eval_data = eval_datas[eval_index]
        for review in eval_data:
            if review['reviewerID'] in selected_users and review['asin'] in selected_businesses:
                selected_eval_datas[eval_index].append(review)
    selected_valid_data, selected_test_data = selected_eval_datas
    
    data_list = [train_data, selected_valid_data, selected_test_data]
#     data_for_user_list = [{} for _ in range(len(data_list))]
    data_for_user_list = [{}, {}, {}]
    train_data_for_item = set()
    for index in range(len(data_list)):
        data = data_list[index]
        data_for_user = data_for_user_list[index]
        for review in data:
            user = review['reviewerID']
            item = review['asin']
            if index == 0:
                train_data_for_item.add(item)
            if user not in data_for_user:
                data_for_user[user] = [item]
            else:
                data_for_user[user].append(item)
    train_data_for_user, valid_data_for_user, test_data_for_user = data_for_user_list # dictionary of user_id:[item_id]
    
    with_neg_list = [valid_data_for_user, test_data_for_user]
#     data_with_neg_list = [[] for _ in range(len(with_neg_list))]
    data_with_neg_list = [[], []]
    for index in range(len(with_neg_list)):
        current_data = with_neg_list[index]
        for user in current_data.keys():
            if user not in selected_users:
                continue
            user_eval = {} # a dict
            business_set = selected_businesses - set(train_data_for_user[user]) - set(current_data[user]) # items not existed in this user's records
            sample_businesses = np.random.choice(list(business_set), size=n_neg_sample, replace=False)    # sample is random.choice
            user_eval['reviewerID'] = user
            user_eval['pos_item_id'] = current_data[user]
            user_eval['neg_item_id'] = list(sample_businesses)
            data_with_neg_list[index].append(user_eval)
    valid_with_neg, test_with_neg = data_with_neg_list
    
    return train_data, selected_valid_data, selected_test_data, valid_with_neg, test_with_neg

In [286]:
# get adjs
def get_adj_matrix(uid2ind, bid2ind, brand_id2ind, cat_id2ind, users, businesses, reviews):
    """
    Notice
    ------
    Not all items have the 'brand' tag.
    
    metapaths: UBU, UBBU, UB, UBB, UBUB, UBCaB, UBBrB, UBCa, 
               UBBCa, UBBr, UBBBr, BB, BCaB, BBrB, BCa, BBCa, BBr, BBBr
    """
    tot_users = len(uid2ind)  # tot for total
    tot_business = len(bid2ind)
    tot_brand = len(brand_id2ind)
    tot_category = len(cat_id2ind)
    print(tot_users, tot_business, tot_brand, tot_category)
    
    adj_UB = np.zeros([tot_users, tot_business])
    adj_BB = np.zeros([tot_business, tot_business])
    adj_BCa = np.zeros([tot_business, tot_category])
    adj_BBr = np.zeros([tot_business, tot_brand])
    
    #-----------------------------------------------
    # step 1:
    # initiate adj_UB, adj_BB, adj_BCa, adj_BBr
    #-----------------------------------------------
    # relation UB
    for review in reviews:
        user_id = review['reviewerID']    # it's already a number
        business_id = review['asin']      # it's already a number
        adj_UB[user_id][business_id] = 1
        
    #relation BB, BCa, BBr
    for business in businesses:
        if business['asin'] not in bid2ind.keys():
            continue
        business_id = bid2ind[business['asin']]
        #------------
        # BBr
        #------------
        if 'brand' in business.keys():
            brand_id = brand_id2ind[business['brand']]
            print("business_id: %d, brand_id: %d" % (business_id, brand_id))
            adj_BBr[business_id][brand_id] = 1
        #------------
        # BCa
        # more than one category for a business
        #------------
        for category in business['categories'][0]:
            category = category.strip()
            category_id = cat_id2ind[category]
            adj_BCa[business_id][category_id] = 1
        #------------
        # BB
        # more than one related item for the given item
        #------------
        relates = list(set(item for item in business['related']))
        for item in relates:
            relate_id = bid2ind[item]
            adj_BB[business_id][relate_id] = 1
            adj_BB[relate_id][business_id] = 1

    #-----------------------------------------------
    # step 2:
    # Compute other metapaths needed
    #-----------------------------------------------
    adj_UBU = adj_UB.dot(adj_UB.T)
    adj_UBB = adj_UB.dot(adj_BB)
    adj_UBBU = adj_UBB.dot(adj_UB.T)
    adj_UBUB = adj_UBU.dot(adj_UB)
    adj_UBCa = adj_UB.dot(adj_BCa)
    adj_UBBCa = adj_UBB.dot(adj_BCa)
    adj_UBCaB = adj_UBCa.dot(adj_BCa.T)
    adj_UBBr = adj_UB.dot(adj_BBr)
    adj_UBBBr = adj_UBB.dot(adj_BBr)
    adj_UBBrB = adj_UBBr.dot(adj_BBr.T)
    adj_BCaB = adj_BCa.dot(adj_BCa.T)
    adj_BBrB = adj_BBr.dot(adj_BBr.T)
    adj_BBBr = adj_BB.dot(adj_BBr)
    
    return adj_UB, adj_UBB, adj_UBBU, adj_UBU, adj_UBUB, adj_UBCa, \
           adj_UBBCa, adj_UBCaB, adj_UBBr, adj_UBBBr, adj_UBBrB, adj_BB, \
           adj_BCa, adj_BBr, adj_BCaB, adj_BBrB, adj_BBBr


In [280]:
train_data, valid_data, test_data, valid_with_neg_sample, test_with_neg_sample \
    = dataset_split(reviews, uid2ind, bid2ind, 0.8, 0.1, 0.1, 50)

In [281]:
print(type(train_data))
print(train_data[0])
print(len(reviews))
print(len(train_data), len(valid_data), len(test_data))
print(valid_with_neg_sample[1253])
print(test_with_neg_sample[5463])

<class 'list'>
{'unixReviewTime': 939600000, 'reviewerID': 13860, 'asin': 20, 'rate': 1.0}
359189
287351 30913 26765
{'reviewerID': 1653, 'pos_item_id': [6772], 'neg_item_id': [12730, 7813, 5829, 12721, 10198, 3297, 9061, 2268, 516, 13174, 6920, 4771, 8503, 19, 3963, 4874, 10750, 11146, 11127, 236, 3943, 10279, 2618, 4934, 10662, 11416, 12522, 6568, 9669, 9244, 12932, 8818, 5379, 5488, 7805, 10047, 10894, 11982, 5076, 9658, 12363, 1570, 993, 13003, 4612, 8556, 6790, 1085, 4465, 5663]}
{'reviewerID': 4220, 'pos_item_id': [11491], 'neg_item_id': [1629, 417, 11694, 11588, 10900, 6634, 4806, 5842, 374, 12151, 2801, 13256, 3042, 12568, 9557, 3168, 748, 4399, 6786, 5602, 5942, 2638, 5981, 13092, 4341, 10770, 1969, 5116, 8324, 4235, 13161, 10578, 7314, 2175, 6653, 10787, 8131, 5693, 11447, 2229, 6912, 7, 13021, 3335, 10257, 3244, 5181, 2772, 2111, 11664]}


In [282]:
rating_path = 'ratings/'
path = dataset_path+rating_path
filenames = ['train_data', 'valid_data', 'test_data', 'valid_with_neg_sample', 'test_with_neg_sample']
objs = [train_data, valid_data, test_data, valid_with_neg_sample, test_with_neg_sample]
for file, obj in zip(filenames, objs):
    write_pickle(path+file+'.pickle', obj)

In [287]:
# get adj matrices
adj_UB, adj_UBB, adj_UBBU, adj_UBU, adj_UBUB, adj_UBCa, \
    adj_UBBCa, adj_UBCaB, adj_UBBr, adj_UBBBr, adj_UBBrB, adj_BB, \
    adj_BCa, adj_BBr, adj_BCaB, adj_BBrB, adj_BBBr\
    = get_adj_matrix(uid2ind, bid2ind, br_id2ind, ca_id2ind, users, items, train_data)

21272 13596 1281 580
business_id: 0, brand_id: 167
business_id: 1, brand_id: 1112
business_id: 3, brand_id: 811
business_id: 5, brand_id: 357
business_id: 6, brand_id: 1247
business_id: 7, brand_id: 747
business_id: 8, brand_id: 747
business_id: 9, brand_id: 747
business_id: 11, brand_id: 209
business_id: 12, brand_id: 747
business_id: 13, brand_id: 1213
business_id: 14, brand_id: 1213
business_id: 15, brand_id: 1213
business_id: 16, brand_id: 1213
business_id: 17, brand_id: 1213
business_id: 18, brand_id: 1213
business_id: 19, brand_id: 1237
business_id: 21, brand_id: 1237
business_id: 22, brand_id: 980
business_id: 23, brand_id: 1230
business_id: 24, brand_id: 1213
business_id: 25, brand_id: 1213
business_id: 26, brand_id: 218
business_id: 27, brand_id: 39
business_id: 28, brand_id: 39
business_id: 29, brand_id: 39
business_id: 30, brand_id: 39
business_id: 31, brand_id: 1072
business_id: 32, brand_id: 1237
business_id: 33, brand_id: 747
business_id: 34, brand_id: 75
business_id: 35,

business_id: 1725, brand_id: 694
business_id: 1728, brand_id: 551
business_id: 1729, brand_id: 478
business_id: 1732, brand_id: 588
business_id: 1733, brand_id: 271
business_id: 1734, brand_id: 849
business_id: 1736, brand_id: 678
business_id: 1737, brand_id: 763
business_id: 1738, brand_id: 717
business_id: 1739, brand_id: 597
business_id: 1740, brand_id: 377
business_id: 1741, brand_id: 694
business_id: 1746, brand_id: 1003
business_id: 1748, brand_id: 649
business_id: 1750, brand_id: 478
business_id: 1756, brand_id: 553
business_id: 1757, brand_id: 18
business_id: 1758, brand_id: 928
business_id: 1759, brand_id: 551
business_id: 1763, brand_id: 424
business_id: 1764, brand_id: 334
business_id: 1765, brand_id: 143
business_id: 1766, brand_id: 334
business_id: 1767, brand_id: 964
business_id: 1768, brand_id: 928
business_id: 1769, brand_id: 928
business_id: 1770, brand_id: 426
business_id: 1771, brand_id: 117
business_id: 1772, brand_id: 1237
business_id: 1773, brand_id: 1031
business

business_id: 3478, brand_id: 928
business_id: 3487, brand_id: 953
business_id: 3488, brand_id: 1017
business_id: 3490, brand_id: 123
business_id: 3491, brand_id: 845
business_id: 3492, brand_id: 123
business_id: 3494, brand_id: 1278
business_id: 3495, brand_id: 731
business_id: 3496, brand_id: 867
business_id: 3501, brand_id: 724
business_id: 3502, brand_id: 1197
business_id: 3503, brand_id: 1018
business_id: 3505, brand_id: 361
business_id: 3507, brand_id: 853
business_id: 3508, brand_id: 811
business_id: 3512, brand_id: 458
business_id: 3515, brand_id: 1182
business_id: 3517, brand_id: 1192
business_id: 3518, brand_id: 960
business_id: 3519, brand_id: 960
business_id: 3520, brand_id: 1235
business_id: 3521, brand_id: 746
business_id: 3522, brand_id: 228
business_id: 3523, brand_id: 551
business_id: 3524, brand_id: 551
business_id: 3526, brand_id: 551
business_id: 3529, brand_id: 123
business_id: 3530, brand_id: 928
business_id: 3532, brand_id: 873
business_id: 3535, brand_id: 694
bus

business_id: 5172, brand_id: 167
business_id: 5173, brand_id: 167
business_id: 5175, brand_id: 724
business_id: 5179, brand_id: 997
business_id: 5180, brand_id: 696
business_id: 5181, brand_id: 811
business_id: 5182, brand_id: 410
business_id: 5183, brand_id: 1235
business_id: 5185, brand_id: 965
business_id: 5186, brand_id: 317
business_id: 5188, brand_id: 527
business_id: 5189, brand_id: 570
business_id: 5190, brand_id: 441
business_id: 5191, brand_id: 485
business_id: 5193, brand_id: 767
business_id: 5194, brand_id: 478
business_id: 5195, brand_id: 1224
business_id: 5198, brand_id: 738
business_id: 5200, brand_id: 162
business_id: 5201, brand_id: 431
business_id: 5202, brand_id: 431
business_id: 5204, brand_id: 779
business_id: 5209, brand_id: 239
business_id: 5212, brand_id: 594
business_id: 5213, brand_id: 754
business_id: 5214, brand_id: 731
business_id: 5216, brand_id: 687
business_id: 5217, brand_id: 687
business_id: 5218, brand_id: 195
business_id: 5219, brand_id: 506
business

business_id: 6835, brand_id: 1212
business_id: 6836, brand_id: 1155
business_id: 6837, brand_id: 1212
business_id: 6838, brand_id: 453
business_id: 6839, brand_id: 458
business_id: 6840, brand_id: 458
business_id: 6841, brand_id: 720
business_id: 6842, brand_id: 1230
business_id: 6843, brand_id: 1230
business_id: 6844, brand_id: 63
business_id: 6845, brand_id: 1228
business_id: 6846, brand_id: 842
business_id: 6847, brand_id: 535
business_id: 6849, brand_id: 811
business_id: 6850, brand_id: 1203
business_id: 6852, brand_id: 693
business_id: 6854, brand_id: 218
business_id: 6856, brand_id: 886
business_id: 6857, brand_id: 695
business_id: 6860, brand_id: 516
business_id: 6861, brand_id: 253
business_id: 6862, brand_id: 609
business_id: 6863, brand_id: 687
business_id: 6864, brand_id: 694
business_id: 6865, brand_id: 1213
business_id: 6866, brand_id: 123
business_id: 6867, brand_id: 123
business_id: 6869, brand_id: 669
business_id: 6871, brand_id: 669
business_id: 6872, brand_id: 39
busi

business_id: 8552, brand_id: 478
business_id: 8553, brand_id: 186
business_id: 8554, brand_id: 186
business_id: 8555, brand_id: 186
business_id: 8556, brand_id: 186
business_id: 8557, brand_id: 819
business_id: 8558, brand_id: 831
business_id: 8559, brand_id: 319
business_id: 8560, brand_id: 977
business_id: 8561, brand_id: 631
business_id: 8562, brand_id: 318
business_id: 8563, brand_id: 318
business_id: 8564, brand_id: 965
business_id: 8565, brand_id: 711
business_id: 8566, brand_id: 93
business_id: 8567, brand_id: 271
business_id: 8569, brand_id: 353
business_id: 8570, brand_id: 145
business_id: 8571, brand_id: 643
business_id: 8572, brand_id: 296
business_id: 8575, brand_id: 658
business_id: 8576, brand_id: 551
business_id: 8577, brand_id: 654
business_id: 8578, brand_id: 271
business_id: 8579, brand_id: 271
business_id: 8581, brand_id: 551
business_id: 8582, brand_id: 453
business_id: 8583, brand_id: 939
business_id: 8585, brand_id: 431
business_id: 8586, brand_id: 295
business_id

business_id: 10282, brand_id: 1174
business_id: 10284, brand_id: 1168
business_id: 10285, brand_id: 724
business_id: 10286, brand_id: 724
business_id: 10287, brand_id: 724
business_id: 10288, brand_id: 724
business_id: 10289, brand_id: 389
business_id: 10290, brand_id: 389
business_id: 10291, brand_id: 1230
business_id: 10293, brand_id: 690
business_id: 10295, brand_id: 104
business_id: 10296, brand_id: 129
business_id: 10297, brand_id: 129
business_id: 10298, brand_id: 129
business_id: 10301, brand_id: 1018
business_id: 10303, brand_id: 388
business_id: 10305, brand_id: 458
business_id: 10307, brand_id: 1221
business_id: 10308, brand_id: 865
business_id: 10309, brand_id: 1264
business_id: 10310, brand_id: 990
business_id: 10311, brand_id: 962
business_id: 10313, brand_id: 431
business_id: 10315, brand_id: 431
business_id: 10317, brand_id: 842
business_id: 10323, brand_id: 1263
business_id: 10325, brand_id: 1168
business_id: 10328, brand_id: 690
business_id: 10329, brand_id: 145
busine

business_id: 11966, brand_id: 271
business_id: 11967, brand_id: 475
business_id: 11968, brand_id: 458
business_id: 11970, brand_id: 458
business_id: 11972, brand_id: 129
business_id: 11973, brand_id: 458
business_id: 11978, brand_id: 297
business_id: 11979, brand_id: 112
business_id: 11985, brand_id: 458
business_id: 11986, brand_id: 458
business_id: 11987, brand_id: 458
business_id: 11988, brand_id: 458
business_id: 11991, brand_id: 313
business_id: 11992, brand_id: 297
business_id: 11993, brand_id: 313
business_id: 11995, brand_id: 605
business_id: 11997, brand_id: 453
business_id: 11998, brand_id: 584
business_id: 12000, brand_id: 1076
business_id: 12001, brand_id: 1076
business_id: 12002, brand_id: 1076
business_id: 12003, brand_id: 1076
business_id: 12004, brand_id: 1076
business_id: 12005, brand_id: 82
business_id: 12006, brand_id: 902
business_id: 12007, brand_id: 492
business_id: 12008, brand_id: 1213
business_id: 12009, brand_id: 1152
business_id: 12010, brand_id: 1218
busines

business_id: 13573, brand_id: 690
business_id: 13574, brand_id: 218
business_id: 13575, brand_id: 1213
business_id: 13578, brand_id: 551
business_id: 13580, brand_id: 458
business_id: 13582, brand_id: 431
business_id: 13583, brand_id: 164
business_id: 13584, brand_id: 388
business_id: 13585, brand_id: 388
business_id: 13587, brand_id: 551
business_id: 13589, brand_id: 373
business_id: 13592, brand_id: 218
business_id: 13593, brand_id: 218
business_id: 13594, brand_id: 218


In [290]:
print(adj_UB.shape)
print(adj_UBCaB[0:100][0:100])
print(adj_UBB[0:100][0:100])
cnt = 0
for row in adj_UBBr:
    for col in row:
        if col:
            cnt += 1
print(cnt)

(21272, 13596)
[[13. 11. 20. ... 12. 22. 11.]
 [14. 10. 15. ... 12. 18. 10.]
 [32. 12. 13. ... 22. 13. 14.]
 ...
 [20. 20. 23. ... 20. 25. 20.]
 [45. 31. 46. ... 38. 47. 36.]
 [ 6.  6.  8. ...  6. 10.  6.]]
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 1. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 1.]
 [0. 0. 0. ... 0. 0. 0.]]
151762


In [292]:
adjs = [adj_UB, adj_UBB, adj_UBBU, adj_UBU, adj_UBUB, adj_UBCa, \
    adj_UBBCa, adj_UBCaB, adj_UBBr, adj_UBBBr, adj_UBBrB, adj_BB, \
    adj_BCa, adj_BBr, adj_BCaB, adj_BBrB, adj_BBBr]
filenames = ['adj_UB', 'adj_UBB', 'adj_UBBU', 'adj_UBU', 'adj_UBUB', 'adj_UBCa', \
    'adj_UBBCa', 'adj_UBCaB', 'adj_UBBr', 'adj_UBBBr', 'adj_UBBrB', 'adj_BB', \
    'adj_BCa', 'adj_BBr', 'adj_BCaB', 'adj_BBrB', 'adj_BBBr']
for adj, file in zip(adjs, filenames):
    write_pickle(dataset_path+adj_path+file+'.pickle', adj)

# Experiments

Further experiments, not related to data preprocessing.

In [10]:
# load dictionaries
path = '../yelp_dataset/adjs/'
uid2ind = read_pickle(path+'uid2ind.pickle')
bid2ind = read_pickle(path+'bid2ind.pickle')
ct_id2ind = read_pickle(path+'ct_id2ind.pickle')
ca_id2ind = read_pickle(path+'ca_id2ind.pickle')
ind2uid = read_pickle(path+'ind2uid.pickle')
ind2bid = read_pickle(path+'ind2bid.pickle')
ind2ct_id = read_pickle(path+'ind2ct_id.pickle')
ind2ca_id = read_pickle(path+'ind2ca_id.pickle')

In [30]:
del reviews

NameError: name 'reviews' is not defined

In [112]:
train_data = read_pickle('../yelp_dataset/rates/train_data.pickle')
print(type(train_data))

<class 'list'>


In [113]:
print(train_data[0])

{'business_id': 181, 'rate': 1.0, 'user_id': 21, 'timestamp': 1116877827.0}


In [99]:
def make_embedding(user_features, item_features):
    user_concat = torch.cat(user_features, 1)
    item_concat = torch.cat(item_features, 1)
    X = []
    for user in user_concat:
        tmp = [torch.cat([user,item], 0).unsqueeze(0) for item in item_concat]
        print("tmp[0].shape", tmp[0].shape)
        tmp = torch.cat(tmp, 0)
        X.append(tmp)
    X = torch.cat(X, 0)
    return X

In [3]:
def load_feature(feature_path, metapaths):
    user_features = [read_pickle(feature_path+metapath+'_user.pickle') for metapath in metapaths]
    item_features = [read_pickle(feature_path+metapath+'_item.pickle') for metapath in metapaths]
        
    return user_features, item_features


In [106]:
def make_labels(Y, n_user, n_item):
    r"""
    Parameter
    ---------
    Y: list of dict
        saves the interaction information in COO form
    
    Return
    ------
    ret: torch.tensor
        still in COO form
    """
    indices = np.array(([y['user_id'] for y in Y], [y['business_id'] for y in Y]))
    values = np.array([1. for y in Y])
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    ret = torch.sparse_coo_tensor(indices, values, size=(n_user,n_item),
                                  dtype=torch.float32, device=device, requires_grad=False)
    return ret

In [115]:
Y = make_labels(train_data, 648, 637)
print(Y.shape)
print(Y)
print(type(Y))
dense = Y.to_dense()
x = [0, 1, 2, 3]
y = [0, 1, 2, 3]
print(dense)
dense[x, y]
# print(dense[99])

torch.Size([648, 637])
tensor(indices=tensor([[ 21,  99,  99,  ..., 530, 568, 118],
                       [181, 569, 290,  ...,  34,   9, 201]]),
       values=tensor([1., 1., 1.,  ..., 1., 1., 1.]),
       device='cuda:0', size=(648, 637), nnz=36692, layout=torch.sparse_coo)
<class 'torch.Tensor'>
tensor([[0., 0., 0.,  ..., 1., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')


tensor([0., 0., 0., 0.], device='cuda:0')

In [4]:
featurepath = '../yelp_dataset/mf_features/'
metapaths = ['UB', 'UUB', 'UBUB', 'UBCaB', 'UBCiB']

In [8]:
user_features, item_features = load_feature(featurepath, metapaths)

In [104]:
# this is for test
user_features = tuple(torch.Tensor(np.zeros((3, 3))) for i in range(2)) # two 3*3 matrices
item_features = tuple(torch.Tensor(np.ones((5, 3))) for i in range(2))  # two 5*3 matrices

print("user_features:", user_features)

user_concat = torch.cat(user_features, 1)
print(user_concat[:, 0:2])
print("user_concat:", user_concat)
item_concat = torch.cat(item_features, 1)
print("item_concat:", item_concat)
user = user_concat[0]
item = item_concat[0]
print(user, item)
user = user_concat[0].view(1, 2, 3)
item = item_concat[0].view(1, 2, 3)
print(user, item)
ui_concat = torch.cat([user, item], 1)
print("ui_concat:", ui_concat)

user_features: (tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]), tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]))
tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])
user_concat: tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])
item_concat: tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
tensor([0., 0., 0., 0., 0., 0.]) tensor([1., 1., 1., 1., 1., 1.])
tensor([[[0., 0., 0.],
         [0., 0., 0.]]]) tensor([[[1., 1., 1.],
         [1., 1., 1.]]])
ui_concat: tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.],
         [1., 1., 1.]]])


In [100]:
X = make_embedding(user_features, item_features)
print(X)

tmp[0].shape torch.Size([1, 12])
tmp[0].shape torch.Size([1, 12])
tmp[0].shape torch.Size([1, 12])
tensor([[0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.]])


In [61]:
print(X.shape)

torch.Size([648, 637, 100])


5647
5153
349913
filter loop
3613
3426
236230
filter loop
2707
2445
174828
filter loop
2095
1953
139223
filter loop
1727
1574
115435
filter loop
1451
1350
100099
filter loop
1277
1222
89362
filter loop
1171
1082
80145
filter loop
1032
986
71864
filter loop
946
881
65722
filter loop
864
816
61480
filter loop
818
784
58698
filter loop
796
764
57019
filter loop
770
743
55129
filter loop
751
725
53627
filter loop
735
709
52327
filter loop
713
703
51216
filter loop
711
690
50583
filter loop
691
689
49752
filter loop
691
672
49025
filter loop
679
672
48564
filter loop
679
659
48005
filter loop
662
659
47346
filter loop
662
644
46705
filter loop
652
644
46319
filter loop
652
638
46064
filter loop
649
638
45947
filter loop
649
637
45904
filter loop
648
637
45865
filter loop
648
637
45865
filter loop
filter complete
648
637
45865


In [20]:
users_small = []
for user in users_comp:
    for u in users:
        if user['user_id'] == u:
            users_small.append(user)
print(len(users_small))
print(users_small[0].keys())

648
dict_keys(['elite', 'yelping_since', 'friends', 'average_stars', 'review_count', 'compliment_cute', 'compliment_note', 'user_id', 'compliment_hot', 'compliment_more', 'compliment_cool', 'cool', 'compliment_plain', 'compliment_funny', 'compliment_writer', 'fans', 'compliment_photos', 'compliment_list', 'name', 'compliment_profile', 'useful', 'funny'])


In [21]:
busi_small = []
for busi in busi_comp:
    for b in businesses:
        if busi['business_id'] == b:
            busi_small.append(busi)
print(len(busi_small))
print(busi_small[0].keys())

637
dict_keys(['categories', 'is_open', 'hours', 'attributes', 'address', 'longitude', 'name', 'state', 'postal_code', 'latitude', 'business_id', 'review_count', 'stars', 'city'])


In [22]:
write_pickle('../yelp_dataset/filtered/users-small.pickle', users_small)

In [23]:
write_pickle('../yelp_dataset/filtered/businesses-small.pickle', busi_small)

In [27]:
write_pickle('../yelp_dataset/filtered/reviews-small.pickle', reviews_small)

In [26]:
# test sparsity
adj = read_pickle('../yelp_dataset/adjs/adj_UBCiB.pickle')
print(type(adj))
# for i in adj_UB

<class 'numpy.ndarray'>


In [27]:
import scipy.sparse as sp
sparse = sp.csr_matrix(adj)

In [28]:
nnz = sparse.nnz
size = adj.shape[0] * adj.shape[1]
print(nnz / size)

0.9474751439037153


注意到除了UB的稠密度只有0.08，其余的都有至少0.85以上，说明metapath的邻接矩阵实际上很稠密

In [62]:
x = torch.ones(10)*10
x[3] = 7
x = x.unsqueeze(1)
print(x)
V = torch.ones(10, 5)
V[:, 3] *= 3
print(V)
# print(torch.matmul(x, V))
out = torch.mul(x, V).sum(1, keepdim=True)
print(out)
out_t = out.sum(0).squeeze()
print(out_t.size())
print(out_t)

tensor([[10.],
        [10.],
        [10.],
        [ 7.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.],
        [10.]])
tensor([[1., 1., 1., 3., 1.],
        [1., 1., 1., 3., 1.],
        [1., 1., 1., 3., 1.],
        [1., 1., 1., 3., 1.],
        [1., 1., 1., 3., 1.],
        [1., 1., 1., 3., 1.],
        [1., 1., 1., 3., 1.],
        [1., 1., 1., 3., 1.],
        [1., 1., 1., 3., 1.],
        [1., 1., 1., 3., 1.]])
tensor([[70.],
        [70.],
        [70.],
        [49.],
        [70.],
        [70.],
        [70.],
        [70.],
        [70.],
        [70.]])
torch.Size([])
tensor(679.)


In [64]:
out_t = out_t.repeat(1, 2)
print(out_t)
out_t[0][1] = -out_t[0][0]
print(out_t[0][1])
print(out_t)

tensor([[ 679., -679.,  679., -679.]])
tensor(-679.)


AttributeError: 'Tensor' object has no attribute 'astype'

# Process valid dataset

In [7]:
valid_data = read_pickle('../yelp_dataset/rates/valid_with_neg_sample.pickle')
train_data = read_pickle('../yelp_dataset/rates/train_data.pickle')

In [9]:
len(train_data)

594375

In [49]:
pos = [[y['user_id'], pos_id, 1] for y in valid_data for pos_id in y['pos_business_id']]
neg = [[y['user_id'], neg_id, 0] for y in valid_data for neg_id in y['neg_business_id']]
ret = pos + neg
ret[0:100]
ret_array = np.array(ret, dtype=np.int64)
ret_array.dtype

dtype('int64')

In [39]:
print(len(valid_data[1]['neg_business_id']))

50


In [33]:
t_set = set(i for i in range(10))
# t_set = [i for i in range(10)]
t_array = np.asarray(list(t_set))
t_tensor = torch.as_tensor(t_array).repeat(2, 1)
print(t_tensor)
print(torch.topk(t_tensor, 2, dim=0))

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
torch.return_types.topk(
values=tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),
indices=tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))


In [173]:
class FMG_YelpDataset(Dataset):
    def __init__(self, interaction_data, n_user, n_item, neg_sample_n, mode, cuda=False):
        r"""
        Parameters
        ----------
        data: list of dict
        
        neg_sample_n: int,
            number of negative samples to be sampled
            
        mode: 'train', 'valid' or 'test',
            among which 'valid' and 'test' have the same functions
            
        Returns
        -------
        If mode == 'train':
        users: list of user ids
        
        items: list of business ids
        
        labels: corresponding labels
        
        If mode == 'valid' or 'test':
        user: list, 
            only one user id, but list size is len(items)
        
        items: list,
            positive samples and negative samples
            
        labels: corresponding labels
        """
        super().__init__()
        self.neg_sample_n = neg_sample_n
        self.mode = mode
        self.n_user = n_user
        self.n_item = n_item
        self.item_ids = set(i for i in range(self.n_item))
        self.device = torch.device('cuda:0' if cuda else 'cpu')

        if self.mode == 'train':
            pos_sampleset_list = [set() for i in range(self.n_user)]
            for y in interaction_data:
                pos_sampleset_list[y['user_id']].add(y['business_id'])            
            self.pos_sampleset_list = pos_sampleset_list    # used for train set
            self.data = torch.tensor(np.asarray([[y['user_id'], y['business_id'], 1] for y in interaction_data]), 
                                     device=self.device)

        elif self.mode == 'valid' or self.mode == 'test':
            self.data = []
            for input in interaction_data:
                pos = [i for i in input['pos_business_id']]
                neg = [i for i in input['neg_business_id']]
                items = pos + neg[0:self.neg_sample_n]
                user = [input['user_id']] * len(items)
                labels = [1] * len(pos) + [0] * self.neg_sample_n
                self.data.append([torch.tensor(np.asarray(user), device=self.device), 
                                  torch.tensor(np.asarray(items),device=self.device), 
                                  torch.tensor(np.asarray(labels), device=self.device)])
        
    def __getitem__(self, index):
        r"""
        ps: index is a number.
        
        return the uid and all pos bids and neg bids, along with labels.
        
        all in python list
        """
        if self.mode == 'train':
            pos_ind = self.data[index][0:2]
            user = pos_ind[0].item()    # user id
            pos_ind = pos_ind.unsqueeze(0)
            neg_sample_array = np.asarray(list(self.item_ids - self.pos_sampleset_list[user]))
            neg_samples = np.random.choice(neg_sample_array, self.neg_sample_n, replace=False)
            print("test: ", [user, neg_samples[0]])
            neg_inds = torch.tensor(np.asarray([[user, neg_sample] for neg_sample in neg_samples]), device=self.device)

            indices = torch.cat((pos_ind, neg_inds), 0)
            labels = torch.tensor(np.asarray([1] + [0]*self.neg_sample_n), device=self.device)
            return indices, labels

        elif self.mode == 'valid' or self.mode == 'test':
            r"""
            here the data contains 'pos_business_id' and 'neg_business_id' keys.
            """
            return self.data[index]
        
    def __len__(self):
        return len(self.data)    # available for both list and np.ndarray

In [174]:
dataset = FMG_YelpDataset(valid_data, len(users), len(businesses), neg_sample_n=3, mode='valid')
uid, bids, labels = dataset.__getitem__(2)
print(uid, bids, labels)
train_set = FMG_YelpDataset(train_data, len(users), len(businesses), neg_sample_n=4, mode='train')
print(train_set.__getitem__(2))

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]) tensor([ 10, 615,  73, 497, 540,  37, 398, 356, 255, 513, 251, 233,  12, 325,
        625, 484, 520, 353,  28, 137, 614]) tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0])
test:  [99, 617]
(tensor([[ 99, 290],
        [ 99, 617],
        [ 99, 278],
        [ 99, 180],
        [ 99, 235]]), tensor([1, 0, 0, 0, 0]))


In [147]:
valid = np.array(valid_data)
# valid = [1, 2, 3, 4, 5, 6]
# index = [2, 3, 4]
# input = valid[index]
# input['neg_business_id']

In [180]:
a = torch.randn(3, 4, 2)
print(a.shape)
print(a.squeeze(2).shape)
print(a.reshape(-1).shape)

torch.Size([3, 4, 2])
torch.Size([3, 4, 2])
torch.Size([24])
