In [1]:
import torch
import torch.nn as nn
import pickle as pk
import numpy as np
from model import LocalPredictorFullSearch
import random

In [2]:
# Load one pickled per-day dict (uid -> that user's sequence) for the files
# 20121101 .. 20121128. The path of each file is echoed as it is read.
# Only load trusted local files: pickle.load can execute arbitrary code.
DATA_PATH_TEMPLATE = '../data/UsersInTokyo2012/201211{:02d}_interp.pk'

data = {}
for day in range(1, 29):
    path = DATA_PATH_TEMPLATE.format(day)
    print(path)
    with open(path, 'rb') as fh:
        data[day] = pk.load(fh)

../data/UsersInTokyo2012/20121101_interp.pk
../data/UsersInTokyo2012/20121102_interp.pk
../data/UsersInTokyo2012/20121103_interp.pk
../data/UsersInTokyo2012/20121104_interp.pk
../data/UsersInTokyo2012/20121105_interp.pk
../data/UsersInTokyo2012/20121106_interp.pk
../data/UsersInTokyo2012/20121107_interp.pk
../data/UsersInTokyo2012/20121108_interp.pk
../data/UsersInTokyo2012/20121109_interp.pk
../data/UsersInTokyo2012/20121110_interp.pk
../data/UsersInTokyo2012/20121111_interp.pk
../data/UsersInTokyo2012/20121112_interp.pk
../data/UsersInTokyo2012/20121113_interp.pk
../data/UsersInTokyo2012/20121114_interp.pk
../data/UsersInTokyo2012/20121115_interp.pk
../data/UsersInTokyo2012/20121116_interp.pk
../data/UsersInTokyo2012/20121117_interp.pk
../data/UsersInTokyo2012/20121118_interp.pk
../data/UsersInTokyo2012/20121119_interp.pk
../data/UsersInTokyo2012/20121120_interp.pk
../data/UsersInTokyo2012/20121121_interp.pk
../data/UsersInTokyo2012/20121122_interp.pk
../data/UsersInTokyo2012/2012112

In [3]:
# For every loaded day, the set of user ids that appear in that day's data.
uid_sets = {day: set(day_data.keys()) for day, day_data in data.items()}

In [4]:
# All user ids seen on days 1-14 (the "doc" history period).
# Not referenced by the cells below.
uid_doc_set = set().union(*(uid_sets[day] for day in range(1, 15)))

In [5]:
def _group_by_user(day_data, days):
    """Collect each user's per-day sequences over `days`, in day order.

    Args:
        day_data: mapping day -> {uid: sequence} (as built by the load cell).
        days: iterable of day keys to include.

    Returns:
        dict uid -> list of `day_data[d][uid]` for every day d in `days` on
        which the user appears; users keep first-seen insertion order.
    """
    grouped = {}
    for d in days:
        for uid in day_data[d]:
            grouped.setdefault(uid, []).append(day_data[d][uid])
    return grouped


# Split by calendar day: 1-14 -> per-user "doc" history,
# 15-21 -> training queries, 22-28 -> test queries.
data_doc = _group_by_user(data, range(1, 15))
data_qry_train = _group_by_user(data, range(15, 22))
data_qry_test = _group_by_user(data, range(22, 29))

In [6]:
# Replace every user's list of per-day sequences with a LongTensor on GPU 1
# (first dimension = number of days the user appears). The three dicts are
# updated in place; only values are reassigned, so iterating is safe.
# NOTE(review): if the stored sequences are numpy arrays, torch.LongTensor on a
# list of arrays is slow — converting with np.asarray first would help; confirm type.
for grouped in (data_doc, data_qry_train, data_qry_test):
    for uid, seqs in grouped.items():
        grouped[uid] = torch.LongTensor(seqs).cuda(1)

In [7]:
# Model / data dimensions passed to LocalPredictorFullSearch.
num_locs = 1600            # number of location ids (presumably the grid size — confirm)
loc_embedding_dim = 128
T = 96                     # time slots per sequence; window starts range over 0 .. T - 2*dT
num_time = T
time_embedding_dim = 32
hidden_dim = 256
latent_dim = 256

# Seed every RNG used later (model init, cohort sampling, window sampling,
# gradient noise) so Restart & Run All is reproducible as far as RNG goes.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

In [8]:
# Step size (halved in place by the training loop every 500 epochs) and the
# decay applied to the accumulated update after each round.
lr, momentum = 1e-3, 0.8

In [9]:
# `local_predictor` is the per-user working copy (reloaded from the server for
# every user); `local_predictor_server` holds the aggregated weights.
# The optimizer is only used for zero_grad() — parameters are updated
# manually in the training loop (optimizer.step() is never called).
model_args = (num_locs, loc_embedding_dim, num_time,
              time_embedding_dim, hidden_dim, latent_dim)
local_predictor = LocalPredictorFullSearch(*model_args).cuda(1)
local_predictor_server = LocalPredictorFullSearch(*model_args).cuda(1)
optimizer = torch.optim.SGD(local_predictor.parameters(), lr=lr)
optimizer.zero_grad()

In [10]:
# Window length in time slots: the model sees slots t .. t+dT-1 and predicts
# slot t+2*dT-1 (dT slots after the last input); doc context spans t .. t+2*dT-1.
dT = 4

In [11]:
# Per-epoch loss histories: epoch -> mean loss per query row.
training_loss, validation_loss = {}, {}

In [12]:
# Federated training. Each epoch ("round") samples a cohort of users; every
# user computes a gradient on a fresh copy of the server weights, and the
# noisy, lr-scaled gradients are summed (not averaged) into a momentum buffer
# that is subtracted from the server parameters and then decayed.
NUM_EPOCHS = 4000
USERS_PER_ROUND = 2048
LR_DECAY_EVERY = 500      # lr is halved at the start of epochs 500, 1000, ...
GRAD_NOISE_STD = 1e-2     # std of Gaussian noise added to each per-user gradient
VAL_EVERY = 5             # validate every VAL_EVERY epochs
VAL_SAMPLE_RATE = 0.05    # each test user is scored with this probability

