# DeepFM
Details:
    
   Embedding part: each feature corresponds to one field. In this dataset every feature holds a single value.
   
   Discrete and continuous values are both processed, giving an input of size ∑ n_i * feature_size_i.
    
   Each feature is embedded, which maps it to a fixed dimension. For each field,
   the embedding matrix has size feature_size_i * embedding.
   
   Deep part: the resulting embeddings are concatenated, so the input size is (n_field * embedding).
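
The concatenation for the Deep part can be illustrated with a small sketch. The sizes below are toy values chosen for illustration, not the ones used in this notebook.

```python
import torch

batch_size, n_field, embedding_size = 2, 3, 4  # toy sizes (assumptions)

# One (batch_size, embedding_size) tensor per field, as the embedding step would produce.
field_embeddings = [torch.randn(batch_size, embedding_size) for _ in range(n_field)]

# Concatenate along the feature axis to form the input of the Deep part.
deep_input = torch.cat(field_embeddings, dim=1)
print(deep_input.shape)  # torch.Size([2, 12]), i.e. n_field * embedding_size columns
```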
   
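   FM part: V_{i,f} here reuses the weights of the Embedding above, i.e. K = embedding. Because the input is one-hot, the matrix product does not use the whole vector x directly. Instead, the index of the entry equal to 1 selects a row of the embedding matrix, which is the same matrix as in the Embedding part. The second-order feature interaction is then computed from the formula.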
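
Both claims in the FM paragraph can be checked numerically. The following is a minimal sketch with toy sizes and random values (all names are assumptions made for illustration): it checks that a one-hot product equals a row lookup, and that the sum-of-squares trick equals the explicit pairwise interaction sum.

```python
import torch

torch.manual_seed(0)
feature_size, k = 5, 4              # toy vocabulary size and embedding size
V = torch.randn(feature_size, k)    # embedding matrix of one field

# 1) Multiplying a one-hot vector by V equals selecting the row at the index of the 1.
idx = 3
one_hot = torch.zeros(feature_size)
one_hot[idx] = 1.0
assert torch.allclose(one_hot @ V, V[idx])

# 2) Second-order term: 0.5 * ((sum_i e_i)^2 - sum_i e_i^2), summed over the k dimensions,
#    equals the sum of inner products <e_i, e_j> over all pairs i < j.
n_field = 6
e = torch.randn(n_field, k)         # one embedded vector per field
fast = 0.5 * (e.sum(dim=0) ** 2 - (e ** 2).sum(dim=0)).sum()
slow = sum(e[i] @ e[j] for i in range(n_field) for j in range(i + 1, n_field))
assert torch.allclose(fast, slow, atol=1e-5)
```

The first form needs a number of operations linear in the number of fields, while the explicit pairwise sum is quadratic.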
   
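   The final layer sums the output elements of the Deep and FM parts, applies an activation to obtain the prediction, and then computes the loss with cross-entropy.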
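
As a rough sketch of this last step, the snippet below sums hypothetical per-sample outputs into one logit each and computes the binary cross-entropy. The tensor shapes and values are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size = 4
# Hypothetical per-sample outputs of the three parts (small values keep the sigmoid unsaturated).
fm_first = 0.1 * torch.randn(batch_size, 8)
fm_second = 0.1 * torch.randn(batch_size, 4)
deep_out = 0.1 * torch.randn(batch_size, 32)

# Sum all output elements per sample to get one logit each.
logits = fm_first.sum(dim=1) + fm_second.sum(dim=1) + deep_out.sum(dim=1)
y = torch.tensor([1.0, 0.0, 0.0, 1.0])

# binary_cross_entropy_with_logits applies the sigmoid internally ...
loss = F.binary_cross_entropy_with_logits(logits, y)
# ... so it matches an explicit sigmoid followed by plain binary cross-entropy.
probs = torch.sigmoid(logits)
manual = F.binary_cross_entropy(probs, y)
assert torch.allclose(loss, manual, atol=1e-5)
```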
   
#### References
[deepFM in pytorch](https://blog.csdn.net/w55100/article/details/90295932)

[A comprehensive analysis of DeepFM (with PyTorch source code)](https://zhuanlan.zhihu.com/p/84526966)

[Translation of the original paper](https://zhuanlan.zhihu.com/p/57873613)

In [1]:
import numpy as np

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

from data.dataset import CriteoDataset

### Load data

In [2]:
train_data = CriteoDataset('./data',train=True)

In [43]:
train_data

<data.dataset.CriteoDataset at 0x26dd3097908>

In [14]:
# Use only the first 800 samples for training in this demo
Num_train = 800
# Build batches by sampling the training indices at random
loader_train = DataLoader(train_data,batch_size=50,
                    sampler=sampler.SubsetRandomSampler(range(Num_train)))
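
To see what `SubsetRandomSampler` does in isolation, here is a small sketch that uses a toy `TensorDataset` as a stand-in for `CriteoDataset` (the toy data and sizes are assumptions for illustration).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, sampler

# Toy stand-in for CriteoDataset: 20 samples whose single feature equals the sample index.
toy_data = TensorDataset(torch.arange(20).unsqueeze(1))

# Only indices 0..7 are ever drawn, in a new random order on every pass.
toy_loader = DataLoader(toy_data, batch_size=4,
                        sampler=sampler.SubsetRandomSampler(range(8)))

seen = sorted(int(v) for (batch,) in toy_loader for v in batch.flatten())
print(seen)  # [0, 1, 2, 3, 4, 5, 6, 7]
```

Passing disjoint index ranges to two samplers over the same dataset is one simple way to obtain a train/validation split.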

In [23]:
feature_sizes = np.loadtxt('./data/feature_sizes.txt', delimiter=',')
feature_sizes = [int(x) for x in feature_sizes]
len(feature_sizes)

39

In [48]:
loader_train.dataset

<data.dataset.CriteoDataset at 0x26dd3097908>

### DeepModel

In [20]:
import torch.nn as nn
import torch.nn.functional as F
from time import time

In [61]:
class DeepFM(nn.Module):
    def __init__(self, feature_sizes, embedding_size=4,
                 hidden_dims=[32, 32], num_classes=10, dropout=[0.5, 0.5], 
                 use_cuda=False, verbose=False):
        super().__init__()
        self.field_size = len(feature_sizes)
        self.feature_sizes = feature_sizes
        self.embedding_size = embedding_size
        self.hidden_dims = hidden_dims
        self.num_classes = num_classes
        self.dtype = torch.float
        
        self.device = torch.device('cpu')
        
        """
        init fm part
        """
#       (0): Linear(in_features=1, out_features=4, bias=True)
#         Continuous values: use a fully connected (Linear) layer
        fm_first_order_Linears = nn.ModuleList(
                [nn.Linear(feature_size, self.embedding_size) for feature_size in self.feature_sizes[:13]])

#       (0): Embedding(1, 4) (1): Embedding(1, 4)
#          Discrete values: use an Embedding
        fm_first_order_embeddings = nn.ModuleList(
                [nn.Embedding(feature_size, self.embedding_size) for feature_size in self.feature_sizes[13:40]])
        
#         (12): Linear(in_features=1, out_features=4, bias=True)
#         (13): Embedding(1, 4)
        self.fm_first_order_models = fm_first_order_Linears.extend(fm_first_order_embeddings)    
        
#         Second order
        fm_second_order_Linears = nn.ModuleList(
                [nn.Linear(feature_size, self.embedding_size) for feature_size in self.feature_sizes[:13]])
        fm_second_order_embeddings = nn.ModuleList(
                [nn.Embedding(feature_size, self.embedding_size) for feature_size in self.feature_sizes[13:40]])
        self.fm_second_order_models = fm_second_order_Linears.extend(fm_second_order_embeddings)
        
        
        """
            init deep part
            This is the final DNN part of the whole model
        """
        all_dims = [self.field_size * self.embedding_size] + \
            self.hidden_dims + [self.num_classes]
        
        for i in range(1, len(hidden_dims) + 1):
            setattr(self, 'linear_'+str(i),
                    nn.Linear(all_dims[i-1], all_dims[i]))
            # nn.init.kaiming_normal_(self.fc1.weight)
            setattr(self, 'batchNorm_' + str(i),
                    nn.BatchNorm1d(all_dims[i]))
            setattr(self, 'dropout_'+str(i),
                    nn.Dropout(dropout[i-1]))


    def forward(self, Xi, Xv):
        """
        Inputs:
        - Xi: feature indices (N, field_size, 1)
        - Xv: feature values (N, field_size, 1)
        """
        """
            fm part
        """
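        # Sanity check on field 20: a valid embedding index must be < feature_sizes[20]
        for num in Xi[:, 20, :][0]:
            if num >= self.feature_sizes[20]:
                print("index out")
#         Embed every feature by looking up its index in the embedding matrix, which avoids a full one-hot matrix product
#         This is the FM first-order part, i.e. ∑vi,f xn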
        fm_first_order_emb_arr = []
        for i, emb in enumerate(self.fm_first_order_models):
#             Continuous values
            if i <=12:
                Xi_tem = Xi[:, i, :].to(device=self.device, dtype=torch.float)
                fm_first_order_emb_arr.append((torch.sum(emb(Xi_tem).unsqueeze(1), 1).t() * Xv[:, i]).t())
            else:
                Xi_tem = Xi[:, i, :].to(device=self.device, dtype=torch.long)
                fm_first_order_emb_arr.append((torch.sum(emb(Xi_tem), 1).t() * Xv[:, i]).t())
        
        fm_first_order = torch.cat(fm_first_order_emb_arr, 1)
        
#         This is the FM second-order part, i.e. ∑vi,f xn
        fm_second_order_emb_arr = []
        for i, emb in enumerate(self.fm_second_order_models):
            if i <=12:
                Xi_tem = Xi[:, i, :].to(device=self.device, dtype=torch.float)
                fm_second_order_emb_arr.append((torch.sum(emb(Xi_tem).unsqueeze(1), 1).t() * Xv[:, i]).t())
            else:
                Xi_tem = Xi[:, i, :].to(device=self.device, dtype=torch.long)
                fm_second_order_emb_arr.append((torch.sum(emb(Xi_tem), 1).t() * Xv[:, i]).t())
        
        fm_sum_second_order_emb = sum(fm_second_order_emb_arr)
        fm_sum_second_order_emb_square = fm_sum_second_order_emb * \
            fm_sum_second_order_emb  # (x+y)^2
        fm_second_order_emb_square = [
            item*item for item in fm_second_order_emb_arr]
        fm_second_order_emb_square_sum = sum(
            fm_second_order_emb_square)  # x^2+y^2
        fm_second_order = (fm_sum_second_order_emb_square -
                           fm_second_order_emb_square_sum) * 0.5
        
        
#             shape = n_field * embedding_size
        deep_emb = torch.cat(fm_second_order_emb_arr, 1)
        deep_out = deep_emb
        for i in range(1, len(self.hidden_dims) + 1):
            deep_out = getattr(self, 'linear_' + str(i))(deep_out)
            deep_out = getattr(self, 'batchNorm_' + str(i))(deep_out)
            deep_out = getattr(self, 'dropout_' + str(i))(deep_out)
        
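#         Sum the first-order, second-order and deep outputs into one logit per sample.
#         This bias is re-sampled on every forward pass and is not registered with the
#         module, so the optimizer never updates it.
        bias = torch.nn.Parameter(torch.randn(Xi.size(0)))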
        total_sum = torch.sum(fm_first_order, 1) + \
                    torch.sum(fm_second_order, 1) + \
                    torch.sum(deep_out, 1) + bias
        return total_sum

    def fit(self, loader_train, loader_val, optimizer, epochs=1, verbose=False, print_every=5):
       
        model = self.train().to(device=self.device)
        criterion = F.binary_cross_entropy_with_logits

        for epoch in range(epochs):
            for t, (xi, xv, y) in enumerate(loader_train):
                xi = xi.to(device=self.device, dtype=self.dtype)
                xv = xv.to(device=self.device, dtype=torch.float)
                y = y.to(device=self.device, dtype=self.dtype)
                
                total = model(xi, xv)
#                print(total.shape)
#                print(y.shape)
                loss = criterion(total, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if verbose and t % print_every == 0:
                    print('Epoch %d Iteration %d, loss = %.4f' % (epoch, t, loss.item()))
                    self.check_accuracy(loader_val, model)
                    print()
    
    def check_accuracy(self, loader, model):
        if loader.dataset.train:
            print('Checking accuracy on validation set')
        else:
            print('Checking accuracy on test set')   
        num_correct = 0
        num_samples = 0
        model.eval()  # set model to evaluation mode
        with torch.no_grad():
            for xi, xv, y in loader:
                xi = xi.to(device=self.device, dtype=self.dtype)  # move to device, e.g. GPU
                xv = xv.to(device=self.device, dtype=self.dtype)
                y = y.to(device=self.device, dtype=self.dtype)
                total = model(xi, xv)
                preds = (torch.sigmoid(total) > 0.5).to(dtype=self.dtype)
#                print(preds.dtype)
#                print(y.dtype)
#                print(preds.eq(y).cpu().sum())
                num_correct += (preds == y).sum()
                num_samples += preds.size(0)
#                print("successful")
            acc = float(num_correct) / num_samples
            print('Got %d / %d correct (%.2f%%)' % (num_correct, num_samples, 100 * acc))
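
In the class above, the bias term is created inside `forward`, so `model.parameters()` never contains it. If a learned bias is wanted, one common pattern is to register it in `__init__`. The following is a minimal sketch on a hypothetical toy module (`WithRegisteredBias` is an illustrative name, not part of this notebook).

```python
import torch
import torch.nn as nn

class WithRegisteredBias(nn.Module):
    """Hypothetical toy module, not the DeepFM class above."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)
        # Assigning an nn.Parameter as an attribute registers it with the module.
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # The scalar bias broadcasts over the batch dimension.
        return self.linear(x).squeeze(1) + self.bias

toy = WithRegisteredBias()
names = [name for name, _ in toy.named_parameters()]
assert 'bias' in names           # visible to any optimizer built from toy.parameters()

out = toy(torch.randn(5, 3))     # shape (5,)
out.sum().backward()
assert toy.bias.grad is not None  # the bias receives a gradient and can be trained
```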



In [62]:
model = DeepFM(feature_sizes, use_cuda=False) 

In [63]:
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.0)

In [65]:
val_data = CriteoDataset('./data', train=True)
loader_val = DataLoader(val_data, batch_size=50,
                        sampler=sampler.SubsetRandomSampler(range(Num_train, 899)))

In [66]:
model.fit(loader_train, loader_val, optimizer, epochs=100, verbose=True)

Epoch 0 Iteration 0, loss = 867554.1875
Checking accuracy on validation set
Got 36 / 99 correct (36.36%)





Epoch 0 Iteration 5, loss = 2145652.5000
Checking accuracy on validation set
Got 36 / 99 correct (36.36%)

Epoch 0 Iteration 10, loss = 414614.0000
Checking accuracy on validation set
Got 37 / 99 correct (37.37%)

Epoch 0 Iteration 15, loss = 364650.5938
Checking accuracy on validation set
Got 37 / 99 correct (37.37%)

Epoch 1 Iteration 0, loss = 362912.5625
Checking accuracy on validation set
Got 37 / 99 correct (37.37%)

Epoch 1 Iteration 5, loss = 270534.9062
Checking accuracy on validation set
Got 37 / 99 correct (37.37%)

Epoch 1 Iteration 10, loss = 1284424.0000
Checking accuracy on validation set
Got 38 / 99 correct (38.38%)

Epoch 1 Iteration 15, loss = 836077.8125
Checking accuracy on validation set
Got 38 / 99 correct (38.38%)

Epoch 2 Iteration 0, loss = 436198.8750
Checking accuracy on validation set
Got 38 / 99 correct (38.38%)

Epoch 2 Iteration 5, loss = 953946.2500
Checking accuracy on validation set
Got 37 / 99 correct (37.37%)

Epoch 2 Iteration 10, loss = 501731.5938

Checking accuracy on validation set
Got 39 / 99 correct (39.39%)

Epoch 19 Iteration 15, loss = 1723087.8750
Checking accuracy on validation set
Got 39 / 99 correct (39.39%)

Epoch 20 Iteration 0, loss = 1361971.8750
Checking accuracy on validation set
Got 39 / 99 correct (39.39%)

Epoch 20 Iteration 5, loss = 325328.9375
Checking accuracy on validation set
Got 39 / 99 correct (39.39%)

Epoch 20 Iteration 10, loss = 288436.9688
Checking accuracy on validation set
Got 40 / 99 correct (40.40%)

Epoch 20 Iteration 15, loss = 449572.5312
Checking accuracy on validation set
Got 40 / 99 correct (40.40%)

Epoch 21 Iteration 0, loss = 401086.4688
Checking accuracy on validation set
Got 40 / 99 correct (40.40%)

Epoch 21 Iteration 5, loss = 766924.5625
Checking accuracy on validation set
Got 40 / 99 correct (40.40%)

Epoch 21 Iteration 10, loss = 206574.5156
Checking accuracy on validation set
Got 40 / 99 correct (40.40%)

Epoch 21 Iteration 15, loss = 263214.4062
Checking accuracy on validatio

Got 46 / 99 correct (46.46%)

Epoch 38 Iteration 15, loss = 434176.9062
Checking accuracy on validation set
Got 46 / 99 correct (46.46%)

Epoch 39 Iteration 0, loss = 473357.8125
Checking accuracy on validation set
Got 46 / 99 correct (46.46%)

Epoch 39 Iteration 5, loss = 480722.6875
Checking accuracy on validation set
Got 46 / 99 correct (46.46%)

Epoch 39 Iteration 10, loss = 525859.9375
Checking accuracy on validation set
Got 46 / 99 correct (46.46%)

Epoch 39 Iteration 15, loss = 431755.5312
Checking accuracy on validation set
Got 46 / 99 correct (46.46%)

Epoch 40 Iteration 0, loss = 334610.0312
Checking accuracy on validation set
Got 46 / 99 correct (46.46%)

Epoch 40 Iteration 5, loss = 297428.9062
Checking accuracy on validation set
Got 46 / 99 correct (46.46%)

Epoch 40 Iteration 10, loss = 349859.3750
Checking accuracy on validation set
Got 46 / 99 correct (46.46%)

Epoch 40 Iteration 15, loss = 158721.9531
Checking accuracy on validation set
Got 46 / 99 correct (46.46%)

Ep

Epoch 57 Iteration 15, loss = 393742.7500
Checking accuracy on validation set
Got 50 / 99 correct (50.51%)

Epoch 58 Iteration 0, loss = 423454.4688
Checking accuracy on validation set
Got 50 / 99 correct (50.51%)

Epoch 58 Iteration 5, loss = 295889.9375
Checking accuracy on validation set
Got 50 / 99 correct (50.51%)

Epoch 58 Iteration 10, loss = 175239.9531
Checking accuracy on validation set
Got 51 / 99 correct (51.52%)

Epoch 58 Iteration 15, loss = 195371.1250
Checking accuracy on validation set
Got 51 / 99 correct (51.52%)

Epoch 59 Iteration 0, loss = 137006.7344
Checking accuracy on validation set
Got 51 / 99 correct (51.52%)

Epoch 59 Iteration 5, loss = 351781.1250
Checking accuracy on validation set
Got 51 / 99 correct (51.52%)

Epoch 59 Iteration 10, loss = 110046.2500
Checking accuracy on validation set
Got 51 / 99 correct (51.52%)

Epoch 59 Iteration 15, loss = 187943.7188
Checking accuracy on validation set
Got 51 / 99 correct (51.52%)

Epoch 60 Iteration 0, loss = 506

Epoch 76 Iteration 15, loss = 418174.0000
Checking accuracy on validation set
Got 59 / 99 correct (59.60%)

Epoch 77 Iteration 0, loss = 63881.1719
Checking accuracy on validation set
Got 59 / 99 correct (59.60%)

Epoch 77 Iteration 5, loss = 240106.2188
Checking accuracy on validation set
Got 59 / 99 correct (59.60%)

Epoch 77 Iteration 10, loss = 97255.1406
Checking accuracy on validation set
Got 58 / 99 correct (58.59%)

Epoch 77 Iteration 15, loss = 262856.9062
Checking accuracy on validation set
Got 58 / 99 correct (58.59%)

Epoch 78 Iteration 0, loss = 274786.2812
Checking accuracy on validation set
Got 58 / 99 correct (58.59%)

Epoch 78 Iteration 5, loss = 932487.6250
Checking accuracy on validation set
Got 58 / 99 correct (58.59%)

Epoch 78 Iteration 10, loss = 60115.8281
Checking accuracy on validation set
Got 58 / 99 correct (58.59%)

Epoch 78 Iteration 15, loss = 190426.2812
Checking accuracy on validation set
Got 58 / 99 correct (58.59%)

Epoch 79 Iteration 0, loss = 108155

Epoch 96 Iteration 5, loss = 120553.8984
Checking accuracy on validation set
Got 61 / 99 correct (61.62%)

Epoch 96 Iteration 10, loss = 268369.9688
Checking accuracy on validation set
Got 61 / 99 correct (61.62%)

Epoch 96 Iteration 15, loss = 86923.5938
Checking accuracy on validation set
Got 62 / 99 correct (62.63%)

Epoch 97 Iteration 0, loss = 124562.5703
Checking accuracy on validation set
Got 62 / 99 correct (62.63%)

Epoch 97 Iteration 5, loss = 87609.6562
Checking accuracy on validation set
Got 62 / 99 correct (62.63%)

Epoch 97 Iteration 10, loss = 40275.4844
Checking accuracy on validation set
Got 62 / 99 correct (62.63%)

Epoch 97 Iteration 15, loss = 167317.3906
Checking accuracy on validation set
Got 62 / 99 correct (62.63%)

Epoch 98 Iteration 0, loss = 143867.7969
Checking accuracy on validation set
Got 62 / 99 correct (62.63%)

Epoch 98 Iteration 5, loss = 84273.2734
Checking accuracy on validation set
Got 62 / 99 correct (62.63%)

Epoch 98 Iteration 10, loss = 28647.1