In [1]:
import os
import sys

import logging
import math,random
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.Attention import Attention
from models.Encoders.CNN import CNN_Encoder
from models.Encoders.RNN import RNN_User_Encoder
from models.Interactors.CNN import CNN_Interactor
from models.Interactors.FIM import FIM_Interactor
from models.Interactors.KNRM import KNRM_Interactor
from models.ESM import ESM
from models.base_model import BaseModel

from data.configs.demo import config
from utils.utils import prepare

from models.base_model import BaseModel
from models.Modules.DRM import DRM_Matching
from models.Modules.TFM import TFM

In [4]:
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# (Restart & Run All) on machines where CUDA is not available.
config.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Number of top-scoring terms kept per news item; handed to DRM_Matching as `k`
# when the model is assembled in a later cell.
config.k = 3

# Build the vocabulary and the data loaders from the (now updated) config.
vocab, loaders = prepare(config)
# Pull a single batch from the first loader for interactive inspection.
# NOTE(review): presumably loaders[0] is the training loader — confirm in utils.utils.prepare.
record = next(iter(loaders[0]))

[2021-07-22 16:44:58,527] INFO (root) Hyper Parameters are scale:demo, mode:train, batch_size:10, title_size:20, abs_size:40, his_size:50, learning_rate:0.001, vert_num:18, subvert_num:293, npratio:4, dropout_p:0.2, query_dim:200, embedding_dim:300, filter_num:150, head_num:16, epochs:8, metrics:auc,mean_mrr,ndcg@5,ndcg@10, device:cuda:0, attrs:['title'], k:3, save_step:[0], validate:False, interval:10, spadam:False, val_freq:2, schedule:None, multiview:False, onehot:False
[2021-07-22 16:44:58,528] INFO (root) preparing dataset...
[2021-07-22 16:45:01,170] INFO (torchtext.vocab) Loading vectors from /home/peitian_zhang/Data/.vector_cache/glove.840B.300d.txt.pt


In [5]:
class DRM_Matching(nn.Module):
    """
    Matching-based document reducer.

    For every news item, keeps the ``k`` words whose embeddings have the
    highest cosine similarity with the user representation and scales each
    kept word embedding by that similarity.
    """
    # NOTE(review): the import cell also pulls `DRM_Matching` from
    # models.Modules.DRM; once this cell runs, this definition shadows it.

    def __init__(self, k, threshold = -float('inf')):
        """
        Args:
            k: number of words to keep per news item
            threshold: similarity scores strictly below this value are replaced
                by 0 before weighting; the default (-inf) never zeroes a score
        """
        super().__init__()
        self.name = "matching-based"
        self.k, self.threshold = k, threshold

    def forward(self, news_embedding, user_repr):
        """
        Select the top-k words of each news item w.r.t. the user interest.

        Args:
            news_embedding: word-level news embedding, [batch_size, his_size, signal_length, hidden_dim]
            user_repr: user representation, [batch_size, 1, hidden_dim]

        Returns:
            tuple of
              - similarity-weighted embeddings of the selected words, [batch_size, his_size, k, hidden_dim]
              - positions of the selected words inside each news item, [batch_size, his_size, k]
        """
        # unit-length word vectors, [bs, hs, sl, hd]
        unit_words = F.normalize(news_embedding, dim=-1)
        # unit-length user vector laid out as a column, [bs, 1, hd, 1]
        unit_user = F.normalize(user_repr, dim=-1).transpose(-2, -1).unsqueeze(1)
        # cosine similarity of every word with the user, [bs, hs, sl]
        similarity = unit_words.matmul(unit_user).squeeze(-1)

        # best k words per news item: values and positions, both [bs, hs, k]
        top_scores, top_positions = similarity.topk(self.k, dim=-1)

        # repeat each position across the hidden dimension so gather() can
        # pick whole embedding rows from the (un-normalized) input
        hidden_dim = news_embedding.size(-1)
        row_index = top_positions.unsqueeze(-1).expand(*top_positions.size(), hidden_dim)
        selected_words = news_embedding.gather(-2, row_index)

        # scores under the threshold contribute a zero weight
        too_low = top_scores < self.threshold
        weights = top_scores.masked_fill(too_low, 0).unsqueeze(-1)

        return selected_words * weights, top_positions

In [6]:
# Assemble the ESM model from its components. Construction order is kept as-is.
# NOTE(review): presumably encoderN encodes news text and encoderU encodes the
# user's click history — confirm against models.Encoders.
encoderN = CNN_Encoder(config, vocab)
# The user encoder is sized to the news encoder's hidden dimension.
encoderU = RNN_User_Encoder(encoderN.hidden_dim)
# Top-k document reducer; threshold is left at its default.
docReducer = DRM_Matching(config.k)
# NOTE(review): termFuser is constructed here but is not handed to ESM below
# (None is passed in the slot between docReducer and interactor) — confirm
# whether the term fuser is intentionally disabled for this experiment.
termFuser = TFM(config.his_size, config.k)
# NOTE(review): the "+ 1" on both lengths presumably accounts for one extra
# position per sequence — confirm against CNN_Interactor / ESM.
interactor = CNN_Interactor(config.title_size + 1, config.k * config.his_size + 1, encoderN.hidden_dim)

