# LightFM features revisited

* LightFMの埋め込み表現の取得の仕方が間違っていたりしたので再訪

In [1]:
import datetime

import faiss
import matplotlib.pyplot as plt
import numpy as np
import optuna
import pandas as pd
import psutil
from lightfm import LightFM

import schema
from metric import mapk
from scipy import sparse
from logzero import logger
from utils import train_valid_split

In [2]:
transactions = pd.read_pickle('input/transformed/transactions_train.pkl')[schema.TRANSACTIONS]
articles = pd.read_pickle('input/transformed/articles.pkl')[schema.ARTICLES]
customers = pd.read_pickle('input/transformed/customers.pkl')[schema.CUSTOMERS]
TOPK = 12

In [3]:
tmp = datetime.date(2020, 9, 16) - datetime.timedelta(days=21)
transactions = transactions.query("t_dat >= @tmp")

users = sorted(transactions.customer_id_idx.unique())
items = sorted(transactions.article_id_idx.unique())
mp_user = {x: i for i, x in enumerate(users)}
mp_item = {x: i for i, x in enumerate(items)}
transactions.customer_id_idx = transactions.customer_id_idx.apply(lambda x: mp_user[x])
transactions.article_id_idx = transactions.article_id_idx.apply(lambda x: mp_item[x])

articles = articles.query("article_id_idx in @items").reset_index(drop=True)
articles.article_id_idx = articles.article_id_idx.apply(lambda x: mp_item[x])

customers = customers.query("customer_id_idx in @users").reset_index(drop=True)
customers.customer_id_idx = customers.customer_id_idx.apply(lambda x: mp_user[x])

n_user = len(users)
n_item = len(items)
print(n_user, n_item)

233174 28297


In [4]:
def create_customer_features(customers):
    df = customers[['age']].reset_index(drop=True)
    df['age_is_null'] = df['age'].isnull().astype(int)
    df['age_10_20'] = ((10 <= df['age']) & (df['age'] < 20)).astype(int)
    df['age_20_30'] = ((20 <= df['age']) & (df['age'] < 30)).astype(int)
    df['age_30_40'] = ((30 <= df['age']) & (df['age'] < 40)).astype(int)
    df['age_40_50'] = ((40 <= df['age']) & (df['age'] < 50)).astype(int)
    df['age_50_60'] = ((50 <= df['age']) & (df['age'] < 60)).astype(int)
    df['age_60_70'] = ((60 <= df['age']) & (df['age'] < 70)).astype(int)
    df['age_70_100'] = ((70 <= df['age']) & (df['age'] < 100)).astype(int)
    df = df.drop(['age'], axis=1)
    idxs = sparse.identity(len(df), dtype='f', format='csr')
    a = df.values / np.sum(df.values, axis=0)
    # return idxs
    return sparse.hstack([idxs, a]).astype('float32')

def create_article_features(articles):
    df = articles.drop(['article_id_idx', 'department_no_idx', 'product_type_no_idx'], axis=1)
    for c in df.columns:
        df = pd.concat([df, pd.get_dummies(df[c], prefix=c)], axis=1).drop(c, axis=1)
    idxs = sparse.identity(len(df), dtype='f', format='csr')
    a = df.values / np.sum(df.values, axis=0)
    # return idxs
    return sparse.hstack([idxs, a]).astype('float32')

user_features = create_customer_features(customers)
item_features = create_article_features(articles)
print(user_features.shape, item_features.shape)
print(user_features.min(), user_features.max())
print(item_features.min(), item_features.max())

(233174, 233182) (28297, 28511)
0.0 1.0
0.0 1.0


In [5]:
train_days = 21
no_components = 1024
learning_schedule = 'adadelta'
loss = 'bpr'
learning_rate = 0.005
item_alpha = 1e-8
user_alpha = 1e-8
max_sampled = 10

lightfm_params = {
    'no_components': no_components,
    'learning_schedule': learning_schedule,
    'loss': loss,
    'learning_rate': learning_rate,
    'item_alpha': item_alpha,
    'user_alpha': user_alpha,
    'max_sampled': max_sampled,
}
print(lightfm_params)

