In [1]:
import torch

from torch.utils.data import DataLoader, Dataset


import pandas as pd
import numpy as np
from gmf import GMFEngine, GMF
from mlp import MLPEngine
from neumf import NeuMFEngine
from data import SampleGenerator
from metrics import MetronAtK

import copy

In [2]:


def evaluate_hit_ndcg(model, evaluate_data):    
    model.eval()
    with torch.no_grad():
        test_users, test_items = evaluate_data[0], evaluate_data[1]
        negative_users, negative_items = evaluate_data[2], evaluate_data[3]        
        test_scores = model(test_users, test_items)
        negative_scores = model(negative_users, negative_items)
        
        metron = MetronAtK(top_k=10)
        
        metron.subjects = [test_users.data.view(-1).tolist(),
                             test_items.data.view(-1).tolist(),
                             test_scores.data.view(-1).tolist(),
                             negative_users.data.view(-1).tolist(),
                             negative_items.data.view(-1).tolist(),
                             negative_scores.data.view(-1).tolist()]
    hit_ratio, ndcg = metron.cal_hit_ratio(), metron.cal_ndcg()
        
    
    return hit_ratio, ndcg


def eval_ce_loss(model, data_loader):
    total_loss = 0.0
    n_batches = 0
    model.eval()
    with torch.no_grad():
        for batch in data_loader:
            users, items, ratings = batch[0], batch[1], batch[2]
            ratings = ratings.float()
            ratings_pred = model(users, items)
            crit = torch.nn.BCELoss()
            loss = crit(ratings_pred.view(-1), ratings)

            total_loss += loss.item()
            n_batches += 1
    
    return total_loss / n_batches

In [3]:
ml1m_dir = 'data/ml-100k/u.data'
ml1m_rating = pd.read_csv(ml1m_dir, sep='\t', header=None, names=['uid', 'mid', 'rating', 'timestamp'],  engine='python')
# Reindex
user_id = ml1m_rating[['uid']].drop_duplicates().reindex()
user_id['userId'] = np.arange(len(user_id))
ml1m_rating = pd.merge(ml1m_rating, user_id, on=['uid'], how='left')
item_id = ml1m_rating[['mid']].drop_duplicates()
item_id['itemId'] = np.arange(len(item_id))
ml1m_rating = pd.merge(ml1m_rating, item_id, on=['mid'], how='left')
ml1m_rating = ml1m_rating[['userId', 'itemId', 'rating', 'timestamp']]
print('Range of userId is [{}, {}]'.format(ml1m_rating.userId.min(), ml1m_rating.userId.max()))
print('Range of itemId is [{}, {}]'.format(ml1m_rating.itemId.min(), ml1m_rating.itemId.max()))

Range of userId is [0, 942]
Range of itemId is [0, 1681]


In [4]:
num_users = len(ml1m_rating.userId.unique())
num_items = len(ml1m_rating.itemId.unique())

gmf_config = {'alias': 'gmf_factor8neg4-implict',
              'num_epoch': 200,
              'batch_size': 1024,
              'optimizer': 'adam',
              'adam_lr': 1e-3,
              'num_users': num_users,
              'num_items': num_items,
              'latent_dim': 12,
              'num_negative': 5,
              'l2_regularization': 0, # 0.01
              'use_cuda': False,
              'device_id': 0,
              'model_dir':'checkpoints/{}_Epoch{}_HR{:.4f}_NDCG{:.4f}.model'}





## centralised training

In [5]:
sample_generator = SampleGenerator(ratings=ml1m_rating)
train_loader = sample_generator.instance_a_train_loader(num_negatives=gmf_config['num_negative'], 
                                                        batch_size=256)
train_dataset = train_loader.dataset
evaluate_data = sample_generator.evaluate_data

