In [1]:
import torch

from torch.utils.data import DataLoader, Dataset


import pandas as pd
import numpy as np
from gmf import GMFEngine, GMF
from mlp import MLPEngine
from neumf import NeuMFEngine
from data import SampleGenerator
from metrics import MetronAtK

import copy

In [2]:


def evaluate_hit_ndcg(model, evaluate_data, top_k=10):
    """Score the test and negative pairs and return ``(hit_ratio, ndcg)``.

    Args:
        model: callable as ``model(users, items)``; it is switched to eval
            mode here and left in that mode.
        evaluate_data: indexable whose entries 0..3 are, in order, the test
            users, test items, negative users and negative items tensors.
        top_k: cut-off handed to ``MetronAtK``. Defaults to 10, the value
            that used to be hard-coded.

    Returns:
        Tuple ``(hit_ratio, ndcg)`` as reported by ``MetronAtK``.
    """
    def _flatten(tensor):
        # reshape (unlike view) also accepts non-contiguous tensors
        return tensor.detach().reshape(-1).tolist()

    model.eval()
    with torch.no_grad():
        test_users, test_items = evaluate_data[0], evaluate_data[1]
        negative_users, negative_items = evaluate_data[2], evaluate_data[3]
        test_scores = model(test_users, test_items)
        negative_scores = model(negative_users, negative_items)

    metron = MetronAtK(top_k=top_k)
    # Order matters: (users, items, scores) for the test pairs, followed by
    # the same triple for the negative pairs.
    metron.subjects = [_flatten(test_users),
                       _flatten(test_items),
                       _flatten(test_scores),
                       _flatten(negative_users),
                       _flatten(negative_items),
                       _flatten(negative_scores)]
    hit_ratio, ndcg = metron.cal_hit_ratio(), metron.cal_ndcg()

    return hit_ratio, ndcg


def eval_ce_loss(model, data_loader):
    """Return the average binary cross-entropy of ``model`` over ``data_loader``.

    The result is the plain mean of the per-batch mean losses (batches are
    not re-weighted by their size).

    Args:
        model: callable as ``model(users, items)`` returning probabilities;
            it is switched to eval mode here and left in that mode.
        data_loader: iterable of batches where ``batch[0]``, ``batch[1]`` and
            ``batch[2]`` are the users, items and ratings.

    Returns:
        float: mean of the batch losses.

    Raises:
        ValueError: if ``data_loader`` yields no batches (previously this
            surfaced as a bare ``ZeroDivisionError``).
    """
    crit = torch.nn.BCELoss()  # stateless, so build it once instead of per batch
    total_loss = 0.0
    n_batches = 0
    model.eval()
    with torch.no_grad():
        for batch in data_loader:
            users, items, ratings = batch[0], batch[1], batch[2]
            ratings_pred = model(users, items)
            # predictions are flattened to match the 1-D ratings vector
            total_loss += crit(ratings_pred.view(-1), ratings.float()).item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("data_loader yielded no batches; cannot average the loss")

    return total_loss / n_batches

In [3]:
# Load the tab-separated ratings file into columns uid / mid / rating / timestamp.
# NOTE: the `ml1m_*` names are kept because later cells reference them, but the
# path below points at the ml-100k directory.
ml1m_dir = 'data/ml-100k/u.data'
ml1m_rating = pd.read_csv(ml1m_dir, sep='\t', header=None,
                          names=['uid', 'mid', 'rating', 'timestamp'])

# Reindex: give every raw uid / mid a contiguous 0-based id, numbered in order
# of first appearance, then attach those ids back onto the ratings.
user_id = (ml1m_rating[['uid']]
           .drop_duplicates()
           .assign(userId=lambda frame: np.arange(len(frame))))
ml1m_rating = pd.merge(ml1m_rating, user_id, on=['uid'], how='left')

item_id = (ml1m_rating[['mid']]
           .drop_duplicates()
           .assign(itemId=lambda frame: np.arange(len(frame))))
ml1m_rating = pd.merge(ml1m_rating, item_id, on=['mid'], how='left')

# Keep only the re-indexed ids plus the rating and timestamp.
ml1m_rating = ml1m_rating[['userId', 'itemId', 'rating', 'timestamp']]
print('Range of userId is [{}, {}]'.format(ml1m_rating.userId.min(), ml1m_rating.userId.max()))
print('Range of itemId is [{}, {}]'.format(ml1m_rating.itemId.min(), ml1m_rating.itemId.max()))

Range of userId is [0, 942]
Range of itemId is [0, 1681]


