In [1]:
import torch

from torch.utils.data import DataLoader, Dataset


import pandas as pd
import numpy as np
from gmf import GMFEngine, GMF
from mlp import MLPEngine
from neumf import NeuMFEngine
from data import SampleGenerator
from metrics import MetronAtK

import copy

In [2]:


def evaluate_hit_ndcg(model, evaluate_data):
    """Score the held-out and negative pairs and return ``(hit_ratio, ndcg)``.

    ``evaluate_data`` is indexed at positions 0..3 for, in order, the test
    users, test items, negative users and negative items. The model is
    switched to eval mode and is left in that mode on return. Both metrics
    are computed by ``MetronAtK`` with ``top_k=10``.
    """
    def _flat(tensor):
        # Flatten to one dimension and convert to a plain Python list.
        return tensor.data.view(-1).tolist()

    model.eval()
    pos_users, pos_items, neg_users, neg_items = (evaluate_data[i] for i in range(4))

    # Inference only: no autograd graph is needed for the scores.
    with torch.no_grad():
        pos_scores = model(pos_users, pos_items)
        neg_scores = model(neg_users, neg_items)

    evaluator = MetronAtK(top_k=10)
    evaluator.subjects = [
        _flat(t)
        for t in (pos_users, pos_items, pos_scores, neg_users, neg_items, neg_scores)
    ]

    hit_ratio = evaluator.cal_hit_ratio()
    ndcg = evaluator.cal_ndcg()
    return hit_ratio, ndcg


def eval_ce_loss(model, data_loader):
    """Return the mean binary cross-entropy of ``model`` over ``data_loader``.

    Args:
        model: callable module taking ``(users, items)`` and returning
            probabilities; it is switched to eval mode and left there.
        data_loader: iterable of batches where ``batch[0]``, ``batch[1]`` and
            ``batch[2]`` are the users, items and ratings.

    Returns:
        The loss averaged over every sample seen, as a Python float, or NaN
        when the loader yields no samples.
    """
    # Sum-reduce per batch and divide by the sample count at the end, so that
    # a smaller final batch is not over-weighted (a mean of batch means would be).
    crit = torch.nn.BCELoss(reduction="sum")
    total_loss = 0.0
    n_samples = 0

    model.eval()
    with torch.no_grad():
        for batch in data_loader:
            users, items, ratings = batch[0], batch[1], batch[2]
            ratings = ratings.float()
            ratings_pred = model(users, items).view(-1)

            total_loss += crit(ratings_pred, ratings).item()
            n_samples += ratings.numel()

    if n_samples == 0:
        # The mean over zero samples is undefined; report NaN instead of
        # failing with a ZeroDivisionError.
        return float("nan")
    return total_loss / n_samples

In [3]:
# Load the raw ratings: a tab-separated file with no header row.
# (The variables carry an "ml1m" prefix although the path points at ml-100k.)
ml1m_dir = 'data/ml-100k/u.data'
ml1m_rating = pd.read_csv(
    ml1m_dir,
    sep='\t',
    header=None,
    names=['uid', 'mid', 'rating', 'timestamp'],
    engine='python',
)

# Reindex: map raw user ids to 0..N-1 in order of first appearance.
user_id = ml1m_rating[['uid']].drop_duplicates()
user_id = user_id.assign(userId=np.arange(len(user_id)))
ml1m_rating = ml1m_rating.merge(user_id, on=['uid'], how='left')

# Same contiguous renumbering for the raw item ids.
item_id = ml1m_rating[['mid']].drop_duplicates()
item_id = item_id.assign(itemId=np.arange(len(item_id)))
ml1m_rating = ml1m_rating.merge(item_id, on=['mid'], how='left')

# Keep only the renumbered ids plus the rating and timestamp columns.
ml1m_rating = ml1m_rating[['userId', 'itemId', 'rating', 'timestamp']]
print('Range of userId is [{}, {}]'.format(ml1m_rating['userId'].min(), ml1m_rating['userId'].max()))
print('Range of itemId is [{}, {}]'.format(ml1m_rating['itemId'].min(), ml1m_rating['itemId'].max()))

Range of userId is [0, 942]
Range of itemId is [0, 1681]


