In [1]:
import torch
import torch.nn as nn
import pickle as pk
import numpy as np
from model import GlobalPredictor
import random

In [2]:
data = dict({})
for d in range(15, 29):
    filename = '/data/fan/UsersInTokyoProcessed/201211{:02d}_interp.pk'.format(d)
    print(filename)
    with open(filename, 'rb') as f:
        data[d] = pk.load(f)

/data/fan/UsersInTokyoProcessed/20121115_interp.pk
/data/fan/UsersInTokyoProcessed/20121116_interp.pk
/data/fan/UsersInTokyoProcessed/20121117_interp.pk
/data/fan/UsersInTokyoProcessed/20121118_interp.pk
/data/fan/UsersInTokyoProcessed/20121119_interp.pk
/data/fan/UsersInTokyoProcessed/20121120_interp.pk
/data/fan/UsersInTokyoProcessed/20121121_interp.pk
/data/fan/UsersInTokyoProcessed/20121122_interp.pk
/data/fan/UsersInTokyoProcessed/20121123_interp.pk
/data/fan/UsersInTokyoProcessed/20121124_interp.pk
/data/fan/UsersInTokyoProcessed/20121125_interp.pk
/data/fan/UsersInTokyoProcessed/20121126_interp.pk
/data/fan/UsersInTokyoProcessed/20121127_interp.pk
/data/fan/UsersInTokyoProcessed/20121128_interp.pk


In [3]:
data_qry_train = dict({})
data_qry_test = dict({})

for d in range(15, 22):
    for uid in data[d]:
        if uid not in data_qry_train:
            data_qry_train[uid] = [data[d][uid]]
        else:
            data_qry_train[uid].append(data[d][uid])
            
for d in range(22, 29):
    for uid in data[d]:
        if uid not in data_qry_test:
            data_qry_test[uid] = [data[d][uid]]
        else:
            data_qry_test[uid].append(data[d][uid])

In [4]:
for uid in data_qry_train:
    data_qry_train[uid] = torch.LongTensor(data_qry_train[uid]).cuda(1)
    
for uid in data_qry_test:
    data_qry_test[uid] = torch.LongTensor(data_qry_test[uid]).cuda(1)

In [5]:
num_locs = 1600
loc_embedding_dim = 128
T = 96
num_time = T
time_embedding_dim = 32
hidden_dim = 256
latent_dim = 256

In [6]:
lr = 1e-2

In [7]:
global_predictor = GlobalPredictor(num_locs, loc_embedding_dim, num_time, time_embedding_dim, hidden_dim, latent_dim).cuda(1)
global_predictor_server = GlobalPredictor(num_locs, loc_embedding_dim, num_time, time_embedding_dim, hidden_dim, latent_dim).cuda(1)
optimizer = torch.optim.SGD(global_predictor.parameters(), lr=lr)
optimizer_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)
optimizer.zero_grad()

In [8]:
dT = 4

In [9]:
training_loss = dict({})
validation_loss = dict({})

In [10]:
user_list_train = list(data_qry_train.keys())
time_list = list(range(T - 2 * dT + 1))

