In [1]:
import os
import pickle as pk
import random

import numpy as np
import torch
import torch.nn as nn

from model import GlobalPredictor

In [2]:
# Load the per-day pickles 201210DD_interp.pk for DD = 15..28 into `data`, keyed by DD.
# NOTE(review): each pickle presumably maps user id -> that day's location sequence
# (that is how later cells index it) — TODO confirm.
# pickle.load can execute arbitrary code: only point DATA_DIR at trusted files.
DATA_DIR = '/data/fan/UsersInOsakaProcessed'  # machine-specific; adjust for your environment
DAYS = range(15, 29)

data = {}
for d in DAYS:
    filename = '{}/201210{:02d}_interp.pk'.format(DATA_DIR, d)
    print(filename)
    with open(filename, 'rb') as f:
        data[d] = pk.load(f)

/data/fan/UsersInOsakaProcessed/20121015_interp.pk
/data/fan/UsersInOsakaProcessed/20121016_interp.pk
/data/fan/UsersInOsakaProcessed/20121017_interp.pk
/data/fan/UsersInOsakaProcessed/20121018_interp.pk
/data/fan/UsersInOsakaProcessed/20121019_interp.pk
/data/fan/UsersInOsakaProcessed/20121020_interp.pk
/data/fan/UsersInOsakaProcessed/20121021_interp.pk
/data/fan/UsersInOsakaProcessed/20121022_interp.pk
/data/fan/UsersInOsakaProcessed/20121023_interp.pk
/data/fan/UsersInOsakaProcessed/20121024_interp.pk
/data/fan/UsersInOsakaProcessed/20121025_interp.pk
/data/fan/UsersInOsakaProcessed/20121026_interp.pk
/data/fan/UsersInOsakaProcessed/20121027_interp.pk
/data/fan/UsersInOsakaProcessed/20121028_interp.pk


In [3]:
# Collect, for every user, the list of that user's per-day entries.
# The split is by time, not by user: days 15-21 form the training set and days
# 22-28 the test set, so the same user id may appear in both dicts.
TRAIN_DAYS = range(15, 22)
TEST_DAYS = range(22, 29)


def _group_days_by_user(days):
    """Return {uid: [data[d][uid] for each d in `days` where uid occurs]}.

    Entries are appended in day order, and users are inserted in the order they
    are first seen, matching a plain nested loop over days then users.
    """
    grouped = {}
    for d in days:
        day_data = data[d]
        for uid in day_data:
            grouped.setdefault(uid, []).append(day_data[uid])
    return grouped


data_qry_train = _group_days_by_user(TRAIN_DAYS)
data_qry_test = _group_days_by_user(TEST_DAYS)

In [4]:
# Turn each user's list of daily entries into a single int64 tensor on GPU 0.
# The names are rebound to new dicts (no in-place mutation), and torch.as_tensor
# returns an existing tensor unchanged when dtype/device already match, so this
# cell can be re-run safely. (The previous in-place version failed on a re-run
# because torch.LongTensor rejects CUDA tensors.)
# NOTE(review): stacking into one tensor assumes all of a user's daily entries
# have the same length — presumably T time slots; TODO confirm.
data_qry_train = {
    uid: torch.as_tensor(seqs, dtype=torch.long, device='cuda:0')
    for uid, seqs in data_qry_train.items()
}
data_qry_test = {
    uid: torch.as_tensor(seqs, dtype=torch.long, device='cuda:0')
    for uid, seqs in data_qry_test.items()
}

In [5]:
# Model / data dimensions passed to GlobalPredictor.
num_locs = 1600             # presumably the size of the location vocabulary
loc_embedding_dim = 128
T = 96                      # presumably time slots per daily sequence; window offsets are drawn from [0, T - 2*dT]
num_time = T                # number of distinct time indices for the time embedding
time_embedding_dim = 32
hidden_dim = 256
latent_dim = 256