In [9]:
def train_main(model_global, train_dataset, evaluate_data):
    opt = torch.optim.Adam(model_global.parameters(), 1e-3)
    for epoch in range(200):
        train_loader = sample_generator.instance_a_train_loader(num_negatives=4, batch_size=1024)
        
        for batch in train_loader:
            opt.zero_grad()
            users, items, ratings = batch[0], batch[1], batch[2]
            ratings = ratings.float()
            ratings_pred = model_global(users, items)
            crit = torch.nn.BCELoss()
            loss = crit(ratings_pred.view(-1), ratings)
            loss.backward()
            opt.step()

        model_global.eval()
        hit_ratio, ndcg = evaluate_hit_ndcg(model_global, evaluate_data)
        train_ce_loss = eval_ce_loss(model_global, train_loader)
        if epoch % 10 == 0:
            print(epoch, "train_ce_loss", train_ce_loss, "hit_ratio", hit_ratio, "ndcg", ndcg)
        

In [None]:
model_c = GMF(gmf_config)
train_main(model_c, train_dataset, evaluate_data)

0 train_ce_loss 0.5467866653006924 hit_ratio 0.088016967126193 ndcg 0.040109803837457436
10 train_ce_loss 0.4965631184129676 hit_ratio 0.11558854718981973 ndcg 0.052032741536971586


## federated training

In [18]:
class DatasetSplit(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        return self.dataset[self.idxs[item]]


class LocalUpdate:
    def __init__(self, model_global, dataset, user_id):
        idxs = torch.where(dataset.user_tensor == user_id)[0]
        self.local_dataset = DatasetSplit(dataset, idxs)
        self.num_local = len(self.local_dataset)        
        self.local_model = copy.deepcopy(model_global)
        self.user_id = user_id
        
        
    def train(self, lr, epoches_local, local_batch_size):
        local_data_loader = DataLoader(self.local_dataset, 
                                       batch_size=local_batch_size, 
                                       shuffle=True)
        
        self.local_model.train()
        opt = torch.optim.Adam(self.local_model.parameters(), lr=lr)
        for ep_local in range(epoches_local):
            t_loss = 0.0
            n_batch = 0
            for users, items, ratings in local_data_loader:
                opt.zero_grad()
                ratings = ratings.float()
                ratings_pred = self.local_model(users, items).view(-1)
                
                crit = torch.nn.BCELoss()
                loss = crit(ratings_pred, ratings)
                loss.backward()
                opt.step()
                
                t_loss += loss.item()
                n_batch += 1
            
            if (ep_local + 1) % 100  == 0:
                print("user", self.user_id, "ep_local", ep_local, t_loss/n_batch)
        
        return t_loss/n_batch


In [23]:
def train_fedavg(model_global, train_dataset, evaluate_data):
    frac = 0.1
    epoches_local = 3
    local_lr = 1e-2
    
    for epoch in range(500):
        # sample users
        m = max(int(frac * num_users), 1)
        idxs_users = np.random.choice(range(num_users), m, replace=False)

        w_locals, n_locals = [], []
        total_loss = 0.0
        for idx in idxs_users:        
            # local update
            local_update = LocalUpdate(model_global, train_dataset, user_id=idx)
            local_batch_size = int(len(local_update.local_dataset)*0.1)
            total_loss += local_update.train(local_lr, epoches_local, local_batch_size=local_batch_size)

            w_locals.append(copy.deepcopy(local_update.local_model.state_dict()))
            n_locals.append(local_update.num_local)

        n_locals = np.array(n_locals, dtype=np.float64)
        n_locals /= n_locals.sum()


        # average model
        w_avg = {}    
        for k in model_global.state_dict().keys():
            w_avg[k] = 0.0
            for i in range(len(w_locals)):
                w_avg[k] += w_locals[i][k] * n_locals[i]

        model_global.load_state_dict(w_avg)
        hit_ratio, ndcg = evaluate_hit_ndcg(model_global, evaluate_data)

        total_loss_all = eval_ce_loss(model_global, train_loader)
        print(epoch, "loss", total_loss/m, total_loss_all, "hit_ratio", hit_ratio, "ndcg", ndcg)


    

    

In [24]:
model_global = GMF(gmf_config)

In [None]:
train_fedavg(model_global, train_dataset, evaluate_data)

0 loss 0.5390761853791988 0.5783632659737556 hit_ratio 0.24284199363732767 ndcg 0.11061090882636314
1 loss 0.48633375453764505 0.5156896067887519 hit_ratio 0.22693531283138918 ndcg 0.10670224238783678