for epoch in range(1, 4001):
    
    optimizer_scheduler.step()
    
    update_user_list = random.sample(user_list_train, 512)
    
    update_weights_dict = global_predictor_server.state_dict()
    for para in update_weights_dict:
        update_weights_dict[para] = []
    
    avg_loss = 0.0
    cnt = 0
    n = 0
    
    uidx = 0
    
    for uid in update_user_list:
        uidx += 1
        # loading server model to local
        global_predictor.load_state_dict(global_predictor_server.state_dict())
        optimizer = torch.optim.SGD(global_predictor.parameters(), lr=lr)
        
        nk = data_qry_train[uid].shape[0]
        n += data_qry_train[uid].shape[0]
        
        for e in range(5):
        
            t = np.random.randint(T - 2 * dT + 1)
            x_loc_qry = data_qry_train[uid][:, t: t + dT]
            x_t_qry = torch.zeros_like(x_loc_qry) + t
            y = data_qry_train[uid][:, t + 2 * dT - 1]

            loss = global_predictor(x_loc_qry, x_t_qry, y)

            loss.backward()
            avg_loss += loss.item()
            cnt += nk

            optimizer.step()
            optimizer.zero_grad()
            
        cur_state_dict = global_predictor.state_dict()
        for para in cur_state_dict:
            update_weights_dict[para].append(cur_state_dict[para].cpu())
        
        print('Epoch {:02d}, {:.1f}%, avg_loss = {:.4f}'.format(epoch, uidx * 100 / len(update_user_list), avg_loss / cnt), end='\r')

    with torch.no_grad():
        
        for para in update_weights_dict:
            update_weights_dict[para] = torch.mean(torch.stack(update_weights_dict[para]), dim=0)
            
        global_predictor_server.load_state_dict(update_weights_dict)
    
    training_loss[epoch] = avg_loss / cnt
    print('')
    # testing
    
    if epoch % 5 == 0:
        with torch.no_grad():

            cnt = 0
            avg_loss = 0.0

            for uid in data_qry_test:

                if np.random.ranf() > 0.1:
                    continue

                nk = data_qry_test[uid].shape[0]

                t = np.random.randint(T - 2 * dT + 1)
                x_loc_qry = data_qry_test[uid][:, t: t + dT]
                x_t_qry = torch.zeros_like(x_loc_qry) + t
                y = data_qry_test[uid][:, t + 2 * dT - 1]

                loss = global_predictor_server(x_loc_qry, x_t_qry, y)

                avg_loss += loss.item()
                cnt += nk

            print('Validation Loss: {:.4f}'.format(avg_loss / cnt))
            validation_loss[epoch] = avg_loss / cnt

Epoch 01, 100.0%, avg_loss = 6.9236
Epoch 02, 100.0%, avg_loss = 6.9713
Epoch 03, 100.0%, avg_loss = 6.9360
Epoch 04, 100.0%, avg_loss = 6.8872
Epoch 05, 100.0%, avg_loss = 6.8709
Validation Loss: 7.3761
Epoch 06, 100.0%, avg_loss = 6.8852
Epoch 07, 100.0%, avg_loss = 6.9495
Epoch 08, 100.0%, avg_loss = 6.8814
Epoch 09, 100.0%, avg_loss = 6.8974
Epoch 10, 100.0%, avg_loss = 6.8656
Validation Loss: 7.3648
Epoch 11, 100.0%, avg_loss = 6.8461
Epoch 12, 100.0%, avg_loss = 6.7761
Epoch 13, 100.0%, avg_loss = 6.8053
Epoch 14, 100.0%, avg_loss = 6.8174
Epoch 15, 100.0%, avg_loss = 6.7880
Validation Loss: 7.3531
Epoch 16, 100.0%, avg_loss = 6.8332
Epoch 17, 100.0%, avg_loss = 6.7856
Epoch 18, 100.0%, avg_loss = 6.8112
Epoch 19, 100.0%, avg_loss = 6.8567
Epoch 20, 100.0%, avg_loss = 6.7561
Validation Loss: 7.3420
Epoch 21, 100.0%, avg_loss = 6.7449
Epoch 22, 100.0%, avg_loss = 6.7073
Epoch 23, 100.0%, avg_loss = 6.6906
Epoch 24, 100.0%, avg_loss = 6.7121
Epoch 25, 100.0%, avg_loss = 6.7049
Vali

Epoch 395, 100.0%, avg_loss = 3.7007
Validation Loss: 5.3192
Epoch 396, 100.0%, avg_loss = 3.4739
Epoch 397, 100.0%, avg_loss = 3.5046
Epoch 398, 100.0%, avg_loss = 3.5119
Epoch 399, 100.0%, avg_loss = 3.5741
Epoch 400, 100.0%, avg_loss = 3.4946
Validation Loss: 5.2972
Epoch 401, 100.0%, avg_loss = 3.6060
Epoch 402, 100.0%, avg_loss = 3.4839
Epoch 403, 100.0%, avg_loss = 3.4766
Epoch 404, 100.0%, avg_loss = 3.4271
Epoch 405, 100.0%, avg_loss = 3.5366
Validation Loss: 5.2616
Epoch 406, 100.0%, avg_loss = 3.4598
Epoch 407, 100.0%, avg_loss = 3.4127
Epoch 408, 100.0%, avg_loss = 3.4651
Epoch 409, 100.0%, avg_loss = 3.5024
Epoch 410, 100.0%, avg_loss = 3.4750
Validation Loss: 5.2573
Epoch 411, 100.0%, avg_loss = 3.3149
Epoch 412, 100.0%, avg_loss = 3.3909
Epoch 413, 100.0%, avg_loss = 3.3265
Epoch 414, 100.0%, avg_loss = 3.4759
Epoch 415, 100.0%, avg_loss = 3.3348
Validation Loss: 5.1996
Epoch 416, 100.0%, avg_loss = 3.2976
Epoch 417, 100.0%, avg_loss = 3.4550
Epoch 418, 100.0%, avg_loss =