In [4]:
# Number of distinct (re-indexed) users and items in the ratings frame.
num_users = ml1m_rating['userId'].unique().size
num_items = ml1m_rating['itemId'].unique().size

# Settings handed to the GMF model constructor.
gmf_config = dict(
    num_users=num_users,
    num_items=num_items,
    latent_dim=10,
    num_negative=5,
    l2_regularization=0,  # 0.01
    use_cuda=False,
    device_id=0,
    model_dir='checkpoints/{}_Epoch{}_HR{:.4f}_NDCG{:.4f}.model',
)





## Federated training

In [5]:
class DatasetSplit(Dataset):
    """Subset view of ``dataset`` that exposes only the positions in ``idxs``.

    Position ``j`` of the split maps to ``dataset[idxs[j]]``; the index
    collection is materialised into a list once, at construction time.
    """

    def __init__(self, dataset, idxs):
        self.dataset, self.idxs = dataset, [*idxs]

    def __len__(self):
        # size of the subset, not of the wrapped dataset
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        return self.dataset[source_position]


class LocalUpdate:
    """Local training of one user's private copy of the global model.

    Attributes:
        local_dataset: ``DatasetSplit`` holding the rows of ``dataset`` whose
            ``dataset.user_tensor`` entry equals ``user_id``.
        num_local: number of rows in ``local_dataset``.
        local_model: deep copy of ``model_global``; training only touches
            this copy, never the model passed in.
        user_id: the id the rows were selected with.
    """

    def __init__(self, model_global, dataset, user_id):
        idxs = torch.where(dataset.user_tensor == user_id)[0]
        self.local_dataset = DatasetSplit(dataset, idxs)
        self.num_local = len(self.local_dataset)
        self.local_model = copy.deepcopy(model_global)
        self.user_id = user_id

    def train(self, lr, epoches_local, local_batch_size):
        """Run Adam + BCE training on the local data and return the last epoch's mean batch loss.

        Args:
            lr: learning rate for Adam.
            epoches_local: number of passes over the local data (must be >= 1).
            local_batch_size: requested batch size; values below 1 are raised
                to 1 because ``DataLoader`` rejects a zero batch size.

        Returns:
            float: mean of the batch losses of the final local epoch.

        Raises:
            ValueError: if ``epoches_local`` is below 1 or the user has no
                local samples (there would be no loss to report).
        """
        if epoches_local < 1:
            raise ValueError("epoches_local must be >= 1, got {}".format(epoches_local))
        if self.num_local == 0:
            raise ValueError("user {} has no local samples to train on".format(self.user_id))

        batch_size = max(1, int(local_batch_size))
        local_data_loader = DataLoader(self.local_dataset,
                                       batch_size=batch_size,
                                       shuffle=True)

        self.local_model.train()
        opt = torch.optim.Adam(self.local_model.parameters(), lr=lr)
        crit = torch.nn.BCELoss()  # stateless, so build it once instead of per batch
        for ep_local in range(epoches_local):
            t_loss = 0.0
            n_batch = 0
            for users, items, ratings in local_data_loader:
                opt.zero_grad()
                ratings = ratings.float()
                ratings_pred = self.local_model(users, items).view(-1)

                loss = crit(ratings_pred, ratings)
                loss.backward()
                opt.step()

                t_loss += loss.item()
                n_batch += 1

            # progress line every 100 local epochs
            if (ep_local + 1) % 100 == 0:
                print("user", self.user_id, "ep_local", ep_local, t_loss/n_batch)

        return t_loss/n_batch


In [6]:
# Build the sample generator from the re-indexed ratings frame; the training
# cell below uses it for `instance_a_train_loader(...)` and `evaluate_data`.
sample_generator = SampleGenerator(ratings=ml1m_rating)

In [7]:
# Global (server-side) GMF model built from gmf_config; each sampled user
# trains a deep copy of it and the aggregate is loaded back into this object.
model_global = GMF(gmf_config)

In [8]:
# Federated training of `model_global`.
# Every round: sample a fraction of users, let each one train a private copy
# of the global model on its own rows, then merge the local models back.
frac = 0.1               # fraction of users sampled per round
epoches_local = 3        # local epochs per sampled user
local_lr = 1e-2          # learning rate for local training
num_rounds = 500         # number of federated rounds
train_batch_size = 1024  # batch size of the full training loader
local_batch_frac = 0.1   # each user's batch holds this share of its own rows

