In [1]:
import torch
import torch.nn as nn
import pickle as pk
import numpy as np
from model import LocalPredictor
import random

In [2]:
# Load one pickled file per day (days 1..28) into `data`, keyed by day number.
# NOTE(review): pickle.load executes arbitrary code on untrusted files; these
# paths are assumed to be trusted, locally produced artifacts.
path_template = '/data/fan/UsersInTokyoProcessed/201210{:02d}_interp.pk'
data = {}
for day in range(1, 29):
    filename = path_template.format(day)
    print(filename)
    with open(filename, 'rb') as fh:
        data[day] = pk.load(fh)

/data/fan/UsersInTokyoProcessed/20121001_interp.pk
/data/fan/UsersInTokyoProcessed/20121002_interp.pk
/data/fan/UsersInTokyoProcessed/20121003_interp.pk
/data/fan/UsersInTokyoProcessed/20121004_interp.pk
/data/fan/UsersInTokyoProcessed/20121005_interp.pk
/data/fan/UsersInTokyoProcessed/20121006_interp.pk
/data/fan/UsersInTokyoProcessed/20121007_interp.pk
/data/fan/UsersInTokyoProcessed/20121008_interp.pk
/data/fan/UsersInTokyoProcessed/20121009_interp.pk
/data/fan/UsersInTokyoProcessed/20121010_interp.pk
/data/fan/UsersInTokyoProcessed/20121011_interp.pk
/data/fan/UsersInTokyoProcessed/20121012_interp.pk
/data/fan/UsersInTokyoProcessed/20121013_interp.pk
/data/fan/UsersInTokyoProcessed/20121014_interp.pk
/data/fan/UsersInTokyoProcessed/20121015_interp.pk
/data/fan/UsersInTokyoProcessed/20121016_interp.pk
/data/fan/UsersInTokyoProcessed/20121017_interp.pk
/data/fan/UsersInTokyoProcessed/20121018_interp.pk
/data/fan/UsersInTokyoProcessed/20121019_interp.pk
/data/fan/UsersInTokyoProcessed

In [3]:
# For every loaded day, the set of user ids that have a record on that day.
uid_sets = {day: set(per_user) for day, per_user in data.items()}

In [4]:
# Union of all user ids seen on days 1..14 (the same day range used to build data_doc).
uid_doc_set = set().union(*(uid_sets[day] for day in range(1, 15)))

In [5]:
def _group_by_user(days):
    """Collect each user's per-day records from `data`, in day order, into one list per user."""
    grouped = {}
    for day in days:
        for uid, record in data[day].items():
            grouped.setdefault(uid, []).append(record)
    return grouped


# Days 1..14 form the per-user "doc" history; days 15..21 are the training
# queries and days 22..28 the held-out test queries.
data_doc = _group_by_user(range(1, 15))
data_qry_train = _group_by_user(range(15, 22))
data_qry_test = _group_by_user(range(22, 29))

In [6]:
# Replace each user's list of daily records with a single LongTensor on GPU 0
# (first axis: one row per day the user appears in that split).
# NOTE: this overwrites the dict values in place, so the cell is not re-runnable
# without first rebuilding the three dicts.
for per_user in (data_doc, data_qry_train, data_qry_test):
    for uid in per_user:
        per_user[uid] = torch.LongTensor(per_user[uid]).cuda(0)

In [7]:
# Sizes handed to LocalPredictor, in its positional-argument order:
# (num_locs, loc_embedding_dim, num_time, time_embedding_dim, hidden_dim, latent_dim).

# Sequence length along axis 1 of each per-user tensor; window start indices are
# drawn from [0, T - 2 * dT]. Presumably 96 time slots per day — TODO confirm.
T = 96
num_time = T

# Presumably the size of the location vocabulary — TODO confirm against model.py.
num_locs = 1600

loc_embedding_dim = 128
time_embedding_dim = 32

hidden_dim = 256
latent_dim = 256

In [8]:
# Hyper-parameters of the hand-rolled server-side update in the training cell:
# `momentum` is the factor applied to the accumulated update after each round,
# `lr` scales every per-user gradient (and is halved there every 500 epochs).
momentum = 0.8
lr = 1e-3