# Wire everything together and move the model to the configured device.
esm = ESM(config, encoderN, encoderU, docReducer, None, interactor).to(config.device)

In [8]:
# Display the gradients accumulated on tensors `a` and `b`.
# NOTE(review): `a` and `b` are not defined in any cell visible here, so this
# cell depends on hidden kernel state and will fail on Restart & Run All
# unless their defining cell (and a backward pass) is restored above.
a.grad,b.grad

(tensor([[[[5.7440e-01, 1.4786e+00, 2.7934e-03, 1.4655e+00, 2.2495e-01],
           [3.9562e-01, 1.0184e+00, 1.9240e-03, 1.0094e+00, 1.5494e-01],
           [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
           [4.6744e-01, 1.2032e+00, 2.2733e-03, 1.1926e+00, 1.8306e-01]],
 
          [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
           [5.5783e-01, 1.4359e+00, 2.7129e-03, 1.4233e+00, 2.1846e-01],
           [4.2650e-01, 1.0979e+00, 2.0742e-03, 1.0882e+00, 1.6703e-01],
           [5.6448e-01, 1.4530e+00, 2.7452e-03, 1.4402e+00, 2.2107e-01]],
 
          [[6.2962e-01, 1.6207e+00, 3.0620e-03, 1.6064e+00, 2.4658e-01],
           [2.9200e-01, 7.5163e-01, 1.4201e-03, 7.4502e-01, 1.1436e-01],
           [4.4374e-01, 1.1422e+00, 2.1580e-03, 1.1322e+00, 1.7378e-01],
           [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]],
 
 
         [[[1.5224e+00, 7.5430e-01, 3.4183e-01, 1.2612e+00, 1.4784e+00],
           [9.7296e-01, 4.8207e-01, 2.1

In [5]:
# Training run.
# NOTE(review): this cell carries the same execution count as the class
# definition cell, and its log timestamps precede those of the data-preparation
# cell — the saved outputs come from an out-of-order run.
config.epochs = 8
# NOTE(review): presumably the number of validations per epoch — the captured
# log shows evaluation at steps 146 and 292 of a 295-step epoch; confirm in
# models.base_model.
config.val_freq = 2
# Per the captured log, tune() trains, periodically evaluates, and saves
# checkpoints under data/model_params/.
esm.tune(config, loaders)


[2021-07-22 16:39:45,572] INFO (models.base_model) training...
epoch 1 , step 140 , loss: 1.5967:  49%|████▉     | 146/295 [00:05<00:05, 26.70it/s][2021-07-22 16:39:51,092] INFO (models.base_model) evaluating...






100%|██████████| 1812/1812 [00:15<00:00, 117.72it/s]
[2021-07-22 16:40:07,070] INFO (models.base_model) current result of esm-cnn-encoder-rnn-user-encoder-matching-based-2dcnn is {'auc': 0.537, 'mean_mrr': 0.2323, 'ndcg@5': 0.2374, 'ndcg@10': 0.3117, 'epoch': 1, 'step': 146}
[2021-07-22 16:40:08,076] INFO (models.base_model) saved model of step 146, epoch 1 at data/model_params/esm-cnn-encoder-rnn-user-encoder-matching-based-2dcnn/demo_epoch1_step146_[hs=50,topk=3].model
epoch 1 , step 290 , loss: 1.5680:  99%|█████████▊| 291/295 [00:27<00:00, 10.65it/s][2021-07-22 16:40:12,960] INFO (models.base_model) evaluating...






100%|██████████| 1812/1812 [00:14<00:00, 128.91it/s]
[2021-07-22 16:40:27,580] INFO (models.base_model) current result of esm-cnn-encoder-rnn-user-encoder-matching-based-2dcnn is {'auc': 0.5513, 'mean_mrr': 0.2427, 'ndcg@5': 0.2566, 'ndcg@10': 0.3233, 'epoch': 1, 'step': 292}
[2021-07-22 16:40:28,533] INFO (models.base_model) saved model of step 292, epoch 1 at data/model_params/esm-cnn-encoder-rnn-user-encoder-matching-based-2dcnn/demo_epoch1_step292_[hs=50,topk=3].model
epoch 1 , step 290 , loss: 1.5680: 100%|██████████| 295/295 [00:43<00:00,  6.84it/s]
epoch 2 , step 140 , loss: 1.4791:  49%|████▉     | 144/295 [00:05<00:05, 26.80it/s][2021-07-22 16:40:34,182] INFO (models.base_model) evaluating...






100%|██████████| 1812/1812 [00:14<00:00, 124.71it/s]
[2021-07-22 16:40:49,252] INFO (models.base_model) current result of esm-cnn-encoder-rnn-user-encoder-matching-based-2dcnn is {'auc': 0.5598, 'mean_mrr': 0.249, 'ndcg@5': 0.2637, 'ndcg@10': 0.3289, 'epoch': 2, 'step': 146}
[2021-07-22 16:40:50,141] INFO (models.base_model) saved model of step 146, epoch 2 at data/model_params/esm-cnn-encoder-rnn-user-encoder-matching-based-2dcnn/demo_epoch2_step146_[hs=50,topk=3].model
epoch 2 , step 290 , loss: 1.4663:  98%|█████████▊| 289/295 [00:26<00:00, 11.02it/s][2021-07-22 16:40:54,984] INFO (models.base_model) evaluating...