user_list_train = list(data_qry_train.keys())
time_list = list(range(T - 2 * dT + 1))


def _window_loss(model, uid, qry_seqs, t):
    """Return `model`'s loss for one user on the window starting at slot `t`.

    Input is columns t .. t+dT-1 of `qry_seqs`, the target is column
    t+2*dT-1; if the user has history in `data_doc`, its columns
    t .. t+2*dT-1 are passed as extra context, otherwise None is passed.
    """
    x_loc_qry = qry_seqs[:, t: t + dT]
    x_t_qry = torch.zeros_like(x_loc_qry) + t
    y = qry_seqs[:, t + 2 * dT - 1]
    if uid not in data_doc:
        return model(x_loc_qry, x_t_qry, None, None, y)
    x_loc_doc = data_doc[uid][:, t: t + 2 * dT]
    x_t_doc = torch.zeros_like(x_loc_doc) + t
    return model(x_loc_qry, x_t_qry, x_loc_doc, x_t_doc, y)


def _accumulate_noisy_update(model, update, step_lr, noise_std):
    """Add step_lr * (grad + N(0, noise_std^2)) of each parameter of `model` into `update`."""
    with torch.no_grad():
        for name, para in model.named_parameters():
            if para.grad is None:
                continue
            step = (para.grad + torch.randn_like(para.grad) * noise_std) * step_lr
            if name in update:
                update[name] += step
            else:
                update[name] = step


def _apply_server_update(server, update, decay):
    """Subtract the accumulated update from `server`'s parameters, then decay it."""
    with torch.no_grad():
        for name, para in server.named_parameters():
            # Parameters that have never received a gradient have no entry
            # (previously this raised KeyError).
            if name in update:
                para -= update[name]
        for name in update:
            update[name] *= decay


def _validate(model, qry_data, sample_rate):
    """Return the summed loss per query row over a random subset of users (NaN if none sampled).

    NOTE(review): `model` is never switched to eval(); if LocalPredictorFullSearch
    uses dropout/batch-norm, validation runs in training mode — confirm intent.
    """
    total, rows = 0.0, 0
    with torch.no_grad():
        for uid, seqs in qry_data.items():
            if np.random.random_sample() > sample_rate:
                continue
            t = np.random.randint(T - 2 * dT + 1)
            total += _window_loss(model, uid, seqs, t).item()
            rows += seqs.shape[0]
    return total / rows if rows else float('nan')


# Momentum buffer persisted across rounds: parameter name -> accumulated update.
local_predictor_update = {}

for epoch in range(1, NUM_EPOCHS + 1):

    if epoch % LR_DECAY_EVERY == 0:
        lr *= 0.5

    update_user_list = random.sample(user_list_train, min(USERS_PER_ROUND, len(user_list_train)))
    random.shuffle(update_user_list)

    avg_loss = 0.0
    cnt = 0

    for uidx, uid in enumerate(update_user_list, start=1):
        # Every user starts from the current server weights.
        # NOTE(review): only parameters are pushed back to the server; any
        # module buffers on the server are never updated — confirm none matter.
        local_predictor.load_state_dict(local_predictor_server.state_dict())

        qry_seqs = data_qry_train[uid]
        nk = qry_seqs.shape[0]

        t = np.random.randint(T - 2 * dT + 1)
        loss = _window_loss(local_predictor, uid, qry_seqs, t)
        loss.backward()
        avg_loss += loss.item()
        cnt += nk

        _accumulate_noisy_update(local_predictor, local_predictor_update, lr, GRAD_NOISE_STD)
        optimizer.zero_grad()

        print('Epoch {:02d}, {:.1f}%, avg_loss = {:.4f}'.format(epoch, uidx * 100 / len(update_user_list), avg_loss / cnt), end='\r')

    _apply_server_update(local_predictor_server, local_predictor_update, momentum)

    training_loss[epoch] = avg_loss / cnt if cnt else float('nan')
    print('')

    if epoch % VAL_EVERY == 0:
        val_loss = _validate(local_predictor_server, data_qry_test, VAL_SAMPLE_RATE)
        print('Validation Loss: {:.4f}'.format(val_loss))
        validation_loss[epoch] = val_loss

Epoch 01, 100.0%, avg_loss = 7.3783
Epoch 02, 100.0%, avg_loss = 7.2986
Epoch 03, 100.0%, avg_loss = 7.1357
Epoch 04, 100.0%, avg_loss = 6.8730
Epoch 05, 100.0%, avg_loss = 6.5930
Validation Loss: 6.4149
Epoch 06, 100.0%, avg_loss = 6.3554
Epoch 07, 100.0%, avg_loss = 6.1058
Epoch 08, 100.0%, avg_loss = 5.5671
Epoch 09, 100.0%, avg_loss = 5.0436
Epoch 10, 100.0%, avg_loss = 4.5968
Validation Loss: 4.1795
Epoch 11, 100.0%, avg_loss = 4.0858
Epoch 12, 100.0%, avg_loss = 3.8204
Epoch 13, 100.0%, avg_loss = 3.5401
Epoch 14, 100.0%, avg_loss = 3.2754
Epoch 15, 100.0%, avg_loss = 3.2098
Validation Loss: 3.1315
Epoch 16, 100.0%, avg_loss = 3.1756
Epoch 17, 100.0%, avg_loss = 2.9148
Epoch 18, 100.0%, avg_loss = 2.9282
Epoch 19, 100.0%, avg_loss = 2.7448
Epoch 20, 100.0%, avg_loss = 2.6724
Validation Loss: 2.8416
Epoch 21, 100.0%, avg_loss = 2.6849
Epoch 22, 100.0%, avg_loss = 2.7125
Epoch 23, 100.0%, avg_loss = 2.6452
Epoch 24, 100.0%, avg_loss = 2.6694
Epoch 25, 100.0%, avg_loss = 2.7241
Vali