In [20]:
# Two identically-shaped models on GPU 0: `local_predictor` is the per-user
# scratch copy, `local_predictor_server` holds the aggregated weights.
predictor_args = (
    num_locs,
    loc_embedding_dim,
    num_time,
    time_embedding_dim,
    hidden_dim,
    latent_dim,
)
local_predictor = LocalPredictor(*predictor_args).cuda(0)
local_predictor_server = LocalPredictor(*predictor_args).cuda(0)

# The optimizer is attached to the local copy; the training cell only uses it
# to clear gradients between users (it never calls optimizer.step()).
optimizer = torch.optim.SGD(local_predictor.parameters(), lr=lr)
optimizer.zero_grad()

In [21]:
# Initialise the server model from a previously saved checkpoint. The file holds
# a whole pickled module (not a bare state_dict), hence the .state_dict() call.
pretrained_path = './results_tokyo/local_predictor_open.pytorch'
local_predictor_server.load_state_dict(
    torch.load(pretrained_path).state_dict()
)

In [22]:
# Window length in time steps: the training cell feeds the query slice [t, t + dT)
# and predicts the location at index t + 2 * dT - 1.
dT = 4

In [23]:
# Loss histories keyed by epoch number; filled in by the training cell.
# Re-running this cell discards anything recorded so far.
training_loss = {}
validation_loss = {}

In [27]:
user_list_train = list(data_qry_train.keys())
# Not used in this cell; kept in case other cells reference it.
time_list = list(range(T - 2 * dT + 1))

# Per-round settings (previously inline magic numbers).
users_per_round = 2048          # users sampled for each round ("epoch")
grad_noise_std = 1e-2           # std of the Gaussian noise added to each user's gradient
validation_sample_rate = 0.05   # probability that a test user is evaluated in a validation pass


def _sampled_window_loss(model, qry_days, doc_days):
    """Draw a random window start and return the model's loss on that window.

    Args:
        model: predictor called as model(x_loc_qry, x_t_qry, x_loc_doc, x_t_doc, y).
        qry_days: a user's query tensor; axis 0 is one row per day, axis 1 is
            sliced by time index.
        doc_days: the same user's history tensor from data_doc, or None when
            the user has no history (the model then receives None for both
            doc inputs).

    Returns:
        Whatever the model returns for these inputs (used as the loss).
    """
    t = np.random.randint(T - 2 * dT + 1)
    # Input window is [t, t + dT); the target sits dT steps after its last index.
    x_loc_qry = qry_days[:, t: t + dT]
    # NOTE(review): the time feature is the window start t at every position,
    # not the per-step time index — confirm this is what the model expects.
    x_t_qry = torch.zeros_like(x_loc_qry) + t
    y = qry_days[:, t + 2 * dT - 1]

    if doc_days is None:
        return model(x_loc_qry, x_t_qry, None, None, y)

    # The history slice spans the input window and the gap up to the target index.
    x_loc_doc = doc_days[:, t: t + 2 * dT]
    x_t_doc = torch.zeros_like(x_loc_doc) + t
    return model(x_loc_qry, x_t_qry, x_loc_doc, x_t_doc, y)


# Initialize global SGD
# Accumulated server update per parameter name; it persists across epochs and is
# scaled by `momentum` after every application.
local_predictor_update = dict({})