Epoch 591, 100.0%, avg_loss = 3.0640
Epoch 592, 100.0%, avg_loss = 3.3152
Epoch 593, 100.0%, avg_loss = 3.2198
Epoch 594, 100.0%, avg_loss = 3.2150
Epoch 595, 100.0%, avg_loss = 3.1369
Validation Loss: 4.6015
Epoch 596, 100.0%, avg_loss = 3.1289
Epoch 597, 100.0%, avg_loss = 3.0756
Epoch 598, 100.0%, avg_loss = 3.2007
Epoch 599, 100.0%, avg_loss = 3.1399
Epoch 600, 100.0%, avg_loss = 3.0217
Validation Loss: 4.5765
Epoch 601, 100.0%, avg_loss = 3.0871
Epoch 602, 100.0%, avg_loss = 3.1963
Epoch 603, 100.0%, avg_loss = 3.1423
Epoch 604, 100.0%, avg_loss = 3.0877
Epoch 605, 100.0%, avg_loss = 3.1652
Validation Loss: 4.5540
Epoch 606, 100.0%, avg_loss = 3.1714
Epoch 607, 100.0%, avg_loss = 3.1942
Epoch 608, 100.0%, avg_loss = 3.1461
Epoch 609, 100.0%, avg_loss = 3.2107
Epoch 610, 100.0%, avg_loss = 3.2281
Validation Loss: 4.5295
Epoch 611, 100.0%, avg_loss = 3.1208
Epoch 612, 100.0%, avg_loss = 3.0411
Epoch 613, 100.0%, avg_loss = 3.1689
Epoch 614, 100.0%, avg_loss = 3.2039
Epoch 615, 100.0

Epoch 983, 100.0%, avg_loss = 2.8703
Epoch 984, 100.0%, avg_loss = 2.9064
Epoch 985, 100.0%, avg_loss = 2.7599
Validation Loss: 3.6219
Epoch 986, 100.0%, avg_loss = 2.8437
Epoch 987, 100.0%, avg_loss = 2.8627
Epoch 988, 100.0%, avg_loss = 2.8877
Epoch 989, 100.0%, avg_loss = 2.9078
Epoch 990, 100.0%, avg_loss = 2.8461
Validation Loss: 3.6062
Epoch 991, 100.0%, avg_loss = 2.8667
Epoch 992, 100.0%, avg_loss = 2.6930
Epoch 993, 100.0%, avg_loss = 2.8959
Epoch 994, 100.0%, avg_loss = 2.7066
Epoch 995, 100.0%, avg_loss = 2.8576
Validation Loss: 3.6085
Epoch 996, 100.0%, avg_loss = 2.7692
Epoch 997, 100.0%, avg_loss = 2.7145
Epoch 998, 100.0%, avg_loss = 2.8472
Epoch 999, 100.0%, avg_loss = 2.7870
Epoch 1000, 100.0%, avg_loss = 2.8479
Validation Loss: 3.5962
Epoch 1001, 100.0%, avg_loss = 2.7517
Epoch 1002, 100.0%, avg_loss = 2.8345
Epoch 1003, 100.0%, avg_loss = 2.8378
Epoch 1004, 100.0%, avg_loss = 2.8620
Epoch 1005, 100.0%, avg_loss = 2.7365
Validation Loss: 3.5840
Epoch 1006, 100.0%, avg

