In [1]:
import torch
import torch.nn as nn
from model import LocalPredictor

In [2]:
import os

In [3]:
import pickle as pk

In [4]:
import numpy as np

In [5]:
# Load the trained predictor. The checkpoint is a whole pickled module (it is used
# directly as a model below), so unpickling needs `model.LocalPredictor` importable.
# map_location='cpu' materialises the weights on the host first, so loading does not
# allocate on / depend on the GPU the checkpoint was saved from; then move it to GPU 1.
# NOTE(review): PyTorch >= 2.6 defaults torch.load(weights_only=True), which rejects
# full-module pickles; pass weights_only=False there if this load fails.
local_predictor = torch.load('./results_osaka/local_predictor.pytorch', map_location='cpu').cuda(1)

In [6]:
# Evaluation days as (year, month, day) tuples, every real calendar date from
# EVAL_START to EVAL_END inclusive. Built with datetime so impossible dates
# (e.g. 2013-02-30) are never generated; days without a data file are still
# skipped by the evaluation loop.
import datetime

EVAL_START = datetime.date(2012, 12, 1)
EVAL_END = datetime.date(2013, 2, 28)
date_list = [
    (day.year, day.month, day.day)
    for day in (EVAL_START + datetime.timedelta(days=offset)
                for offset in range((EVAL_END - EVAL_START).days + 1))
]

In [7]:
# Sequence geometry used by the evaluation loop:
#   T  - number of time steps per day sequence (presumably 96 = 15-minute slots; confirm)
#   dT - window length in steps: input window is [t, t + dT), target is step t + 2*dT - 1
# batch_size is not referenced by any cell shown here.
T, dT, batch_size = 96, 4, 256

In [8]:
# Load the history days 2012-11-01 .. 2012-11-28 (range(1, 29) stops at the 28th)
# into `data`, keyed by (month, day). Each pickle is a preprocessed day file;
# presumably a mapping uid -> that day's sequence (confirm against the preprocessing).
# pickle.load can execute arbitrary code, so only load trusted files.
data = {}
history_days = [(m, d) for m in range(11, 12) for d in range(1, 29)]
for month, day in history_days:
    path = '/data/fan/UsersInOsakaProcessed/2012{:02d}{:02d}_interp.pk'.format(month, day)
    print(path)
    with open(path, 'rb') as fh:
        data[(month, day)] = pk.load(fh)

/data/fan/UsersInOsakaProcessed/20121101_interp.pk
/data/fan/UsersInOsakaProcessed/20121102_interp.pk
/data/fan/UsersInOsakaProcessed/20121103_interp.pk
/data/fan/UsersInOsakaProcessed/20121104_interp.pk
/data/fan/UsersInOsakaProcessed/20121105_interp.pk
/data/fan/UsersInOsakaProcessed/20121106_interp.pk
/data/fan/UsersInOsakaProcessed/20121107_interp.pk
/data/fan/UsersInOsakaProcessed/20121108_interp.pk
/data/fan/UsersInOsakaProcessed/20121109_interp.pk
/data/fan/UsersInOsakaProcessed/20121110_interp.pk
/data/fan/UsersInOsakaProcessed/20121111_interp.pk
/data/fan/UsersInOsakaProcessed/20121112_interp.pk
/data/fan/UsersInOsakaProcessed/20121113_interp.pk
/data/fan/UsersInOsakaProcessed/20121114_interp.pk
/data/fan/UsersInOsakaProcessed/20121115_interp.pk
/data/fan/UsersInOsakaProcessed/20121116_interp.pk
/data/fan/UsersInOsakaProcessed/20121117_interp.pk
/data/fan/UsersInOsakaProcessed/20121118_interp.pk
/data/fan/UsersInOsakaProcessed/20121119_interp.pk
/data/fan/UsersInOsakaProcessed

In [9]:
# Group the history by user: data_doc[uid] is the list of that uid's daily entries,
# in the order of the days in `data`. A uid missing on some day simply has no entry
# for it, so lists may differ in length between users.
data_doc = {}
for day_key in data:
    day_users = data[day_key]
    for uid in day_users:
        data_doc.setdefault(uid, []).append(day_users[uid])

In [10]:
# Turn each user's list of daily entries into a single LongTensor on GPU 1
# (one row per history day on which the user appears). torch.LongTensor requires
# every daily entry of a user to have the same length.
data_doc.update({uid: torch.LongTensor(rows).cuda(1) for uid, rows in data_doc.items()})