Epoch 199, 100.0%, avg_loss = 1.8512
Epoch 200, 100.0%, avg_loss = 1.8478
Validation Loss: 1.8419
Epoch 201, 100.0%, avg_loss = 1.8917
Epoch 202, 100.0%, avg_loss = 1.7780
Epoch 203, 100.0%, avg_loss = 1.9045
Epoch 204, 100.0%, avg_loss = 1.7961
Epoch 205, 100.0%, avg_loss = 1.8659
Validation Loss: 1.8424
Epoch 206, 100.0%, avg_loss = 1.8297
Epoch 207, 100.0%, avg_loss = 1.8113
Epoch 208, 100.0%, avg_loss = 1.8466
Epoch 209, 100.0%, avg_loss = 1.8451
Epoch 210, 100.0%, avg_loss = 1.7873
Validation Loss: 1.8711
Epoch 211, 100.0%, avg_loss = 1.7392
Epoch 212, 100.0%, avg_loss = 1.9183
Epoch 213, 100.0%, avg_loss = 1.8788
Epoch 214, 100.0%, avg_loss = 1.8319
Epoch 215, 100.0%, avg_loss = 1.8559
Validation Loss: 1.8251
Epoch 216, 100.0%, avg_loss = 1.8646
Epoch 217, 100.0%, avg_loss = 1.7497
Epoch 218, 100.0%, avg_loss = 1.9005
Epoch 219, 100.0%, avg_loss = 1.8728
Epoch 220, 100.0%, avg_loss = 1.8578
Validation Loss: 1.8100
Epoch 221, 100.0%, avg_loss = 1.8106
Epoch 222, 100.0%, avg_loss =

Epoch 395, 100.0%, avg_loss = 1.7724
Validation Loss: 1.7583
Epoch 396, 100.0%, avg_loss = 1.7392
Epoch 397, 100.0%, avg_loss = 1.7399
Epoch 398, 100.0%, avg_loss = 1.7238
Epoch 399, 100.0%, avg_loss = 1.6896
Epoch 400, 100.0%, avg_loss = 1.8216
Validation Loss: 1.7579
Epoch 401, 100.0%, avg_loss = 1.8548
Epoch 402, 100.0%, avg_loss = 1.7383
Epoch 403, 100.0%, avg_loss = 1.7208
Epoch 404, 100.0%, avg_loss = 1.7308
Epoch 405, 100.0%, avg_loss = 1.7438
Validation Loss: 1.8287
Epoch 406, 100.0%, avg_loss = 1.7026
Epoch 407, 100.0%, avg_loss = 1.7673
Epoch 408, 100.0%, avg_loss = 1.7610
Epoch 409, 100.0%, avg_loss = 1.6938
Epoch 410, 100.0%, avg_loss = 1.7889
Validation Loss: 1.7189
Epoch 411, 100.0%, avg_loss = 1.7921
Epoch 412, 100.0%, avg_loss = 1.7654
Epoch 413, 100.0%, avg_loss = 1.7461
Epoch 414, 100.0%, avg_loss = 1.7293
Epoch 415, 100.0%, avg_loss = 1.7399
Validation Loss: 1.7577
Epoch 416, 100.0%, avg_loss = 1.7353
Epoch 417, 100.0%, avg_loss = 1.8305
Epoch 418, 100.0%, avg_loss =

Epoch 591, 100.0%, avg_loss = 1.6308
Epoch 592, 100.0%, avg_loss = 1.6644
Epoch 593, 100.0%, avg_loss = 1.6521
Epoch 594, 100.0%, avg_loss = 1.5943
Epoch 595, 100.0%, avg_loss = 1.6481
Validation Loss: 1.6353
Epoch 596, 100.0%, avg_loss = 1.6854
Epoch 597, 100.0%, avg_loss = 1.6387
Epoch 598, 100.0%, avg_loss = 1.6251
Epoch 599, 100.0%, avg_loss = 1.6798
Epoch 600, 100.0%, avg_loss = 1.6114
Validation Loss: 1.6804
Epoch 601, 100.0%, avg_loss = 1.5494
Epoch 602, 100.0%, avg_loss = 1.6795
Epoch 603, 100.0%, avg_loss = 1.6820
Epoch 604, 100.0%, avg_loss = 1.6893
Epoch 605, 100.0%, avg_loss = 1.6840
Validation Loss: 1.6250
Epoch 606, 100.0%, avg_loss = 1.6258
Epoch 607, 100.0%, avg_loss = 1.6409
Epoch 608, 100.0%, avg_loss = 1.5855
Epoch 609, 100.0%, avg_loss = 1.7391
Epoch 610, 100.0%, avg_loss = 1.6698
Validation Loss: 1.6130
Epoch 611, 100.0%, avg_loss = 1.6557
Epoch 612, 100.0%, avg_loss = 1.5881
Epoch 613, 100.0%, avg_loss = 1.6084
Epoch 614, 100.0%, avg_loss = 1.6956
Epoch 615, 100.0

