In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from scipy import sparse
import numpy as np
from torch import nn
from torch.nn import functional as F
import faiss
from metric import mapk
import pandas as pd
import schema
from utils import extract_transactions_train, extract_transactions_valid
import datetime
from logzero import logger

In [2]:
class BPRDataset(Dataset):
    """Pairwise training samples for Bayesian Personalized Ranking.

    Every observed (user, item) transaction yields one sample made of the
    user, the purchased (positive) item and a uniformly sampled item the user
    has no recorded transaction with (negative).
    """

    def __init__(self, n_user: int, n_item: int, transactions: np.ndarray):
        """
        Parameters
        ----------
        n_user
            number of users
        n_item
            number of items
        transactions
            (n_transactions, 2) 2d array of (user index, item index) rows
        """
        self.n_user = n_user
        self.n_item = n_item
        self.transactions = transactions
        # user x item indicator matrix: 1 wherever a transaction was observed.
        self.transactions_matrix = sparse.lil_matrix((self.n_user, self.n_item), dtype='int')
        self.transactions_matrix[transactions[:,0], transactions[:,1]] = 1

    def __len__(self):
        """Number of observed transactions (one sample per transaction)."""
        return len(self.transactions)

    def __getitem__(self, idx):
        """Return ``(user, pos_item, neg_item)`` as 0-dim tensors for transaction ``idx``.

        Raises
        ------
        ValueError
            If the user has interacted with every item, so no negative exists.
        """
        user, pos_item = self.transactions[idx]

        # Rejection sampling below would spin forever if the user's row is full.
        if len(self.transactions_matrix.rows[user]) >= self.n_item:
            raise ValueError(
                f"user {user} has interacted with all {self.n_item} items; "
                "cannot sample a negative item"
            )

        while True:
            # Use torch's RNG rather than np.random: DataLoader seeds torch per
            # worker, whereas NumPy's global state may be duplicated across
            # workers depending on the PyTorch version.
            neg_item = int(torch.randint(0, self.n_item, (1,)))
            if self.transactions_matrix[user, neg_item] == 0:
                break
        return torch.tensor(user), torch.tensor(pos_item), torch.tensor(neg_item)

In [3]:
class BPRModel(nn.Module):
    """Dot-product scorer with one embedding table for users and one for items.

    The score of a (user, item) pair is the inner product of the two
    embedding vectors.
    """

    def __init__(self, n_user: int, n_item: int, embedding_dim: int):
        """
        Parameters
        ----------
        n_user
            number of rows in the user embedding table
        n_item
            number of rows in the item embedding table
        embedding_dim
            width of both embedding tables
        """
        super().__init__()
        self.n_user, self.n_item = n_user, n_item
        self.embedding_dim = embedding_dim

        # Created user-first, item-second so parameter initialisation order is kept.
        self.user_embedding = nn.Embedding(n_user, embedding_dim)
        self.item_embedding = nn.Embedding(n_item, embedding_dim)

    def forward_user(self, users):
        """Look up the embedding rows for the given user indices."""
        return self.user_embedding(users)

    def forward_item(self, items):
        """Look up the embedding rows for the given item indices."""
        return self.item_embedding(items)

    def forward(self, users, pos_items, neg_items):
        """Return ``(pos, neg)``: per-row user·item scores for both item sets."""
        user_vecs = self.forward_user(users)
        pos_vecs = self.forward_item(pos_items)
        neg_vecs = self.forward_item(neg_items)

        pos_scores = torch.sum(user_vecs * pos_vecs, dim=1)
        neg_scores = torch.sum(user_vecs * neg_vecs, dim=1)
        return pos_scores, neg_scores

In [4]:
# Load the pre-transformed tables and subset each with the matching `schema`
# selector (presumably a list of column names -- confirm in schema.py).
# NOTE: pickle files can execute arbitrary code on load; only read trusted files.
transactions = pd.read_pickle('input/transformed/transactions_train.pkl')
transactions = transactions[schema.TRANSACTIONS]

articles = pd.read_pickle('input/transformed/articles.pkl')
articles = articles[schema.ARTICLES]

customers = pd.read_pickle('input/transformed/customers.pkl')
customers = customers[schema.CUSTOMERS]

