# LightGCN implementation

In [None]:
!pip install ipdb

In [2]:
import random
import numpy as np
import scipy.sparse as sp
import torch
from torch import nn
import torch.nn.functional as F
import ipdb
from torch.utils.data import Dataset, DataLoader
import os
from itertools import product
from datetime import datetime
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import time

In [7]:
class LightGCN(nn.Module):
    """LightGCN-style recommender built on a user-item adjacency graph.

    Users and items each own an embedding table. ``propagate`` repeatedly
    multiplies the stacked tables by the sparse graph, L2-normalises every
    row after each hop, and sums the outputs of all hops (hop 0 included).

    Parameters
    ----------
    conf: dict providing "n_users", "n_items", "emb_size" and "n_layer"
    ui_adj_graph: sparse tensor passed to ``torch.sparse.mm``; it must be
        square with ``n_users + n_items`` rows for the final split to work
    """

    def __init__(self, conf, ui_adj_graph):
        super().__init__()
        self.conf = conf
        self.num_users = conf["n_users"]
        self.num_items = conf["n_items"]
        self.emb_size = conf['emb_size']
        self.n_layer = conf['n_layer']
        self.Graph = ui_adj_graph
        self.__init_weight()

    def __init_weight(self):
        """Create both embedding tables (Xavier-uniform) and the sigmoid."""
        self.embedding_user = torch.nn.Embedding(
            num_embeddings=self.num_users, embedding_dim=self.emb_size)
        self.embedding_item = torch.nn.Embedding(
            num_embeddings=self.num_items, embedding_dim=self.emb_size)

        # Users first, then items, so the random stream is consumed in a
        # fixed order.
        for table in (self.embedding_user, self.embedding_item):
            nn.init.xavier_uniform_(table.weight, gain=1)
        self.f = nn.Sigmoid()

    def propagate(self):
        """Return the propagated ``(user_embeddings, item_embeddings)``."""
        hidden = torch.cat(
            [self.embedding_user.weight, self.embedding_item.weight])
        per_layer = [hidden]

        adjacency = self.Graph
        for _ in range(self.n_layer):
            hidden = F.normalize(
                torch.sparse.mm(adjacency, hidden), p=2, dim=1)
            per_layer.append(hidden)

        # Sum the representation of every hop, including the raw tables.
        pooled = torch.stack(per_layer, dim=1).sum(dim=1)

        users, items = torch.split(pooled, [self.num_users, self.num_items])
        return users, items

    def getUsersRating(self, users):
        """Sigmoid of the dot products between ``users`` and every item."""
        all_users, all_items = self.propagate()
        selected = all_users[users.long()]
        return self.f(torch.matmul(selected, all_items.t()))

    def getEmbedding(self, users, pos_items, neg_items):
        """Look up propagated and raw embeddings for one triple batch.

        Returns the propagated user / positive / negative embeddings followed
        by the corresponding rows of the un-propagated embedding tables.
        """
        all_users, all_items = self.propagate()
        propagated = (all_users[users],
                      all_items[pos_items],
                      all_items[neg_items])
        ego = (self.embedding_user(users),
               self.embedding_item(pos_items),
               self.embedding_item(neg_items))
        return propagated + ego

    def bpr_loss(self, users, pos, neg):
        """Return ``(loss, reg_loss)`` for a batch of (user, pos, neg) ids.

        ``loss`` is the mean softplus of ``neg_score - pos_score``;
        ``reg_loss`` is half the squared L2 norm of the raw embeddings used
        in the batch, divided by the batch length.
        """
        (users_emb, pos_emb, neg_emb,
         user_ego, pos_ego, neg_ego) = self.getEmbedding(
            users.long(), pos.long(), neg.long())

        squared_norms = (user_ego.norm(2).pow(2) +
                         pos_ego.norm(2).pow(2) +
                         neg_ego.norm(2).pow(2))
        reg_loss = (1 / 2) * squared_norms / float(len(users))

        pos_scores = (users_emb * pos_emb).sum(dim=1)
        neg_scores = (users_emb * neg_emb).sum(dim=1)
        loss = F.softplus(neg_scores - pos_scores).mean()

        return loss, reg_loss

    def forward(self, batch):
        """Return the BPR loss of ``batch``; the regulariser is discarded."""
        user_ids, pos_ids, neg_ids = batch
        bpr, _ = self.bpr_loss(user_ids, pos_ids, neg_ids)
        return bpr

    def evaluate(self, propagate_result, users):
        """Score every item for ``users`` from a cached ``propagate()`` result."""
        users_feature, item_feature = propagate_result
        return torch.mm(users_feature[users], item_feature.t())