Epoch 1366, 100.0%, avg_loss = 2.5898
Epoch 1367, 100.0%, avg_loss = 2.6684
Epoch 1368, 100.0%, avg_loss = 2.6289
Epoch 1369, 100.0%, avg_loss = 2.6126
Epoch 1370, 100.0%, avg_loss = 2.5624
Validation Loss: 3.0819
Epoch 1371, 100.0%, avg_loss = 2.5994
Epoch 1372, 100.0%, avg_loss = 2.7336
Epoch 1373, 100.0%, avg_loss = 2.6773
Epoch 1374, 100.0%, avg_loss = 2.5958
Epoch 1375, 100.0%, avg_loss = 2.4873
Validation Loss: 3.0720
Epoch 1376, 100.0%, avg_loss = 2.5717
Epoch 1377, 100.0%, avg_loss = 2.5277
Epoch 1378, 100.0%, avg_loss = 2.6911
Epoch 1379, 100.0%, avg_loss = 2.6554
Epoch 1380, 100.0%, avg_loss = 2.5841
Validation Loss: 3.0530
Epoch 1381, 100.0%, avg_loss = 2.6613
Epoch 1382, 100.0%, avg_loss = 2.5592
Epoch 1383, 100.0%, avg_loss = 2.5869
Epoch 1384, 100.0%, avg_loss = 2.6170
Epoch 1385, 100.0%, avg_loss = 2.6095
Validation Loss: 3.0203
Epoch 1386, 100.0%, avg_loss = 2.5235
Epoch 1387, 100.0%, avg_loss = 2.7353
Epoch 1388, 100.0%, avg_loss = 2.6078
Epoch 1389, 100.0%, avg_loss =

Epoch 1748, 100.0%, avg_loss = 2.6189
Epoch 1749, 100.0%, avg_loss = 2.5325
Epoch 1750, 100.0%, avg_loss = 2.4973
Validation Loss: 2.7499
Epoch 1751, 100.0%, avg_loss = 2.5956
Epoch 1752, 100.0%, avg_loss = 2.3774
Epoch 1753, 100.0%, avg_loss = 2.5520
Epoch 1754, 100.0%, avg_loss = 2.4801
Epoch 1755, 100.0%, avg_loss = 2.4352
Validation Loss: 2.7440
Epoch 1756, 100.0%, avg_loss = 2.5045
Epoch 1757, 100.0%, avg_loss = 2.5407
Epoch 1758, 100.0%, avg_loss = 2.5357
Epoch 1759, 100.0%, avg_loss = 2.5710
Epoch 1760, 100.0%, avg_loss = 2.4669
Validation Loss: 2.7683
Epoch 1761, 100.0%, avg_loss = 2.4337
Epoch 1762, 100.0%, avg_loss = 2.4896
Epoch 1763, 100.0%, avg_loss = 2.4422
Epoch 1764, 100.0%, avg_loss = 2.4216
Epoch 1765, 100.0%, avg_loss = 2.5883
Validation Loss: 2.7465
Epoch 1766, 100.0%, avg_loss = 2.5209
Epoch 1767, 100.0%, avg_loss = 2.4636
Epoch 1768, 100.0%, avg_loss = 2.5407
Epoch 1769, 100.0%, avg_loss = 2.3635
Epoch 1770, 100.0%, avg_loss = 2.6014
Validation Loss: 2.7596
Epoch 

Epoch 2130, 100.0%, avg_loss = 2.2837
Validation Loss: 2.5325
Epoch 2131, 100.0%, avg_loss = 2.4030
Epoch 2132, 100.0%, avg_loss = 2.3745
Epoch 2133, 100.0%, avg_loss = 2.4694
Epoch 2134, 100.0%, avg_loss = 2.3744
Epoch 2135, 100.0%, avg_loss = 2.4871
Validation Loss: 2.5342
Epoch 2136, 100.0%, avg_loss = 2.3860
Epoch 2137, 100.0%, avg_loss = 2.4785
Epoch 2138, 100.0%, avg_loss = 2.3610
Epoch 2139, 100.0%, avg_loss = 2.4773
Epoch 2140, 100.0%, avg_loss = 2.4103
Validation Loss: 2.5218
Epoch 2141, 100.0%, avg_loss = 2.4228
Epoch 2142, 100.0%, avg_loss = 2.4318
Epoch 2143, 100.0%, avg_loss = 2.2899
Epoch 2144, 100.0%, avg_loss = 2.4069
Epoch 2145, 100.0%, avg_loss = 2.4132
Validation Loss: 2.5693
Epoch 2146, 100.0%, avg_loss = 2.2996
Epoch 2147, 100.0%, avg_loss = 2.4064
Epoch 2148, 100.0%, avg_loss = 2.3391
Epoch 2149, 100.0%, avg_loss = 2.3425
Epoch 2150, 100.0%, avg_loss = 2.4113
Validation Loss: 2.5498
Epoch 2151, 100.0%, avg_loss = 2.4349
Epoch 2152, 100.0%, avg_loss = 2.3940
Epoch 