Epoch 787, 100.0%, avg_loss = 1.6221
Epoch 788, 100.0%, avg_loss = 1.6521
Epoch 789, 100.0%, avg_loss = 1.5819
Epoch 790, 100.0%, avg_loss = 1.5320
Validation Loss: 1.5937
Epoch 791, 100.0%, avg_loss = 1.5966
Epoch 792, 100.0%, avg_loss = 1.4986
Epoch 793, 100.0%, avg_loss = 1.6227
Epoch 794, 100.0%, avg_loss = 1.6686
Epoch 795, 100.0%, avg_loss = 1.6122
Validation Loss: 1.5988
Epoch 796, 100.0%, avg_loss = 1.6338
Epoch 797, 100.0%, avg_loss = 1.6054
Epoch 798, 100.0%, avg_loss = 1.5753
Epoch 799, 100.0%, avg_loss = 1.5119
Epoch 800, 100.0%, avg_loss = 1.7005
Validation Loss: 1.5636
Epoch 801, 100.0%, avg_loss = 1.5358
Epoch 802, 100.0%, avg_loss = 1.5947
Epoch 803, 100.0%, avg_loss = 1.6060
Epoch 804, 100.0%, avg_loss = 1.5805
Epoch 805, 100.0%, avg_loss = 1.6137
Validation Loss: 1.6420
Epoch 806, 100.0%, avg_loss = 1.4885
Epoch 807, 100.0%, avg_loss = 1.5774
Epoch 808, 100.0%, avg_loss = 1.5637
Epoch 809, 100.0%, avg_loss = 1.5724
Epoch 810, 100.0%, avg_loss = 1.6414
Validation Loss:

Epoch 983, 100.0%, avg_loss = 1.5568
Epoch 984, 100.0%, avg_loss = 1.5692
Epoch 985, 100.0%, avg_loss = 1.6043
Validation Loss: 1.5966
Epoch 986, 100.0%, avg_loss = 1.5783
Epoch 987, 100.0%, avg_loss = 1.5744
Epoch 988, 100.0%, avg_loss = 1.5817
Epoch 989, 100.0%, avg_loss = 1.6466
Epoch 990, 100.0%, avg_loss = 1.5154
Validation Loss: 1.5933
Epoch 991, 100.0%, avg_loss = 1.5126
Epoch 992, 100.0%, avg_loss = 1.5475
Epoch 993, 100.0%, avg_loss = 1.5859
Epoch 994, 100.0%, avg_loss = 1.6360
Epoch 995, 100.0%, avg_loss = 1.5686
Validation Loss: 1.5993
Epoch 996, 100.0%, avg_loss = 1.5661
Epoch 997, 100.0%, avg_loss = 1.5768
Epoch 998, 100.0%, avg_loss = 1.5700
Epoch 999, 100.0%, avg_loss = 1.5968
Epoch 1000, 100.0%, avg_loss = 1.5378
Validation Loss: 1.6262
Epoch 1001, 100.0%, avg_loss = 1.5481
Epoch 1002, 100.0%, avg_loss = 1.5676
Epoch 1003, 100.0%, avg_loss = 1.5633
Epoch 1004, 100.0%, avg_loss = 1.5311
Epoch 1005, 100.0%, avg_loss = 1.5991
Validation Loss: 1.6076
Epoch 1006, 100.0%, avg

Epoch 1175, 100.0%, avg_loss = 1.4142
Validation Loss: 1.5805
Epoch 1176, 100.0%, avg_loss = 1.5163
Epoch 1177, 100.0%, avg_loss = 1.5564
Epoch 1178, 100.0%, avg_loss = 1.5568
Epoch 1179, 100.0%, avg_loss = 1.5182
Epoch 1180, 100.0%, avg_loss = 1.5616
Validation Loss: 1.5615
Epoch 1181, 100.0%, avg_loss = 1.6264
Epoch 1182, 100.0%, avg_loss = 1.5693
Epoch 1183, 100.0%, avg_loss = 1.5712
Epoch 1184, 100.0%, avg_loss = 1.4796
Epoch 1185, 100.0%, avg_loss = 1.5782
Validation Loss: 1.5585
Epoch 1186, 100.0%, avg_loss = 1.5071
Epoch 1187, 100.0%, avg_loss = 1.5856
Epoch 1188, 100.0%, avg_loss = 1.5521
Epoch 1189, 100.0%, avg_loss = 1.5207
Epoch 1190, 100.0%, avg_loss = 1.5729
Validation Loss: 1.5339
Epoch 1191, 100.0%, avg_loss = 1.6627
Epoch 1192, 100.0%, avg_loss = 1.5392
Epoch 1193, 100.0%, avg_loss = 1.4788
Epoch 1194, 100.0%, avg_loss = 1.6117
Epoch 1195, 100.0%, avg_loss = 1.5462
Validation Loss: 1.5936
Epoch 1196, 100.0%, avg_loss = 1.4964
Epoch 1197, 100.0%, avg_loss = 1.5366
Epoch 

Epoch 1366, 100.0%, avg_loss = 1.4856
Epoch 1367, 100.0%, avg_loss = 1.5728
Epoch 1368, 100.0%, avg_loss = 1.5719
Epoch 1369, 100.0%, avg_loss = 1.5096
Epoch 1370, 100.0%, avg_loss = 1.5657
Validation Loss: 1.5411
Epoch 1371, 100.0%, avg_loss = 1.5065
Epoch 1372, 100.0%, avg_loss = 1.5691
Epoch 1373, 100.0%, avg_loss = 1.5165
Epoch 1374, 100.0%, avg_loss = 1.5029
Epoch 1375, 100.0%, avg_loss = 1.5361
Validation Loss: 1.5696
Epoch 1376, 100.0%, avg_loss = 1.5164
Epoch 1377, 100.0%, avg_loss = 1.5267
Epoch 1378, 100.0%, avg_loss = 1.5280
Epoch 1379, 100.0%, avg_loss = 1.5258
Epoch 1380, 100.0%, avg_loss = 1.3944
Validation Loss: 1.5386
Epoch 1381, 100.0%, avg_loss = 1.5394
Epoch 1382, 100.0%, avg_loss = 1.4606
Epoch 1383, 100.0%, avg_loss = 1.4965
Epoch 1384, 100.0%, avg_loss = 1.5682
Epoch 1385, 100.0%, avg_loss = 1.5736
Validation Loss: 1.5346
Epoch 1386, 100.0%, avg_loss = 1.5204
Epoch 1387, 100.0%, avg_loss = 1.5108
Epoch 1388, 100.0%, avg_loss = 1.5555
Epoch 1389, 100.0%, avg_loss =

