In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from libreco.data import random_split, DatasetPure
from libreco.algorithms import NCF  # pure data, 
from libreco.evaluation import evaluate

Instructions for updating:
non-resource variables are not supported in the long term


In [2]:
df_filtered = pd.read_csv("data/english_reviews.csv")
df_filtered = df_filtered[['user_id', 'gmap_id', 'rating']]

df_filtered['user_id'] = df_filtered['user_id'].astype('category')
df_filtered['gmap_id'] = df_filtered['gmap_id'].astype('category')

# Convert user_id and gmap_id to their respective codes if using embedding layers in a neural model
df_filtered['user_id'] = df_filtered['user_id'].cat.codes
df_filtered['gmap_id'] = df_filtered['gmap_id'].cat.codes

data = df_filtered.rename(columns={
    'user_id': 'user',
    'gmap_id': 'item',
    'rating': 'label'
})

train_data, test_data = train_test_split(data, 
                                         test_size=0.3, random_state=209)

In [3]:
train_data, data_info= DatasetPure.build_trainset(train_data)
test_data = DatasetPure.build_testset(test_data)

In [4]:
data_info

n_users: 118311, n_items: 5891, data density: 0.0542 %

In [5]:
ncf = NCF(
    task="rating",
    data_info=data_info,
    loss_type="cross_entropy",
    embed_size=16,
    n_epochs=10,
    lr=1e-3,
    batch_size=2048,
    num_neg=1,
)

In [6]:
ncf.fit(
    train_data,
    neg_sampling=False, #for rating, this param is false else True
    verbose=2,
    metrics=["loss"],
)

# do final evaluation on test data
evaluate(
    model=ncf,
    data=test_data,
    neg_sampling=False,
    metrics=["loss"],
)