In [5]:
# Keep transactions from 21 days before 2020-09-16 onwards. There is no upper
# bound, so rows on/after 2020-09-16 are kept as well.
tmp = datetime.date(2020, 9, 16) - datetime.timedelta(days=21)
transactions = transactions.query("t_dat >= @tmp")

# Compact the surviving user/item ids to contiguous 0..n-1 indices (sorted by
# original id) so they can be used directly as embedding row indices.
users = sorted(transactions.customer_id_idx.unique())
items = sorted(transactions.article_id_idx.unique())
mp_user = {x: i for i, x in enumerate(users)}
mp_item = {x: i for i, x in enumerate(items)}

# `.assign` builds a new frame instead of writing into the filtered result
# (avoids chained-assignment warnings); `.map(dict)` replaces the per-row
# Python lambda. Every id is a key of the mapping by construction.
transactions = transactions.assign(
    customer_id_idx=transactions.customer_id_idx.map(mp_user),
    article_id_idx=transactions.article_id_idx.map(mp_item),
)

# Restrict the side tables to the ids that survived the filter, then apply the
# same re-indexing so all three frames share one index space.
customers = customers.query("customer_id_idx in @users").reset_index(drop=True)
articles = articles.query("article_id_idx in @items").reset_index(drop=True)

customers = customers.assign(customer_id_idx=customers.customer_id_idx.map(mp_user))
articles = articles.assign(article_id_idx=articles.article_id_idx.map(mp_item))

n_user = len(users)
n_item = len(items)

In [6]:

# Split around 2020-09-16. Per the helpers' log output, validation covers
# [2020-09-16, 2020-09-23) and training the 21 days before the split date.
valid_start = datetime.date(2020, 9, 16)
transactions_valid = extract_transactions_valid(transactions, valid_start)
transactions_train = extract_transactions_train(transactions, valid_start, 21)

[I 220307 13:41:56 utils:14] valid: [2020-09-16, 2020-09-23)
[I 220307 13:41:56 utils:16] # of records: 240311
[I 220307 13:41:56 utils:27] train: [2020-08-26, 2020-09-16)
[I 220307 13:41:56 utils:29] # of records: 803079


In [7]:
# One row per validation customer holding the list of article indices they
# appear with in the validation transactions.
val = (
    transactions_valid
    .groupby('customer_id_idx')['article_id_idx']
    .apply(list)
    .reset_index()
)

In [8]:
# (user index, item index) pairs from the training window feed the BPR sampler.
train_pairs = transactions_train[['customer_id_idx', 'article_id_idx']].values
dataset = BPRDataset(n_user, n_item, train_pairs)

# 128-dimensional user/item embeddings, placed on the GPU.
model = BPRModel(n_user, n_item, 128)
model = model.cuda()