for epoch in range(2001, 3001):

    # Halve the learning rate every 500 epochs. `lr` is a notebook-level variable,
    # so re-running this cell continues from the already decayed value.
    if epoch % 500 == 0:
        lr *= 0.5

    # random.sample raises ValueError when asked for more items than exist, so cap
    # the request at the population size. The sample is already in random order.
    update_user_list = random.sample(
        user_list_train, min(users_per_round, len(user_list_train))
    )

    avg_loss = 0.0
    cnt = 0

    for uidx, uid in enumerate(update_user_list, start=1):
        # loading server model to local
        local_predictor.load_state_dict(local_predictor_server.state_dict())

        nk = data_qry_train[uid].shape[0]

        loss = _sampled_window_loss(
            local_predictor, data_qry_train[uid], data_doc.get(uid)
        )

        loss.backward()
        avg_loss += loss.item()
        cnt += nk

        # Fold this user's noisy, lr-scaled gradient into the pending server update.
        with torch.no_grad():
            for name, para in local_predictor.named_parameters():
                if para.grad is None:
                    continue
                noisy_step = (para.grad + torch.randn_like(para.grad) * grad_noise_std) * lr
                if name in local_predictor_update:
                    local_predictor_update[name] += noisy_step
                else:
                    local_predictor_update[name] = noisy_step

        optimizer.zero_grad()

        print('Epoch {:02d}, {:.1f}%, avg_loss = {:.4f}'.format(epoch, uidx * 100 / len(update_user_list), avg_loss / cnt), end='\r')

    # update
    with torch.no_grad():
        for name, para in local_predictor_server.named_parameters():
            # Parameters that never received a gradient have no accumulated
            # update; indexing the dict unconditionally would raise KeyError.
            if name in local_predictor_update:
                para -= local_predictor_update[name]

        for name in local_predictor_update:
            local_predictor_update[name] *= momentum

    training_loss[epoch] = avg_loss / cnt
    print('')
    # testing

    if epoch % 5 == 0:
        val_loss_sum = 0.0
        val_cnt = 0

        with torch.no_grad():
            for uid in data_qry_test:

                # Evaluate only a random subset of the test users on each pass.
                if np.random.ranf() > validation_sample_rate:
                    continue

                val_loss_sum += _sampled_window_loss(
                    local_predictor_server, data_qry_test[uid], data_doc.get(uid)
                ).item()
                val_cnt += data_qry_test[uid].shape[0]

        # The random subset can be empty; skip the record instead of dividing by zero.
        if val_cnt > 0:
            print('Validation Loss: {:.4f}'.format(val_loss_sum / val_cnt))
            validation_loss[epoch] = val_loss_sum / val_cnt
        else:
            print('Validation skipped: no test users were sampled')

Epoch 2001, 100.0%, avg_loss = 1.5154
Epoch 2002, 100.0%, avg_loss = 1.5615
Epoch 2003, 100.0%, avg_loss = 1.5015
Epoch 2004, 100.0%, avg_loss = 1.5619
Epoch 2005, 100.0%, avg_loss = 1.4334
Validation Loss: 1.5480
Epoch 2006, 100.0%, avg_loss = 1.5662
Epoch 2007, 100.0%, avg_loss = 1.5925
Epoch 2008, 100.0%, avg_loss = 1.4826
Epoch 2009, 100.0%, avg_loss = 1.4964
Epoch 2010, 100.0%, avg_loss = 1.5007
Validation Loss: 1.5312
Epoch 2011, 100.0%, avg_loss = 1.5122
Epoch 2012, 100.0%, avg_loss = 1.4908
Epoch 2013, 100.0%, avg_loss = 1.5643
Epoch 2014, 100.0%, avg_loss = 1.4965
Epoch 2015, 100.0%, avg_loss = 1.5107
Validation Loss: 1.5410
Epoch 2016, 100.0%, avg_loss = 1.4971
Epoch 2017, 100.0%, avg_loss = 1.5687
Epoch 2018, 100.0%, avg_loss = 1.5178
Epoch 2019, 100.0%, avg_loss = 1.5049
Epoch 2020, 100.0%, avg_loss = 1.5196
Validation Loss: 1.5411
Epoch 2021, 100.0%, avg_loss = 1.5615
Epoch 2022, 100.0%, avg_loss = 1.5617
Epoch 2023, 100.0%, avg_loss = 1.5399
Epoch 2024, 100.0%, avg_loss =