# keys of user and item embeddings in the model state dict
k_u, k_i = 'embedding_user.weight', 'embedding_item.weight'
eps = 1e-10  # keeps the contribution normalisation finite for untouched items

for epoch in range(num_rounds):
    train_loader = sample_generator.instance_a_train_loader(num_negatives=gmf_config['num_negative'],
                                                            batch_size=train_batch_size)
    train_dataset = train_loader.dataset
    evaluate_data = sample_generator.evaluate_data

    # sample users
    m = max(int(frac * num_users), 1)
    idxs_users = np.random.choice(num_users, m, replace=False)

    w_locals, n_locals = [], []
    total_loss = 0.0
    for idx in idxs_users:
        # local update
        local_update = LocalUpdate(model_global, train_dataset, user_id=idx)
        # never below 1: DataLoader rejects batch_size=0, which the plain
        # int(...) truncation yields for users with fewer than 10 rows
        local_batch_size = max(1, int(len(local_update.local_dataset) * local_batch_frac))
        total_loss += local_update.train(local_lr, epoches_local, local_batch_size=local_batch_size)

        w_locals.append(copy.deepcopy(local_update.local_model.state_dict()))
        n_locals.append(local_update.num_local)

    # aggregated model
    w_avg = {}
    w_global = model_global.state_dict()

    # aggregated user embeddings
    # line 13 - 16 of Algorithm 3
    # user embedding: the k-th chosen user only updates its own user embedding
    w_avg[k_u] = w_global[k_u].clone()
    for user, w_local in zip(idxs_users, w_locals):
        w_avg[k_u][user] = w_local[k_u][user]

    # contribution of each user to all items, measured by L1 distance between the
    # local item embeddings to the global item embeddings
    # line 6 - 9 of Algorithm 3
    # contrib has shape (m, num_items); each column is normalised over users
    contrib = torch.stack([(w_global[k_i] - w_local[k_i]).abs().sum(1) for w_local in w_locals])
    contrib = contrib / (contrib.sum(0) + eps)

    # aggregate item embeddings: line 10 - 12 of Algorithm 3
    # Contribution-weighted sum of the local item embeddings, computed for all
    # items at once instead of a Python loop over (item, user) pairs.
    local_items = torch.stack([w_local[k_i] for w_local in w_locals])  # (m, num_items, dim)
    weighted_items = (local_items * contrib.unsqueeze(-1)).sum(0)
    # an item not updated by any chosen user keeps its original embedding
    item_updated = (contrib.sum(0) != 0).unsqueeze(-1)
    w_avg[k_i] = torch.where(item_updated, weighted_items, w_global[k_i])

    # aggregate other model parameters
    # line 2 - 3 of algorithm 3
    # plain FedAvg: weight every local model by its share of the local samples
    n_locals = np.array(n_locals, dtype=np.float64)
    n_locals /= n_locals.sum()
    for key in w_global.keys():
        if key != k_u and key != k_i:
            w_avg[key] = sum(w_local[key] * float(share)
                             for w_local, share in zip(w_locals, n_locals))

    model_global.load_state_dict(w_avg)
    hit_ratio, ndcg = evaluate_hit_ndcg(model_global, evaluate_data)

    total_loss_global = eval_ce_loss(model_global, train_loader)
    print(epoch, "local_loss", total_loss/m, "total_loss_global", total_loss_global,
          "hit_ratio", hit_ratio, "ndcg", ndcg)






A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_in_top_k['ndcg'] = test_in_top_k['rank'].apply(lambda x: math.log(2) / math.log(1 + x)) # the rank starts from 1


0 local_loss 0.6179362699504747 total_loss_global 0.6475661893514676 hit_ratio 0.19618239660657477 ndcg 0.09073737534794556
1 local_loss 0.5446828639685992 total_loss_global 0.5654165337500187 hit_ratio 0.19830328738069988 ndcg 0.09026113888950565
2 local_loss 0.49065841142109784 total_loss_global 0.5095021511827197 hit_ratio 0.2067868504772004 ndcg 0.09186212939027832
3 local_loss 0.458707283362314 total_loss_global 0.47738166889512396 hit_ratio 0.20254506892895016 ndcg 0.09649998715824556
4 local_loss 0.43598022367553046 total_loss_global 0.4609015198254544 hit_ratio 0.21208907741251326 ndcg 0.09435351147607289
5 local_loss 0.4321984362584942 total_loss_global 0.45571748140346574 hit_ratio 0.2110286320254507 ndcg 0.09112216115582966
6 local_loss 0.42983504326246696 total_loss_global 0.4553811625543847 hit_ratio 0.20360551431601273 ndcg 0.09386684921890157
7 local_loss 0.43104893537940797 total_loss_global 0.4555468778294253 hit_ratio 0.20996818663838812 ndcg 0.0935092541494785
8 loca