In [9]:
def _calc_representations(model, forward_fn, n_rows, batch_size):
    """Apply ``forward_fn`` to indices ``0..n_rows-1`` in batches and stack the outputs.

    Index batches are moved to the device of the model's first parameter (CPU
    if the model has no parameters) and the forward passes run without
    gradient tracking.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    indices = torch.arange(n_rows, dtype=torch.long)
    loader = DataLoader(TensorDataset(indices), batch_size=batch_size, shuffle=False)

    chunks = []
    with torch.no_grad():  # inference only: skip building an autograd graph
        for (batch,) in loader:
            out = forward_fn(batch.to(device))
            chunks.append(out.detach().cpu().numpy())
    return np.vstack(chunks)


def calc_user_representations(model, batch_size: int = 256):
    """Return the representations of all users as a 2d NumPy array.

    Parameters
    ----------
    model
        module exposing ``n_user`` and ``forward_user(indices)``
    batch_size
        number of users embedded per forward pass

    Returns
    -------
    np.ndarray
        row ``i`` is the output of ``model.forward_user`` for user index ``i``
    """
    return _calc_representations(model, model.forward_user, model.n_user, batch_size)


def calc_item_representations(model, batch_size: int = 256):
    """Return the representations of all items as a 2d NumPy array.

    Parameters
    ----------
    model
        module exposing ``n_item`` and ``forward_item(indices)``
    batch_size
        number of items embedded per forward pass

    Returns
    -------
    np.ndarray
        row ``i`` is the output of ``model.forward_item`` for item index ``i``
    """
    return _calc_representations(model, model.forward_item, model.n_item, batch_size)


In [10]:
# Shuffled mini-batches of 128 (user, pos, neg) triples; the trailing partial
# batch is dropped and samples are produced by 4 worker processes.
dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)

# Move the model to the GPU before handing its parameters to Adam.
model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def loss_fn(pos_output, neg_output):
    """BPR loss: mean over the batch of ``-log(sigmoid(pos - neg))``.

    Computed with ``F.logsigmoid`` instead of ``sigmoid().log()``: for a large
    negative margin the sigmoid underflows to 0 and its log becomes ``-inf``,
    which makes the loss (and its gradients) non-finite.

    Parameters
    ----------
    pos_output
        scores of the positive items
    neg_output
        scores of the sampled negative items, same shape as ``pos_output``

    Returns
    -------
    torch.Tensor
        scalar loss
    """
    return -F.logsigmoid(pos_output - neg_output).mean()

# The faiss GPU resources do not depend on the epoch, so allocate them once
# and reuse them for every evaluation index instead of once per epoch.
gpu_resources = faiss.StandardGpuResources()

for _ in range(1000):
    # ---- one training pass over the shuffled (user, pos, neg) triples ----
    model.train()
    losses = []
    # Batch variables carry a `batch_` prefix so the loop does not overwrite
    # the notebook-level `users` list built while re-indexing the data.
    for batch_users, batch_pos, batch_neg in dataloader:
        batch_users, batch_pos, batch_neg = batch_users.cuda(), batch_pos.cuda(), batch_neg.cuda()

        pos_output, neg_output = model(batch_users, batch_pos, batch_neg)
        loss = loss_fn(pos_output, neg_output)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    # ---- evaluation: top-12 items per user by inner product ----
    model.eval()
    user_representations = calc_user_representations(model)
    item_representations = calc_item_representations(model)

    # The index is rebuilt each epoch because the item embeddings have changed.
    index = faiss.index_factory(model.embedding_dim, "Flat", faiss.METRIC_INNER_PRODUCT)
    index = faiss.index_cpu_to_gpu(gpu_resources, 0, index)
    index.add(item_representations)
    _, idxs = index.search(user_representations, 12)

    # Score only the customers in `val`; `mapk` is presumably mean average
    # precision at k -- confirm in metric.py.
    logger.info(f"loss: {np.mean(losses)}, map: {mapk(val.article_id_idx, idxs[val.customer_id_idx])}")

[I 220307 13:43:07 3238338069:33] loss: 6.20738950146586, map: 8.841319609637867e-05
[I 220307 13:44:17 3238338069:33] loss: 4.937219098528594, map: 8.59528006454416e-05
[I 220307 13:45:26 3238338069:33] loss: 3.9870531508069074, map: 9.182804280218174e-05
[I 220307 13:46:34 3238338069:33] loss: 3.2497269384299337, map: 9.368210615304804e-05
[I 220307 13:47:43 3238338069:33] loss: 2.6611531829811184, map: 0.0001014235435985407
[I 220307 13:48:49 3238338069:33] loss: 2.187795760983451, map: 9.696115095555545e-05
[I 220307 13:49:53 3238338069:33] loss: 1.8236980932531062, map: 0.00012735965634981557
[I 220307 13:50:57 3238338069:33] loss: 1.508634222443977, map: 0.00015151952995847435
[I 220307 13:52:01 3238338069:33] loss: 1.2487400842246867, map: 0.00018393238183052048
[I 220307 13:53:05 3238338069:33] loss: 1.0283187861973822, map: 0.0002211758454885743
[I 220307 13:54:09 3238338069:33] loss: 0.8594680130030458, map: 0.00029535055204050906
[I 220307 13:55:15 3238338069:33] loss: 0.705

KeyboardInterrupt: 

In [15]:
# Inspect the smallest value in the item embedding matrix from the last
# completed evaluation.
np.min(item_representations)

-4.7970004