In [1]:
import torch
import torch.nn as nn
from model import LocalPredictor

In [2]:
import os

In [3]:
import pickle as pk

In [4]:
import numpy as np

In [5]:
# Restore the trained local predictor from disk and move it to GPU 0.
# NOTE(review): the checkpoint appears to be a whole pickled module rather than
# a state_dict (the result is used directly and `.cuda(0)` is called on it) —
# presumably this is why `LocalPredictor` is imported above; confirm.
local_predictor = torch.load(
    './results_tokyo/local_predictor_broader_5.pytorch'
)
local_predictor = local_predictor.cuda(0)

In [6]:
# Candidate evaluation days as (year, month, day): December 2012 followed by
# January and February 2013. Every month is enumerated with days 1..31, so the
# list deliberately contains calendar-invalid entries (e.g. Feb 30); the
# evaluation cell skips any date whose data file does not exist.
date_list = [
    (year, month, day)
    for year, months in ((2012, range(12, 13)), (2013, range(1, 3)))
    for month in months
    for day in range(1, 32)
]

In [7]:
# Evaluation settings.
batch_size = 256  # not referenced by any cell visible in this notebook
dT = 4            # input window length in time steps; the target sits at t + 2*dT - 1
T = 96            # exclusive upper bound on the time index (presumably steps per day — confirm)

In [8]:
# Read the per-day pickles for 2012-11-01 .. 2012-11-28 into `data`, keyed by
# (month, day). Each file path is printed before it is opened.
# NOTE(review): each pickle is presumably a dict keyed by user id — that is how
# the following cell iterates over it; confirm against the preprocessing code.
data = {}
for m, d in [(month, day) for month in range(11, 12) for day in range(1, 29)]:
    filename = '/data/fan/UsersInTokyoProcessed/2012{:02d}{:02d}_interp.pk'.format(m, d)
    print(filename)
    with open(filename, 'rb') as f:
        day_records = pk.load(f)
    data[(m, d)] = day_records

/data/fan/UsersInTokyoProcessed/20121101_interp.pk
/data/fan/UsersInTokyoProcessed/20121102_interp.pk
/data/fan/UsersInTokyoProcessed/20121103_interp.pk
/data/fan/UsersInTokyoProcessed/20121104_interp.pk
/data/fan/UsersInTokyoProcessed/20121105_interp.pk
/data/fan/UsersInTokyoProcessed/20121106_interp.pk
/data/fan/UsersInTokyoProcessed/20121107_interp.pk
/data/fan/UsersInTokyoProcessed/20121108_interp.pk
/data/fan/UsersInTokyoProcessed/20121109_interp.pk
/data/fan/UsersInTokyoProcessed/20121110_interp.pk
/data/fan/UsersInTokyoProcessed/20121111_interp.pk
/data/fan/UsersInTokyoProcessed/20121112_interp.pk
/data/fan/UsersInTokyoProcessed/20121113_interp.pk
/data/fan/UsersInTokyoProcessed/20121114_interp.pk
/data/fan/UsersInTokyoProcessed/20121115_interp.pk
/data/fan/UsersInTokyoProcessed/20121116_interp.pk
/data/fan/UsersInTokyoProcessed/20121117_interp.pk
/data/fan/UsersInTokyoProcessed/20121118_interp.pk
/data/fan/UsersInTokyoProcessed/20121119_interp.pk
/data/fan/UsersInTokyoProcessed

In [9]:
# Regroup the history by user: data_doc[uid] is the list of that user's
# per-day entries, in the order the days appear in `data`.
data_doc = {}

for date, users in data.items():
    for uid, day_entry in users.items():
        # setdefault creates the list on a user's first appearance.
        data_doc.setdefault(uid, []).append(day_entry)

In [10]:
# Replace each user's list of per-day entries with a single LongTensor on
# GPU 0. Only values are reassigned (no keys are added or removed), so updating
# the dict while iterating over it is safe.
# NOTE: this cell is not idempotent — it expects the values to still be lists.
for uid, history in data_doc.items():
    history_tensor = torch.LongTensor(history)
    data_doc[uid] = history_tensor.cuda(0)

In [11]:
# Evaluate `local_predictor` on every day in `date_list` whose data file exists.
# For each day a random ~5% of (time step, user) pairs is scored, and the
# per-day result (MRR, mean rank, acc@5) is stored in `mrr_avg_rank_dict`
# under the key (year, month, day).
mrr_avg_rank_dict = dict({})

# Dedicated, seeded generator so the 5% subsample is identical on every run.
sample_rng = np.random.RandomState(0)