Epoch 2321, 100.0%, avg_loss = 2.3764
Epoch 2322, 100.0%, avg_loss = 2.3484
Epoch 2323, 100.0%, avg_loss = 2.4681
Epoch 2324, 100.0%, avg_loss = 2.2889
Epoch 2325, 100.0%, avg_loss = 2.5204
Validation Loss: 2.4735
Epoch 2326, 100.0%, avg_loss = 2.3257
Epoch 2327, 100.0%, avg_loss = 2.4143
Epoch 2328, 100.0%, avg_loss = 2.4772
Epoch 2329, 100.0%, avg_loss = 2.2694
Epoch 2330, 100.0%, avg_loss = 2.3606
Validation Loss: 2.5053
Epoch 2331, 100.0%, avg_loss = 2.3824
Epoch 2332, 100.0%, avg_loss = 2.4348
Epoch 2333, 100.0%, avg_loss = 2.3196
Epoch 2334, 100.0%, avg_loss = 2.4967
Epoch 2335, 100.0%, avg_loss = 2.4465
Validation Loss: 2.4735
Epoch 2336, 100.0%, avg_loss = 2.2891
Epoch 2337, 100.0%, avg_loss = 2.3136
Epoch 2338, 100.0%, avg_loss = 2.4345
Epoch 2339, 100.0%, avg_loss = 2.4948
Epoch 2340, 100.0%, avg_loss = 2.5152
Validation Loss: 2.4805
Epoch 2341, 100.0%, avg_loss = 2.4113
Epoch 2342, 100.0%, avg_loss = 2.2667
Epoch 2343, 100.0%, avg_loss = 2.3216
Epoch 2344, 100.0%, avg_loss =

Epoch 2703, 100.0%, avg_loss = 2.2706
Epoch 2704, 100.0%, avg_loss = 2.3002
Epoch 2705, 100.0%, avg_loss = 2.2289
Validation Loss: 2.3804
Epoch 2706, 100.0%, avg_loss = 2.5102
Epoch 2707, 100.0%, avg_loss = 2.3661
Epoch 2708, 100.0%, avg_loss = 2.3144
Epoch 2709, 100.0%, avg_loss = 2.3922
Epoch 2710, 100.0%, avg_loss = 2.4077
Validation Loss: 2.3810
Epoch 2711, 100.0%, avg_loss = 2.3441
Epoch 2712, 100.0%, avg_loss = 2.4159
Epoch 2713, 100.0%, avg_loss = 2.3649
Epoch 2714, 100.0%, avg_loss = 2.2634
Epoch 2715, 100.0%, avg_loss = 2.3937
Validation Loss: 2.3770
Epoch 2716, 100.0%, avg_loss = 2.3820
Epoch 2717, 100.0%, avg_loss = 2.2793
Epoch 2718, 100.0%, avg_loss = 2.3608
Epoch 2719, 100.0%, avg_loss = 2.3387
Epoch 2720, 100.0%, avg_loss = 2.3368
Validation Loss: 2.3624
Epoch 2721, 100.0%, avg_loss = 2.2279
Epoch 2722, 100.0%, avg_loss = 2.2738
Epoch 2723, 100.0%, avg_loss = 2.4191
Epoch 2724, 100.0%, avg_loss = 2.3950
Epoch 2725, 100.0%, avg_loss = 2.3114
Validation Loss: 2.3582
Epoch 

Epoch 3085, 100.0%, avg_loss = 2.3396
Validation Loss: 2.3030
Epoch 3086, 100.0%, avg_loss = 2.2635
Epoch 3087, 100.0%, avg_loss = 2.1958
Epoch 3088, 100.0%, avg_loss = 2.2232
Epoch 3089, 100.0%, avg_loss = 2.2601
Epoch 3090, 100.0%, avg_loss = 2.2507
Validation Loss: 2.2960
Epoch 3091, 100.0%, avg_loss = 2.4231
Epoch 3092, 100.0%, avg_loss = 2.3593
Epoch 3093, 100.0%, avg_loss = 2.1838
Epoch 3094, 100.0%, avg_loss = 2.3379
Epoch 3095, 100.0%, avg_loss = 2.2474
Validation Loss: 2.3246
Epoch 3096, 100.0%, avg_loss = 2.3215
Epoch 3097, 100.0%, avg_loss = 2.2734
Epoch 3098, 100.0%, avg_loss = 2.3833
Epoch 3099, 100.0%, avg_loss = 2.2384
Epoch 3100, 100.0%, avg_loss = 2.2767
Validation Loss: 2.2655
Epoch 3101, 100.0%, avg_loss = 2.2288
Epoch 3102, 100.0%, avg_loss = 2.2402
Epoch 3103, 100.0%, avg_loss = 2.2566
Epoch 3104, 100.0%, avg_loss = 2.1314
Epoch 3105, 100.0%, avg_loss = 2.2526
Validation Loss: 2.2660
Epoch 3106, 100.0%, avg_loss = 2.2117
Epoch 3107, 100.0%, avg_loss = 2.2541
Epoch 

