In [1]:
import torch
import torch.nn as nn
import pickle as pk
import numpy as np
from model import LocalPredictorFullSearch
import random

In [2]:
# Load one pickle per day of November 2012 into `data`, keyed by day number.
# Each file presumably holds a dict keyed by user id (later cells iterate its
# keys as `uid`) — TODO confirm the per-user record format.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
DATA_DIR = '/data/fan/UsersInTokyoProcessed'  # change here instead of editing the loop
NUM_DAYS = 28

data = {}
for d in range(1, NUM_DAYS + 1):
    filename = '{}/201211{:02d}_interp.pk'.format(DATA_DIR, d)
    print(filename)
    with open(filename, 'rb') as f:
        data[d] = pk.load(f)

/data/fan/UsersInTokyoProcessed/20121101_interp.pk
/data/fan/UsersInTokyoProcessed/20121102_interp.pk
/data/fan/UsersInTokyoProcessed/20121103_interp.pk
/data/fan/UsersInTokyoProcessed/20121104_interp.pk
/data/fan/UsersInTokyoProcessed/20121105_interp.pk
/data/fan/UsersInTokyoProcessed/20121106_interp.pk
/data/fan/UsersInTokyoProcessed/20121107_interp.pk
/data/fan/UsersInTokyoProcessed/20121108_interp.pk
/data/fan/UsersInTokyoProcessed/20121109_interp.pk
/data/fan/UsersInTokyoProcessed/20121110_interp.pk
/data/fan/UsersInTokyoProcessed/20121111_interp.pk
/data/fan/UsersInTokyoProcessed/20121112_interp.pk
/data/fan/UsersInTokyoProcessed/20121113_interp.pk
/data/fan/UsersInTokyoProcessed/20121114_interp.pk
/data/fan/UsersInTokyoProcessed/20121115_interp.pk
/data/fan/UsersInTokyoProcessed/20121116_interp.pk
/data/fan/UsersInTokyoProcessed/20121117_interp.pk
/data/fan/UsersInTokyoProcessed/20121118_interp.pk
/data/fan/UsersInTokyoProcessed/20121119_interp.pk
/data/fan/UsersInTokyoProcessed

In [3]:
# Day number -> set of user ids that have a record on that day.
uid_sets = {d: set(day_records) for d, day_records in data.items()}

In [4]:
# Union of the user ids seen on any of days 1-14.
uid_doc_set = set().union(*(uid_sets[d] for d in range(1, 15)))

In [5]:
# Split the per-day dicts into three disjoint calendar periods:
#   days 1-14  -> data_doc       (per-user history)
#   days 15-21 -> data_qry_train (training queries)
#   days 22-28 -> data_qry_test  (test queries)
DOC_DAYS = range(1, 15)
QRY_TRAIN_DAYS = range(15, 22)
QRY_TEST_DAYS = range(22, 29)


def _group_days_by_user(per_day, days):
    """Collect each user's per-day records over `days`.

    Args:
        per_day: dict mapping day number -> {uid: record}.
        days: iterable of day numbers to include, in order.

    Returns:
        dict mapping uid -> list of that user's records, one per day on which
        the user appears, in the order of `days`. Users are inserted in order
        of first appearance.
    """
    grouped = {}
    for d in days:
        for uid, record in per_day[d].items():
            grouped.setdefault(uid, []).append(record)
    return grouped


data_doc = _group_days_by_user(data, DOC_DAYS)
data_qry_train = _group_days_by_user(data, QRY_TRAIN_DAYS)
data_qry_test = _group_days_by_user(data, QRY_TEST_DAYS)

In [6]:
# Replace each user's list of day records with a single int64 tensor on GPU 0;
# dimension 0 indexes the days in the order they were collected.
for split in (data_doc, data_qry_train, data_qry_test):
    for uid, day_records in split.items():
        split[uid] = torch.LongTensor(day_records).cuda(0)

In [7]:
# Model hyper-parameters passed to LocalPredictorFullSearch.
num_locs = 1600               # number of distinct location ids (presumably grid cells — TODO confirm)
loc_embedding_dim = 128
T = num_time = 96             # time slots per day (96 would be 15-minute bins — confirm)
time_embedding_dim = 32
hidden_dim = latent_dim = 256

In [8]:
# Base learning rate for the federated update, and the factor by which the
# accumulated update buffer is decayed after every round.
lr, momentum = 1e-3, 0.8