Epoch 2192, 100.0%, avg_loss = 1.5641
Epoch 2193, 100.0%, avg_loss = 1.4942
Epoch 2194, 100.0%, avg_loss = 1.5011
Epoch 2195, 100.0%, avg_loss = 1.5830
Validation Loss: 1.5236
Epoch 2196, 100.0%, avg_loss = 1.5278
Epoch 2197, 100.0%, avg_loss = 1.4985
Epoch 2198, 100.0%, avg_loss = 1.5144
Epoch 2199, 100.0%, avg_loss = 1.5081
Epoch 2200, 100.0%, avg_loss = 1.4687
Validation Loss: 1.5702
Epoch 2201, 100.0%, avg_loss = 1.5362
Epoch 2202, 100.0%, avg_loss = 1.5384
Epoch 2203, 100.0%, avg_loss = 1.5533
Epoch 2204, 100.0%, avg_loss = 1.5360
Epoch 2205, 100.0%, avg_loss = 1.4947
Validation Loss: 1.4994
Epoch 2206, 100.0%, avg_loss = 1.5116
Epoch 2207, 100.0%, avg_loss = 1.5082
Epoch 2208, 100.0%, avg_loss = 1.4616
Epoch 2209, 100.0%, avg_loss = 1.5932
Epoch 2210, 100.0%, avg_loss = 1.4159
Validation Loss: 1.5957
Epoch 2211, 100.0%, avg_loss = 1.4854
Epoch 2212, 100.0%, avg_loss = 1.5196
Epoch 2213, 100.0%, avg_loss = 1.4938
Epoch 2214, 100.0%, avg_loss = 1.4494
Epoch 2215, 100.0%, avg_loss =

Epoch 2383, 100.0%, avg_loss = 1.5262
Epoch 2384, 100.0%, avg_loss = 1.5318
Epoch 2385, 100.0%, avg_loss = 1.5910
Validation Loss: 1.5258
Epoch 2386, 100.0%, avg_loss = 1.5420
Epoch 2387, 100.0%, avg_loss = 1.4921
Epoch 2388, 100.0%, avg_loss = 1.4910
Epoch 2389, 100.0%, avg_loss = 1.4138
Epoch 2390, 100.0%, avg_loss = 1.4860
Validation Loss: 1.5234
Epoch 2391, 100.0%, avg_loss = 1.5164
Epoch 2392, 100.0%, avg_loss = 1.5547
Epoch 2393, 100.0%, avg_loss = 1.5130
Epoch 2394, 100.0%, avg_loss = 1.4624
Epoch 2395, 100.0%, avg_loss = 1.4800
Validation Loss: 1.5747
Epoch 2396, 100.0%, avg_loss = 1.4758
Epoch 2397, 100.0%, avg_loss = 1.5793
Epoch 2398, 100.0%, avg_loss = 1.4697
Epoch 2399, 100.0%, avg_loss = 1.5546
Epoch 2400, 100.0%, avg_loss = 1.6150
Validation Loss: 1.5197
Epoch 2401, 100.0%, avg_loss = 1.5157
Epoch 2402, 100.0%, avg_loss = 1.5006
Epoch 2403, 100.0%, avg_loss = 1.5699
Epoch 2404, 100.0%, avg_loss = 1.4578
Epoch 2405, 100.0%, avg_loss = 1.5690
Validation Loss: 1.5234
Epoch 

Epoch 2574, 100.0%, avg_loss = 1.5059
Epoch 2575, 100.0%, avg_loss = 1.5373
Validation Loss: 1.5194
Epoch 2576, 100.0%, avg_loss = 1.5462
Epoch 2577, 100.0%, avg_loss = 1.4745
Epoch 2578, 100.0%, avg_loss = 1.4825
Epoch 2579, 100.0%, avg_loss = 1.5020
Epoch 2580, 100.0%, avg_loss = 1.4795
Validation Loss: 1.5250
Epoch 2581, 100.0%, avg_loss = 1.4803
Epoch 2582, 100.0%, avg_loss = 1.4830
Epoch 2583, 100.0%, avg_loss = 1.4970
Epoch 2584, 100.0%, avg_loss = 1.4859
Epoch 2585, 100.0%, avg_loss = 1.4893
Validation Loss: 1.5091
Epoch 2586, 100.0%, avg_loss = 1.4960
Epoch 2587, 100.0%, avg_loss = 1.5538
Epoch 2588, 100.0%, avg_loss = 1.5300
Epoch 2589, 100.0%, avg_loss = 1.5089
Epoch 2590, 100.0%, avg_loss = 1.5181
Validation Loss: 1.5589
Epoch 2591, 100.0%, avg_loss = 1.5604
Epoch 2592, 100.0%, avg_loss = 1.5365
Epoch 2593, 100.0%, avg_loss = 1.5087
Epoch 2594, 100.0%, avg_loss = 1.4975
Epoch 2595, 100.0%, avg_loss = 1.4632
Validation Loss: 1.5426
Epoch 2596, 100.0%, avg_loss = 1.4701
Epoch 