# Seed every RNG the notebook draws from: python `random` (user sampling),
# numpy (window offsets, validation sampling) and torch (model initialisation),
# so Restart & Run All reproduces the same run. Must run before models are built.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

In [6]:
# Optimisation hyper-parameters for the hand-written server update:
#   lr       - base learning rate that scales each accumulated gradient;
#   momentum - factor the accumulated server update is multiplied by after each round.
lr, momentum = 5e-4, 0.8

In [7]:
# Two predictors with identical hyper-parameters on GPU 0:
#   * global_predictor        - local model, reset to the server weights and
#                               back-propagated through for each sampled user;
#   * global_predictor_server - holds the aggregated (server) weights.
# The SGD optimizer is only used for zero_grad(); the weight updates (including
# the `momentum` factor) are applied by hand in the training loop.
_predictor_args = (num_locs, loc_embedding_dim, num_time, time_embedding_dim, hidden_dim, latent_dim)
global_predictor = GlobalPredictor(*_predictor_args).cuda(0)
global_predictor_server = GlobalPredictor(*_predictor_args).cuda(0)
optimizer = torch.optim.SGD(global_predictor.parameters(), lr=lr)

In [8]:
# Window length in time slots: each example uses dT consecutive slots as input and
# predicts the slot dT steps after the last input slot, so one example spans 2*dT slots.
dT = 4
# The window must fit in a length-T sequence; otherwise the training loop's
# np.random.randint(T - 2*dT + 1) would get a non-positive upper bound.
assert 0 < 2 * dT <= T, 'dT={} does not fit in T={}'.format(dT, T)

In [9]:
# Loss curves keyed by round number: the training loss for every round, and the
# validation loss only for rounds that run a validation pass.
training_loss, validation_loss = {}, {}

In [10]:
# Federated training of the global predictor.
#
# Each round ("epoch") samples up to USERS_PER_ROUND training users. For each one,
# the local model is reset to the server weights and one window is drawn from that
# user's sequences. The gradient of the loss on that window, scaled by the current
# learning rate, is added to a server-side update buffer. After all sampled users,
# the server weights are decreased by the buffer, and the buffer is multiplied by
# `momentum` (heavy-ball style).
#
# NOTE(review): per-user gradients are summed, not averaged or weighted by each
# user's row count (the original tracked that total as `n` but never used it).
# Confirm this is the intended aggregation.
# NOTE(review): every step of an input window gets the same time index `t` (the
# window start); confirm GlobalPredictor expects this rather than t .. t+dT-1.
# NOTE(review): losses are divided by the number of rows; that is a per-row mean
# only if GlobalPredictor returns a loss summed over rows — TODO confirm.
# NOTE(review): the server model is never switched to eval() for validation; if
# GlobalPredictor has dropout/batch-norm this affects the validation numbers.

NUM_ROUNDS = 2000
USERS_PER_ROUND = 2048
LR_DECAY_EVERY = 500        # the step size is multiplied by LR_DECAY_FACTOR every this many rounds
LR_DECAY_FACTOR = 0.5
VALIDATE_EVERY = 5          # rounds between validation passes
VALIDATION_FRACTION = 0.05  # approx. fraction of test users scored per validation pass
LOG_EVERY = 128             # users between progress-line refreshes

user_list_train = list(data_qry_train.keys())
time_list = list(range(T - 2 * dT + 1))  # every admissible window start offset
num_windows = len(time_list)


def _safe_mean(total, count):
    """Return total / count, or NaN when nothing was counted."""
    return total / count if count else float('nan')


def _make_window(seqs, t):
    """Cut one prediction example out of a user's stacked sequences.

    Returns (x_loc, x_t, y): columns t .. t+dT-1 of `seqs` as input, a tensor of
    the same shape/dtype/device filled with `t` as the time index, and column
    t + 2*dT - 1 (dT slots after the last input column) as the target.
    """
    x_loc = seqs[:, t: t + dT]
    x_t = torch.full_like(x_loc, t)
    y = seqs[:, t + 2 * dT - 1]
    return x_loc, x_t, y