Epoch 1557, 100.0%, avg_loss = 1.4414
Epoch 1558, 100.0%, avg_loss = 1.5201
Epoch 1559, 100.0%, avg_loss = 1.5460
Epoch 1560, 100.0%, avg_loss = 1.5550
Validation Loss: 1.5408
Epoch 1561, 100.0%, avg_loss = 1.5427
Epoch 1562, 100.0%, avg_loss = 1.4867
Epoch 1563, 100.0%, avg_loss = 1.5003
Epoch 1564, 100.0%, avg_loss = 1.5034
Epoch 1565, 100.0%, avg_loss = 1.5431
Validation Loss: 1.5292
Epoch 1566, 100.0%, avg_loss = 1.4949
Epoch 1567, 100.0%, avg_loss = 1.5119
Epoch 1568, 100.0%, avg_loss = 1.4607
Epoch 1569, 100.0%, avg_loss = 1.6534
Epoch 1570, 100.0%, avg_loss = 1.4681
Validation Loss: 1.5627
Epoch 1571, 100.0%, avg_loss = 1.4652
Epoch 1572, 100.0%, avg_loss = 1.5589
Epoch 1573, 100.0%, avg_loss = 1.4481
Epoch 1574, 100.0%, avg_loss = 1.4318
Epoch 1575, 100.0%, avg_loss = 1.5182
Validation Loss: 1.5628
Epoch 1576, 100.0%, avg_loss = 1.4444
Epoch 1577, 100.0%, avg_loss = 1.5360
Epoch 1578, 100.0%, avg_loss = 1.4324
Epoch 1579, 100.0%, avg_loss = 1.5109
Epoch 1580, 100.0%, avg_loss =

Epoch 1748, 100.0%, avg_loss = 1.5238
Epoch 1749, 100.0%, avg_loss = 1.4704
Epoch 1750, 100.0%, avg_loss = 1.4832
Validation Loss: 1.5262
Epoch 1751, 100.0%, avg_loss = 1.4827
Epoch 1752, 100.0%, avg_loss = 1.5534
Epoch 1753, 100.0%, avg_loss = 1.4329
Epoch 1754, 100.0%, avg_loss = 1.5096
Epoch 1755, 100.0%, avg_loss = 1.5181
Validation Loss: 1.5322
Epoch 1756, 100.0%, avg_loss = 1.4613
Epoch 1757, 100.0%, avg_loss = 1.5025
Epoch 1758, 100.0%, avg_loss = 1.5834
Epoch 1759, 100.0%, avg_loss = 1.5254
Epoch 1760, 100.0%, avg_loss = 1.4378
Validation Loss: 1.4790
Epoch 1761, 100.0%, avg_loss = 1.5376
Epoch 1762, 100.0%, avg_loss = 1.5022
Epoch 1763, 100.0%, avg_loss = 1.4648
Epoch 1764, 100.0%, avg_loss = 1.4409
Epoch 1765, 100.0%, avg_loss = 1.5275
Validation Loss: 1.5065
Epoch 1766, 100.0%, avg_loss = 1.5215
Epoch 1767, 100.0%, avg_loss = 1.4480
Epoch 1768, 100.0%, avg_loss = 1.4679
Epoch 1769, 100.0%, avg_loss = 1.5794
Epoch 1770, 100.0%, avg_loss = 1.5498
Validation Loss: 1.5307
Epoch 

Epoch 1939, 100.0%, avg_loss = 1.4447
Epoch 1940, 100.0%, avg_loss = 1.5054
Validation Loss: 1.5291
Epoch 1941, 100.0%, avg_loss = 1.5297
Epoch 1942, 100.0%, avg_loss = 1.4715
Epoch 1943, 100.0%, avg_loss = 1.4715
Epoch 1944, 100.0%, avg_loss = 1.4669
Epoch 1945, 100.0%, avg_loss = 1.5009
Validation Loss: 1.4931
Epoch 1946, 100.0%, avg_loss = 1.4638
Epoch 1947, 100.0%, avg_loss = 1.5436
Epoch 1948, 100.0%, avg_loss = 1.4577
Epoch 1949, 100.0%, avg_loss = 1.4687
Epoch 1950, 100.0%, avg_loss = 1.4755
Validation Loss: 1.5140
Epoch 1951, 100.0%, avg_loss = 1.5016
Epoch 1952, 100.0%, avg_loss = 1.4620
Epoch 1953, 100.0%, avg_loss = 1.5307
Epoch 1954, 100.0%, avg_loss = 1.5304
Epoch 1955, 100.0%, avg_loss = 1.4501
Validation Loss: 1.5109
Epoch 1956, 100.0%, avg_loss = 1.4890
Epoch 1957, 100.0%, avg_loss = 1.5605
Epoch 1958, 100.0%, avg_loss = 1.5642
Epoch 1959, 100.0%, avg_loss = 1.4863
Epoch 1960, 100.0%, avg_loss = 1.5183
Validation Loss: 1.4850
Epoch 1961, 100.0%, avg_loss = 1.5109
Epoch 