In [9]:
# Two models with identical hyper-parameters on GPU 0:
#   local_predictor        - client copy; reset to the server weights per user
#   local_predictor_server - global model that receives the aggregated update
_predictor_args = (num_locs, loc_embedding_dim, num_time, time_embedding_dim, hidden_dim, latent_dim)
local_predictor, local_predictor_server = (
    LocalPredictorFullSearch(*_predictor_args).cuda(0) for _ in range(2)
)

# In this notebook the optimizer is only used to clear gradients (zero_grad);
# parameter updates are applied manually to the server model.
optimizer = torch.optim.SGD(params=local_predictor.parameters(), lr=lr)
optimizer.zero_grad()

In [10]:
# Initialise the server model from a pretrained checkpoint. The checkpoint is a
# whole pickled module, so its weights are taken via .state_dict().
# NOTE(review): the recorded run raised FileNotFoundError here, so the training
# output below did not start from this checkpoint. Also, on torch >= 2.6
# torch.load defaults to weights_only=True, which may refuse a pickled module —
# confirm the torch version in use.
PRETRAINED_SERVER_PATH = './results_osaka/local_predictor_full_search_oct_4096.pytorch'
local_predictor_server.load_state_dict(
    torch.load(PRETRAINED_SERVER_PATH).state_dict()
)

FileNotFoundError: [Errno 2] No such file or directory: './results_osaka/local_predictor_full_search_oct_4096.pytorch'

In [None]:
# Window length in time slots: the model observes slots [t, t + dT) and the
# target is slot t + 2*dT - 1, i.e. dT slots after the last observed one.
dT: int = 4

In [None]:
# epoch -> average loss, filled by the training cell.
training_loss, validation_loss = {}, {}

In [16]:
# Federated training of the server model.
#
# Each round ("epoch") samples up to CLIENTS_PER_ROUND training users. For every
# sampled user the client model is reset to the current server weights, one
# random window start `t` is drawn, and the user's gradient plus Gaussian noise
# (std GRAD_NOISE_STD) is scaled by the round's learning rate and summed into
# `local_predictor_update`. After the round the server weights are decreased by
# that buffer, and the buffer is multiplied by `momentum`, so it carries over
# into the next round.
#
# The learning rate for a round is derived from the base `lr` (halved every
# LR_DECAY_EVERY epochs) instead of mutating `lr`, so re-running this cell does
# not keep shrinking it.

CLIENTS_PER_ROUND = 2048      # users sampled per round
GRAD_NOISE_STD = 1e-2         # std of the noise added to each per-user gradient
LR_DECAY_EVERY = 500          # halve the learning rate every this many epochs
VALIDATE_EVERY = 5            # run a validation pass every this many epochs
VALIDATION_FRACTION = 0.05    # probability that a test user is included in a pass
NUM_EPOCHS = 2000

user_list_train = list(data_qry_train.keys())
assert user_list_train, 'no training users'


def _window_inputs(qry_days, uid, t):
    """Build the positional model inputs for one user and window start `t`.

    Returns ``(x_loc_qry, x_t_qry, x_loc_doc, x_t_doc, y)``:
      * x_loc_qry: columns [t, t + dT) of `qry_days`
      * x_t_qry:   tensor shaped like x_loc_qry, filled with `t`
      * x_loc_doc: columns [t, t + 2*dT) of the user's `data_doc` tensor, or
                   None when the user has no doc history
      * x_t_doc:   tensor shaped like x_loc_doc filled with `t`, or None
      * y:         column t + 2*dT - 1 of `qry_days` (the prediction target)
    """
    x_loc_qry = qry_days[:, t: t + dT]
    x_t_qry = torch.zeros_like(x_loc_qry) + t
    y = qry_days[:, t + 2 * dT - 1]
    if uid in data_doc:
        x_loc_doc = data_doc[uid][:, t: t + 2 * dT]
        x_t_doc = torch.zeros_like(x_loc_doc) + t
    else:
        x_loc_doc = x_t_doc = None
    return x_loc_qry, x_t_qry, x_loc_doc, x_t_doc, y


# Accumulated (momentum) update per parameter name.
local_predictor_update = {}