Epoch 2765, 100.0%, avg_loss = 1.4892
Validation Loss: 1.5621
Epoch 2766, 100.0%, avg_loss = 1.5239
Epoch 2767, 100.0%, avg_loss = 1.4712
Epoch 2768, 100.0%, avg_loss = 1.4549
Epoch 2769, 100.0%, avg_loss = 1.4902
Epoch 2770, 100.0%, avg_loss = 1.4978
Validation Loss: 1.5289
Epoch 2771, 100.0%, avg_loss = 1.5615
Epoch 2772, 100.0%, avg_loss = 1.5416
Epoch 2773, 100.0%, avg_loss = 1.5027
Epoch 2774, 100.0%, avg_loss = 1.4832
Epoch 2775, 100.0%, avg_loss = 1.4487
Validation Loss: 1.5001
Epoch 2776, 100.0%, avg_loss = 1.4724
Epoch 2777, 100.0%, avg_loss = 1.4824
Epoch 2778, 100.0%, avg_loss = 1.4782
Epoch 2779, 100.0%, avg_loss = 1.4808
Epoch 2780, 100.0%, avg_loss = 1.4948
Validation Loss: 1.5296
Epoch 2781, 100.0%, avg_loss = 1.5006
Epoch 2782, 100.0%, avg_loss = 1.4811
Epoch 2783, 100.0%, avg_loss = 1.4763
Epoch 2784, 100.0%, avg_loss = 1.5353
Epoch 2785, 100.0%, avg_loss = 1.5001
Validation Loss: 1.5062
Epoch 2786, 100.0%, avg_loss = 1.4881
Epoch 2787, 100.0%, avg_loss = 1.5211
Epoch 

Epoch 2956, 100.0%, avg_loss = 1.4823
Epoch 2957, 100.0%, avg_loss = 1.4888
Epoch 2958, 100.0%, avg_loss = 1.4600
Epoch 2959, 100.0%, avg_loss = 1.4996
Epoch 2960, 100.0%, avg_loss = 1.5194
Validation Loss: 1.5289
Epoch 2961, 100.0%, avg_loss = 1.5841
Epoch 2962, 100.0%, avg_loss = 1.5217
Epoch 2963, 100.0%, avg_loss = 1.5260
Epoch 2964, 100.0%, avg_loss = 1.5142
Epoch 2965, 100.0%, avg_loss = 1.5325
Validation Loss: 1.5758
Epoch 2966, 100.0%, avg_loss = 1.4611
Epoch 2967, 100.0%, avg_loss = 1.5117
Epoch 2968, 100.0%, avg_loss = 1.4663
Epoch 2969, 100.0%, avg_loss = 1.5219
Epoch 2970, 100.0%, avg_loss = 1.4702
Validation Loss: 1.5306
Epoch 2971, 100.0%, avg_loss = 1.4924
Epoch 2972, 100.0%, avg_loss = 1.3768
Epoch 2973, 100.0%, avg_loss = 1.5291
Epoch 2974, 100.0%, avg_loss = 1.5364
Epoch 2975, 100.0%, avg_loss = 1.6037
Validation Loss: 1.5082
Epoch 2976, 100.0%, avg_loss = 1.4629
Epoch 2977, 100.0%, avg_loss = 1.5760
Epoch 2978, 100.0%, avg_loss = 1.4569
Epoch 2979, 100.0%, avg_loss =

In [28]:
# Save the server model, which is the one that receives the aggregated updates.
# `local_predictor` is only a scratch copy reloaded from the server for every user
# and never stepped, so saving it would miss the final round's update.
# The whole module is pickled (not just its state_dict) to match how the
# pretrained checkpoint is loaded in this notebook.
torch.save(local_predictor_server, './results_tokyo/local_predictor_fl_inc_open.pytorch')

In [30]:
# Persist both per-epoch loss histories as pickles next to the saved model.
loss_dumps = (
    ('./results_tokyo/federated_pretrained_local_predictor_open_training_loss.pk', training_loss),
    ('./results_tokyo/federated_pretrained_local_predictor_open_validation_loss.pk', validation_loss),
)
for out_path, history in loss_dumps:
    with open(out_path, 'wb') as fh:
        pk.dump(history, fh)