for y, m, d in date_list:
    filename = '/data/fan/UsersInTokyoProcessed/{:04d}{:02d}{:02d}_interp.pk'.format(y, m, d)
    if not os.path.isfile(filename):
        # date_list enumerates days 1..31 for every month, so some entries
        # have no file (invalid dates or missing days).
        continue
    print(filename)

    # Keep the file open only for the read itself.
    # NOTE(review): pickle.load can execute arbitrary code; only use on trusted files.
    with open(filename, 'rb') as f:
        day_data = pk.load(f)
    # A separate name is used so the history dict `data` from the loading cell
    # is not clobbered and the earlier cells stay re-runnable.
    for uid in day_data:
        # The extra list level adds a leading dimension of size 1.
        day_data[uid] = torch.LongTensor([day_data[uid]]).cuda(0)

    cnt_day = 0      # number of (t, uid) pairs actually scored
    mrr = 0.0        # running sum of reciprocal ranks
    avg_rank = 0.0   # running sum of ranks
    acc5 = 0.0       # running count of ranks <= 5

    with torch.no_grad():
        for t in range(T - 2 * dT):
            for uid in day_data:
                # Score roughly 5% of the pairs.
                if sample_rng.random_sample() >= 0.05:
                    continue
                cnt_day += 1

                # Input: dT consecutive steps starting at t; target: the value
                # at step t + 2*dT - 1.
                data_x = day_data[uid][:, t: t + dT]
                data_t = torch.zeros_like(data_x) + t
                data_y = day_data[uid][:, t + 2 * dT - 1].cpu().numpy()

                if uid in data_doc:
                    # Collect windows of length 2*dT from the user's history,
                    # shifted by -2..+2 steps around t, keeping only shifts
                    # that stay inside [0, T]. j = 0 always qualifies, so the
                    # lists are never empty.
                    x_loc_doc = []
                    x_t_doc = []
                    for j in range(-2, 3):
                        if t + j >= 0 and t + 2 * dT + j <= T:
                            tmp = data_doc[uid][:, t + j: t + 2 * dT + j]
                            x_loc_doc.append(tmp)
                            x_t_doc.append(torch.zeros_like(tmp) + t + j)
                    pred = local_predictor.predict(
                        data_x, data_t,
                        torch.cat(x_loc_doc, dim=0), torch.cat(x_t_doc, dim=0),
                    ).cpu().numpy()
                else:
                    # No history for this user: predict from the query alone.
                    pred = local_predictor.predict(data_x, data_t, None, None).cpu().numpy()

                # argsort(-pred) orders column indices by descending score; the
                # position (1-based) at which the index equals data_y is the rank.
                # NOTE(review): assumes pred columns are indexed by location id — confirm.
                rank = np.where((np.argsort(-pred, axis=1, ).T - data_y).T == 0)[1] + 1
                mrr += np.sum(1 / rank)
                avg_rank += np.sum(rank)
                acc5 += np.sum(rank <= 5)

    if cnt_day == 0:
        # Empty file or no pair sampled: avoid dividing by zero and record nothing.
        print('Evaluation: no samples drawn, skipping')
        continue

    day_metrics = (mrr / cnt_day, avg_rank / cnt_day, acc5 / cnt_day)
    print('Evaluation: mrr = {:.4f}, avg_rank = {:.4f}, acc@5 = {:.4f}'.format(*day_metrics))
    mrr_avg_rank_dict[(y, m, d)] = day_metrics

/data/fan/UsersInTokyoProcessed/20121201_interp.pk
Evaluation: mrr = 0.8145, avg_rank = 17.4593, acc@5 = 0.8731
/data/fan/UsersInTokyoProcessed/20121202_interp.pk
Evaluation: mrr = 0.8261, avg_rank = 16.2655, acc@5 = 0.8814
/data/fan/UsersInTokyoProcessed/20121203_interp.pk
Evaluation: mrr = 0.8385, avg_rank = 13.0843, acc@5 = 0.8953
/data/fan/UsersInTokyoProcessed/20121204_interp.pk
Evaluation: mrr = 0.8347, avg_rank = 13.5921, acc@5 = 0.8928
/data/fan/UsersInTokyoProcessed/20121205_interp.pk
Evaluation: mrr = 0.8321, avg_rank = 13.7138, acc@5 = 0.8901
/data/fan/UsersInTokyoProcessed/20121206_interp.pk
Evaluation: mrr = 0.8304, avg_rank = 14.1563, acc@5 = 0.8883
/data/fan/UsersInTokyoProcessed/20121207_interp.pk
Evaluation: mrr = 0.8250, avg_rank = 15.0040, acc@5 = 0.8835
/data/fan/UsersInTokyoProcessed/20121208_interp.pk
Evaluation: mrr = 0.8134, avg_rank = 17.8889, acc@5 = 0.8711
/data/fan/UsersInTokyoProcessed/20121209_interp.pk
Evaluation: mrr = 0.8253, avg_rank = 16.6384, acc@5 =

Evaluation: mrr = 0.8313, avg_rank = 17.8045, acc@5 = 0.8833
/data/fan/UsersInTokyoProcessed/20130213_interp.pk
Evaluation: mrr = 0.8284, avg_rank = 17.7971, acc@5 = 0.8817
/data/fan/UsersInTokyoProcessed/20130214_interp.pk
Evaluation: mrr = 0.8280, avg_rank = 18.1824, acc@5 = 0.8809
/data/fan/UsersInTokyoProcessed/20130215_interp.pk
Evaluation: mrr = 0.8249, avg_rank = 18.8182, acc@5 = 0.8783
/data/fan/UsersInTokyoProcessed/20130216_interp.pk
Evaluation: mrr = 0.8212, avg_rank = 20.7361, acc@5 = 0.8736
/data/fan/UsersInTokyoProcessed/20130217_interp.pk
Evaluation: mrr = 0.8349, avg_rank = 18.5812, acc@5 = 0.8863
/data/fan/UsersInTokyoProcessed/20130218_interp.pk
Evaluation: mrr = 0.8324, avg_rank = 18.3028, acc@5 = 0.8836
/data/fan/UsersInTokyoProcessed/20130219_interp.pk
Evaluation: mrr = 0.8301, avg_rank = 18.5972, acc@5 = 0.8818
/data/fan/UsersInTokyoProcessed/20130220_interp.pk
Evaluation: mrr = 0.8250, avg_rank = 18.6905, acc@5 = 0.8785
/data/fan/UsersInTokyoProcessed/20130221_in

In [12]:
# Persist the per-day metrics dict (keyed by (year, month, day)) as a pickle.
with open('./results_tokyo/evaluation_local_broader_5_rank.pk', 'wb') as out_file:
    pk.dump(
        mrr_avg_rank_dict,
        out_file,
    )