In [8]:
class TrainData4UI(Dataset):
    """Training dataset of (user, positive item, sampled negative item).

    Parameters
    ----------
    conf: configuration dict (stored, not otherwise read here)
    ui_pairs: indexable collection of [user, item] training interactions
    ui_graph_train: user-item matrix supporting ``[u, j]`` lookups; its
        second dimension is the number of items
    """

    def __init__(self, conf, ui_pairs, ui_graph_train):
        self.conf = conf
        self.ui_pairs = ui_pairs
        self.ui_graph = ui_graph_train
        self.n_items = ui_graph_train.shape[1]

    def __len__(self):
        return len(self.ui_pairs)

    def __getitem__(self, idx):
        # Output: user, ground-truth item, negative item.
        user, pos_item = self.ui_pairs[idx]
        last_item = self.n_items - 1

        # Rejection sampling: redraw until the candidate is not a training
        # interaction of this user.
        while True:
            neg_item = random.randint(0, last_item)
            if self.ui_graph[user, neg_item] != 1:
                break

        return user, pos_item, int(neg_item)



class TestData(Dataset):
    """Per-user evaluation dataset.

    Each sample is one row index together with the dense ground-truth row
    and the dense training row (used by the caller to mask seen items).

    Parameters
    ----------
    conf: configuration dict (stored, not otherwise read here)
    ui_graph_test: sparse user-item matrix holding the held-out interactions
    ui_graph_train: sparse user-item matrix holding the training interactions
    """

    def __init__(self, conf, ui_graph_test, ui_graph_train):
        self.conf = conf
        self.ui_graph_test = ui_graph_test
        self.ui_graph_train = ui_graph_train

    def __len__(self):
        return self.ui_graph_test.shape[0]

    @staticmethod
    def _row_as_tensor(graph, row):
        # Densify a single sparse row and drop the size-1 dimensions.
        dense = graph[row].toarray()
        return torch.from_numpy(dense).squeeze()

    def __getitem__(self, idx):
        # Output: index, ground-truth test item vector, train mask.
        grd = self._row_as_tensor(self.ui_graph_test, idx)
        train_mask = self._row_as_tensor(self.ui_graph_train, idx)
        return idx, grd, train_mask

    