In [11]:
# Evaluate the local predictor on every day of `date_list` that has a data file.
# At each start offset t, roughly SAMPLE_RATE of the day's users are sampled; each
# sampled user contributes the predictor's scalar loss for target step t + 2*dT - 1
# given the input window [t, t + dT) (plus the user's history window [t, t + 2*dT)
# from `data_doc` when the uid has history).
# Results: avg_loss_dict[(y, m, d, t)] = mean loss at offset t,
#          avg_loss_dict[(y, m, d)]    = mean over all samples of that day;
#          NaN marks an offset/day where no user was sampled.
SAMPLE_RATE = 0.05                      # fraction of users evaluated per time offset
sample_rng = np.random.RandomState(0)   # seeded so the sampled users are reproducible

# torch.no_grad() alone does not switch dropout / batch-norm to inference behaviour.
local_predictor.eval()


def _window_loss(seq, uid, t):
    """Return the predictor's scalar loss for one user's window starting at offset ``t``.

    ``seq`` is the user's day sequence as built below (a LongTensor of shape (1, ...)).
    The history window comes from the global ``data_doc`` when ``uid`` is present there.
    """
    data_x = seq[:, t: t + dT]
    data_t = torch.zeros_like(data_x) + t
    data_y = seq[:, t + 2 * dT - 1]
    if uid in data_doc:
        data_doc_x = data_doc[uid][:, t: t + 2 * dT]
        data_doc_t = torch.zeros_like(data_doc_x) + t
        return local_predictor(data_x, data_t, data_doc_x, data_doc_t, data_y).item()
    return local_predictor(data_x, data_t, None, None, data_y).item()


avg_loss_dict = dict({})

for y, m, d in date_list:
    filename = '/data/fan/UsersInOsakaProcessed/{:04d}{:02d}{:02d}_interp.pk'.format(y, m, d)
    if not os.path.isfile(filename):
        continue
    print(filename)

    # Read the day and close the file before the long GPU loop. A separate name keeps
    # the history dict `data` (used to build data_doc) from being overwritten.
    with open(filename, 'rb') as f:
        day_data = pk.load(f)
    for uid in day_data:
        day_data[uid] = torch.LongTensor([day_data[uid]]).cuda(1)

    day_loss_sum = 0.0
    cnt_day = 0
    with torch.no_grad():
        # NOTE(review): if each day sequence has length T, offset t = T - 2*dT is also a
        # valid window but is excluded by this range; kept as-is — confirm intended.
        for t in range(T - 2 * dT):
            t_loss_sum = 0.0
            cnt_time = 0
            for uid in day_data:
                if sample_rng.random_sample() < SAMPLE_RATE:
                    loss = _window_loss(day_data[uid], uid, t)
                    t_loss_sum += loss
                    day_loss_sum += loss
                    cnt_time += 1
                    cnt_day += 1
            # Avoid ZeroDivisionError when no user was sampled at this offset.
            avg_loss_dict[(y, m, d, t)] = t_loss_sum / cnt_time if cnt_time else float('nan')

    avg_loss = day_loss_sum / cnt_day if cnt_day else float('nan')
    avg_loss_dict[(y, m, d)] = avg_loss
    print('Evaluation: {:.4f}'.format(avg_loss))

/data/fan/UsersInOsakaProcessed/20121201_interp.pk
Evaluation: 1.3858
/data/fan/UsersInOsakaProcessed/20121202_interp.pk
Evaluation: 1.3316
/data/fan/UsersInOsakaProcessed/20121203_interp.pk
Evaluation: 1.2214
/data/fan/UsersInOsakaProcessed/20121204_interp.pk
Evaluation: 1.2449
/data/fan/UsersInOsakaProcessed/20121205_interp.pk
Evaluation: 1.2620
/data/fan/UsersInOsakaProcessed/20121206_interp.pk
Evaluation: 1.2574
/data/fan/UsersInOsakaProcessed/20121207_interp.pk
Evaluation: 1.3114
/data/fan/UsersInOsakaProcessed/20121208_interp.pk
Evaluation: 1.3938
/data/fan/UsersInOsakaProcessed/20121209_interp.pk
Evaluation: 1.3426
/data/fan/UsersInOsakaProcessed/20121210_interp.pk
Evaluation: 1.2504
/data/fan/UsersInOsakaProcessed/20121211_interp.pk
Evaluation: 1.2848
/data/fan/UsersInOsakaProcessed/20121212_interp.pk
Evaluation: 1.2987
/data/fan/UsersInOsakaProcessed/20121213_interp.pk
Evaluation: 1.2958
/data/fan/UsersInOsakaProcessed/20121214_interp.pk
Evaluation: 1.3465
/data/fan/UsersInOsa

In [12]:
# Persist every per-day and per-offset loss so later analysis can reload them
# without re-running the GPU evaluation.
with open('./results_osaka/evaluation_local_nll.pk', 'wb') as out_file:
    pk.dump(avg_loss_dict, out_file)