Epoch 2130, 100.0%, avg_loss = 1.4553
Validation Loss: 1.4831
Epoch 2131, 100.0%, avg_loss = 1.4449
Epoch 2132, 100.0%, avg_loss = 1.4201
Epoch 2133, 100.0%, avg_loss = 1.4552
Epoch 2134, 100.0%, avg_loss = 1.4861
Epoch 2135, 100.0%, avg_loss = 1.3736
Validation Loss: 1.4835
Epoch 2136, 100.0%, avg_loss = 1.4307
Epoch 2137, 100.0%, avg_loss = 1.3901
Epoch 2138, 100.0%, avg_loss = 1.5283
Epoch 2139, 100.0%, avg_loss = 1.5331
Epoch 2140, 100.0%, avg_loss = 1.5016
Validation Loss: 1.4626
Epoch 2141, 100.0%, avg_loss = 1.5017
Epoch 2142, 100.0%, avg_loss = 1.4594
Epoch 2143, 100.0%, avg_loss = 1.4847
Epoch 2144, 100.0%, avg_loss = 1.5669
Epoch 2145, 100.0%, avg_loss = 1.4004
Validation Loss: 1.4725
Epoch 2146, 100.0%, avg_loss = 1.4627
Epoch 2147, 100.0%, avg_loss = 1.6155
Epoch 2148, 100.0%, avg_loss = 1.5611
Epoch 2149, 100.0%, avg_loss = 1.3577
Epoch 2150, 100.0%, avg_loss = 1.4528
Validation Loss: 1.5267
Epoch 2151, 100.0%, avg_loss = 1.4717
Epoch 2152, 100.0%, avg_loss = 1.4990
Epoch 

Epoch 2321, 100.0%, avg_loss = 1.4623
Epoch 2322, 100.0%, avg_loss = 1.5154
Epoch 2323, 100.0%, avg_loss = 1.4205
Epoch 2324, 100.0%, avg_loss = 1.4655
Epoch 2325, 100.0%, avg_loss = 1.3523
Validation Loss: 1.4999
Epoch 2326, 100.0%, avg_loss = 1.4716
Epoch 2327, 100.0%, avg_loss = 1.5472
Epoch 2328, 100.0%, avg_loss = 1.5323
Epoch 2329, 100.0%, avg_loss = 1.4709
Epoch 2330, 100.0%, avg_loss = 1.4398
Validation Loss: 1.5070
Epoch 2331, 100.0%, avg_loss = 1.5022
Epoch 2332, 100.0%, avg_loss = 1.5143
Epoch 2333, 100.0%, avg_loss = 1.4728
Epoch 2334, 100.0%, avg_loss = 1.4578
Epoch 2335, 100.0%, avg_loss = 1.4517
Validation Loss: 1.5012
Epoch 2336, 100.0%, avg_loss = 1.5236
Epoch 2337, 100.0%, avg_loss = 1.4773
Epoch 2338, 100.0%, avg_loss = 1.4333
Epoch 2339, 100.0%, avg_loss = 1.4370
Epoch 2340, 100.0%, avg_loss = 1.4445
Validation Loss: 1.4966
Epoch 2341, 100.0%, avg_loss = 1.5074
Epoch 2342, 100.0%, avg_loss = 1.4705
Epoch 2343, 100.0%, avg_loss = 1.4930
Epoch 2344, 100.0%, avg_loss =

Epoch 2512, 100.0%, avg_loss = 1.4166
Epoch 2513, 100.0%, avg_loss = 1.3867
Epoch 2514, 100.0%, avg_loss = 1.4152
Epoch 2515, 100.0%, avg_loss = 1.4057
Validation Loss: 1.5500
Epoch 2516, 100.0%, avg_loss = 1.5644
Epoch 2517, 100.0%, avg_loss = 1.5161
Epoch 2518, 100.0%, avg_loss = 1.4828
Epoch 2519, 100.0%, avg_loss = 1.4926
Epoch 2520, 100.0%, avg_loss = 1.4188
Validation Loss: 1.4665
Epoch 2521, 100.0%, avg_loss = 1.4866
Epoch 2522, 100.0%, avg_loss = 1.4525
Epoch 2523, 100.0%, avg_loss = 1.5360
Epoch 2524, 100.0%, avg_loss = 1.4168
Epoch 2525, 100.0%, avg_loss = 1.4522
Validation Loss: 1.4884
Epoch 2526, 100.0%, avg_loss = 1.4690
Epoch 2527, 100.0%, avg_loss = 1.4734
Epoch 2528, 100.0%, avg_loss = 1.4543
Epoch 2529, 100.0%, avg_loss = 1.4373
Epoch 2530, 100.0%, avg_loss = 1.5024
Validation Loss: 1.4616
Epoch 2531, 100.0%, avg_loss = 1.4653
Epoch 2532, 100.0%, avg_loss = 1.4905
Epoch 2533, 100.0%, avg_loss = 1.4490
Epoch 2534, 100.0%, avg_loss = 1.4792
Epoch 2535, 100.0%, avg_loss =

Epoch 2703, 100.0%, avg_loss = 1.5018
Epoch 2704, 100.0%, avg_loss = 1.3966
Epoch 2705, 100.0%, avg_loss = 1.4301
Validation Loss: 1.4931
Epoch 2706, 100.0%, avg_loss = 1.4309
Epoch 2707, 100.0%, avg_loss = 1.4224
Epoch 2708, 100.0%, avg_loss = 1.4396
Epoch 2709, 100.0%, avg_loss = 1.4303
Epoch 2710, 100.0%, avg_loss = 1.4714
Validation Loss: 1.4900
Epoch 2711, 100.0%, avg_loss = 1.4777
Epoch 2712, 100.0%, avg_loss = 1.4800
Epoch 2713, 100.0%, avg_loss = 1.4756
Epoch 2714, 100.0%, avg_loss = 1.4537
Epoch 2715, 100.0%, avg_loss = 1.4303
Validation Loss: 1.5303
Epoch 2716, 100.0%, avg_loss = 1.5117
Epoch 2717, 100.0%, avg_loss = 1.4723
Epoch 2718, 100.0%, avg_loss = 1.4473
Epoch 2719, 100.0%, avg_loss = 1.4298
Epoch 2720, 100.0%, avg_loss = 1.4674
Validation Loss: 1.4744
Epoch 2721, 100.0%, avg_loss = 1.4419
Epoch 2722, 100.0%, avg_loss = 1.4502
Epoch 2723, 100.0%, avg_loss = 1.4404
Epoch 2724, 100.0%, avg_loss = 1.4863
Epoch 2725, 100.0%, avg_loss = 1.4467
Validation Loss: 1.5168
Epoch 