66 local_loss 0.30977111977076527 total_loss_global 0.38583547492445847 hit_ratio 0.5206786850477201 ndcg 0.28337602408680335
67 local_loss 0.3082049751402071 total_loss_global 0.38422145352100956 hit_ratio 0.5206786850477201 ndcg 0.28651072289751356
68 local_loss 0.30879831564607285 total_loss_global 0.38299355690532627 hit_ratio 0.528101802757158 ndcg 0.2894538970794147
69 local_loss 0.31184220848392474 total_loss_global 0.3818929615509079 hit_ratio 0.5344644750795334 ndcg 0.29237362976121406
70 local_loss 0.2973967442341235 total_loss_global 0.38067944365811635 hit_ratio 0.5323435843054083 ndcg 0.2931042931514285
71 local_loss 0.3054746288006381 total_loss_global 0.3794966309801026 hit_ratio 0.5397667020148462 ndcg 0.2966719487434487
72 local_loss 0.31267589258671347 total_loss_global 0.37847980364464645 hit_ratio 0.5387062566277837 ndcg 0.2963291914867383
73 local_loss 0.3109090305018155 total_loss_global 0.375865135053349 hit_ratio 0.5387062566277837 ndcg 0.2971632298380665
74 loc

132 local_loss 0.29631959335586655 total_loss_global 0.3308452825845827 hit_ratio 0.6118769883351007 ndcg 0.3448083148365382
133 local_loss 0.2897401956949434 total_loss_global 0.3300653512424528 hit_ratio 0.6118769883351007 ndcg 0.3466830709561086
134 local_loss 0.30010487952042625 total_loss_global 0.3309326547866434 hit_ratio 0.6139978791092259 ndcg 0.3481694468384065
135 local_loss 0.2930338955489734 total_loss_global 0.330042617405138 hit_ratio 0.616118769883351 ndcg 0.35019009475220175
136 local_loss 0.28546486335141824 total_loss_global 0.3291516212711236 hit_ratio 0.6182396606574762 ndcg 0.35419682642490724
137 local_loss 0.2877586657068796 total_loss_global 0.3290502376687711 hit_ratio 0.616118769883351 ndcg 0.3514891444594073
138 local_loss 0.2922534888698512 total_loss_global 0.3286843860005758 hit_ratio 0.6182396606574762 ndcg 0.35390318937040965
139 local_loss 0.29552125254793665 total_loss_global 0.32853690374533606 hit_ratio 0.623541887592789 ndcg 0.3562721959860389
140 

198 local_loss 0.27803767324466705 total_loss_global 0.3108871708946261 hit_ratio 0.6670201484623541 ndcg 0.3970116585781533
199 local_loss 0.27127854610312335 total_loss_global 0.3114147666399327 hit_ratio 0.6765641569459173 ndcg 0.4003247901868654
200 local_loss 0.25972111120475977 total_loss_global 0.3108364850180079 hit_ratio 0.679745493107105 ndcg 0.40280394534671965
201 local_loss 0.2557495825672686 total_loss_global 0.3105658928826836 hit_ratio 0.6829268292682927 ndcg 0.40466207779927515
202 local_loss 0.2638026675458896 total_loss_global 0.31125703784869585 hit_ratio 0.6871686108165429 ndcg 0.40560121593578935
203 local_loss 0.2712420144437216 total_loss_global 0.31310416196381574 hit_ratio 0.689289501590668 ndcg 0.40581955580816675
204 local_loss 0.26424376515255493 total_loss_global 0.3117653398218335 hit_ratio 0.6914103923647932 ndcg 0.407958065553307
205 local_loss 0.26162636418461205 total_loss_global 0.3103087051497474 hit_ratio 0.6956521739130435 ndcg 0.40738124080978577