In [4]:
# Number of distinct users / items after the contiguous renumbering.
num_users = len(ml1m_rating['userId'].unique())
num_items = len(ml1m_rating['itemId'].unique())

# Settings handed to the GMF model constructor.
gmf_config = {
    'num_users': num_users,
    'num_items': num_items,
    'latent_dim': 12,
    'num_negative': 5,
    'l2_regularization': 0,  # 0.01
    'use_cuda': False,
    'device_id': 0,
    'model_dir': 'checkpoints/{}_Epoch{}_HR{:.4f}_NDCG{:.4f}.model',
}





## centralised training

In [5]:
# Build the sample generator once; later cells reuse it as a notebook global.
sample_generator = SampleGenerator(ratings=ml1m_rating)

# Training loader using the negative-sample count from the config.
train_loader = sample_generator.instance_a_train_loader(
    num_negatives=gmf_config['num_negative'],
    batch_size=256,
)

# The loader's underlying dataset, and the evaluation data used for HR/NDCG.
train_dataset = train_loader.dataset
evaluate_data = sample_generator.evaluate_data

In [6]:
def train_main(model_global, train_dataset, evaluate_data,
               num_epochs=200, lr=1e-3, num_negatives=4, batch_size=1024):
    """Train ``model_global`` centrally with Adam and BCE, printing metrics per epoch.

    Args:
        model_global: module called as ``model(users, items)``; trained in place.
        train_dataset: not read by this function. Batches come from the
            notebook-level ``sample_generator``; the parameter is kept so
            existing calls keep working.
        evaluate_data: passed through to ``evaluate_hit_ndcg``.
        num_epochs: number of passes to run (default 200, the former constant).
        lr: Adam learning rate (default 1e-3, the former constant).
        num_negatives: forwarded to ``instance_a_train_loader`` (default 4).
        batch_size: forwarded to ``instance_a_train_loader`` (default 1024).

    Returns:
        None. One line per epoch is printed with the train loss, hit ratio and NDCG.
    """
    opt = torch.optim.Adam(model_global.parameters(), lr)
    # The criterion is stateless, so build it once rather than once per batch.
    crit = torch.nn.BCELoss()

    for epoch in range(num_epochs):
        # A new loader is requested from the global sample_generator every epoch.
        train_loader = sample_generator.instance_a_train_loader(
            num_negatives=num_negatives, batch_size=batch_size)

        # The evaluation at the end of the previous epoch leaves the model in
        # eval mode; switch back so every epoch optimises in training mode.
        model_global.train()
        for batch in train_loader:
            opt.zero_grad()
            users, items, ratings = batch[0], batch[1], batch[2]
            ratings = ratings.float()
            ratings_pred = model_global(users, items)
            loss = crit(ratings_pred.view(-1), ratings)
            loss.backward()
            opt.step()

        model_global.eval()
        hit_ratio, ndcg = evaluate_hit_ndcg(model_global, evaluate_data)
        train_ce_loss = eval_ce_loss(model_global, train_loader)
        print(epoch, "train_ce_loss", train_ce_loss, "hit_ratio", hit_ratio, "ndcg", ndcg)
        

In [7]:
# Centralised baseline: build a GMF model from gmf_config and train it with train_main.
model_c = GMF(gmf_config)
train_main(model_c, train_dataset, evaluate_data)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_in_top_k['ndcg'] = test_in_top_k['rank'].apply(lambda x: math.log(2) / math.log(1 + x)) # the rank starts from 1


0 train_ce_loss 0.5740198267639176 hit_ratio 0.2142099681866384 ndcg 0.09704850859113098
1 train_ce_loss 0.5274813306356264 hit_ratio 0.20784729586426298 ndcg 0.09476810643240419
2 train_ce_loss 0.5086326403312447 hit_ratio 0.2258748674443266 ndcg 0.10858717181525783
3 train_ce_loss 0.5022407069432834 hit_ratio 0.21739130434782608 ndcg 0.10477327401283115
4 train_ce_loss 0.5001253020788027 hit_ratio 0.23753976670201485 ndcg 0.11131652518011188
5 train_ce_loss 0.49751815132119437 hit_ratio 0.26193001060445387 ndcg 0.129711627431123
6 train_ce_loss 0.4880572518041311 hit_ratio 0.31919406150583246 ndcg 0.161500028277831
7 train_ce_loss 0.46632805731424615 hit_ratio 0.3743372216330859 ndcg 0.19927761662079796
8 train_ce_loss 0.4391517216755339 hit_ratio 0.41675503711558853 ndcg 0.22848971762007547
9 train_ce_loss 0.41371900188036204 hit_ratio 0.471898197242842 ndcg 0.25775813152238747
10 train_ce_loss 0.3937424024767127 hit_ratio 0.5047720042417816 ndcg 0.27362133957651646
11 train_ce_loss