Epoch 2894, 100.0%, avg_loss = 1.3762
Epoch 2895, 100.0%, avg_loss = 1.4499
Validation Loss: 1.4988
Epoch 2896, 100.0%, avg_loss = 1.4401
Epoch 2897, 100.0%, avg_loss = 1.4644
Epoch 2898, 100.0%, avg_loss = 1.4627
Epoch 2899, 100.0%, avg_loss = 1.4340
Epoch 2900, 100.0%, avg_loss = 1.4340
Validation Loss: 1.5172
Epoch 2901, 100.0%, avg_loss = 1.5074
Epoch 2902, 100.0%, avg_loss = 1.4661
Epoch 2903, 100.0%, avg_loss = 1.4640
Epoch 2904, 100.0%, avg_loss = 1.4278
Epoch 2905, 100.0%, avg_loss = 1.5449
Validation Loss: 1.4938
Epoch 2906, 100.0%, avg_loss = 1.4504
Epoch 2907, 100.0%, avg_loss = 1.4141
Epoch 2908, 100.0%, avg_loss = 1.4970
Epoch 2909, 100.0%, avg_loss = 1.4212
Epoch 2910, 100.0%, avg_loss = 1.4475
Validation Loss: 1.5168
Epoch 2911, 100.0%, avg_loss = 1.5107
Epoch 2912, 100.0%, avg_loss = 1.4831
Epoch 2913, 100.0%, avg_loss = 1.4480
Epoch 2914, 100.0%, avg_loss = 1.5006
Epoch 2915, 100.0%, avg_loss = 1.4743
Validation Loss: 1.4627
Epoch 2916, 100.0%, avg_loss = 1.4296
Epoch 

Epoch 3085, 100.0%, avg_loss = 1.4602
Validation Loss: 1.4540
Epoch 3086, 100.0%, avg_loss = 1.4626
Epoch 3087, 100.0%, avg_loss = 1.3636
Epoch 3088, 100.0%, avg_loss = 1.4825
Epoch 3089, 100.0%, avg_loss = 1.3811
Epoch 3090, 100.0%, avg_loss = 1.4500
Validation Loss: 1.4446
Epoch 3091, 100.0%, avg_loss = 1.4558
Epoch 3092, 100.0%, avg_loss = 1.4376
Epoch 3093, 100.0%, avg_loss = 1.4535
Epoch 3094, 100.0%, avg_loss = 1.5059
Epoch 3095, 100.0%, avg_loss = 1.5256
Validation Loss: 1.5126
Epoch 3096, 100.0%, avg_loss = 1.4536
Epoch 3097, 100.0%, avg_loss = 1.5158
Epoch 3098, 100.0%, avg_loss = 1.4572
Epoch 3099, 100.0%, avg_loss = 1.4394
Epoch 3100, 100.0%, avg_loss = 1.3826
Validation Loss: 1.4989
Epoch 3101, 100.0%, avg_loss = 1.5249
Epoch 3102, 100.0%, avg_loss = 1.4955
Epoch 3103, 100.0%, avg_loss = 1.4907
Epoch 3104, 100.0%, avg_loss = 1.3670
Epoch 3105, 100.0%, avg_loss = 1.4108
Validation Loss: 1.4835
Epoch 3106, 100.0%, avg_loss = 1.4569
Epoch 3107, 100.0%, avg_loss = 1.4869
Epoch 

Epoch 3276, 100.0%, avg_loss = 1.4535
Epoch 3277, 100.0%, avg_loss = 1.4335
Epoch 3278, 100.0%, avg_loss = 1.4090
Epoch 3279, 100.0%, avg_loss = 1.4018
Epoch 3280, 100.0%, avg_loss = 1.4462
Validation Loss: 1.4615
Epoch 3281, 100.0%, avg_loss = 1.4644
Epoch 3282, 100.0%, avg_loss = 1.4124
Epoch 3283, 100.0%, avg_loss = 1.4626
Epoch 3284, 100.0%, avg_loss = 1.4556
Epoch 3285, 100.0%, avg_loss = 1.4362
Validation Loss: 1.4584
Epoch 3286, 100.0%, avg_loss = 1.4670
Epoch 3287, 100.0%, avg_loss = 1.5412
Epoch 3288, 100.0%, avg_loss = 1.5164
Epoch 3289, 100.0%, avg_loss = 1.5370
Epoch 3290, 100.0%, avg_loss = 1.4809
Validation Loss: 1.4637
Epoch 3291, 100.0%, avg_loss = 1.4049
Epoch 3292, 100.0%, avg_loss = 1.4725
Epoch 3293, 100.0%, avg_loss = 1.4264
Epoch 3294, 100.0%, avg_loss = 1.4503
Epoch 3295, 100.0%, avg_loss = 1.3935
Validation Loss: 1.4712
Epoch 3296, 100.0%, avg_loss = 1.4599
Epoch 3297, 100.0%, avg_loss = 1.4707
Epoch 3298, 100.0%, avg_loss = 1.5047
Epoch 3299, 100.0%, avg_loss =

