In [1]:
import torch
import torch.nn as nn
import pickle as pk
import numpy as np
from model import LocalPredictor
import random

In [2]:
# Load one pickle per day (Nov 1-28, 2012). data[day] is a mapping keyed by
# user id (presumably uid -> that user's interpolated record for the day; verify).
# The directory and number of days are named here so the cell is easy to re-point.
data_dir = '/data/fan/UsersInOsakaProcessed'
num_days = 28

data = {}
for d in range(1, num_days + 1):
    filename = '{}/201211{:02d}_interp.pk'.format(data_dir, d)
    print(filename)
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(filename, 'rb') as f:
        data[d] = pk.load(f)

/data/fan/UsersInOsakaProcessed/20121101_interp.pk
/data/fan/UsersInOsakaProcessed/20121102_interp.pk
/data/fan/UsersInOsakaProcessed/20121103_interp.pk
/data/fan/UsersInOsakaProcessed/20121104_interp.pk
/data/fan/UsersInOsakaProcessed/20121105_interp.pk
/data/fan/UsersInOsakaProcessed/20121106_interp.pk
/data/fan/UsersInOsakaProcessed/20121107_interp.pk
/data/fan/UsersInOsakaProcessed/20121108_interp.pk
/data/fan/UsersInOsakaProcessed/20121109_interp.pk
/data/fan/UsersInOsakaProcessed/20121110_interp.pk
/data/fan/UsersInOsakaProcessed/20121111_interp.pk
/data/fan/UsersInOsakaProcessed/20121112_interp.pk
/data/fan/UsersInOsakaProcessed/20121113_interp.pk
/data/fan/UsersInOsakaProcessed/20121114_interp.pk
/data/fan/UsersInOsakaProcessed/20121115_interp.pk
/data/fan/UsersInOsakaProcessed/20121116_interp.pk
/data/fan/UsersInOsakaProcessed/20121117_interp.pk
/data/fan/UsersInOsakaProcessed/20121118_interp.pk
/data/fan/UsersInOsakaProcessed/20121119_interp.pk
/data/fan/UsersInOsakaProcessed

In [3]:
# For each day, the set of user ids that appear in that day's data.
uid_sets = {day: set(day_records.keys()) for day, day_records in data.items()}

In [4]:
# All user ids seen on at least one of days 1-14.
uid_doc_set = set().union(*(uid_sets[day] for day in range(1, 15)))

In [5]:
def _group_by_user(day_data, days):
    """Gather each user's per-day records over ``days`` into a list, in day order.

    Args:
        day_data: mapping day -> {uid: record}.
        days: iterable of day keys to include.

    Returns:
        dict mapping uid -> list of that user's records, one entry per day on
        which the user appears (users are ordered by first appearance).
    """
    grouped = {}
    for day in days:
        for uid, record in day_data[day].items():
            grouped.setdefault(uid, []).append(record)
    return grouped


# Days 1-14 -> data_doc (the history the training loop passes as x_loc_doc);
# days 15-21 -> training queries; days 22-28 -> test queries.
data_doc = _group_by_user(data, range(1, 15))
data_qry_train = _group_by_user(data, range(15, 22))
data_qry_test = _group_by_user(data, range(22, 29))

In [6]:
# Replace each user's list of per-day records with a LongTensor on GPU 1
# (one row per day). Only values are reassigned, so iterating items() is safe.
for split in (data_doc, data_qry_train, data_qry_test):
    for uid, day_records in split.items():
        split[uid] = torch.LongTensor(day_records).cuda(1)

In [7]:
# Model hyper-parameters passed to LocalPredictor.
num_locs = 1600            # number of distinct location ids (TODO confirm meaning, e.g. grid cells)
loc_embedding_dim = 128
T = num_time = 96          # time steps per record; also passed to the model as num_time
time_embedding_dim = 32
hidden_dim = latent_dim = 256

In [8]:
# Base learning rate and the decay factor applied to the server's update buffer.
lr, momentum = 1e-3, 0.8