def _validation_loss():
    """Score the server model on a random subset of test users.

    Returns summed loss / scored rows, or None if the random subset was empty.
    """
    loss_sum, row_cnt = 0.0, 0
    with torch.no_grad():
        for seqs in data_qry_test.values():
            # random_sample() is the function the legacy alias np.random.ranf() calls.
            if np.random.random_sample() > VALIDATION_FRACTION:
                continue
            t = np.random.randint(num_windows)
            loss = global_predictor_server(*_make_window(seqs, t))
            loss_sum += loss.item()
            row_cnt += seqs.shape[0]
    return loss_sum / row_cnt if row_cnt else None


# Server-side update buffer: parameter name -> accumulated lr-scaled gradient.
global_predictor_update = {}

for epoch in range(1, NUM_ROUNDS + 1):
    # Closed-form schedule (same values as halving every LR_DECAY_EVERY rounds) so
    # the global `lr` is not mutated and re-running this cell does not compound the decay.
    round_lr = lr * LR_DECAY_FACTOR ** (epoch // LR_DECAY_EVERY)

    # random.sample already returns the users in random order; capping the sample
    # size avoids a ValueError when there are fewer users than USERS_PER_ROUND.
    update_user_list = random.sample(user_list_train, min(USERS_PER_ROUND, len(user_list_train)))
    num_users = len(update_user_list)

    avg_loss = 0.0  # summed training loss this round
    cnt = 0         # rows scored this round

    for uidx, uid in enumerate(update_user_list, start=1):
        # Reset the local model to the server weights. Parameters are never stepped
        # inside a round, so after the first user this only matters for buffers
        # (if GlobalPredictor has any).
        global_predictor.load_state_dict(global_predictor_server.state_dict())

        seqs = data_qry_train[uid]
        t = np.random.randint(num_windows)
        loss = global_predictor(*_make_window(seqs, t))
        loss.backward()
        avg_loss += loss.item()
        cnt += seqs.shape[0]

        with torch.no_grad():
            for name, para in global_predictor.named_parameters():
                if para.grad is None:  # parameter not reached by this loss
                    continue
                if name in global_predictor_update:
                    global_predictor_update[name] += para.grad * round_lr
                else:
                    global_predictor_update[name] = para.grad * round_lr

        optimizer.zero_grad()

        if uidx % LOG_EVERY == 0 or uidx == num_users:
            print('Epoch {:02d}, {:.1f}%, avg_loss = {:.4f}'.format(
                epoch, uidx * 100 / num_users, _safe_mean(avg_loss, cnt)), end='\r')

    # Apply the accumulated update to the server weights, then decay it by `momentum`.
    with torch.no_grad():
        for name, para in global_predictor_server.named_parameters():
            if name in global_predictor_update:
                para -= global_predictor_update[name]
        for update in global_predictor_update.values():
            update *= momentum

    training_loss[epoch] = _safe_mean(avg_loss, cnt)
    print('')

    if epoch % VALIDATE_EVERY == 0:
        val_loss = _validation_loss()
        if val_loss is None:
            print('Validation Loss: n/a (no test users sampled)')
        else:
            print('Validation Loss: {:.4f}'.format(val_loss))
            validation_loss[epoch] = val_loss

Epoch 01, 100.0%, avg_loss = 7.3859
Epoch 02, 100.0%, avg_loss = 7.3340
Epoch 03, 100.0%, avg_loss = 7.2415
Epoch 04, 100.0%, avg_loss = 7.0446
Epoch 05, 100.0%, avg_loss = 6.9063
Validation Loss: 6.5875
Epoch 06, 100.0%, avg_loss = 6.6321
Epoch 07, 100.0%, avg_loss = 6.3155
Epoch 08, 100.0%, avg_loss = 6.0075
Epoch 09, 100.0%, avg_loss = 5.7423
Epoch 10, 100.0%, avg_loss = 5.2845
Validation Loss: 4.7598
Epoch 11, 100.0%, avg_loss = 4.7604
Epoch 12, 100.0%, avg_loss = 4.4010
Epoch 13, 100.0%, avg_loss = 4.0416
Epoch 14, 100.0%, avg_loss = 3.6594
Epoch 15, 100.0%, avg_loss = 3.4762
Validation Loss: 3.2698
Epoch 16, 100.0%, avg_loss = 3.3236
Epoch 17, 100.0%, avg_loss = 3.1460
Epoch 18, 100.0%, avg_loss = 3.0412
Epoch 19, 100.0%, avg_loss = 2.7512
Epoch 20, 100.0%, avg_loss = 2.8803
Validation Loss: 2.7592
Epoch 21, 100.0%, avg_loss = 2.7525
Epoch 22, 100.0%, avg_loss = 2.5330
Epoch 23, 100.0%, avg_loss = 2.5063
Epoch 24, 100.0%, avg_loss = 2.6269
Epoch 25, 100.0%, avg_loss = 2.5235
Vali

Epoch 199, 100.0%, avg_loss = 1.7383
Epoch 200, 100.0%, avg_loss = 1.6626
Validation Loss: 1.7252
Epoch 201, 100.0%, avg_loss = 1.7375
Epoch 202, 100.0%, avg_loss = 1.7359
Epoch 203, 100.0%, avg_loss = 1.7739
Epoch 204, 100.0%, avg_loss = 1.7102
Epoch 205, 100.0%, avg_loss = 1.8066
Validation Loss: 1.8054
Epoch 206, 100.0%, avg_loss = 1.8077
Epoch 207, 100.0%, avg_loss = 1.7584
Epoch 208, 100.0%, avg_loss = 1.7747
Epoch 209, 100.0%, avg_loss = 1.8580
Epoch 210, 100.0%, avg_loss = 1.7180
Validation Loss: 1.7272
Epoch 211, 100.0%, avg_loss = 1.6851
Epoch 212, 100.0%, avg_loss = 1.7097
Epoch 213, 100.0%, avg_loss = 1.7331
Epoch 214, 100.0%, avg_loss = 1.7457
Epoch 215, 100.0%, avg_loss = 1.6773
Validation Loss: 1.7608
Epoch 216, 100.0%, avg_loss = 1.7558
Epoch 217, 100.0%, avg_loss = 1.6896
Epoch 218, 100.0%, avg_loss = 1.7389
Epoch 219, 100.0%, avg_loss = 1.7680
Epoch 220, 100.0%, avg_loss = 1.8155
Validation Loss: 1.7942
Epoch 221, 100.0%, avg_loss = 1.6893
Epoch 222, 100.0%, avg_loss =

Epoch 395, 100.0%, avg_loss = 1.7169
Validation Loss: 1.7054
Epoch 396, 100.0%, avg_loss = 1.7533
Epoch 397, 100.0%, avg_loss = 1.7315
Epoch 398, 100.0%, avg_loss = 1.6334
Epoch 399, 100.0%, avg_loss = 1.6061
Epoch 400, 100.0%, avg_loss = 1.6769
Validation Loss: 1.6884
Epoch 401, 100.0%, avg_loss = 1.6639
Epoch 402, 100.0%, avg_loss = 1.6962
Epoch 403, 100.0%, avg_loss = 1.7475
Epoch 404, 100.0%, avg_loss = 1.6862
Epoch 405, 100.0%, avg_loss = 1.7671
Validation Loss: 1.6419
Epoch 406, 100.0%, avg_loss = 1.6428
Epoch 407, 100.0%, avg_loss = 1.6905
Epoch 408, 100.0%, avg_loss = 1.6720
Epoch 409, 100.0%, avg_loss = 1.6676
Epoch 410, 100.0%, avg_loss = 1.6309
Validation Loss: 1.6595
Epoch 411, 100.0%, avg_loss = 1.6667
Epoch 412, 100.0%, avg_loss = 1.7026
Epoch 413, 100.0%, avg_loss = 1.6789
Epoch 414, 100.0%, avg_loss = 1.7096
Epoch 415, 100.0%, avg_loss = 1.6771
Validation Loss: 1.6827
Epoch 416, 100.0%, avg_loss = 1.6718
Epoch 417, 100.0%, avg_loss = 1.6336
Epoch 418, 100.0%, avg_loss =

Epoch 591, 100.0%, avg_loss = 1.6698
Epoch 592, 100.0%, avg_loss = 1.6015
Epoch 593, 100.0%, avg_loss = 1.6658
Epoch 594, 100.0%, avg_loss = 1.6147
Epoch 595, 100.0%, avg_loss = 1.6270
Validation Loss: 1.6312
Epoch 596, 100.0%, avg_loss = 1.6030
Epoch 597, 100.0%, avg_loss = 1.6792
Epoch 598, 100.0%, avg_loss = 1.6091
Epoch 599, 100.0%, avg_loss = 1.5619
Epoch 600, 100.0%, avg_loss = 1.5967
Validation Loss: 1.6423
Epoch 601, 100.0%, avg_loss = 1.5816
Epoch 602, 100.0%, avg_loss = 1.6085
Epoch 603, 100.0%, avg_loss = 1.5842
Epoch 604, 100.0%, avg_loss = 1.6302
Epoch 605, 100.0%, avg_loss = 1.5935
Validation Loss: 1.6444
Epoch 606, 100.0%, avg_loss = 1.6190
Epoch 607, 100.0%, avg_loss = 1.6055
Epoch 608, 100.0%, avg_loss = 1.6021
Epoch 609, 100.0%, avg_loss = 1.5263
Epoch 610, 100.0%, avg_loss = 1.5930
Validation Loss: 1.6277
Epoch 611, 100.0%, avg_loss = 1.6940
Epoch 612, 100.0%, avg_loss = 1.6538
Epoch 613, 100.0%, avg_loss = 1.6212
Epoch 614, 100.0%, avg_loss = 1.5867
Epoch 615, 100.0

Epoch 787, 100.0%, avg_loss = 1.6043
Epoch 788, 100.0%, avg_loss = 1.6115
Epoch 789, 100.0%, avg_loss = 1.6400
Epoch 790, 100.0%, avg_loss = 1.6040
Validation Loss: 1.6440
Epoch 791, 100.0%, avg_loss = 1.5762
Epoch 792, 100.0%, avg_loss = 1.5691
Epoch 793, 100.0%, avg_loss = 1.5778
Epoch 794, 100.0%, avg_loss = 1.5624
Epoch 795, 100.0%, avg_loss = 1.5952
Validation Loss: 1.6421
Epoch 796, 100.0%, avg_loss = 1.6130
Epoch 797, 100.0%, avg_loss = 1.5330
Epoch 798, 100.0%, avg_loss = 1.6128
Epoch 799, 100.0%, avg_loss = 1.6655
Epoch 800, 100.0%, avg_loss = 1.5880
Validation Loss: 1.6708
Epoch 801, 100.0%, avg_loss = 1.6158
Epoch 802, 100.0%, avg_loss = 1.6161
Epoch 803, 100.0%, avg_loss = 1.6473
Epoch 804, 100.0%, avg_loss = 1.6268
Epoch 805, 100.0%, avg_loss = 1.5099
Validation Loss: 1.6271
Epoch 806, 100.0%, avg_loss = 1.6448
Epoch 807, 100.0%, avg_loss = 1.5644
Epoch 808, 100.0%, avg_loss = 1.5821
Epoch 809, 100.0%, avg_loss = 1.5683
Epoch 810, 100.0%, avg_loss = 1.6097
Validation Loss:

Epoch 983, 100.0%, avg_loss = 1.7373
Epoch 984, 100.0%, avg_loss = 1.6728
Epoch 985, 100.0%, avg_loss = 1.5644
Validation Loss: 1.6457
Epoch 986, 100.0%, avg_loss = 1.6015
Epoch 987, 100.0%, avg_loss = 1.6684
Epoch 988, 100.0%, avg_loss = 1.5317
Epoch 989, 100.0%, avg_loss = 1.6519
Epoch 990, 100.0%, avg_loss = 1.5378
Validation Loss: 1.5369
Epoch 991, 100.0%, avg_loss = 1.6094
Epoch 992, 100.0%, avg_loss = 1.5793
Epoch 993, 100.0%, avg_loss = 1.5896
Epoch 994, 100.0%, avg_loss = 1.5328
Epoch 995, 100.0%, avg_loss = 1.6157
Validation Loss: 1.6883
Epoch 996, 100.0%, avg_loss = 1.5945
Epoch 997, 100.0%, avg_loss = 1.6118
Epoch 998, 100.0%, avg_loss = 1.6356
Epoch 999, 100.0%, avg_loss = 1.5561
Epoch 1000, 100.0%, avg_loss = 1.5695
Validation Loss: 1.5949
Epoch 1001, 100.0%, avg_loss = 1.6518
Epoch 1002, 100.0%, avg_loss = 1.6329
Epoch 1003, 100.0%, avg_loss = 1.6231
Epoch 1004, 100.0%, avg_loss = 1.6125
Epoch 1005, 100.0%, avg_loss = 1.6508
Validation Loss: 1.5390
Epoch 1006, 100.0%, avg

Epoch 1175, 100.0%, avg_loss = 1.5640
Validation Loss: 1.5501
Epoch 1176, 100.0%, avg_loss = 1.5572
Epoch 1177, 100.0%, avg_loss = 1.6591
Epoch 1178, 100.0%, avg_loss = 1.6091
Epoch 1179, 100.0%, avg_loss = 1.5990
Epoch 1180, 100.0%, avg_loss = 1.5637
Validation Loss: 1.5952
Epoch 1181, 100.0%, avg_loss = 1.6155
Epoch 1182, 100.0%, avg_loss = 1.5255
Epoch 1183, 100.0%, avg_loss = 1.6691
Epoch 1184, 100.0%, avg_loss = 1.5618
Epoch 1185, 100.0%, avg_loss = 1.6169
Validation Loss: 1.5677
Epoch 1186, 100.0%, avg_loss = 1.5282
Epoch 1187, 100.0%, avg_loss = 1.5769
Epoch 1188, 100.0%, avg_loss = 1.5913
Epoch 1189, 100.0%, avg_loss = 1.5304
Epoch 1190, 100.0%, avg_loss = 1.5928
Validation Loss: 1.5862
Epoch 1191, 100.0%, avg_loss = 1.5278
Epoch 1192, 100.0%, avg_loss = 1.5639
Epoch 1193, 100.0%, avg_loss = 1.5723
Epoch 1194, 100.0%, avg_loss = 1.5915
Epoch 1195, 100.0%, avg_loss = 1.5988
Validation Loss: 1.6368
Epoch 1196, 100.0%, avg_loss = 1.5904
Epoch 1197, 100.0%, avg_loss = 1.5769
Epoch 

Epoch 1366, 100.0%, avg_loss = 1.5156
Epoch 1367, 100.0%, avg_loss = 1.5536
Epoch 1368, 100.0%, avg_loss = 1.5686
Epoch 1369, 100.0%, avg_loss = 1.5370
Epoch 1370, 100.0%, avg_loss = 1.5858
Validation Loss: 1.5445
Epoch 1371, 100.0%, avg_loss = 1.5888
Epoch 1372, 100.0%, avg_loss = 1.4805
Epoch 1373, 100.0%, avg_loss = 1.5207
Epoch 1374, 100.0%, avg_loss = 1.6101
Epoch 1375, 100.0%, avg_loss = 1.5880
Validation Loss: 1.5502
Epoch 1376, 100.0%, avg_loss = 1.5536
Epoch 1377, 100.0%, avg_loss = 1.5257
Epoch 1378, 100.0%, avg_loss = 1.6488
Epoch 1379, 100.0%, avg_loss = 1.6104
Epoch 1380, 100.0%, avg_loss = 1.5931
Validation Loss: 1.6145
Epoch 1381, 100.0%, avg_loss = 1.6050
Epoch 1382, 100.0%, avg_loss = 1.6176
Epoch 1383, 100.0%, avg_loss = 1.6139
Epoch 1384, 100.0%, avg_loss = 1.4840
Epoch 1385, 100.0%, avg_loss = 1.4868
Validation Loss: 1.5342
Epoch 1386, 100.0%, avg_loss = 1.6096
Epoch 1387, 100.0%, avg_loss = 1.6079
Epoch 1388, 100.0%, avg_loss = 1.5694
Epoch 1389, 100.0%, avg_loss =

Epoch 1557, 100.0%, avg_loss = 1.6364
Epoch 1558, 100.0%, avg_loss = 1.5674
Epoch 1559, 100.0%, avg_loss = 1.6097
Epoch 1560, 100.0%, avg_loss = 1.5593
Validation Loss: 1.6506
Epoch 1561, 100.0%, avg_loss = 1.4849
Epoch 1562, 100.0%, avg_loss = 1.5583
Epoch 1563, 100.0%, avg_loss = 1.5360
Epoch 1564, 100.0%, avg_loss = 1.5568
Epoch 1565, 100.0%, avg_loss = 1.5928
Validation Loss: 1.5343
Epoch 1566, 100.0%, avg_loss = 1.6540
Epoch 1567, 100.0%, avg_loss = 1.5076
Epoch 1568, 100.0%, avg_loss = 1.6091
Epoch 1569, 100.0%, avg_loss = 1.5021
Epoch 1570, 100.0%, avg_loss = 1.5782
Validation Loss: 1.6097
Epoch 1571, 100.0%, avg_loss = 1.5396
Epoch 1572, 100.0%, avg_loss = 1.5401
Epoch 1573, 100.0%, avg_loss = 1.5763
Epoch 1574, 100.0%, avg_loss = 1.5492
Epoch 1575, 100.0%, avg_loss = 1.5403
Validation Loss: 1.6257
Epoch 1576, 100.0%, avg_loss = 1.5751
Epoch 1577, 100.0%, avg_loss = 1.5281
Epoch 1578, 100.0%, avg_loss = 1.5530
Epoch 1579, 100.0%, avg_loss = 1.5322
Epoch 1580, 100.0%, avg_loss =

Epoch 1748, 100.0%, avg_loss = 1.4857
Epoch 1749, 100.0%, avg_loss = 1.5429
Epoch 1750, 100.0%, avg_loss = 1.5323
Validation Loss: 1.5302
Epoch 1751, 100.0%, avg_loss = 1.6034
Epoch 1752, 100.0%, avg_loss = 1.5106
Epoch 1753, 100.0%, avg_loss = 1.5367
Epoch 1754, 100.0%, avg_loss = 1.5912
Epoch 1755, 100.0%, avg_loss = 1.5764
Validation Loss: 1.5525
Epoch 1756, 100.0%, avg_loss = 1.5257
Epoch 1757, 100.0%, avg_loss = 1.4937
Epoch 1758, 100.0%, avg_loss = 1.5487
Epoch 1759, 100.0%, avg_loss = 1.4975
Epoch 1760, 100.0%, avg_loss = 1.5148
Validation Loss: 1.5293
Epoch 1761, 100.0%, avg_loss = 1.5595
Epoch 1762, 100.0%, avg_loss = 1.6006
Epoch 1763, 100.0%, avg_loss = 1.5545
Epoch 1764, 100.0%, avg_loss = 1.5450
Epoch 1765, 100.0%, avg_loss = 1.5054
Validation Loss: 1.5358
Epoch 1766, 100.0%, avg_loss = 1.6367
Epoch 1767, 100.0%, avg_loss = 1.4957
Epoch 1768, 100.0%, avg_loss = 1.5905
Epoch 1769, 100.0%, avg_loss = 1.5647
Epoch 1770, 100.0%, avg_loss = 1.5460
Validation Loss: 1.5853
Epoch 

Epoch 1939, 100.0%, avg_loss = 1.5770
Epoch 1940, 100.0%, avg_loss = 1.5506
Validation Loss: 1.5802
Epoch 1941, 100.0%, avg_loss = 1.5658
Epoch 1942, 100.0%, avg_loss = 1.5899
Epoch 1943, 100.0%, avg_loss = 1.5634
Epoch 1944, 100.0%, avg_loss = 1.5257
Epoch 1945, 100.0%, avg_loss = 1.5652
Validation Loss: 1.5645
Epoch 1946, 100.0%, avg_loss = 1.5085
Epoch 1947, 100.0%, avg_loss = 1.5559
Epoch 1948, 100.0%, avg_loss = 1.5261
Epoch 1949, 100.0%, avg_loss = 1.5919
Epoch 1950, 100.0%, avg_loss = 1.5535
Validation Loss: 1.6310
Epoch 1951, 100.0%, avg_loss = 1.4843
Epoch 1952, 100.0%, avg_loss = 1.5485
Epoch 1953, 100.0%, avg_loss = 1.5428
Epoch 1954, 100.0%, avg_loss = 1.5741
Epoch 1955, 100.0%, avg_loss = 1.5668
Validation Loss: 1.5949
Epoch 1956, 100.0%, avg_loss = 1.5890
Epoch 1957, 100.0%, avg_loss = 1.5946
Epoch 1958, 100.0%, avg_loss = 1.5666
Epoch 1959, 100.0%, avg_loss = 1.5442
Epoch 1960, 100.0%, avg_loss = 1.5023
Validation Loss: 1.6199
Epoch 1961, 100.0%, avg_loss = 1.4947
Epoch 

In [12]:
# Persist the trained model. Save the *server* model: the aggregated updates are
# applied to it, and it is the one that is validated. The local `global_predictor`
# is only reloaded from the server before each user's gradient step, so after the
# last round it is still missing the final server update.
# torch.save pickles the whole module, so loading it later needs
# `model.GlobalPredictor` to be importable.
os.makedirs('./results_osaka', exist_ok=True)
torch.save(global_predictor_server, './results_osaka/global_predictor_fl.pytorch')

In [13]:
# Save the per-round training / validation loss dicts for later plotting.
# open(..., 'wb') fails if the directory is missing, so create it first.
os.makedirs('./results_osaka', exist_ok=True)
for curve, path in (
    (training_loss, './results_osaka/federated_global_predictor_training_loss.pk'),
    (validation_loss, './results_osaka/federated_global_predictor_validation_loss.pk'),
):
    with open(path, 'wb') as f:
        pk.dump(curve, f)

In [None]:
# Final evaluation of the server model on *all* test users, with one random window
# per user. A dedicated seeded RandomState makes the reported number reproducible
# across re-runs and leaves the global np.random stream untouched.
# NOTE(review): this cell was never executed in the saved notebook (In [None]).
# NOTE(review): the loss is divided by the number of rows; this is a per-row mean
# only if GlobalPredictor returns a loss summed over rows — TODO confirm.
EVAL_SEED = 0
eval_rng = np.random.RandomState(EVAL_SEED)

with torch.no_grad():
    cnt = 0
    avg_loss = 0.0

    for seqs in data_qry_test.values():
        t = eval_rng.randint(T - 2 * dT + 1)
        x_loc_qry = seqs[:, t: t + dT]
        x_t_qry = torch.full_like(x_loc_qry, t)   # same time index for every window step
        y = seqs[:, t + 2 * dT - 1]                # target: dT slots after the last input slot

        loss = global_predictor_server(x_loc_qry, x_t_qry, y)

        avg_loss += loss.item()
        cnt += seqs.shape[0]

if cnt == 0:
    print('Validation Loss: n/a (no test users)')
else:
    print('Validation Loss: {:.4f}'.format(avg_loss / cnt))