Epoch 3467, 100.0%, avg_loss = 1.4394
Epoch 3468, 100.0%, avg_loss = 1.4981
Epoch 3469, 100.0%, avg_loss = 1.4386
Epoch 3470, 100.0%, avg_loss = 1.4598
Validation Loss: 1.4598
Epoch 3471, 100.0%, avg_loss = 1.4263
Epoch 3472, 100.0%, avg_loss = 1.4750
Epoch 3473, 100.0%, avg_loss = 1.4467
Epoch 3474, 100.0%, avg_loss = 1.4427
Epoch 3475, 100.0%, avg_loss = 1.4711
Validation Loss: 1.4836
Epoch 3476, 100.0%, avg_loss = 1.4843
Epoch 3477, 100.0%, avg_loss = 1.3956
Epoch 3478, 100.0%, avg_loss = 1.4083
Epoch 3479, 100.0%, avg_loss = 1.5183
Epoch 3480, 100.0%, avg_loss = 1.5020
Validation Loss: 1.4688
Epoch 3481, 100.0%, avg_loss = 1.5222
Epoch 3482, 100.0%, avg_loss = 1.4893
Epoch 3483, 100.0%, avg_loss = 1.4456
Epoch 3484, 100.0%, avg_loss = 1.4522
Epoch 3485, 100.0%, avg_loss = 1.4631
Validation Loss: 1.4585
Epoch 3486, 100.0%, avg_loss = 1.4545
Epoch 3487, 100.0%, avg_loss = 1.4402
Epoch 3488, 100.0%, avg_loss = 1.4638
Epoch 3489, 100.0%, avg_loss = 1.4410
Epoch 3490, 100.0%, avg_loss =

Epoch 3658, 100.0%, avg_loss = 1.4016
Epoch 3659, 100.0%, avg_loss = 1.4906
Epoch 3660, 100.0%, avg_loss = 1.4270
Validation Loss: 1.4823
Epoch 3661, 100.0%, avg_loss = 1.4255
Epoch 3662, 100.0%, avg_loss = 1.5004
Epoch 3663, 100.0%, avg_loss = 1.4183
Epoch 3664, 100.0%, avg_loss = 1.4277
Epoch 3665, 100.0%, avg_loss = 1.4699
Validation Loss: 1.4579
Epoch 3666, 100.0%, avg_loss = 1.3990
Epoch 3667, 100.0%, avg_loss = 1.4608
Epoch 3668, 100.0%, avg_loss = 1.4466
Epoch 3669, 100.0%, avg_loss = 1.5053
Epoch 3670, 100.0%, avg_loss = 1.3599
Validation Loss: 1.4587
Epoch 3671, 100.0%, avg_loss = 1.4455
Epoch 3672, 100.0%, avg_loss = 1.4049
Epoch 3673, 100.0%, avg_loss = 1.4475
Epoch 3674, 100.0%, avg_loss = 1.4269
Epoch 3675, 100.0%, avg_loss = 1.4229
Validation Loss: 1.4979
Epoch 3676, 100.0%, avg_loss = 1.4294
Epoch 3677, 100.0%, avg_loss = 1.4264
Epoch 3678, 100.0%, avg_loss = 1.4038
Epoch 3679, 100.0%, avg_loss = 1.4406
Epoch 3680, 100.0%, avg_loss = 1.5071
Validation Loss: 1.4562
Epoch 

Epoch 3849, 100.0%, avg_loss = 1.4357
Epoch 3850, 100.0%, avg_loss = 1.4581
Validation Loss: 1.4631
Epoch 3851, 100.0%, avg_loss = 1.4784
Epoch 3852, 100.0%, avg_loss = 1.4881
Epoch 3853, 100.0%, avg_loss = 1.3773
Epoch 3854, 100.0%, avg_loss = 1.4311
Epoch 3855, 100.0%, avg_loss = 1.4714
Validation Loss: 1.4698
Epoch 3856, 100.0%, avg_loss = 1.4436
Epoch 3857, 100.0%, avg_loss = 1.4968
Epoch 3858, 100.0%, avg_loss = 1.4694
Epoch 3859, 100.0%, avg_loss = 1.3975
Epoch 3860, 100.0%, avg_loss = 1.4555
Validation Loss: 1.5038
Epoch 3861, 100.0%, avg_loss = 1.4310
Epoch 3862, 100.0%, avg_loss = 1.4407
Epoch 3863, 100.0%, avg_loss = 1.4466
Epoch 3864, 100.0%, avg_loss = 1.4191
Epoch 3865, 100.0%, avg_loss = 1.4030
Validation Loss: 1.4570
Epoch 3866, 100.0%, avg_loss = 1.4569
Epoch 3867, 100.0%, avg_loss = 1.4905
Epoch 3868, 100.0%, avg_loss = 1.4650
Epoch 3869, 100.0%, avg_loss = 1.4070
Epoch 3870, 100.0%, avg_loss = 1.4161
Validation Loss: 1.5067
Epoch 3871, 100.0%, avg_loss = 1.4148
Epoch 

In [13]:
# Save the aggregated server model. `local_predictor` is only reloaded from the
# server at the start of each user's step, so after training it still holds
# the weights from before the final round's update; the server model is the
# one that was trained and validated.
# NOTE(review): saving the whole module pickles the class; saving
# state_dict() is more portable if loaders can be updated accordingly.
torch.save(local_predictor_server, 'local_predictor_fullsearch_fl_scratch.pytorch')

In [14]:
# Persist the per-epoch training / validation loss dicts as pickles.
for out_path, history in (
    ('./federated_local_predictor_fullsearch_training_loss.pk', training_loss),
    ('./federated_local_predictor_fullsearch_validation_loss.pk', validation_loss),
):
    with open(out_path, 'wb') as fh:
        pk.dump(history, fh)