class UI_Dataset():
    """Load the user-item interaction files and expose loaders and the graph.

    Files read from ``conf["target_path"]`` (used as a plain string prefix):
    "dataset_size.txt" (first line: number of users and items), plus
    "train.txt", "valid.txt" and "test.txt" (tab-separated lines, each a
    user id followed by item ids).

    Other ``conf`` keys read: "batch_size", "test_batch_size",
    "data_loader_num", and "device" (only in ``getSparseGraph``).
    """

    def __init__(self, conf):
        self.conf = conf
        self.n_users, self.n_items = self.get_dataset_size()

        self.ui_pairs_train, self.ui_graph_train = self.get_graph("train.txt")
        _, self.ui_graph_val = self.get_graph("valid.txt")
        _, self.ui_graph_test = self.get_graph("test.txt")

        self.train_set = TrainData4UI(conf, self.ui_pairs_train, self.ui_graph_train)
        self.train_loader = DataLoader(self.train_set, batch_size=conf["batch_size"], shuffle=True, num_workers=conf["data_loader_num"])
        self.test_set = TestData(conf, self.ui_graph_test, self.ui_graph_train)
        self.test_loader = DataLoader(self.test_set, batch_size=conf["test_batch_size"], shuffle=False, num_workers=conf["data_loader_num"])
        self.val_set = TestData(conf, self.ui_graph_val, self.ui_graph_train)
        self.val_loader = DataLoader(self.val_set, batch_size=conf["test_batch_size"], shuffle=False, num_workers=conf["data_loader_num"])

        # Filled in by getSparseGraph().
        self.Graph = None

    def get_dataset_size(self):
        """Return ``(n_users, n_items)`` parsed from the first line of
        "dataset_size.txt", or ``(0, 0)`` when the file is empty."""
        target_path = self.conf["target_path"]

        # Context manager so the handle is closed even if parsing fails.
        with open(target_path + "dataset_size.txt") as size_file:
            first_line = size_file.readline()

        if not first_line:
            return 0, 0

        n_users, n_items = first_line.strip().split()
        return int(n_users), int(n_items)

    def get_graph(self, filename):
        """Read an interaction file.

        Output: the list of [user, item] pairs and the corresponding
        ``(n_users, n_items)`` CSR matrix with a 1 for every listed pair.
        """
        target_path = self.conf["target_path"]

        ui_pairs = []
        with open(target_path + filename) as graph_file:
            for line in graph_file:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                user, *items = (int(x) for x in line.split("\t"))
                ui_pairs.extend([user, item] for item in items)

        # reshape keeps the array 2-D even when there are no pairs at all,
        # so the column slicing below cannot fail on an empty file.
        indice = np.array(ui_pairs, dtype=np.int32).reshape(-1, 2)
        values = np.ones(len(ui_pairs), dtype=np.float32)
        ui_graph = sp.coo_matrix((values, (indice[:, 0], indice[:, 1])), shape=(self.n_users, self.n_items)).tocsr()

        return ui_pairs, ui_graph

    def getSparseGraph(self):
        """Build, save and return the symmetrically normalised adjacency.

        The bipartite adjacency ``[[0, R], [R^T, 0]]`` is built from the
        training matrix ``R`` and scaled as ``D^-1/2 A D^-1/2``. The result is
        saved to ``<target_path>s_pre_adj_mat.npz``, converted to a coalesced
        torch sparse tensor on ``conf['device']``, stored in ``self.Graph``
        and returned.
        """
        print("generating adjacency matrix")
        R = self.ui_graph_train.tocsr()
        # Assemble the block matrix directly instead of slice-assigning into
        # a dok/lil matrix, which is very slow on large graphs.
        adj_mat = sp.bmat([[None, R], [R.T, None]], format="csr", dtype=np.float32)

        rowsum = np.array(adj_mat.sum(axis=1))
        d_inv = np.power(rowsum, -0.5).flatten()
        # Isolated nodes have degree 0 -> inf; they must contribute nothing.
        d_inv[np.isinf(d_inv)] = 0.
        d_mat = sp.diags(d_inv)

        norm_adj = d_mat.dot(adj_mat).dot(d_mat).tocsr()
        sp.save_npz(self.conf['target_path'] + 's_pre_adj_mat.npz', norm_adj)

        self.Graph = self._convert_sp_mat_to_sp_tensor(norm_adj)
        self.Graph = self.Graph.coalesce().to(self.conf['device'])
        return self.Graph

    def _convert_sp_mat_to_sp_tensor(self, X):
        """Convert a SciPy sparse matrix to a float32 torch sparse COO tensor."""
        coo = X.tocoo().astype(np.float32)
        # Keep the indices integral end to end: going through a float32
        # tensor would round any index above 2**24.
        index = torch.from_numpy(np.vstack([coo.row, coo.col]).astype(np.int64))
        data = torch.FloatTensor(coo.data)
        return torch.sparse_coo_tensor(index, data, torch.Size(coo.shape))

In [9]:
def init_best_metrics(conf):
    """Create the empty best-score trackers.

    Returns
    -------
    best_metrics: ``{"val"|"test": {"recall"|"ndcg": {topk: 0}}}`` for every
        ``topk`` in ``conf['topk']``
    best_perform: ``{"val": {}, "test": {}}``
    """
    splits = ("val", "test")
    metric_names = ("recall", "ndcg")

    best_metrics = {
        split: {
            name: {topk: 0 for topk in conf['topk']}
            for name in metric_names
        }
        for split in splits
    }
    best_perform = {split: {} for split in splits}

    return best_metrics, best_perform