In [9]:
# Client copy (receives per-user gradients) and server copy (holds the global
# weights); both are built with identical architecture arguments.
model_args = (num_locs, loc_embedding_dim, num_time, time_embedding_dim, hidden_dim, latent_dim)
local_predictor = LocalPredictor(*model_args).cuda(1)
local_predictor_server = LocalPredictor(*model_args).cuda(1)
# In this notebook the optimizer is only used to zero the client's gradients;
# parameters are updated manually in the training loop.
optimizer = torch.optim.SGD(local_predictor.parameters(), lr=lr)
optimizer.zero_grad()

In [10]:
# Window length in time steps: the model sees dT steps and predicts the step dT after the window.
dT: int = 4

In [11]:
# Per-epoch loss histories (epoch -> averaged loss), filled by the training loop.
training_loss, validation_loss = {}, {}

In [None]:
# NOTE(review): an identical copy of this training cell follows; running both
# trains the server model twice.
#
# Simulated federated training with server-side momentum.
# Each epoch:
#   1. Sample users.
#   2. For each user, copy the server weights into the client model and
#      back-propagate that user's loss.
#   3. Accumulate lr * (grad + Gaussian noise) into a buffer keyed by parameter name.
#   4. Subtract the buffer from the server weights, then decay the buffer by `momentum`.
# The learning rate is derived from `lr` each epoch rather than halved in place,
# so re-running this cell restarts from the configured rate.
user_list_train = list(data_qry_train.keys())
time_list = list(range(T - 2 * dT + 1))

base_lr = lr                                       # configured rate; never mutated below
noise_std = 1e-2                                   # std of the noise added to every client gradient
users_per_epoch = min(2048, len(user_list_train))  # random.sample raises if k > population
val_fraction = 0.05                                # share of test users sampled per validation pass

# Initialize global SGD (update/momentum buffer keyed by parameter name)
local_predictor_update = dict({})