92 train_ce_loss 0.22376056197137872 hit_ratio 0.799575821845175 ndcg 0.5115728670381563
93 train_ce_loss 0.22347815085417969 hit_ratio 0.7985153764581124 ndcg 0.5106200889140203
94 train_ce_loss 0.22378784799871365 hit_ratio 0.7942735949098622 ndcg 0.5121273757069116
95 train_ce_loss 0.22284690814077363 hit_ratio 0.7974549310710498 ndcg 0.51347040284599
96 train_ce_loss 0.22284306952160252 hit_ratio 0.8006362672322376 ndcg 0.5128588148546073
97 train_ce_loss 0.22156178495608086 hit_ratio 0.7974549310710498 ndcg 0.5136760683119489
98 train_ce_loss 0.22175343950425297 hit_ratio 0.7985153764581124 ndcg 0.5136367559092091
99 train_ce_loss 0.22148555791205612 hit_ratio 0.8016967126193001 ndcg 0.5153130573076777
100 train_ce_loss 0.22079066110174517 hit_ratio 0.8016967126193001 ndcg 0.512572832794505
101 train_ce_loss 0.2207164763172796 hit_ratio 0.8006362672322376 ndcg 0.5116272035179897
102 train_ce_loss 0.22059614436069797 hit_ratio 0.8006362672322376 ndcg 0.5152782886868305
103 train_ce

183 train_ce_loss 0.20905940918144117 hit_ratio 0.8123011664899258 ndcg 0.5309151173295159
184 train_ce_loss 0.2092302415119715 hit_ratio 0.8101802757158006 ndcg 0.5283339316524382
185 train_ce_loss 0.20807968338658986 hit_ratio 0.8123011664899258 ndcg 0.5294628533603106
186 train_ce_loss 0.20891578218414764 hit_ratio 0.8133616118769883 ndcg 0.5297744670089422
187 train_ce_loss 0.20829693523566584 hit_ratio 0.8123011664899258 ndcg 0.5293670763685744
188 train_ce_loss 0.2089595706317543 hit_ratio 0.8144220572640509 ndcg 0.5295784025697623
189 train_ce_loss 0.20816267771292324 hit_ratio 0.8154825026511134 ndcg 0.5297002690497877
190 train_ce_loss 0.20827406934223885 hit_ratio 0.8144220572640509 ndcg 0.5273491643964787
191 train_ce_loss 0.20807476153920504 hit_ratio 0.8133616118769883 ndcg 0.5275154876171488
192 train_ce_loss 0.20782285811733608 hit_ratio 0.8144220572640509 ndcg 0.5294258187862193
193 train_ce_loss 0.20839203278387874 hit_ratio 0.8101802757158006 ndcg 0.5273630046234773
1

## federated training

In [8]:
class DatasetSplit(Dataset):
    """Subset view of ``dataset`` restricted to the positions in ``idxs``."""

    def __init__(self, dataset, idxs):
        # Keep a reference to the parent dataset and a materialised index list.
        self.dataset = dataset
        self.idxs = [index for index in idxs]

    def __len__(self):
        """Number of selected positions."""
        return len(self.idxs)

    def __getitem__(self, item):
        """Return the parent dataset's element for the ``item``-th selected position."""
        source_index = self.idxs[item]
        return self.dataset[source_index]