def write_log(log_path, topk, step, metrics):
    """Append and print the validation / test scores for one cutoff.

    Parameters
    ----------
    log_path: file the two lines are appended to
    topk: cutoff whose recall and ndcg are reported
    step: unused; kept so existing callers keep working
    metrics: ``{"val"|"test": {"recall"|"ndcg": {topk: score}}}``
    """
    curr_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    val_scores = metrics["val"]
    test_scores = metrics["test"]

    val_str = "%s, Top_%d, Val:  recall: %f, ndcg: %f" %(curr_time, topk, val_scores["recall"][topk], val_scores["ndcg"][topk])
    test_str = "%s, Top_%d, Test: recall: %f, ndcg: %f" %(curr_time, topk, test_scores["recall"][topk], test_scores["ndcg"][topk])

    # Context manager guarantees the handle is closed even if a write fails.
    with open(log_path, "a") as log:
        log.write("%s\n" %(val_str))
        log.write("%s\n" %(test_str))

    print(val_str)
    print(test_str)
     

def train(conf):
    """Train a LightGCN model and periodically evaluate it.

    Parameters
    ----------
    conf: configuration dict. Keys read here: "lr", "emb_size",
        "batch_size", "weight_decay", "num_epoches", "pop", "test_interval"
        and "topk", plus whatever UI_Dataset and LightGCN read. The keys
        "device", "n_users" and "n_items" are written into it.

    Side effects: creates "./log/" and "./checkpoints/" if missing; logging
    and checkpointing are delegated to ``log_metrics``.
    """
    conf["device"] = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    for k, v in conf.items():
        print(k, v)

    dataset = UI_Dataset(conf)
    conf['n_users'] = dataset.n_users
    conf['n_items'] = dataset.n_items

    log_path = "./log/"
    checkpoint_path = "./checkpoints/"
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(checkpoint_path, exist_ok=True)

    # The run name encodes the main hyper-parameters and today's date.
    settings = ["LR"+str(conf['lr']), "emb"+str(conf['emb_size']), "bs"+str(conf['batch_size']), "WD"+str(conf['weight_decay']), 'epoch'+str(conf['num_epoches']), 'pop'+str(conf['pop']), time.strftime("%m_%d", time.localtime())]

    setting = "_".join(settings)
    log_path = log_path + "/" + setting
    checkpoint_path = checkpoint_path + "/" + setting

    model = LightGCN(conf, dataset.getSparseGraph())

    model.to(device=conf["device"])
    optimizer = torch.optim.Adam(model.parameters(), lr=conf["lr"], weight_decay=conf['weight_decay'])
    print("%s start training ... "%datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    batch_cnt = len(dataset.train_loader)
    # "test_interval" is measured in epochs; clamp to at least one batch so
    # the modulo below can never divide by zero for small intervals.
    test_interval_bs = max(1, int(batch_cnt * conf["test_interval"]))
    best_metrics, best_perform = init_best_metrics(conf)
    best_epoch = 0

    # Train
    for epoch in range(conf["num_epoches"]):
        epoch_anchor = epoch * batch_cnt
        pbar = tqdm(enumerate(dataset.train_loader), total=batch_cnt)

        for batch_i, batch in pbar:
            # test() switches the model to eval mode, so re-enable training
            # mode on every step.
            model.train(True)
            optimizer.zero_grad()
            batch = [x.to(conf["device"]) for x in batch]
            batch_anchor = epoch_anchor + batch_i

            loss = model(batch)
            loss.backward()
            optimizer.step()

            pbar.set_description("epoch: %d, loss: %.4f" %(epoch, loss.item()))

            # Test
            if (batch_anchor + 1) % test_interval_bs == 0:
                metrics = {}

                metrics["val"] = test(model, dataset.val_loader, conf)
                metrics["test"] = test(model, dataset.test_loader, conf)

                best_metrics, best_perform, best_epoch = log_metrics(conf, model, metrics, log_path, checkpoint_path, epoch, batch_anchor, best_metrics, best_perform, best_epoch)

        # early stopping
        if epoch >= 10 and epoch-best_epoch>=20:
            break


def log_metrics(conf, model, metrics, log_path, checkpoint_path, epoch, batch_anchor, best_metrics, best_perform, best_epoch):
    """Log the current scores and checkpoint the model on improvement.

    The model counts as improved when both validation recall and validation
    ndcg at the reference cutoff exceed the best values so far. In that case
    the state dict is saved to ``checkpoint_path`` and the best trackers are
    overwritten with the current scores for every cutoff.

    Returns
    -------
    (best_metrics, best_perform, best_epoch), updated in place where needed
    """
    for topk in conf["topk"]:
        write_log(log_path, topk, batch_anchor, metrics)

    # Reference cutoff: 20 as before, falling back to the first configured
    # cutoff when 20 is not evaluated (avoids a KeyError below).
    topks = list(conf["topk"])
    topk_ = 20 if (20 in topks or not topks) else topks[0]
    print("top%d as the final evaluation standard" %(topk_))

    improved = (metrics["val"]["recall"][topk_] > best_metrics["val"]["recall"][topk_]
                and metrics["val"]["ndcg"][topk_] > best_metrics["val"]["ndcg"][topk_])
    if not improved:
        return best_metrics, best_perform, best_epoch

    torch.save(model.state_dict(), checkpoint_path)
    best_epoch = epoch
    curr_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Context manager so the log handle is closed even if a write fails.
    with open(log_path, "a") as log:
        for topk in conf['topk']:
            for key, res in best_metrics.items():
                for metric in res:
                    best_metrics[key][metric][topk] = metrics[key][metric][topk]

            best_perform["test"][topk] = "%s, Best in epoch %d, TOP %d: REC_T=%.5f, NDCG_T=%.5f" %(curr_time, best_epoch, topk, best_metrics["test"]["recall"][topk], best_metrics["test"]["ndcg"][topk])
            best_perform["val"][topk] = "%s, Best in epoch %d, TOP %d: REC_V=%.5f, NDCG_V=%.5f" %(curr_time, best_epoch, topk, best_metrics["val"]["recall"][topk], best_metrics["val"]["ndcg"][topk])
            print(best_perform["val"][topk])
            print(best_perform["test"][topk])
            log.write(best_perform["val"][topk] + "\n")
            log.write(best_perform["test"][topk] + "\n")

    return best_metrics, best_perform, best_epoch


def test(model, dataloader, conf):
    '''

    Run model on the validation or test data.

    Parameters
    ----------
    model: the trained model; must provide propagate() and evaluate()
    dataloader: validation or test data loader yielding
        (user indices, ground truth rows, train mask rows)
    conf: model configuration; "device" and "topk" are read

    Returns
    -------
    metrics: recall and ndcg for topk, as {metric: {topk: value}}; a value
        is 0.0 when no evaluated user had a positive item
    '''
    device = conf["device"]
    topks = conf["topk"]

    # Per metric and cutoff: [sum of per-user scores, number of scored users].
    tmp_metrics = {m: {topk: [0, 0] for topk in topks} for m in ["recall", "ndcg"]}

    model.eval()

    # Evaluation needs no gradients; without no_grad the whole propagation
    # graph would be recorded and kept alive for every batch.
    with torch.no_grad():
        rs = model.propagate()
        for users, ground_truth, train_mask in dataloader:
            users = users.to(device)
            ground_truth = ground_truth.to(device)

            pred = model.evaluate(rs, users)
            torch.cuda.empty_cache()
            # Push items already seen in training far below every real score.
            pred -= 1e8 * train_mask.to(device)
            tmp_metrics = get_metrics(tmp_metrics, ground_truth, pred, topks)

    metrics = {}
    for name, topk_res in tmp_metrics.items():
        metrics[name] = {}
        for topk, (score_sum, user_cnt) in topk_res.items():
            # Guard against an empty loader or users without positives.
            metrics[name][topk] = score_sum / user_cnt if user_cnt else 0.0

    return metrics




def get_metrics(metrics, grd, pred, topks):
    """Accumulate one batch of recall / ndcg results into ``metrics``.

    Parameters
    ----------
    metrics: running totals, ``{"recall"|"ndcg": {topk: [sum, count]}}``;
        updated in place and returned
    grd: ground-truth matrix indexed as [row, item]
    pred: score matrix with one row per evaluated user
    topks: cutoffs to evaluate
    """
    # Column vector of row numbers, broadcast against the top-k column ids.
    batch_rows = torch.arange(pred.shape[0], device=pred.device, dtype=torch.long).view(-1, 1)

    per_batch = {"recall": {}, "ndcg": {}}
    for k in topks:
        _, top_cols = torch.topk(pred, k)
        top_rows = torch.zeros_like(top_cols) + batch_rows
        # Ground-truth value of each recommended item, shaped (rows, k).
        is_hit = grd[top_rows.view(-1), top_cols.view(-1)].view(-1, k)

        per_batch["recall"][k] = get_recall(pred, grd, is_hit, k)
        per_batch["ndcg"][k] = get_ndcg(pred, grd, is_hit, k)

    for name, by_topk in per_batch.items():
        for k, pair in by_topk.items():
            totals = metrics[name][k]
            for pos, value in enumerate(pair):
                totals[pos] += value

    return metrics


def get_recall(pred, grd, is_hit, topk):
    """Return ``[summed recall, number of users with positives]`` for a batch.

    ``topk`` is not used; the cutoff is already reflected in ``is_hit``.
    """
    eps = 1e-8  # keeps the division defined for users without positives
    positives_per_user = grd.sum(dim=1)
    hits_per_user = is_hit.sum(dim=1)

    # Users without any positive item are excluded from the denominator.
    users_without_positives = (positives_per_user == 0).sum().item()
    scored_users = pred.shape[0] - users_without_positives
    recall_sum = (hits_per_user / (positives_per_user + eps)).sum().item()

    return [recall_sum, scored_users]


def get_ndcg(pred, grd, is_hit, topk):
    """Return ``[summed NDCG@topk, number of users with positives]``.

    Parameters
    ----------
    pred: score matrix; only its number of rows is used
    grd: ground-truth matrix, one row per user
    is_hit: ground-truth values of the top-k recommended items, in rank
        order, shaped (rows, topk)
    topk: cutoff
    """
    device = grd.device

    def discounts(on_device):
        # log2(rank + 1) for ranks 1..topk.
        return torch.log2(torch.arange(2, topk+2, device=on_device, dtype=torch.float))

    # Ideal DCG lookup indexed by the number of positives (capped at topk).
    # Built once on the CPU from prefix sums of the discounted gains, then
    # moved to the data's device so the index tensor and the table are
    # always on the same device.
    ideal_gains = 1.0 / discounts("cpu")
    IDCGs = torch.empty(1+topk, dtype=torch.float)
    # Avoid 0/0
    IDCGs[0] = 1
    for i in range(1, topk+1):
        IDCGs[i] = ideal_gains[:i].sum()
    IDCGs = IDCGs.to(device)

    num_pos = grd.sum(dim=1).clamp(0, topk).to(torch.long)
    dcg = (is_hit.to(device) / discounts(device)).sum(-1)
    ndcg = dcg / IDCGs[num_pos]

    # Remove those test cases who don't have any positive items
    denorm = pred.shape[0] - (num_pos == 0).sum().item()
    nomina = ndcg.sum().item()

    return [nomina, denorm]


In [None]:
# Run configuration; key order is kept because train() prints the dict.
conf = {
    # Data location (used as a plain string prefix for the data files).
    'target_path': './data/',
    # Model.
    'emb_size': 64,
    'n_layer': 2,
    # Optimisation.
    'lr': 0.001,
    'weight_decay': 1.0e-7,
    'num_epoches': 100,
    'batch_size': 2048,
    'test_batch_size': 2048,
    # Evaluation interval, in epochs.
    'test_interval': 5,
    # Number of DataLoader worker processes.
    'data_loader_num': 10,
    # Cutoffs for recall / ndcg.
    'topk': [20, 50, 100],
    # Only used in the run name built by train().
    'pop': 0,
}

train(conf)