transactions_train, transactions_valid = train_valid_split(transactions, datetime.date(2020, 9, 16), 21)

val = transactions_valid.groupby('customer_id_idx')['article_id_idx'].apply(list).reset_index()

train = sparse.lil_matrix((n_user, n_item))
train[transactions_train.customer_id_idx, transactions_train.article_id_idx] = 1

model = LightFM(**lightfm_params)

[I 220308 15:30:14 utils:29] train: [2020-08-26, 2020-09-16)
[I 220308 15:30:14 utils:31] # of records: 803079
[I 220308 15:30:14 utils:16] valid: [2020-09-16, 2020-09-23)
[I 220308 15:30:14 utils:18] # of records: 240311


{'no_components': 1024, 'learning_schedule': 'adadelta', 'loss': 'bpr', 'learning_rate': 0.005, 'item_alpha': 1e-08, 'user_alpha': 1e-08, 'max_sampled': 10}


In [6]:
for epoch in range(10000):
    model.fit_partial(train, user_features=user_features, item_features=item_features, epochs=5, num_threads=psutil.cpu_count(logical=False), verbose=True)

    user_biases, user_embeddings = model.get_user_representations(user_features)
    item_biases, item_embeddings = model.get_item_representations(item_features)

    # naive mapk
    index = faiss.index_factory(no_components, "Flat", faiss.METRIC_INNER_PRODUCT)
    index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, index)
    index.add(item_embeddings)
    _, idxs = index.search(user_embeddings, TOPK)

    naive_mapk = mapk(val.article_id_idx, idxs[val.customer_id_idx])

    # fine mapk
    index = faiss.index_factory(no_components, "Flat", faiss.METRIC_INNER_PRODUCT)
    index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, index)
    index.add(item_embeddings)
    _, candidates = index.search(user_embeddings, 100)

    item_high_bias = np.argsort(item_biases)[::-1][:100]
    item_high_bias = np.array([item_high_bias] * n_user)
    candidates = np.hstack([candidates, item_high_bias])

    user_idxs = np.repeat(range(n_user), 100 + 100)

    result = model.predict(user_idxs, candidates.flatten(), user_features=user_features, item_features=item_features, num_threads=psutil.cpu_count(logical=False))
    result = result.reshape(n_user, 100 + 100)

    idxs_each_user = np.argsort(result, axis=1)[:,::-1][:,:12]
    pred = np.array([candidates[i, x] for i, x in enumerate(idxs_each_user)])

    fine_mapk = mapk(val.article_id_idx, pred[val.customer_id_idx])

    logger.info(f"epoch={epoch} naive_mapk={naive_mapk}, fine_mapk={fine_mapk}")

Epoch: 100%|██████████| 5/5 [02:26<00:00, 29.33s/it]
[I 220308 15:33:05 262837137:35] epoch=0 naive_mapk=0.016052232593096398, fine_mapk=0.01585857209207743
Epoch: 100%|██████████| 5/5 [02:27<00:00, 29.48s/it]
[I 220308 15:35:53 262837137:35] epoch=1 naive_mapk=0.018376696328719473, fine_mapk=0.01761700442582566
Epoch: 100%|██████████| 5/5 [02:29<00:00, 29.98s/it]
[I 220308 15:38:43 262837137:35] epoch=2 naive_mapk=0.019321531362885024, fine_mapk=0.01812911974368896
Epoch: 100%|██████████| 5/5 [02:30<00:00, 30.04s/it]
[I 220308 15:41:34 262837137:35] epoch=3 naive_mapk=0.01965725530836565, fine_mapk=0.01838752951030088
Epoch: 100%|██████████| 5/5 [02:28<00:00, 29.71s/it]
[I 220308 15:44:24 262837137:35] epoch=4 naive_mapk=0.019907902669869196, fine_mapk=0.018586802384635094
Epoch: 100%|██████████| 5/5 [02:28<00:00, 29.74s/it]
[I 220308 15:47:13 262837137:35] epoch=5 naive_mapk=0.02006273660369779, fine_mapk=0.01872121031794825
Epoch: 100%|██████████| 5/5 [02:25<00:00, 29.14s/it]
[I 220

KeyboardInterrupt: 