class LocalUpdate:
    """Local training of one user's copy of the global model.

    The constructor selects the rows of ``dataset`` whose ``user_tensor`` entry
    equals ``user_id`` and deep-copies ``model_global``, so training never
    modifies the global model directly.

    Attributes set by the constructor:
        local_dataset: ``DatasetSplit`` over the selected rows.
        num_local: number of selected rows.
        local_model: the deep copy that ``train`` optimises.
        user_id: the id passed in.
    """

    def __init__(self, model_global, dataset, user_id):
        idxs = torch.where(dataset.user_tensor == user_id)[0]
        self.local_dataset = DatasetSplit(dataset, idxs)
        self.num_local = len(self.local_dataset)
        self.local_model = copy.deepcopy(model_global)
        self.user_id = user_id

    def train(self, lr, epoches_local, local_batch_size):
        """Run ``epoches_local`` passes of Adam/BCE over the local data.

        Args:
            lr: Adam learning rate.
            epoches_local: number of passes over the local data; must be >= 1.
            local_batch_size: batch size handed to the ``DataLoader``.

        Returns:
            The mean per-batch loss of the last local epoch, as a float.

        Raises:
            ValueError: if ``epoches_local`` is below 1 or this user has no
                local samples. Previously these cases ended in an
                UnboundLocalError or a division by zero.
        """
        if epoches_local < 1:
            raise ValueError(
                "epoches_local must be >= 1, got {}".format(epoches_local))
        if self.num_local == 0:
            raise ValueError(
                "user {} has no local samples to train on".format(self.user_id))

        local_data_loader = DataLoader(self.local_dataset,
                                       batch_size=local_batch_size,
                                       shuffle=True)

        self.local_model.train()
        opt = torch.optim.Adam(self.local_model.parameters(), lr=lr)
        # The criterion is stateless, so build it once rather than once per batch.
        crit = torch.nn.BCELoss()

        mean_loss = 0.0
        for ep_local in range(epoches_local):
            t_loss = 0.0
            n_batch = 0
            for users, items, ratings in local_data_loader:
                opt.zero_grad()
                ratings = ratings.float()
                ratings_pred = self.local_model(users, items).view(-1)

                loss = crit(ratings_pred, ratings)
                loss.backward()
                opt.step()

                t_loss += loss.item()
                n_batch += 1

            mean_loss = t_loss / n_batch

            # Progress line on every 100th local epoch only.
            if (ep_local + 1) % 100 == 0:
                print("user", self.user_id, "ep_local", ep_local, mean_loss)

        return mean_loss


In [12]:
def train_fedavg(model_global, train_dataset, evaluate_data,
                 num_rounds=500, frac=0.1, epoches_local=3, local_lr=1e-3,
                 eval_batch_size=256):
    """Federated averaging over per-user clients, printing metrics every round.

    Each round samples a fraction of users without replacement, trains a
    ``LocalUpdate`` copy of the global model for each of them, and replaces
    the global weights with the average of the local state dicts weighted by
    each client's sample count.

    Args:
        model_global: global model; updated in place via ``load_state_dict``.
        train_dataset: dataset handed to every ``LocalUpdate`` and used to
            compute the global loss.
        evaluate_data: passed through to ``evaluate_hit_ndcg``.
        num_rounds: number of communication rounds (default 500).
        frac: fraction of users sampled per round (default 0.1).
        epoches_local: local epochs per client (default 3).
        local_lr: local Adam learning rate (default 1e-3).
        eval_batch_size: batch size used when computing the global loss.

    Returns:
        None. One line per round is printed.

    Note:
        The number of users is read from the notebook-level ``num_users``.
    """
    # Build the global-loss loader from the dataset passed in, instead of
    # depending on a ``train_loader`` left behind by an earlier notebook cell.
    eval_loader = DataLoader(train_dataset, batch_size=eval_batch_size)

    # Clients per round; always at least one.
    m = max(int(frac * num_users), 1)

    for epoch in range(num_rounds):
        # sample users
        idxs_users = np.random.choice(range(num_users), m, replace=False)

        w_locals, n_locals = [], []
        total_loss = 0.0
        for idx in idxs_users:
            # local update
            local_update = LocalUpdate(model_global, train_dataset, user_id=idx)
            # A tenth of the client's data per batch, but never 0: clients with
            # fewer than 10 samples would otherwise request batch_size=0.
            local_batch_size = max(1, int(len(local_update.local_dataset) * 0.1))
            total_loss += local_update.train(local_lr, epoches_local,
                                             local_batch_size=local_batch_size)

            w_locals.append(copy.deepcopy(local_update.local_model.state_dict()))
            n_locals.append(local_update.num_local)

        # Normalised client weights, proportional to local sample counts.
        weights = np.array(n_locals, dtype=np.float64)
        weights /= weights.sum()

        # average model: every state_dict entry uses the same client weights
        w_avg = {}
        for k in model_global.state_dict().keys():
            w_avg[k] = 0.0
            for w_local, weight in zip(w_locals, weights):
                # float() keeps the scalar a plain Python number for torch.
                w_avg[k] += w_local[k] * float(weight)

        model_global.load_state_dict(w_avg)
        hit_ratio, ndcg = evaluate_hit_ndcg(model_global, evaluate_data)

        total_loss_global = eval_ce_loss(model_global, eval_loader)
        print(epoch, "local_loss", total_loss/m, "total_loss_global", total_loss_global,
              "hit_ratio", hit_ratio, "ndcg", ndcg)


    

    