264 local_loss 0.2497757671442337 total_loss_global 0.2912059727665069 hit_ratio 0.721102863202545 ndcg 0.4259513359498229
265 local_loss 0.2516055230970892 total_loss_global 0.29221297369663973 hit_ratio 0.7232237539766702 ndcg 0.4263885368589603
266 local_loss 0.25648000008179817 total_loss_global 0.29018286568162366 hit_ratio 0.7200424178154825 ndcg 0.4241824804227965
267 local_loss 0.2480975783780862 total_loss_global 0.2892911473414163 hit_ratio 0.7158006362672322 ndcg 0.42435133036296624
268 local_loss 0.2529352743565456 total_loss_global 0.28874625655439345 hit_ratio 0.71898197242842 ndcg 0.4259861361206685
269 local_loss 0.25200258430203404 total_loss_global 0.28948708768555714 hit_ratio 0.721102863202545 ndcg 0.42554621480365845
270 local_loss 0.25487944472119994 total_loss_global 0.28863921179315927 hit_ratio 0.7200424178154825 ndcg 0.42537663513844404
271 local_loss 0.24971787389448877 total_loss_global 0.2902675530329834 hit_ratio 0.7221633085896076 ndcg 0.4275747672484179


330 local_loss 0.2393723563828849 total_loss_global 0.2783767409148192 hit_ratio 0.7454931071049841 ndcg 0.44302069930297266
331 local_loss 0.24017963054233246 total_loss_global 0.2782075355364821 hit_ratio 0.7401908801696713 ndcg 0.441271786589976
332 local_loss 0.25612263023669557 total_loss_global 0.2783040553074901 hit_ratio 0.7486744432661718 ndcg 0.44300148117254534
333 local_loss 0.2515379355135023 total_loss_global 0.27751900591661516 hit_ratio 0.7497348886532343 ndcg 0.44297451663102394
334 local_loss 0.24357529129802366 total_loss_global 0.2778279050646058 hit_ratio 0.7454931071049841 ndcg 0.44086262592363823
335 local_loss 0.23486822563378018 total_loss_global 0.27872002730065903 hit_ratio 0.7497348886532343 ndcg 0.44330594325423345
336 local_loss 0.24468440236097036 total_loss_global 0.2801573098270085 hit_ratio 0.750795334040297 ndcg 0.44336276253417584
337 local_loss 0.25014629174767833 total_loss_global 0.27908832063055283 hit_ratio 0.7539766702014846 ndcg 0.444545880460

396 local_loss 0.24371458967817541 total_loss_global 0.27310678144843153 hit_ratio 0.76033934252386 ndcg 0.45251733613911643
397 local_loss 0.23395787940105692 total_loss_global 0.2721554166581257 hit_ratio 0.76033934252386 ndcg 0.44961372208886285
398 local_loss 0.24399574417945152 total_loss_global 0.2712585226650287 hit_ratio 0.7518557794273595 ndcg 0.44897508470876063
399 local_loss 0.233060308088633 total_loss_global 0.2718408551201763 hit_ratio 0.7539766702014846 ndcg 0.44899943596644737
400 local_loss 0.24047697322464284 total_loss_global 0.2716134942029306 hit_ratio 0.7592788971367974 ndcg 0.4515758784129861
401 local_loss 0.24224215956137746 total_loss_global 0.2720312692109566 hit_ratio 0.7560975609756098 ndcg 0.45088881142224946
402 local_loss 0.23664977574285037 total_loss_global 0.27433086693184305 hit_ratio 0.76033934252386 ndcg 0.4497333735247179
403 local_loss 0.240089006245684 total_loss_global 0.27354518653706306 hit_ratio 0.7613997879109226 ndcg 0.4498243356143674
40

462 local_loss 0.2332259818261247 total_loss_global 0.2679242177797472 hit_ratio 0.7582184517497349 ndcg 0.45607066733125984
463 local_loss 0.23464705105249473 total_loss_global 0.27101592893957477 hit_ratio 0.7560975609756098 ndcg 0.4556563328628844
464 local_loss 0.24055912649381458 total_loss_global 0.2693397076015013 hit_ratio 0.7582184517497349 ndcg 0.4552696077772617
465 local_loss 0.2405727497508256 total_loss_global 0.26703285722129316 hit_ratio 0.7613997879109226 ndcg 0.4575051105090892
466 local_loss 0.24603184262417596 total_loss_global 0.2654680074184569 hit_ratio 0.76033934252386 ndcg 0.4592975513679743
467 local_loss 0.2467537702081727 total_loss_global 0.2655064184058349 hit_ratio 0.7624602332979852 ndcg 0.4578025551887124
468 local_loss 0.2485935057731384 total_loss_global 0.26493800761059516 hit_ratio 0.7550371155885471 ndcg 0.45667150689327907
469 local_loss 0.2303995772085659 total_loss_global 0.2668455017926155 hit_ratio 0.7560975609756098 ndcg 0.4571091394486868
47