In [1]:
import torch
import torch.nn as nn
import pickle as pk
import numpy as np
from model import LocalPredictor
import random

In [2]:
data = dict({})
for d in range(1, 29):
    filename = '/data/fan/UsersInOsakaProcessed/201210{:02d}_interp.pk'.format(d)
    print(filename)
    with open(filename, 'rb') as f:
        data[d] = pk.load(f)

/data/fan/UsersInOsakaProcessed/20121001_interp.pk
/data/fan/UsersInOsakaProcessed/20121002_interp.pk
/data/fan/UsersInOsakaProcessed/20121003_interp.pk
/data/fan/UsersInOsakaProcessed/20121004_interp.pk
/data/fan/UsersInOsakaProcessed/20121005_interp.pk
/data/fan/UsersInOsakaProcessed/20121006_interp.pk
/data/fan/UsersInOsakaProcessed/20121007_interp.pk
/data/fan/UsersInOsakaProcessed/20121008_interp.pk
/data/fan/UsersInOsakaProcessed/20121009_interp.pk
/data/fan/UsersInOsakaProcessed/20121010_interp.pk
/data/fan/UsersInOsakaProcessed/20121011_interp.pk
/data/fan/UsersInOsakaProcessed/20121012_interp.pk
/data/fan/UsersInOsakaProcessed/20121013_interp.pk
/data/fan/UsersInOsakaProcessed/20121014_interp.pk
/data/fan/UsersInOsakaProcessed/20121015_interp.pk
/data/fan/UsersInOsakaProcessed/20121016_interp.pk
/data/fan/UsersInOsakaProcessed/20121017_interp.pk
/data/fan/UsersInOsakaProcessed/20121018_interp.pk
/data/fan/UsersInOsakaProcessed/20121019_interp.pk
/data/fan/UsersInOsakaProcessed

In [3]:
uid_sets = dict({})
for d in data:
    uid_sets[d] = set(list(data[d].keys()))

In [4]:
uid_doc_set = set([])
for d in range(1, 15):
    uid_doc_set = uid_doc_set | uid_sets[d]

In [5]:
data_doc = dict({})
data_qry_train = dict({})
data_qry_test = dict({})

for d in range(1, 15):
    for uid in data[d]:
        if uid not in data_doc:
            data_doc[uid] = [data[d][uid]]
        else:
            data_doc[uid].append(data[d][uid])

for d in range(15, 22):
    for uid in data[d]:
        if uid not in data_qry_train:
            data_qry_train[uid] = [data[d][uid]]
        else:
            data_qry_train[uid].append(data[d][uid])
            
for d in range(22, 29):
    for uid in data[d]:
        if uid not in data_qry_test:
            data_qry_test[uid] = [data[d][uid]]
        else:
            data_qry_test[uid].append(data[d][uid])

In [6]:
for uid in data_doc:
    data_doc[uid] = torch.LongTensor(data_doc[uid]).cuda(1)

for uid in data_qry_train:
    data_qry_train[uid] = torch.LongTensor(data_qry_train[uid]).cuda(1)
    
for uid in data_qry_test:
    data_qry_test[uid] = torch.LongTensor(data_qry_test[uid]).cuda(1)

In [7]:
num_locs = 1600
loc_embedding_dim = 128
T = 96
num_time = T
time_embedding_dim = 32
hidden_dim = 256
latent_dim = 256
n_layers = 2

In [8]:
local_predictor = LocalPredictor(num_locs, loc_embedding_dim, num_time, time_embedding_dim, hidden_dim, latent_dim, n_layers).cuda(1)

In [9]:
optimizer = torch.optim.RMSprop(local_predictor.parameters(), lr=1e-3)
optimizer_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.8)

In [10]:
dT = 4

In [11]:
training_loss = dict({})
validation_loss = dict({})

In [14]:
batch_size = 128
user_list_train = list(data_qry_train.keys())

for epoch in range(1, 11):
    optimizer.zero_grad()
    optimizer_scheduler.step()
    
    random.shuffle(user_list_train)
    
    avg_loss = 0.0
    cnt = 0
    
    for uid in user_list_train:
        t = np.random.randint(T - 2 * dT + 1)
        x_loc_qry = data_qry_train[uid][:, t: t + dT]
        x_t_qry = torch.zeros_like(x_loc_qry) + t
        y = data_qry_train[uid][:, t + 2 * dT - 1]
        
        if uid not in data_doc:
            loss = local_predictor(x_loc_qry, x_t_qry, None, None, y)
        else:
            x_loc_doc = []
            x_t_doc = []
            for j in range(-2, 2):
                if t + j >= 0 and t + 2 * dT + j <= T:
                    tmp = data_doc[uid][:, t + j: t + 2 * dT + j]
                    x_loc_doc.append(tmp)
                    x_t_doc.append(torch.zeros_like(tmp) + t + j)
            loss = local_predictor(x_loc_qry, x_t_qry, torch.cat(x_loc_doc, dim=0), torch.cat(x_t_doc, dim=0), y)
        
        loss.backward()
        cnt += data_qry_train[uid].shape[0]
        avg_loss += loss.item()

        if cnt % batch_size == 0:
            optimizer.step()
            optimizer.zero_grad()

            print('Epoch {:02d}, avg_loss = {:.4f}'.format(epoch, avg_loss / cnt), end='\r')
    training_loss[epoch] = avg_loss / cnt
    print('')
    
    if epoch % 5 == 0:
        # testing
        cnt = 0
        avg_loss = 0.0

        with torch.no_grad():
            for uid in data_qry_test:
                t = np.random.randint(T - 2 * dT + 1)
                x_loc_qry = data_qry_test[uid][:, t: t + dT]
                x_t_qry = torch.zeros_like(x_loc_qry) + t
                y = data_qry_test[uid][:, t + 2 * dT - 1]

                if uid not in data_doc:
                    loss = local_predictor(x_loc_qry, x_t_qry, None, None, y)
                else:
                    x_loc_doc = []
                    x_t_doc = []
                    for j in range(-2, 2):
                        if t + j >= 0 and t + 2 * dT + j <= T:
                            tmp = data_doc[uid][:, t + j: t + 2 * dT + j]
                            x_loc_doc.append(tmp)
                            x_t_doc.append(torch.zeros_like(tmp) + t + j)
                    loss = local_predictor(x_loc_qry, x_t_qry, torch.cat(x_loc_doc, dim=0), torch.cat(x_t_doc, dim=0), y)

                cnt += data_qry_test[uid].shape[0]
                avg_loss += loss.item()

        print('Validation Loss = {:.4f}'.format(avg_loss / cnt))
        validation_loss[epoch] = avg_loss / cnt

Epoch 01, avg_loss = 1.2566
Epoch 02, avg_loss = 1.2526
Epoch 03, avg_loss = 1.2577
Epoch 04, avg_loss = 1.2528
Epoch 05, avg_loss = 1.2510
Validation Loss = 1.3123
Epoch 06, avg_loss = 1.2425
Epoch 07, avg_loss = 1.2422
Epoch 08, avg_loss = 1.2456
Epoch 09, avg_loss = 1.2544
Epoch 10, avg_loss = 1.2529
Validation Loss = 1.3051


In [15]:
torch.save(local_predictor, './results_osaka/local_predictor_broader_5.pytorch')

In [14]:
with open('./results_osaka/centralized_local_predictor_training_loss.pk', 'wb') as f:
    pk.dump(training_loss, f)
with open('./results_osaka/centralized_local_predictor_validation_loss.pk', 'wb') as f:
    pk.dump(validation_loss, f)