In [13]:
# GMF model used as the global model; train_fedavg overwrites its weights every round.
model_global = GMF(gmf_config)

In [None]:
# Run federated training; train_fedavg prints the losses, hit ratio and NDCG after each round.
train_fedavg(model_global, train_dataset, evaluate_data)

0 local_loss 0.7245157656632725 total_loss_global 0.7282792217355674 hit_ratio 0.20042417815482502 ndcg 0.08980022042868435
1 local_loss 0.7112132534989998 total_loss_global 0.7142292800183752 hit_ratio 0.20042417815482502 ndcg 0.09001893182618896
2 local_loss 0.6950572024115744 total_loss_global 0.7004692702587055 hit_ratio 0.20254506892895016 ndcg 0.09063698058384649
3 local_loss 0.6838121922553392 total_loss_global 0.6872859585398132 hit_ratio 0.20360551431601273 ndcg 0.09069299799974677
4 local_loss 0.673727118703798 total_loss_global 0.6746066728611633 hit_ratio 0.20360551431601273 ndcg 0.0910813326418671
5 local_loss 0.6588897539810701 total_loss_global 0.6624931042354544 hit_ratio 0.2046659597030753 ndcg 0.09065401637839415
6 local_loss 0.6481224210914832 total_loss_global 0.6510529661106714 hit_ratio 0.20360551431601273 ndcg 0.0907122851514535
7 local_loss 0.6372062382372942 total_loss_global 0.6399484349673006 hit_ratio 0.20572640509013787 ndcg 0.09126728676385613
8 local_loss

66 local_loss 0.44965678617588795 total_loss_global 0.4507348009936261 hit_ratio 0.19936373276776245 ndcg 0.08851442111256895
67 local_loss 0.4473368826079067 total_loss_global 0.4507341010369277 hit_ratio 0.19618239660657477 ndcg 0.08909198723231407
68 local_loss 0.4502548539733633 total_loss_global 0.4507049663019632 hit_ratio 0.19300106044538706 ndcg 0.08849320102182787
69 local_loss 0.45336838299750853 total_loss_global 0.45072872704416383 hit_ratio 0.19300106044538706 ndcg 0.08819500359173019
70 local_loss 0.446662590906341 total_loss_global 0.4507091249324041 hit_ratio 0.1951219512195122 ndcg 0.08842741607590049
71 local_loss 0.44952919865081103 total_loss_global 0.45068773442460586 hit_ratio 0.1951219512195122 ndcg 0.08866971033402114
72 local_loss 0.4498004335917286 total_loss_global 0.4506787903010075 hit_ratio 0.1919406150583245 ndcg 0.08668266651562578
73 local_loss 0.44740900675292394 total_loss_global 0.45066595828348355 hit_ratio 0.19618239660657477 ndcg 0.088415589979879

132 local_loss 0.44776814829565353 total_loss_global 0.4506344108952211 hit_ratio 0.18345705196182396 ndcg 0.08269708216859757
133 local_loss 0.4467796896294307 total_loss_global 0.4506754910478707 hit_ratio 0.17709437963944857 ndcg 0.0811873423544701
134 local_loss 0.44790973761758435 total_loss_global 0.45066033624606744 hit_ratio 0.1855779427359491 ndcg 0.08483895495613072
135 local_loss 0.4508586850341562 total_loss_global 0.45066539197518435 hit_ratio 0.18345705196182396 ndcg 0.08372828872820583
136 local_loss 0.44766839427632105 total_loss_global 0.45066565267491404 hit_ratio 0.18345705196182396 ndcg 0.08283909624253445
137 local_loss 0.44746033923711254 total_loss_global 0.4506986111043138 hit_ratio 0.18345705196182396 ndcg 0.08261067781362205
138 local_loss 0.4512632311932469 total_loss_global 0.4506853099889944 hit_ratio 0.18663838812301167 ndcg 0.08329705573175182
139 local_loss 0.45045489640456116 total_loss_global 0.45071666780779013 hit_ratio 0.18451749734888653 ndcg 0.084