for epoch in range(1, NUM_EPOCHS + 1):

    # Same schedule as halving at every multiple of LR_DECAY_EVERY.
    epoch_lr = lr * 0.5 ** (epoch // LR_DECAY_EVERY)

    # random.sample already returns the users in random order; cap the sample
    # size so small user pools do not raise ValueError.
    update_user_list = random.sample(
        user_list_train, min(CLIENTS_PER_ROUND, len(user_list_train))
    )

    avg_loss = 0.0
    cnt = 0

    for uidx, uid in enumerate(update_user_list, start=1):
        # Each client starts from the current server weights.
        local_predictor.load_state_dict(local_predictor_server.state_dict())

        qry_days = data_qry_train[uid]
        nk = qry_days.shape[0]

        t = np.random.randint(T - 2 * dT + 1)
        loss = local_predictor(*_window_inputs(qry_days, uid, t))

        loss.backward()
        # NOTE(review): dividing by the number of query days assumes the model
        # returns a loss summed over days — confirm against LocalPredictorFullSearch.
        avg_loss += loss.item()
        cnt += nk

        with torch.no_grad():
            for name, para in local_predictor.named_parameters():
                if para.grad is None:
                    continue
                step = (para.grad + torch.randn_like(para.grad) * GRAD_NOISE_STD) * epoch_lr
                if name in local_predictor_update:
                    local_predictor_update[name] += step
                else:
                    local_predictor_update[name] = step

        optimizer.zero_grad()

        print('Epoch {:02d}, {:.1f}%, avg_loss = {:.4f}'.format(epoch, uidx * 100 / len(update_user_list), avg_loss / cnt), end='\r')

    # Apply the aggregated update to the server model, then decay the buffer.
    with torch.no_grad():
        for name, para in local_predictor_server.named_parameters():
            # Parameters that have never received a gradient have no entry yet.
            if name in local_predictor_update:
                para -= local_predictor_update[name]

        for buf in local_predictor_update.values():
            buf *= momentum

    training_loss[epoch] = avg_loss / cnt
    print('')

    # Validation on a random subset of test users, using the server model.
    # NOTE(review): the server model is never switched to eval() mode; this only
    # matters if the model contains dropout/batch-norm — confirm.
    if epoch % VALIDATE_EVERY == 0:
        with torch.no_grad():

            cnt = 0
            avg_loss = 0.0

            for uid, qry_days in data_qry_test.items():

                if np.random.random() > VALIDATION_FRACTION:
                    continue

                nk = qry_days.shape[0]

                t = np.random.randint(T - 2 * dT + 1)
                loss = local_predictor_server(*_window_inputs(qry_days, uid, t))

                avg_loss += loss.item()
                cnt += nk

            if cnt == 0:
                # Sampling can select no users; avoid ZeroDivisionError.
                print('Validation skipped: no test users sampled')
            else:
                print('Validation Loss: {:.4f}'.format(avg_loss / cnt))
                validation_loss[epoch] = avg_loss / cnt

Epoch 2000, 100.0%, avg_loss = 1.4328
Validation Loss: 1.5613
Epoch 2001, 100.0%, avg_loss = 1.4631
Epoch 2002, 100.0%, avg_loss = 1.5110
Epoch 2003, 100.0%, avg_loss = 1.5575
Epoch 2004, 100.0%, avg_loss = 1.5032
Epoch 2005, 100.0%, avg_loss = 1.4924
Validation Loss: 1.5054
Epoch 2006, 100.0%, avg_loss = 1.4662
Epoch 2007, 100.0%, avg_loss = 1.4960
Epoch 2008, 100.0%, avg_loss = 1.4797
Epoch 2009, 100.0%, avg_loss = 1.5138
Epoch 2010, 100.0%, avg_loss = 1.5388
Validation Loss: 1.5394
Epoch 2011, 100.0%, avg_loss = 1.5478
Epoch 2012, 100.0%, avg_loss = 1.4758
Epoch 2013, 100.0%, avg_loss = 1.5442
Epoch 2014, 100.0%, avg_loss = 1.4074
Epoch 2015, 100.0%, avg_loss = 1.5605
Validation Loss: 1.4903
Epoch 2016, 100.0%, avg_loss = 1.4527
Epoch 2017, 100.0%, avg_loss = 1.4450
Epoch 2018, 100.0%, avg_loss = 1.5062
Epoch 2019, 100.0%, avg_loss = 1.5322
Epoch 2020, 100.0%, avg_loss = 1.5316
Validation Loss: 1.4806
Epoch 2021, 100.0%, avg_loss = 1.5072
Epoch 2022, 100.0%, avg_loss = 1.4288
Epoch 

Epoch 2191, 100.0%, avg_loss = 1.5052
Epoch 2192, 100.0%, avg_loss = 1.5452
Epoch 2193, 100.0%, avg_loss = 1.4109
Epoch 2194, 100.0%, avg_loss = 1.5225
Epoch 2195, 100.0%, avg_loss = 1.4954
Validation Loss: 1.4825
Epoch 2196, 100.0%, avg_loss = 1.4757
Epoch 2197, 100.0%, avg_loss = 1.4317
Epoch 2198, 100.0%, avg_loss = 1.4831
Epoch 2199, 100.0%, avg_loss = 1.5369
Epoch 2200, 100.0%, avg_loss = 1.4738
Validation Loss: 1.5497
Epoch 2201, 100.0%, avg_loss = 1.5290
Epoch 2202, 100.0%, avg_loss = 1.4622
Epoch 2203, 100.0%, avg_loss = 1.5217
Epoch 2204, 100.0%, avg_loss = 1.4345
Epoch 2205, 100.0%, avg_loss = 1.5196
Validation Loss: 1.5119
Epoch 2206, 100.0%, avg_loss = 1.5034
Epoch 2207, 100.0%, avg_loss = 1.5447
Epoch 2208, 100.0%, avg_loss = 1.4679
Epoch 2209, 100.0%, avg_loss = 1.3737
Epoch 2210, 100.0%, avg_loss = 1.5438
Validation Loss: 1.4963
Epoch 2211, 100.0%, avg_loss = 1.4815
Epoch 2212, 100.0%, avg_loss = 1.5387
Epoch 2213, 100.0%, avg_loss = 1.5391
Epoch 2214, 100.0%, avg_loss =

Epoch 2382, 100.0%, avg_loss = 1.5157
Epoch 2383, 100.0%, avg_loss = 1.5175
Epoch 2384, 100.0%, avg_loss = 1.4856
Epoch 2385, 100.0%, avg_loss = 1.4978
Validation Loss: 1.4784
Epoch 2386, 100.0%, avg_loss = 1.4875
Epoch 2387, 100.0%, avg_loss = 1.4936
Epoch 2388, 100.0%, avg_loss = 1.5271
Epoch 2389, 100.0%, avg_loss = 1.4984
Epoch 2390, 100.0%, avg_loss = 1.4465
Validation Loss: 1.4915
Epoch 2391, 100.0%, avg_loss = 1.5219
Epoch 2392, 100.0%, avg_loss = 1.4527
Epoch 2393, 100.0%, avg_loss = 1.5369
Epoch 2394, 100.0%, avg_loss = 1.4701
Epoch 2395, 100.0%, avg_loss = 1.5098
Validation Loss: 1.5549
Epoch 2396, 100.0%, avg_loss = 1.4818
Epoch 2397, 100.0%, avg_loss = 1.4639
Epoch 2398, 100.0%, avg_loss = 1.5427
Epoch 2399, 100.0%, avg_loss = 1.5175
Epoch 2400, 100.0%, avg_loss = 1.5142
Validation Loss: 1.4984
Epoch 2401, 100.0%, avg_loss = 1.5319
Epoch 2402, 100.0%, avg_loss = 1.5150
Epoch 2403, 100.0%, avg_loss = 1.4410
Epoch 2404, 100.0%, avg_loss = 1.4875
Epoch 2405, 100.0%, avg_loss =

Epoch 2573, 100.0%, avg_loss = 1.5423
Epoch 2574, 100.0%, avg_loss = 1.4189
Epoch 2575, 100.0%, avg_loss = 1.3954
Validation Loss: 1.5140
Epoch 2576, 100.0%, avg_loss = 1.4997
Epoch 2577, 100.0%, avg_loss = 1.4996
Epoch 2578, 100.0%, avg_loss = 1.4614
Epoch 2579, 100.0%, avg_loss = 1.4469
Epoch 2580, 100.0%, avg_loss = 1.4779
Validation Loss: 1.4865
Epoch 2581, 100.0%, avg_loss = 1.4980
Epoch 2582, 100.0%, avg_loss = 1.4874
Epoch 2583, 100.0%, avg_loss = 1.4936
Epoch 2584, 100.0%, avg_loss = 1.3981
Epoch 2585, 100.0%, avg_loss = 1.4718
Validation Loss: 1.5490
Epoch 2586, 100.0%, avg_loss = 1.5290
Epoch 2587, 100.0%, avg_loss = 1.4925
Epoch 2588, 100.0%, avg_loss = 1.4407
Epoch 2589, 100.0%, avg_loss = 1.5086
Epoch 2590, 100.0%, avg_loss = 1.4833
Validation Loss: 1.4955
Epoch 2591, 100.0%, avg_loss = 1.4123
Epoch 2592, 100.0%, avg_loss = 1.5111
Epoch 2593, 100.0%, avg_loss = 1.5364
Epoch 2594, 100.0%, avg_loss = 1.4800
Epoch 2595, 100.0%, avg_loss = 1.4273
Validation Loss: 1.5348
Epoch 

Epoch 2764, 100.0%, avg_loss = 1.4805
Epoch 2765, 100.0%, avg_loss = 1.4725
Validation Loss: 1.4901
Epoch 2766, 100.0%, avg_loss = 1.4396
Epoch 2767, 100.0%, avg_loss = 1.4417
Epoch 2768, 100.0%, avg_loss = 1.4417
Epoch 2769, 100.0%, avg_loss = 1.4407
Epoch 2770, 100.0%, avg_loss = 1.5476
Validation Loss: 1.5083
Epoch 2771, 100.0%, avg_loss = 1.4743
Epoch 2772, 100.0%, avg_loss = 1.4894
Epoch 2773, 100.0%, avg_loss = 1.4226
Epoch 2774, 100.0%, avg_loss = 1.4685
Epoch 2775, 100.0%, avg_loss = 1.4797
Validation Loss: 1.4581
Epoch 2776, 100.0%, avg_loss = 1.4375
Epoch 2777, 100.0%, avg_loss = 1.4602
Epoch 2778, 100.0%, avg_loss = 1.4705
Epoch 2779, 100.0%, avg_loss = 1.5081
Epoch 2780, 100.0%, avg_loss = 1.4877
Validation Loss: 1.5332
Epoch 2781, 100.0%, avg_loss = 1.5010
Epoch 2782, 100.0%, avg_loss = 1.4774
Epoch 2783, 100.0%, avg_loss = 1.4424
Epoch 2784, 100.0%, avg_loss = 1.4518
Epoch 2785, 100.0%, avg_loss = 1.4830
Validation Loss: 1.5068
Epoch 2786, 100.0%, avg_loss = 1.4979
Epoch 

Epoch 2955, 100.0%, avg_loss = 1.4461
Validation Loss: 1.4642
Epoch 2956, 100.0%, avg_loss = 1.5020
Epoch 2957, 100.0%, avg_loss = 1.5001
Epoch 2958, 100.0%, avg_loss = 1.4744
Epoch 2959, 100.0%, avg_loss = 1.5163
Epoch 2960, 100.0%, avg_loss = 1.4580
Validation Loss: 1.5179
Epoch 2961, 100.0%, avg_loss = 1.5815
Epoch 2962, 100.0%, avg_loss = 1.4250
Epoch 2963, 100.0%, avg_loss = 1.5227
Epoch 2964, 100.0%, avg_loss = 1.4248
Epoch 2965, 100.0%, avg_loss = 1.5143
Validation Loss: 1.4383
Epoch 2966, 100.0%, avg_loss = 1.5087
Epoch 2967, 100.0%, avg_loss = 1.4085
Epoch 2968, 100.0%, avg_loss = 1.5162
Epoch 2969, 100.0%, avg_loss = 1.5024
Epoch 2970, 100.0%, avg_loss = 1.5353
Validation Loss: 1.4982
Epoch 2971, 100.0%, avg_loss = 1.4756
Epoch 2972, 100.0%, avg_loss = 1.3793
Epoch 2973, 100.0%, avg_loss = 1.5210
Epoch 2974, 100.0%, avg_loss = 1.4223
Epoch 2975, 100.0%, avg_loss = 1.4721
Validation Loss: 1.4982
Epoch 2976, 100.0%, avg_loss = 1.4870
Epoch 2977, 100.0%, avg_loss = 1.5222
Epoch 

In [17]:
# Save the global (server) model: it is the one that receives each round's
# update and is evaluated in validation. `local_predictor` only holds a copy of
# the server weights from before the final round's update.
torch.save(local_predictor_server, './results_osaka/local_pretrained_4096_predictor_full_search_fl.pytorch')

In [18]:
# Persist the per-epoch loss histories as pickles.
# NOTE(review): only the validation filename contains "inc_1024" — confirm
# whether the two names are meant to match.
for loss_path, loss_history in (
    ('./results_osaka/federated_pretrained_4096_local_predictor_full_search_training_loss.pk', training_loss),
    ('./results_osaka/federated_pretrained_4096_local_predictor_full_search_inc_1024_validation_loss.pk', validation_loss),
):
    with open(loss_path, 'wb') as f:
        pk.dump(loss_history, f)