Epoch 3276, 100.0%, avg_loss = 2.2857
Epoch 3277, 100.0%, avg_loss = 2.2103
Epoch 3278, 100.0%, avg_loss = 2.1628
Epoch 3279, 100.0%, avg_loss = 2.1866
Epoch 3280, 100.0%, avg_loss = 2.3074
Validation Loss: 2.2643
Epoch 3281, 100.0%, avg_loss = 2.1709
Epoch 3282, 100.0%, avg_loss = 2.1620
Epoch 3283, 100.0%, avg_loss = 2.3254
Epoch 3284, 100.0%, avg_loss = 2.3026
Epoch 3285, 100.0%, avg_loss = 2.1721
Validation Loss: 2.2514
Epoch 3286, 100.0%, avg_loss = 2.2488
Epoch 3287, 100.0%, avg_loss = 2.2241
Epoch 3288, 100.0%, avg_loss = 2.3521
Epoch 3289, 100.0%, avg_loss = 2.2119
Epoch 3290, 100.0%, avg_loss = 2.2914
Validation Loss: 2.2859
Epoch 3291, 100.0%, avg_loss = 2.2631
Epoch 3292, 100.0%, avg_loss = 2.2188
Epoch 3293, 100.0%, avg_loss = 2.3299
Epoch 3294, 100.0%, avg_loss = 2.2143
Epoch 3295, 100.0%, avg_loss = 2.2267
Validation Loss: 2.2597
Epoch 3296, 100.0%, avg_loss = 2.1909
Epoch 3297, 100.0%, avg_loss = 2.2500
Epoch 3298, 100.0%, avg_loss = 2.2583
Epoch 3299, 100.0%, avg_loss =

Epoch 3658, 100.0%, avg_loss = 2.3367
Epoch 3659, 100.0%, avg_loss = 2.2860
Epoch 3660, 100.0%, avg_loss = 2.2459
Validation Loss: 2.2453
Epoch 3661, 100.0%, avg_loss = 2.2195
Epoch 3662, 100.0%, avg_loss = 2.3144
Epoch 3663, 100.0%, avg_loss = 2.3434
Epoch 3664, 100.0%, avg_loss = 2.1782
Epoch 3665, 100.0%, avg_loss = 2.1515
Validation Loss: 2.2580
Epoch 3666, 100.0%, avg_loss = 2.2277
Epoch 3667, 100.0%, avg_loss = 2.1730
Epoch 3668, 100.0%, avg_loss = 2.2646
Epoch 3669, 100.0%, avg_loss = 2.1669
Epoch 3670, 100.0%, avg_loss = 2.3647
Validation Loss: 2.2070
Epoch 3671, 100.0%, avg_loss = 2.1847
Epoch 3672, 100.0%, avg_loss = 2.2734
Epoch 3673, 100.0%, avg_loss = 2.2999
Epoch 3674, 100.0%, avg_loss = 2.2079
Epoch 3675, 100.0%, avg_loss = 2.0325
Validation Loss: 2.2207
Epoch 3676, 100.0%, avg_loss = 2.1999
Epoch 3677, 100.0%, avg_loss = 2.2475
Epoch 3678, 100.0%, avg_loss = 2.1759
Epoch 3679, 100.0%, avg_loss = 2.2018
Epoch 3680, 100.0%, avg_loss = 2.1566
Validation Loss: 2.2228
Epoch 

In [11]:
torch.save(global_predictor, './results_tokyo/global_predictor_fl_fedavg.pytorch')

In [13]:
with open('./results_tokyo/global_predictor_fl_fedavg_training_loss.pk', 'wb') as f:
    pk.dump(training_loss, f)
with open('./results_tokyo/global_predictor_fl_fedavg_validation_loss.pk', 'wb') as f:
    pk.dump(validation_loss, f)