Training start time: [35m2024-05-08 20:34:08[0m
Instructions for updating:
Colocations handled automatically by placer.


  net = tf.layers.batch_normalization(net, training=is_training)
Instructions for updating:
Colocations handled automatically by placer.
  net = tf.layers.batch_normalization(net, training=is_training)
2024-05-08 20:34:09.246982: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:388] MLIR V1 optimization pass is not enabled
train: 100%|██████████| 369/369 [00:04<00:00, 88.69it/s] 


Epoch 1 elapsed: 4.189s
	 [32mtrain_loss: 3.7902[0m


train: 100%|██████████| 369/369 [00:03<00:00, 110.60it/s]


Epoch 2 elapsed: 3.338s
	 [32mtrain_loss: 0.9363[0m


train: 100%|██████████| 369/369 [00:03<00:00, 108.54it/s]


Epoch 3 elapsed: 3.401s
	 [32mtrain_loss: 0.5924[0m


train: 100%|██████████| 369/369 [00:03<00:00, 109.49it/s]


Epoch 4 elapsed: 3.372s
	 [32mtrain_loss: 0.426[0m


train: 100%|██████████| 369/369 [00:03<00:00, 106.20it/s]


Epoch 5 elapsed: 3.476s
	 [32mtrain_loss: 0.3296[0m


train: 100%|██████████| 369/369 [00:03<00:00, 99.52it/s] 


Epoch 6 elapsed: 3.709s
	 [32mtrain_loss: 0.27[0m


train: 100%|██████████| 369/369 [00:03<00:00, 93.39it/s] 


Epoch 7 elapsed: 3.953s
	 [32mtrain_loss: 0.2325[0m


train: 100%|██████████| 369/369 [00:03<00:00, 108.35it/s]


Epoch 8 elapsed: 3.407s
	 [32mtrain_loss: 0.2051[0m


train: 100%|██████████| 369/369 [00:04<00:00, 86.40it/s] 


Epoch 9 elapsed: 4.273s
	 [32mtrain_loss: 0.1861[0m


train: 100%|██████████| 369/369 [00:03<00:00, 95.51it/s] 


Epoch 10 elapsed: 3.866s
	 [32mtrain_loss: 0.1725[0m


eval_pointwise:   0%|          | 0/20 [00:00<?, ?it/s]

[31mDetect 943 unknown interaction(s), position: [4097, 4098, 2051, 2052, 4099, 6, 4100, 2056, 4101, 4106, 2060, 6157, 17, 2068, 4116, 22, 4118, 26, 4123, 6172, 2078, 2081, 4131, 6179, 6180, 38, 2086, 4137, 6187, 2092, 2093, 46, 4145, 59, 6203, 2110, 2111, 66, 2114, 6212, 4165, 4167, 6220, 4175, 4178, 2136, 6235, 4188, 93, 2142, 6237, 6242, 99, 2147, 2148, 102, 6245, 4201, 106, 2155, 108, 111, 2159, 4209, 2162, 4211, 4214, 119, 6266, 123, 2172, 125, 6268, 4227, 6275, 135, 6283, 6292, 6293, 153, 2201, 6298, 6299, 4253, 6138, 6306, 164, 6307, 4262, 4264, 2217, 2220, 4272, 2225, 6322, 6325, 4280, 2237, 4285, 6333, 6340, 4294, 202, 6346, 206, 2257, 214, 2265, 4313, 219, 2281, 2282, 4332, 4333, 238, 6380, 2289, 2290, 4338, 244, 2293, 4339, 247, 4341, 249, 4343, 4348, 2304, 4352, 4356, 6406, 267, 6411, 4365, 4367, 6415, 2321, 274, 276, 4376, 282, 2330, 4379, 4380, 6428, 291, 6439, 2345, 4393, 301, 2349, 4398, 4402, 4404, 2359, 2360, 316, 4412, 318, 2370, 6469, 6473, 330, 6475, 334, 6478, 64

eval_pointwise:  35%|███▌      | 7/20 [00:00<00:00, 69.69it/s]

[31mDetect 909 unknown interaction(s), position: [3, 2051, 2053, 6151, 4107, 13, 18, 2074, 4124, 2077, 4130, 36, 2086, 6184, 6186, 46, 4142, 49, 50, 54, 2102, 4152, 61, 6206, 4159, 2112, 4160, 6208, 67, 6210, 4166, 4168, 2122, 6222, 6230, 4183, 88, 6233, 6238, 95, 4191, 6241, 2147, 106, 2157, 4205, 111, 112, 2160, 6259, 6261, 123, 6268, 126, 4224, 131, 4229, 6281, 138, 6282, 6284, 141, 2191, 4246, 153, 2203, 156, 157, 4253, 4255, 160, 2212, 165, 2219, 174, 2222, 4272, 6318, 6321, 2227, 4275, 6327, 6332, 191, 2240, 193, 4290, 2243, 2244, 197, 198, 4291, 4292, 4297, 4298, 6341, 204, 6354, 211, 2260, 4308, 215, 6360, 2268, 8186, 231, 6375, 6377, 238, 4335, 6383, 8189, 6387, 4342, 249, 250, 4349, 4350, 6398, 2304, 4354, 4356, 4358, 2311, 2315, 4364, 6412, 2320, 273, 2322, 275, 4374, 4376, 4378, 6427, 288, 6433, 292, 295, 6443, 6449, 4402, 308, 312, 6458, 6461, 6462, 4419, 4431, 2384, 2385, 340, 6485, 346, 6494, 2404, 4455, 4458, 363, 4466, 6515, 376, 6530, 390, 2440, 6536, 394, 6539, 6540

eval_pointwise: 100%|██████████| 20/20 [00:00<00:00, 88.70it/s]

[31mDetect 915 unknown interaction(s), position: [4104, 6154, 11, 2059, 2064, 2065, 2066, 4115, 20, 22, 6166, 6167, 2081, 6178, 35, 2090, 6189, 4148, 6197, 54, 4152, 60, 61, 62, 63, 6208, 6209, 6213, 77, 78, 79, 2127, 6225, 82, 4179, 6228, 91, 92, 6238, 6239, 96, 97, 98, 4201, 4207, 112, 4209, 2162, 115, 2167, 121, 4222, 4223, 6271, 4227, 6278, 135, 4231, 6284, 6291, 2196, 150, 2201, 4250, 6299, 6137, 6304, 4258, 6309, 6313, 4271, 6320, 4274, 2227, 6329, 6330, 188, 190, 2238, 192, 2239, 6337, 4291, 2244, 2245, 6341, 6343, 4299, 204, 6347, 206, 4302, 6351, 210, 2258, 4306, 215, 4313, 2267, 6363, 4317, 222, 223, 6367, 226, 4324, 4327, 233, 6377, 6380, 6383, 244, 251, 8191, 6398, 6399, 256, 257, 2304, 6401, 263, 265, 2313, 4361, 270, 2321, 4372, 282, 4378, 285, 286, 287, 2338, 4387, 295, 2345, 6441, 2348, 303, 4403, 2358, 312, 6458, 315, 6459, 317, 6460, 2369, 4417, 4419, 6468, 6472, 2377, 4425, 6474, 2381, 2383, 4431, 2386, 4439, 4440, 6493, 2399, 4447, 6500, 360, 4462, 4474, 6528, 4481




{'loss': 1.188821}