## FM

参考:
[特征交叉](https://zhuanlan.zhihu.com/p/499809627)

https://zhuanlan.zhihu.com/p/165610397?ivk_sa=1024320u
https://github.com/rixwew/pytorch-fm/tree/master/torchfm/model
https://github.com/shenweichen/DeepCTR-Torch

In [15]:
import numpy as np
import pandas as pd
import torch
import torch.utils.data as Data

In [8]:
data = pd.read_csv('../../samples/movielens_sample.txt')
# 将评分离散化为二维: 好评 vs 差评
# noinspection PyShadowingNames
data['applaud'] = data['rating'].apply(lambda x: 0 if x <= 3 else 1)
data

Unnamed: 0,user_id,movie_id,rating,timestamp,title,genres,gender,age,occupation,zip,applaud
0,3299,235,4,968035345,Ed Wood (1994),Comedy|Drama,F,25,4,19119,1
1,3630,3256,3,966536874,Patriot Games (1992),Action|Thriller,M,18,4,77005,0
2,517,105,4,976203603,"Bridges of Madison County, The (1995)",Drama|Romance,F,25,14,55408,1
3,785,2115,3,975430389,Indiana Jones and the Temple of Doom (1984),Action|Adventure,M,18,19,29307,0
4,5848,909,5,957782527,"Apartment, The (1960)",Comedy|Drama,M,50,20,20009,1
...,...,...,...,...,...,...,...,...,...,...,...
195,1427,3596,3,974840560,Screwed (2000),Comedy,M,25,12,21401,0
196,3868,1626,3,965855033,Fire Down Below (1997),Action|Drama|Thriller,M,18,12,73112,0
197,249,2369,3,976730191,Desperately Seeking Susan (1985),Comedy|Romance,F,18,14,48126,0
198,5720,349,4,958503395,Clear and Present Danger (1994),Action|Adventure|Thriller,M,25,0,60610,1


In [13]:
items = data.to_numpy()[:, :1]
np.max(items, axis=0)

array([6040], dtype=object)

In [44]:
user_and_item_ids = data[['user_id', 'movie_id']].to_numpy(dtype=np.long)
x_train_tensor = torch.from_numpy(data[['user_id', 'movie_id']].to_numpy(dtype=np.long))
y_train_tensor = torch.from_numpy(data['applaud'].to_numpy(dtype=np.float32)).float()
# 将训练数据的特征和标签组合
dataset = Data.TensorDataset(x_train_tensor, y_train_tensor)

user_item_n_fields = np.max(user_and_item_ids, axis=0) + 1
print(f'user_item_n_fields: {user_item_n_fields}')

# 把 dataset 放入 DataLoader
data_iter = Data.DataLoader(
    dataset=dataset,  # torch TensorDataset format
    batch_size=30,  # mini batch size
    shuffle=True,  # 要不要打乱数据 (打乱比较好)
    num_workers=2,  # 多线程来读数据, 这里数据直接内存生成
)

user_item_n_fields: [6041 3949]


In [14]:
class FeaturesEmbedding(torch.nn.Module):

    def __init__(self, field_dims, embed_dim):
        """
        构造一个embedding层, 同时embedding user和item
        :param field_dims: (user size, item size)
        :param embed_dim: embedding dim size
        """
        super().__init__()
        self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim)
        self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.long)
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)

    # noinspection PyShadowingNames
    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        x = x + x.new_tensor(self.offsets).unsqueeze(0)
        return self.embedding(x)

In [45]:
embedding = FeaturesEmbedding(user_item_n_fields, 10)

x = [each for i, each in enumerate(data_iter) if i == 0 ][0][0]
x

tensor([[ 366, 1077],
        [2841,  680],
        [1321, 2240],
        [ 753,  434],
        [3618, 3374],
        [ 877, 1485],
        [2271, 2671],
        [4966, 2100],
        [3759, 2151],
        [5039, 1792],
        [4658, 1009],
        [5365, 1892],
        [ 615,  296],
        [5108,  367],
        [3808,   61],
        [5893, 2144],
        [5371, 3194],
        [3568, 1230],
        [3558, 1580],
        [1106, 3624],
        [1836, 2736],
        [1685, 2664],
        [1579, 2420],
        [5746, 1242],
        [2230, 2873],
        [1601, 1396],
        [  80, 2059],
        [4802, 1208],
        [6040, 3224],
        [5056, 2700]], dtype=torch.int32)

In [46]:
embed_res = embedding(x)
print(embed_res.shape)
embed_res

torch.Size([30, 2, 10])


tensor([[[-1.0238e-03, -7.9253e-03, -2.0530e-02, -1.1655e-02,  7.4949e-03,
          -5.1574e-03,  1.5906e-02, -1.9512e-02, -1.9675e-02, -1.0940e-02],
         [ 1.9630e-02, -1.0728e-02,  1.4925e-02,  9.8429e-03, -1.8240e-02,
          -1.2869e-03, -1.9622e-02,  1.5011e-02,  1.5504e-02,  4.8443e-03]],

        [[-1.5660e-02,  9.2649e-03,  6.6648e-03, -1.1224e-04,  1.8061e-02,
           8.0601e-03,  1.0101e-02,  1.5616e-02,  2.7211e-03,  1.9118e-02],
         [ 1.9375e-03, -2.2568e-02, -1.2676e-02,  1.7642e-02,  7.7042e-04,
          -2.4413e-02, -1.5074e-02,  1.2074e-02, -1.3733e-02, -6.7527e-03]],

        [[ 1.6729e-02,  1.2651e-02, -3.2266e-03,  5.7436e-03,  6.0802e-03,
          -1.1706e-02, -1.0880e-03, -2.2708e-02, -2.1249e-03,  2.0017e-02],
         [ 1.8121e-02,  4.8907e-03, -1.0081e-02, -6.1574e-04,  1.2938e-02,
           1.4851e-02, -8.4624e-03, -1.9395e-02, -1.5049e-02,  1.7139e-02]],

        [[ 1.4806e-02,  1.1371e-02, -1.9086e-02, -6.5733e-03, -1.5890e-02,
          -1.