for epoch in range(1, 4001):

    # Halve every 500 epochs -- same schedule as the previous in-place `lr *= 0.5`.
    epoch_lr = base_lr * 0.5 ** (epoch // 500)

    update_user_list = random.sample(user_list_train, users_per_epoch)
    random.shuffle(update_user_list)

    avg_loss = 0.0
    cnt = 0

    for uidx, uid in enumerate(update_user_list, start=1):
        # loading server model to local
        local_predictor.load_state_dict(local_predictor_server.state_dict())

        nk = data_qry_train[uid].shape[0]

        # Input: dT steps starting at t; target: dT steps after the last input step.
        t = np.random.randint(T - 2 * dT + 1)
        x_loc_qry = data_qry_train[uid][:, t: t + dT]
        x_t_qry = torch.zeros_like(x_loc_qry) + t
        y = data_qry_train[uid][:, t + 2 * dT - 1]

        if uid not in data_doc:
            loss = local_predictor(x_loc_qry, x_t_qry, None, None, y)
        else:
            x_loc_doc = data_doc[uid][:, t: t + 2 * dT]
            x_t_doc = torch.zeros_like(x_loc_doc) + t
            loss = local_predictor(x_loc_qry, x_t_qry, x_loc_doc, x_t_doc, y)

        loss.backward()
        avg_loss += loss.item()
        cnt += nk

        with torch.no_grad():
            for name, para in local_predictor.named_parameters():
                if para.grad is None:
                    continue
                step = (para.grad + torch.randn_like(para.grad) * noise_std) * epoch_lr
                if name not in local_predictor_update:
                    local_predictor_update[name] = step
                else:
                    local_predictor_update[name] += step

        optimizer.zero_grad()

        print('Epoch {:02d}, {:.1f}%, avg_loss = {:.4f}'.format(epoch, uidx * 100 / len(update_user_list), avg_loss / cnt), end='\r')

    # update
    with torch.no_grad():
        for name, para in local_predictor_server.named_parameters():
            # A parameter that has never received a gradient has no buffer entry.
            if name in local_predictor_update:
                para -= local_predictor_update[name]

        for name in local_predictor_update:
            local_predictor_update[name] *= momentum

    training_loss[epoch] = avg_loss / cnt
    print('')

    # testing
    if epoch % 5 == 0:
        # The server model is never trained through backward here, so leaving it
        # in eval mode is safe (matters only if the model has dropout/batch-norm).
        local_predictor_server.eval()
        with torch.no_grad():

            val_cnt = 0
            val_loss = 0.0

            for uid in data_qry_test:

                if np.random.random_sample() > val_fraction:
                    continue

                nk = data_qry_test[uid].shape[0]

                t = np.random.randint(T - 2 * dT + 1)
                x_loc_qry = data_qry_test[uid][:, t: t + dT]
                x_t_qry = torch.zeros_like(x_loc_qry) + t
                y = data_qry_test[uid][:, t + 2 * dT - 1]

                if uid not in data_doc:
                    loss = local_predictor_server(x_loc_qry, x_t_qry, None, None, y)
                else:
                    x_loc_doc = data_doc[uid][:, t: t + 2 * dT]
                    x_t_doc = torch.zeros_like(x_loc_doc) + t
                    loss = local_predictor_server(x_loc_qry, x_t_qry, x_loc_doc, x_t_doc, y)

                val_loss += loss.item()
                val_cnt += nk

            # Random subsampling can (rarely) select no test user at all.
            if val_cnt > 0:
                print('Validation Loss: {:.4f}'.format(val_loss / val_cnt))
                validation_loss[epoch] = val_loss / val_cnt

In [12]:
# Simulated federated training with server-side momentum.
# Each epoch:
#   1. Sample users.
#   2. For each user, copy the server weights into the client model and
#      back-propagate that user's loss.
#   3. Accumulate lr * (grad + Gaussian noise) into a buffer keyed by parameter name.
#   4. Subtract the buffer from the server weights, then decay the buffer by `momentum`.
# The learning rate is derived from `lr` each epoch rather than halved in place,
# so re-running this cell restarts from the configured rate.
user_list_train = list(data_qry_train.keys())
time_list = list(range(T - 2 * dT + 1))

base_lr = lr                                       # configured rate; never mutated below
noise_std = 1e-2                                   # std of the noise added to every client gradient
users_per_epoch = min(2048, len(user_list_train))  # random.sample raises if k > population
val_fraction = 0.05                                # share of test users sampled per validation pass

# Initialize global SGD (update/momentum buffer keyed by parameter name)
local_predictor_update = dict({})

for epoch in range(1, 4001):

    # Halve every 500 epochs -- same schedule as the previous in-place `lr *= 0.5`.
    epoch_lr = base_lr * 0.5 ** (epoch // 500)

    update_user_list = random.sample(user_list_train, users_per_epoch)
    random.shuffle(update_user_list)

    avg_loss = 0.0
    cnt = 0

    for uidx, uid in enumerate(update_user_list, start=1):
        # loading server model to local
        local_predictor.load_state_dict(local_predictor_server.state_dict())

        nk = data_qry_train[uid].shape[0]

        # Input: dT steps starting at t; target: dT steps after the last input step.
        t = np.random.randint(T - 2 * dT + 1)
        x_loc_qry = data_qry_train[uid][:, t: t + dT]
        x_t_qry = torch.zeros_like(x_loc_qry) + t
        y = data_qry_train[uid][:, t + 2 * dT - 1]

        if uid not in data_doc:
            loss = local_predictor(x_loc_qry, x_t_qry, None, None, y)
        else:
            x_loc_doc = data_doc[uid][:, t: t + 2 * dT]
            x_t_doc = torch.zeros_like(x_loc_doc) + t
            loss = local_predictor(x_loc_qry, x_t_qry, x_loc_doc, x_t_doc, y)

        loss.backward()
        avg_loss += loss.item()
        cnt += nk

        with torch.no_grad():
            for name, para in local_predictor.named_parameters():
                if para.grad is None:
                    continue
                step = (para.grad + torch.randn_like(para.grad) * noise_std) * epoch_lr
                if name not in local_predictor_update:
                    local_predictor_update[name] = step
                else:
                    local_predictor_update[name] += step

        optimizer.zero_grad()

        print('Epoch {:02d}, {:.1f}%, avg_loss = {:.4f}'.format(epoch, uidx * 100 / len(update_user_list), avg_loss / cnt), end='\r')

    # update
    with torch.no_grad():
        for name, para in local_predictor_server.named_parameters():
            # A parameter that has never received a gradient has no buffer entry.
            if name in local_predictor_update:
                para -= local_predictor_update[name]

        for name in local_predictor_update:
            local_predictor_update[name] *= momentum

    training_loss[epoch] = avg_loss / cnt
    print('')

    # testing
    if epoch % 5 == 0:
        # The server model is never trained through backward here, so leaving it
        # in eval mode is safe (matters only if the model has dropout/batch-norm).
        local_predictor_server.eval()
        with torch.no_grad():

            val_cnt = 0
            val_loss = 0.0

            for uid in data_qry_test:

                if np.random.random_sample() > val_fraction:
                    continue

                nk = data_qry_test[uid].shape[0]

                t = np.random.randint(T - 2 * dT + 1)
                x_loc_qry = data_qry_test[uid][:, t: t + dT]
                x_t_qry = torch.zeros_like(x_loc_qry) + t
                y = data_qry_test[uid][:, t + 2 * dT - 1]

                if uid not in data_doc:
                    loss = local_predictor_server(x_loc_qry, x_t_qry, None, None, y)
                else:
                    x_loc_doc = data_doc[uid][:, t: t + 2 * dT]
                    x_t_doc = torch.zeros_like(x_loc_doc) + t
                    loss = local_predictor_server(x_loc_qry, x_t_qry, x_loc_doc, x_t_doc, y)

                val_loss += loss.item()
                val_cnt += nk

            # Random subsampling can (rarely) select no test user at all.
            if val_cnt > 0:
                print('Validation Loss: {:.4f}'.format(val_loss / val_cnt))
                validation_loss[epoch] = val_loss / val_cnt

Epoch 01, 100.0%, avg_loss = 7.3795
Epoch 02, 100.0%, avg_loss = 7.2845
Epoch 03, 100.0%, avg_loss = 7.0945
Epoch 04, 100.0%, avg_loss = 6.8151
Epoch 05, 100.0%, avg_loss = 6.4295
Validation Loss: 6.2226
Epoch 06, 100.0%, avg_loss = 6.1004
Epoch 07, 100.0%, avg_loss = 5.6692
Epoch 08, 100.0%, avg_loss = 5.1676
Epoch 09, 100.0%, avg_loss = 4.6220
Epoch 10, 100.0%, avg_loss = 4.0262
Validation Loss: 3.6390
Epoch 11, 100.0%, avg_loss = 3.6241
Epoch 12, 100.0%, avg_loss = 3.3216
Epoch 13, 100.0%, avg_loss = 3.1746
Epoch 14, 100.0%, avg_loss = 2.9315
Epoch 15, 100.0%, avg_loss = 2.7485
Validation Loss: 2.8675
Epoch 16, 100.0%, avg_loss = 2.8898
Epoch 17, 100.0%, avg_loss = 2.6990
Epoch 18, 100.0%, avg_loss = 2.4582
Epoch 19, 100.0%, avg_loss = 2.5521
Epoch 20, 100.0%, avg_loss = 2.5039
Validation Loss: 2.5392
Epoch 21, 100.0%, avg_loss = 2.5431
Epoch 22, 100.0%, avg_loss = 2.3621
Epoch 23, 100.0%, avg_loss = 2.3590
Epoch 24, 100.0%, avg_loss = 2.4526
Epoch 25, 100.0%, avg_loss = 2.3812
Vali

Epoch 395, 100.0%, avg_loss = 1.4057
Validation Loss: 1.4092
Epoch 396, 100.0%, avg_loss = 1.3064
Epoch 397, 100.0%, avg_loss = 1.4039
Epoch 398, 100.0%, avg_loss = 1.3944
Epoch 399, 100.0%, avg_loss = 1.3503
Epoch 400, 100.0%, avg_loss = 1.4623
Validation Loss: 1.5816
Epoch 401, 100.0%, avg_loss = 1.3510
Epoch 402, 100.0%, avg_loss = 1.3711
Epoch 403, 100.0%, avg_loss = 1.4255
Epoch 404, 100.0%, avg_loss = 1.3563
Epoch 405, 100.0%, avg_loss = 1.3539
Validation Loss: 1.5343
Epoch 406, 100.0%, avg_loss = 1.4452
Epoch 407, 100.0%, avg_loss = 1.3553
Epoch 408, 100.0%, avg_loss = 1.4289
Epoch 409, 100.0%, avg_loss = 1.4356
Epoch 410, 100.0%, avg_loss = 1.3561
Validation Loss: 1.4788
Epoch 411, 100.0%, avg_loss = 1.3404
Epoch 412, 100.0%, avg_loss = 1.3970
Epoch 413, 100.0%, avg_loss = 1.3426
Epoch 414, 100.0%, avg_loss = 1.4584
Epoch 415, 100.0%, avg_loss = 1.3622
Validation Loss: 1.5212
Epoch 416, 100.0%, avg_loss = 1.4012
Epoch 417, 100.0%, avg_loss = 1.3927
Epoch 418, 100.0%, avg_loss =

Epoch 591, 100.0%, avg_loss = 1.2875
Epoch 592, 100.0%, avg_loss = 1.2661
Epoch 593, 100.0%, avg_loss = 1.2996
Epoch 594, 100.0%, avg_loss = 1.2833
Epoch 595, 100.0%, avg_loss = 1.2960
Validation Loss: 1.3589
Epoch 596, 100.0%, avg_loss = 1.3104
Epoch 597, 100.0%, avg_loss = 1.2924
Epoch 598, 100.0%, avg_loss = 1.3050
Epoch 599, 100.0%, avg_loss = 1.3361
Epoch 600, 100.0%, avg_loss = 1.3043
Validation Loss: 1.4089
Epoch 601, 100.0%, avg_loss = 1.2698
Epoch 602, 100.0%, avg_loss = 1.2801
Epoch 603, 100.0%, avg_loss = 1.3178
Epoch 604, 100.0%, avg_loss = 1.2584
Epoch 605, 100.0%, avg_loss = 1.3624
Validation Loss: 1.3398
Epoch 606, 100.0%, avg_loss = 1.2950
Epoch 607, 100.0%, avg_loss = 1.2881
Epoch 608, 100.0%, avg_loss = 1.3246
Epoch 609, 100.0%, avg_loss = 1.2869
Epoch 610, 100.0%, avg_loss = 1.3117
Validation Loss: 1.3802
Epoch 611, 100.0%, avg_loss = 1.1984
Epoch 612, 100.0%, avg_loss = 1.2866
Epoch 613, 100.0%, avg_loss = 1.3049
Epoch 614, 100.0%, avg_loss = 1.2946
Epoch 615, 100.0

Epoch 983, 100.0%, avg_loss = 1.1876
Epoch 984, 100.0%, avg_loss = 1.2671
Epoch 985, 100.0%, avg_loss = 1.1893
Validation Loss: 1.3720
Epoch 986, 100.0%, avg_loss = 1.2440
Epoch 987, 100.0%, avg_loss = 1.2440
Epoch 988, 100.0%, avg_loss = 1.2135
Epoch 989, 100.0%, avg_loss = 1.2088
Epoch 990, 100.0%, avg_loss = 1.1938
Validation Loss: 1.3775
Epoch 991, 100.0%, avg_loss = 1.2609
Epoch 992, 100.0%, avg_loss = 1.2485
Epoch 993, 100.0%, avg_loss = 1.2475
Epoch 994, 100.0%, avg_loss = 1.2055
Epoch 995, 100.0%, avg_loss = 1.2045
Validation Loss: 1.4022
Epoch 996, 100.0%, avg_loss = 1.2897
Epoch 997, 100.0%, avg_loss = 1.2271
Epoch 998, 100.0%, avg_loss = 1.2637
Epoch 999, 100.0%, avg_loss = 1.2980
Epoch 1000, 100.0%, avg_loss = 1.2180
Validation Loss: 1.3720
Epoch 1001, 100.0%, avg_loss = 1.2401
Epoch 1002, 100.0%, avg_loss = 1.2503
Epoch 1003, 100.0%, avg_loss = 1.2551
Epoch 1004, 100.0%, avg_loss = 1.2789
Epoch 1005, 100.0%, avg_loss = 1.2428
Validation Loss: 1.3565
Epoch 1006, 100.0%, avg

Epoch 1366, 100.0%, avg_loss = 1.2285
Epoch 1367, 100.0%, avg_loss = 1.2376
Epoch 1368, 100.0%, avg_loss = 1.1705
Epoch 1369, 100.0%, avg_loss = 1.2350
Epoch 1370, 100.0%, avg_loss = 1.1745
Validation Loss: 1.3574
Epoch 1371, 100.0%, avg_loss = 1.2365
Epoch 1372, 100.0%, avg_loss = 1.2423
Epoch 1373, 100.0%, avg_loss = 1.2293
Epoch 1374, 100.0%, avg_loss = 1.2388
Epoch 1375, 100.0%, avg_loss = 1.1881
Validation Loss: 1.2935
Epoch 1376, 100.0%, avg_loss = 1.2321
Epoch 1377, 100.0%, avg_loss = 1.2024
Epoch 1378, 100.0%, avg_loss = 1.2641
Epoch 1379, 100.0%, avg_loss = 1.1927
Epoch 1380, 100.0%, avg_loss = 1.2380
Validation Loss: 1.2974
Epoch 1381, 100.0%, avg_loss = 1.2434
Epoch 1382, 100.0%, avg_loss = 1.1899
Epoch 1383, 100.0%, avg_loss = 1.2242
Epoch 1384, 100.0%, avg_loss = 1.2086
Epoch 1385, 100.0%, avg_loss = 1.2707
Validation Loss: 1.3219
Epoch 1386, 100.0%, avg_loss = 1.1939
Epoch 1387, 100.0%, avg_loss = 1.2434
Epoch 1388, 100.0%, avg_loss = 1.1535
Epoch 1389, 100.0%, avg_loss =

Epoch 1748, 100.0%, avg_loss = 1.1910
Epoch 1749, 100.0%, avg_loss = 1.2647
Epoch 1750, 100.0%, avg_loss = 1.2155
Validation Loss: 1.2682
Epoch 1751, 100.0%, avg_loss = 1.1618
Epoch 1752, 100.0%, avg_loss = 1.1956
Epoch 1753, 100.0%, avg_loss = 1.1596
Epoch 1754, 100.0%, avg_loss = 1.1998
Epoch 1755, 100.0%, avg_loss = 1.2106
Validation Loss: 1.2924
Epoch 1756, 100.0%, avg_loss = 1.1527
Epoch 1757, 100.0%, avg_loss = 1.2522
Epoch 1758, 100.0%, avg_loss = 1.1869
Epoch 1759, 100.0%, avg_loss = 1.2320
Epoch 1760, 100.0%, avg_loss = 1.2105
Validation Loss: 1.2821
Epoch 1761, 100.0%, avg_loss = 1.2350
Epoch 1762, 100.0%, avg_loss = 1.2270
Epoch 1763, 100.0%, avg_loss = 1.2238
Epoch 1764, 100.0%, avg_loss = 1.2133
Epoch 1765, 100.0%, avg_loss = 1.2047
Validation Loss: 1.3127
Epoch 1766, 100.0%, avg_loss = 1.1847
Epoch 1767, 100.0%, avg_loss = 1.1392
Epoch 1768, 100.0%, avg_loss = 1.2507
Epoch 1769, 100.0%, avg_loss = 1.1781
Epoch 1770, 100.0%, avg_loss = 1.1879
Validation Loss: 1.2605
Epoch 

KeyboardInterrupt: 

In [13]:
# Persist the trained global (server) model.
# `local_predictor` only holds the server weights copied in before the last
# client's gradient; no optimizer step is ever applied to it. After a completed
# epoch it therefore lags the final server update, so the server model is the
# one to save. The whole module is still pickled, so existing loading code keeps working.
import os

os.makedirs('./results_osaka', exist_ok=True)
torch.save(local_predictor_server, './results_osaka/local_predictor_fl_scratch_noise001.pytorch')

In [14]:
# Pickle the per-epoch loss histories (epoch -> averaged loss) for later analysis.
loss_histories = (
    ('./results_osaka/federated_local_predictor_noise001_training_loss.pk', training_loss),
    ('./results_osaka/federated_local_predictor_noise001_validation_loss.pk', validation_loss),
)
for out_path, history in loss_histories:
    with open(out_path, 'wb') as fh:
        pk.dump(history, fh)