197 local_loss 0.4525464210176492 total_loss_global 0.45076090844410643 hit_ratio 0.19406150583244963 ndcg 0.08717705208627838
198 local_loss 0.44864012590777025 total_loss_global 0.45080697228434163 hit_ratio 0.18981972428419935 ndcg 0.08453197583639761
199 local_loss 0.44863890835031794 total_loss_global 0.4507950442449272 hit_ratio 0.19936373276776245 ndcg 0.0874509269902903
200 local_loss 0.44342509458126145 total_loss_global 0.45081757082187546 hit_ratio 0.19618239660657477 ndcg 0.08727907365376289
201 local_loss 0.4453288133776741 total_loss_global 0.450841164712142 hit_ratio 0.19724284199363734 ndcg 0.08728796191580918
202 local_loss 0.4487759639953968 total_loss_global 0.45082194046867397 hit_ratio 0.20042417815482502 ndcg 0.08871879315201452
203 local_loss 0.44968976861935966 total_loss_global 0.45085486122804097 hit_ratio 0.20042417815482502 ndcg 0.08915533723911157
204 local_loss 0.45275633457464454 total_loss_global 0.4508171539872397 hit_ratio 0.19830328738069988 ndcg 0.08

262 local_loss 0.44824794327382655 total_loss_global 0.45090581082661124 hit_ratio 0.18769883351007424 ndcg 0.08489385806874963
263 local_loss 0.4476095658782833 total_loss_global 0.4509092332806493 hit_ratio 0.18981972428419935 ndcg 0.08411494974512451
264 local_loss 0.45164805881866404 total_loss_global 0.4509149568923889 hit_ratio 0.19300106044538706 ndcg 0.08535342136601819
265 local_loss 0.4505292272414992 total_loss_global 0.45090807890912565 hit_ratio 0.19088016967126192 ndcg 0.08479772558642495
266 local_loss 0.44865696984748543 total_loss_global 0.45088037682904136 hit_ratio 0.18769883351007424 ndcg 0.08326021413784548
267 local_loss 0.45019451181189923 total_loss_global 0.4508581949681891 hit_ratio 0.1823966065747614 ndcg 0.08226458143315106
268 local_loss 0.4474221736552983 total_loss_global 0.4508228319970623 hit_ratio 0.18345705196182396 ndcg 0.08266598575174976
269 local_loss 0.4470686737354305 total_loss_global 0.45082370414495676 hit_ratio 0.19088016967126192 ndcg 0.085

327 local_loss 0.44850263189906286 total_loss_global 0.4508520683191851 hit_ratio 0.19830328738069988 ndcg 0.08973161085775541
328 local_loss 0.4486675353702769 total_loss_global 0.4508771948214752 hit_ratio 0.20360551431601273 ndcg 0.09051617468727978
329 local_loss 0.4472956399795396 total_loss_global 0.45091933746157176 hit_ratio 0.20890774125132555 ndcg 0.09261337638222777
330 local_loss 0.44407753198762456 total_loss_global 0.45096995094333814 hit_ratio 0.20784729586426298 ndcg 0.09194949995707073
331 local_loss 0.44666863838408866 total_loss_global 0.45098182333121517 hit_ratio 0.2046659597030753 ndcg 0.09124795177173552
332 local_loss 0.44851153470446126 total_loss_global 0.45099474100795517 hit_ratio 0.2046659597030753 ndcg 0.091477762085216
333 local_loss 0.44930138608271414 total_loss_global 0.4510310531997352 hit_ratio 0.20784729586426298 ndcg 0.0928579181261422
334 local_loss 0.4530149297078988 total_loss_global 0.4509296124603705 hit_ratio 0.20784729586426298 ndcg 0.092720