In [9]:
import torch
import torch.nn as nn
import pickle as pk
import numpy as np
from model import LocalPredictorFullSearchOffset
import random

In [2]:
data = dict({})
for d in range(1, 29):
    filename = '../data/UsersInTokyo2012/201210{:02d}_interp.pk'.format(d)
    print(filename)
    with open(filename, 'rb') as f:
        data[d] = pk.load(f)

../data/UsersInTokyo2012/20121001_interp.pk
../data/UsersInTokyo2012/20121002_interp.pk
../data/UsersInTokyo2012/20121003_interp.pk
../data/UsersInTokyo2012/20121004_interp.pk
../data/UsersInTokyo2012/20121005_interp.pk
../data/UsersInTokyo2012/20121006_interp.pk
../data/UsersInTokyo2012/20121007_interp.pk
../data/UsersInTokyo2012/20121008_interp.pk
../data/UsersInTokyo2012/20121009_interp.pk
../data/UsersInTokyo2012/20121010_interp.pk
../data/UsersInTokyo2012/20121011_interp.pk
../data/UsersInTokyo2012/20121012_interp.pk
../data/UsersInTokyo2012/20121013_interp.pk
../data/UsersInTokyo2012/20121014_interp.pk
../data/UsersInTokyo2012/20121015_interp.pk
../data/UsersInTokyo2012/20121016_interp.pk
../data/UsersInTokyo2012/20121017_interp.pk
../data/UsersInTokyo2012/20121018_interp.pk
../data/UsersInTokyo2012/20121019_interp.pk
../data/UsersInTokyo2012/20121020_interp.pk
../data/UsersInTokyo2012/20121021_interp.pk
../data/UsersInTokyo2012/20121022_interp.pk
../data/UsersInTokyo2012/2012102

In [3]:
uid_sets = dict({})
for d in data:
    uid_sets[d] = set(list(data[d].keys()))

In [4]:
uid_doc_set = set([])
for d in range(1, 15):
    uid_doc_set = uid_doc_set | uid_sets[d]

In [5]:
data_doc = dict({})
data_qry_train = dict({})
data_qry_test = dict({})

for d in range(1, 15):
    for uid in data[d]:
        if uid not in data_doc:
            data_doc[uid] = [data[d][uid]]
        else:
            data_doc[uid].append(data[d][uid])

for d in range(15, 22):
    for uid in data[d]:
        if uid not in data_qry_train:
            data_qry_train[uid] = [data[d][uid]]
        else:
            data_qry_train[uid].append(data[d][uid])
            
for d in range(22, 29):
    for uid in data[d]:
        if uid not in data_qry_test:
            data_qry_test[uid] = [data[d][uid]]
        else:
            data_qry_test[uid].append(data[d][uid])

In [6]:
for uid in data_doc:
    data_doc[uid] = torch.LongTensor(data_doc[uid]).cuda(0)

for uid in data_qry_train:
    data_qry_train[uid] = torch.LongTensor(data_qry_train[uid]).cuda(0)
    
for uid in data_qry_test:
    data_qry_test[uid] = torch.LongTensor(data_qry_test[uid]).cuda(0)

In [7]:
num_locs = 1600
loc_embedding_dim = 128
T = 96
num_time = T
time_embedding_dim = 32
hidden_dim = 256
latent_dim = 256
n_layers = 2

In [10]:
local_predictor = LocalPredictorFullSearchOffset(num_locs, loc_embedding_dim, num_time, time_embedding_dim, hidden_dim, latent_dim, n_layers).cuda(0)

In [11]:
optimizer = torch.optim.RMSprop(local_predictor.parameters(), lr=1e-3)
optimizer_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

In [12]:
dT = 4

In [13]:
training_loss = dict({})
validation_loss = dict({})

In [14]:
batch_size = 128
user_list_train = list(data_qry_train.keys())

for epoch in range(1, 21):
    optimizer.zero_grad()
    optimizer_scheduler.step()
    
    random.shuffle(user_list_train)
    
    avg_loss = 0.0
    cnt = 0
    
    for uid in user_list_train:
        t = np.random.randint(T - 2 * dT + 1)
        x_loc_qry = data_qry_train[uid][:, t: t + dT]
        x_t_qry = torch.zeros_like(x_loc_qry) + t
        y = data_qry_train[uid][:, t + 2 * dT - 1]
        
        if uid not in data_doc:
            loss = local_predictor(x_loc_qry, x_t_qry, None, None, y)
        else:
            x_loc_doc = data_doc[uid][:, t: t + 2 * dT]
            x_t_doc = torch.zeros_like(x_loc_doc) + t
            loss = local_predictor(x_loc_qry, x_t_qry, x_loc_doc, x_t_doc, y)
        
        loss.backward()
        cnt += data_qry_train[uid].shape[0]
        avg_loss += loss.item()

        if cnt % batch_size == 0:
            optimizer.step()
            optimizer.zero_grad()

            print('Epoch {:02d}, avg_loss = {:.4f}'.format(epoch, avg_loss / cnt), end='\r')
    training_loss[epoch] = avg_loss / cnt
    print('')
    
    if epoch % 5 == 0:
        # testing
        cnt = 0
        avg_loss = 0.0

        with torch.no_grad():
            for uid in data_qry_test:
                t = np.random.randint(T - 2 * dT + 1)
                x_loc_qry = data_qry_test[uid][:, t: t + dT]
                x_t_qry = torch.zeros_like(x_loc_qry) + t
                y = data_qry_test[uid][:, t + 2 * dT - 1]

                if uid not in data_doc:
                    loss = local_predictor(x_loc_qry, x_t_qry, None, None, y)
                else:
                    x_loc_doc = data_doc[uid][:, t: t + 2 * dT]
                    x_t_doc = torch.zeros_like(x_loc_doc) + t
                    loss = local_predictor(x_loc_qry, x_t_qry, x_loc_doc, x_t_doc, y)

                cnt += data_qry_test[uid].shape[0]
                avg_loss += loss.item()

        print('Validation Loss = {:.4f}'.format(avg_loss / cnt))
        validation_loss[epoch] = avg_loss / cnt

Epoch 01, avg_loss = 2.1359
Epoch 02, avg_loss = 1.5604
Epoch 03, avg_loss = 1.4180
Epoch 04, avg_loss = 1.3662
Epoch 05, avg_loss = 1.3188
Validation Loss = 1.3486
Epoch 06, avg_loss = 1.3044
Epoch 07, avg_loss = 1.2795
Epoch 08, avg_loss = 1.2735
Epoch 09, avg_loss = 1.2669
Epoch 10, avg_loss = 1.2561
Validation Loss = 1.2929
Epoch 11, avg_loss = 1.2584
Epoch 12, avg_loss = 1.2574
Epoch 13, avg_loss = 1.2584
Epoch 14, avg_loss = 1.2541
Epoch 15, avg_loss = 1.2534
Validation Loss = 1.3008
Epoch 16, avg_loss = 1.2538
Epoch 17, avg_loss = 1.2584
Epoch 18, avg_loss = 1.2492
Epoch 19, avg_loss = 1.2490
Epoch 20, avg_loss = 1.2568
Validation Loss = 1.2899


In [15]:
torch.save(local_predictor, 'local_predictor_full_search_offset.pytorch')

In [17]:
with open('./centralized_local_predictor_full_search_oct_4096_training_loss.pk', 'wb') as f:
    pk.dump(training_loss, f)
with open('./centralized_local_predictor_full_search_oct_4096_validation_loss.pk', 'wb') as f:
    pk.dump(validation_loss, f)