In [1]:
import re
import numpy as np
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from collections import defaultdict
from utils.utils import prepare
from data.configs.demo import config

from transformers import BertTokenizer,BertModel
from utils.Manager import Manager

from models.Embeddings.BERT import BERT_Embedding
from models.Encoders.CNN import CNN_Encoder
from models.Encoders.RNN import RNN_User_Encoder
from models.Modules.DRM import Matching_Reducer
from models.Modules.DRM import BM25_Reducer
from models.Rankers.BERT import BERT_Onepass_Ranker
from models.Rankers.BERT import BERT_Original_Ranker
from models.Rankers.CNN import CNN_Ranker
from models.ESM import ESM

In [2]:
# Build the experiment manager from the demo config imported in the first cell.
manager = Manager(config)
# prepare() comes from utils.utils; the "(utils.utils)" log lines shown in this
# cell's output (hyper parameters, dataset preparation) appear to be emitted by
# this call. The result is indexed as loaders[0] later in the notebook.
loaders = prepare(manager)

[2021-08-20 16:55:15,566] INFO (utils.utils) Hyper Parameters are 
scale:demo
mode:tune
epochs:8
batch_size:10
k:5
threshold:-inf
title_length:20
abs_length:40
signal_length:80
npratio:4
his_size:50
cdd_size:5
impr_size:10
dropout_p:0.2
device:cpu
lr:0.0001
bert_lr:3e-05
metrics:auc,mean_mrr,ndcg@5,ndcg@10
embedding:bert
selector:sfi
reducer:matching
interactor:onepass
embedding_dim:300
hidden_dim:150
rank:0
world_size:0
step:0
seeds:42
interval:10
ascend_history:False
disable_dedup:False
schedule:linear
warmup:100
pin_memory:False
shuffle:False
bert:bert-base-uncased
num_workers:0
path:../../../Data/
tb:False
[2021-08-20 16:55:15,568] INFO (utils.utils) preparing dataset...
[2021-08-20 16:55:15,572] INFO (utils.MIND) using cached user behavior from data/cache/bert/MINDdemo_train/10/behaviors..pkl
[2021-08-20 16:55:15,586] INFO (utils.MIND) using cached news tokenization from data/cache/bert/MINDdemo_train/news.pkl
[2021-08-20 16:55:15,916] INFO (utils.utils) deduplicating...
[2021-08-

In [3]:
class ESM(nn.Module):
    """Click-scoring model assembled from pluggable sub-modules.

    The model wires together an embedding layer, a news encoder, a user
    encoder, a document reducer and a ranker.  A small feed-forward head
    (``learningToRank``) maps the ranker output, concatenated with a coarse
    news/user dot-product score, to one score per candidate news.
    """

    def __init__(self, config, embedding, encoderN, encoderU, docReducer, termFuser, ranker):
        """Store the sub-modules and build the scoring head.

        Args:
            config: object exposing ``scale``, ``cdd_size``, ``batch_size``,
                ``his_size``, ``device`` and ``k``; its ``name`` attribute is
                overwritten with the name of this model
            embedding: module applied to the encoded token indices
            encoderN: news encoder, must expose ``hidden_dim`` and ``name``
            encoderU: user encoder, must expose ``name``
            docReducer: document reducer, must expose ``name``
            termFuser: stored as-is; not called anywhere in this class
            ranker: ranker, must expose ``final_dim`` and ``name``
        """
        super().__init__()

        # plain settings copied from the config
        self.scale = config.scale
        self.cdd_size = config.cdd_size
        self.batch_size = config.batch_size
        self.his_size = config.his_size
        self.device = config.device

        self.k = config.k

        # sub-modules; the assignment order fixes the registration order
        self.embedding = embedding
        self.encoderN = encoderN
        self.encoderU = encoderU
        self.docReducer = docReducer
        self.termFuser = termFuser
        self.ranker = ranker

        self.hidden_dim = encoderN.hidden_dim
        self.final_dim = ranker.final_dim

        # scoring head: the extra input feature (+1) is the coarse score
        half_dim = int(self.final_dim / 2)
        self.learningToRank = nn.Sequential(
            nn.Linear(self.final_dim + 1, half_dim),
            nn.ReLU(),
            nn.Linear(half_dim, 1),
        )

        # the model name is derived from its components and mirrored to the config
        component_names = (
            self.encoderN.name,
            self.encoderU.name,
            self.docReducer.name,
            self.ranker.name,
        )
        self.name = '__'.join(('esm',) + component_names)
        config.name = self.name

    def clickPredictor(self, reduced_tensor, cdd_news_repr, user_repr):
        """ calculate a batch of click scores

        Args:
            reduced_tensor: [batch_size, cdd_size, final_dim]
            cdd_news_repr: news-level representation, [batch_size, cdd_size, hidden_dim]
            user_repr: user representation, [batch_size, 1, hidden_dim]

        Returns:
            score of each candidate news, [batch_size, cdd_size]
        """
        # dot product between every candidate news and the user representation
        coarse = torch.matmul(cdd_news_repr, user_repr.transpose(-2, -1))
        features = torch.cat((reduced_tensor, coarse), dim=-1)

        logits = self.learningToRank(features)
        return logits.squeeze(dim=-1)

    def _forward(self, x):
        """ compute the raw click score of every candidate news

        Args:
            x: mapping with the entries "cdd_encoded_index", "his_encoded_index",
                "his_attn_mask" and "cdd_attn_mask"

        Returns:
            tuple of (score produced by clickPredictor, term ids returned by docReducer)
        """
        device = self.device

        # keep batch_size in sync with the batch that is actually fed in
        self.batch_size = x["cdd_encoded_index"].size(0)

        # candidate news: embed, then encode (only the news-level output is used)
        cdd_index = x["cdd_encoded_index"].long().to(device)
        cdd_embedding = self.embedding(cdd_index)
        cdd_repr = self.encoderN(cdd_embedding)[1]

        # history news: embed, then encode (both encoder outputs are used)
        his_index = x["his_encoded_index"].long().to(device)
        his_embedding = self.embedding(his_index)
        his_encoded, his_repr = self.encoderN(his_embedding)

        user_repr = self.encoderU(his_repr)

        his_mask = x["his_attn_mask"].to(device).bool()
        ps_terms, ps_term_ids = self.docReducer(his_encoded, his_embedding, user_repr, his_mask)
        # if self.termFuser:
        #     ps_terms = self.termFuser(ps_terms, ps_term_ids, his_index)

        cdd_mask = x["cdd_attn_mask"].to(device)
        reduced = self.ranker(cdd_embedding, ps_terms, cdd_mask)

        return self.clickPredictor(reduced, cdd_repr, user_repr), ps_term_ids

    def forward(self, x):
        """
        Decoupled function: _forward yields the unnormalized click score, which
        is turned into log-softmax over dim 1 in training mode and into a
        sigmoid otherwise. The term ids from _forward are passed through.
        """
        score, ps_term_ids = self._forward(x)

        if not self.training:
            return torch.sigmoid(score), ps_term_ids

        return nn.functional.log_softmax(score, dim=1), ps_term_ids

In [4]:
# Sub-modules of the model; each one is constructed from the shared manager.
embedding = BERT_Embedding(manager)
encoderN = CNN_Encoder(manager)
encoderU = RNN_User_Encoder(manager)

# Document reducer: Matching_Reducer is active, BM25_Reducer is the
# commented-out alternative.
docReducer = Matching_Reducer(manager)
# docReducer = BM25_Reducer(manager)

# Ranker: BERT_Onepass_Ranker is active; the CNN and original-BERT rankers are
# kept as commented-out alternatives.
# ranker = CNN_Ranker(manager)
ranker = BERT_Onepass_Ranker(manager)
# ranker = BERT_Original_Ranker(manager)

# NOTE: ESM here resolves to the class defined in the previous cell, which
# shadows the `from models.ESM import ESM` import of the first cell.
# The manager is passed in the `config` position and None in the `termFuser` position.
esm = ESM(manager, embedding, encoderN, encoderU, docReducer, None, ranker).to(manager.device)

In [5]:
# Run the manager's ps-term inspection on the assembled model.
# 3534 matches the step in the checkpoint file name reported by the log below
# (demo_step3534_[k=5].model).
# loaders[0] is presumably the training loader -- TODO confirm against prepare().
manager.inspect_ps_terms(esm, 3534, loaders[0], bm25=False)

[2021-08-20 16:55:37,918] INFO (utils.Manager) loading model from data/model_params/esm__cnn__rnn-u__matching-reducer__onepass-bert/demo_step3534_[k=5].model...
