In [1]:
import re
import numpy as np
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from collections import defaultdict
from utils.utils import prepare
from data.configs.demo import config

from transformers import BertTokenizer,BertModel
from utils.Manager import Manager

from models.Embeddings.BERT import BERT_Embedding
from models.Encoders.CNN import CNN_Encoder
from models.Encoders.RNN import RNN_User_Encoder
from models.Modules.DRM import Matching_Reducer
from models.Modules.DRM import BM25_Reducer
from models.Rankers.BERT import BERT_Onepass_Ranker
from models.Rankers.BERT import BERT_Original_Ranker
from models.Rankers.CNN import CNN_Ranker
from models.ESM import ESM

In [2]:
# Build the experiment manager from the demo config and let `prepare` create the loaders.
manager = Manager(config)
loaders = prepare(manager)

# Keep a single sample batch for interactive inspection. Pull just the first
# batch from the first loader instead of materializing every batch with list(...).
record = next(iter(loaders[0]))

[2021-08-22 13:38:51,634] INFO (utils.utils) Hyper Parameters are 
scale:demo
mode:tune
epochs:8
batch_size:10
k:5
threshold:-inf
title_length:20
abs_length:40
signal_length:80
npratio:4
his_size:50
cdd_size:5
impr_size:10
dropout_p:0.2
device:cpu
lr:0.0001
bert_lr:3e-05
metrics:auc,mean_mrr,ndcg@5,ndcg@10
embedding:bert
selector:sfi
reducer:matching
interactor:onepass
embedding_dim:300
hidden_dim:150
rank:0
world_size:0
step:0
seeds:42
interval:10
ascend_history:False
no_dedup:False
scheduler:linear
warmup:100
pin_memory:False
shuffle:False
bert:bert-base-uncased
num_workers:0
smoothing:0.3
path:../../../Data/
tb:False
[2021-08-22 13:38:51,635] INFO (utils.utils) preparing dataset...
[2021-08-22 13:38:51,640] INFO (utils.MIND) using cached user behavior from data/cache/bert/MINDdemo_train/10/behaviors..pkl
[2021-08-22 13:38:51,654] INFO (utils.MIND) using cached news tokenization from data/cache/bert/MINDdemo_train/news.pkl
[2021-08-22 13:38:52,175] INFO (utils.utils) deduplicating...

In [4]:
class ESM(nn.Module):
    """
    Recommender assembled from interchangeable parts:
    embedding -> news encoder -> user encoder -> term reducer -> (optional) fuser -> ranker,
    followed by a small MLP that turns the ranker output plus a coarse
    news/user dot product into one click score per candidate.
    """
    def __init__(self, config, embedding, encoderN, encoderU, reducer, fuser, ranker):
        """
        Args:
            config: object exposing scale, cdd_size, batch_size, his_size, device and k;
                its ``name`` attribute is overwritten with the assembled model name
            embedding: module mapping token indices to word embeddings
            encoderN: news encoder, returns (word-level encoding, news-level representation)
            encoderU: user encoder, maps news-level history representations to a user representation
            reducer: term reducer; its ``name`` selects the input and call signature used in ``_forward``
            fuser: optional module applied to the reducer output, or None to skip it
            ranker: module exposing ``final_dim`` and producing one vector per candidate
        """
        super().__init__()

        self.scale = config.scale
        self.cdd_size = config.cdd_size
        self.batch_size = config.batch_size
        self.his_size = config.his_size
        self.device = config.device

        self.k = config.k

        self.embedding = embedding
        self.encoderN = encoderN
        self.encoderU = encoderU
        self.reducer = reducer
        self.fuser = fuser
        self.ranker = ranker

        self.final_dim = ranker.final_dim

        # the +1 input feature is the coarse news/user dot product appended in clickPredictor
        half_dim = self.final_dim // 2
        self.learningToRank = nn.Sequential(
            nn.Linear(self.final_dim + 1, half_dim),
            nn.ReLU(),
            nn.Linear(half_dim, 1)
        )

        self.name = '__'.join(['esm', self.encoderN.name, self.encoderU.name, self.reducer.name, self.ranker.name])
        # side effect: the assembled name is written back onto the config object
        config.name = self.name

    def clickPredictor(self, reduced_tensor, cdd_news_repr, user_repr):
        """ calculate a batch of click scores

        Args:
            reduced_tensor: [batch_size, cdd_size, final_dim]
            cdd_news_repr: news-level representation, [batch_size, cdd_size, hidden_dim]
            user_repr: user representation, [batch_size, 1, hidden_dim]

        Returns:
            score of each candidate news, [batch_size, cdd_size]
        """
        # coarse news-level match: dot product of every candidate with the user vector
        score_coarse = cdd_news_repr.matmul(user_repr.transpose(-2,-1))
        score = torch.cat([reduced_tensor, score_coarse], dim=-1)

        return self.learningToRank(score).squeeze(dim=-1)

    def _forward(self,x):
        """
        Run the full pipeline on one batch.

        Args:
            x: dict of batch tensors; reads "cdd_encoded_index", "cdd_attn_mask",
                "his_attn_mask", plus "his_reduced_index" for the bm25 reducer or
                "his_encoded_index" otherwise, plus "his_attn_mask_k" for the matching reducer

        Returns:
            tuple of (unnormalized click score per candidate, per-term scores from the
            matching reducer or None when the reducer does not produce them)
        """
        # the last batch may be smaller than the configured size
        self.batch_size = x["cdd_encoded_index"].size(0)

        cdd_news = x["cdd_encoded_index"].long().to(self.device)
        cdd_news_embedding = self.embedding(cdd_news)
        _, cdd_news_repr = self.encoderN(
            cdd_news_embedding
        )
        if self.reducer.name == 'bm25':
            his_news = x["his_reduced_index"].long().to(self.device)
        else:
            his_news = x["his_encoded_index"].long().to(self.device)
        his_news_embedding = self.embedding(his_news)
        his_news_encoded_embedding, his_news_repr = self.encoderN(
            his_news_embedding
        )

        user_repr = self.encoderU(his_news_repr)

        # Only the matching reducer returns per-term scores; default to None so the
        # return statement below is valid for every other reducer as well.
        scores = None
        if self.reducer.name == 'matching':
            ps_terms, ps_term_mask, scores = self.reducer(his_news_encoded_embedding, his_news_embedding, user_repr, x["his_attn_mask"].to(self.device), x["his_attn_mask_k"].to(self.device).bool())
        else:
            ps_terms, ps_term_mask = self.reducer(his_news_encoded_embedding, his_news_embedding, user_repr, x["his_attn_mask"].to(self.device))

        if self.fuser:
            ps_terms, ps_term_mask = self.fuser(ps_terms, ps_term_mask)

        reduced_tensor = self.ranker(cdd_news_embedding, ps_terms, x["cdd_attn_mask"].to(self.device), ps_term_mask)

        return self.clickPredictor(reduced_tensor, cdd_news_repr, user_repr), scores

    def forward(self,x):
        """
        Decoupled function: ``_forward`` yields the unnormalized click score, which is
        turned into log-probabilities over the candidates (dim=1) while training and
        into independent sigmoid probabilities otherwise.

        Returns:
            tuple of (prob, scores); ``scores`` is None unless the matching reducer is used
        """
        score, scores = self._forward(x)

        if self.training:
            prob = nn.functional.log_softmax(score, dim=1)
        else:
            prob = torch.sigmoid(score)

        return prob, scores

In [5]:
class Matching_Reducer(nn.Module):
    """
    basic document reducer: keeps the top-k terms of each historical article,
    ranked by cosine similarity between the encoded term and the user representation
    """
    def __init__(self, config):
        """
        Args:
            config: object exposing k, his_size and threshold; its ``term_num``
                attribute is set to k * his_size as a side effect
        """
        super().__init__()

        self.name = "matching"

        self.k = config.k

        config.term_num = config.k * config.his_size

        # kept as a buffer so it follows the module across devices
        threshold = torch.tensor([config.threshold])
        self.register_buffer('threshold', threshold)

    def forward(self, news_selection_embedding, news_embedding, user_repr, his_attn_mask, his_attn_mask_k):
        """
        Extract words from news text according to the overall user interest

        Args:
            news_selection_embedding: encoded word-level embedding, [batch_size, his_size, signal_length, hidden_dim]
            news_embedding: word-level news embedding, [batch_size, his_size, signal_length, hidden_dim]
            user_repr: user representation, [batch_size, 1, hidden_dim]
            his_attn_mask: attention mask of the history, [batch_size, his_size, signal_length]
            his_attn_mask_k: boolean mask of the positions eligible for selection, [batch_size, his_size, signal_length]

        Returns:
            ps_terms: weighted embedding for personalized terms, [batch_size, his_size, k, hidden_dim]
            ps_term_mask: attention mask of output terms, [batch_size, his_size, k]
            scores: masked cosine score of every non-[CLS] term, [batch_size, his_size, signal_length - 1]
        """
        # strip off [CLS]
        news_selection_embedding = news_selection_embedding[:, :, 1:]
        news_embedding = news_embedding[:, :, 1:]

        # cosine similarity between every term and the user vector, [bs, hs, sl - 1]
        user_direction = F.normalize(user_repr, dim=-1).transpose(-2,-1).unsqueeze(1)
        scores = F.normalize(news_selection_embedding, dim=-1).matmul(user_direction).squeeze(-1)
        # mask the padded term
        scores = scores.masked_fill(~his_attn_mask_k[:, :, 1:], -float('inf'))

        score_k, score_kid = scores.topk(dim=-1, k=self.k, sorted=False)

        gather_index = score_kid.unsqueeze(-1).expand(score_kid.size() + (news_embedding.size(-1),))
        personalized_terms = news_embedding.gather(dim=-2, index=gather_index)
        ps_term_mask = his_attn_mask[:, :, 1:].gather(dim=-1, index=score_kid)

        # Drop terms below the threshold. A masked position that topk had to pick
        # (fewer than k eligible terms) carries a score of -inf; with threshold == -inf
        # the comparison alone would keep it and multiply the embedding by -inf,
        # so such positions are dropped explicitly.
        mask_pos = (score_k < self.threshold) | (score_k == -float('inf'))
        ps_terms = personalized_terms * (score_k.masked_fill(mask_pos, 0).unsqueeze(-1))
        ps_term_mask = ps_term_mask * (~mask_pos)

        # Keep the gradient of the scores for inspection; retain_grad() raises on
        # tensors that do not require grad (e.g. under torch.no_grad()).
        if scores.requires_grad:
            scores.retain_grad()

        return ps_terms, ps_term_mask, scores

In [None]:
from transformers import BertModel,BertConfig

class BERT_Onepass_Ranker(nn.Module):
    """
    Ranker that pushes every candidate and the personalized terms through a BERT
    encoder in one pass, laid out as: cdd1 cdd2 ... cddn [SEP] pst1 pst2 ...
    """
    def __init__(self, config):
        """
        Args:
            config: object exposing embedding_dim (must be 768), signal_length,
                term_num, cdd_size, bert (pretrained model name) and path (cache root)
        """
        from models.Rankers.Modules.OnePassAttn import BertSelfAttention
        # the pretrained weights loaded below are 768-dimensional
        assert config.embedding_dim == 768

        super().__init__()

        self.name = 'onepass-bert'
        self.signal_length = config.signal_length
        # one extra slot for the [SEP] token prepended in forward
        self.term_num = config.term_num + 1
        self.embedding_dim = config.embedding_dim
        self.final_dim = self.embedding_dim

        # Build a default-configured encoder, then replace the self-attention of every
        # layer with the project's one-pass variant, which reads the extra fields
        # attached to the config below.
        onepass_config = BertConfig()
        encoder = BertModel(onepass_config).encoder
        onepass_config.signal_length = self.signal_length
        onepass_config.term_num = config.term_num + 1
        onepass_config.cdd_size = config.cdd_size
        for layer in encoder.layer:
            layer.attention.self = BertSelfAttention(onepass_config)

        # copy the pretrained encoder weights into the patched encoder
        pretrained = BertModel.from_pretrained(
            config.bert,
            cache_dir=config.path + 'bert_cache/'
        )
        encoder.load_state_dict(pretrained.encoder.state_dict())
        self.bert = encoder

        self.dropOut = pretrained.embeddings.dropout

        # [2, embedding_dim]
        self.token_type_embedding = nn.Parameter(pretrained.embeddings.token_type_embeddings.weight)
        # trainable copy of the [SEP] word embedding (vocabulary id 102), [1, 1, embedding_dim]
        sep_vector = pretrained.embeddings.word_embeddings(torch.tensor([102]))
        self.sep_embedding = nn.Parameter(sep_vector.clone().detach().requires_grad_(True).view(1,1,self.embedding_dim))


    def forward(self, cdd_news_embedding, ps_terms, cdd_attn_mask, his_attn_mask):
        """
        encode candidates and personalized terms jointly and return one vector per candidate

        Args:
            cdd_news_embedding: word-level representation of candidate news, [batch_size, cdd_size, signal_length, embedding_dim]
            ps_terms: concatenated historical news or personalized terms, reshaped here to [batch_size, term_num, embedding_dim]
            cdd_attn_mask: attention mask of the candidate news, [batch_size, cdd_size, signal_length]
            his_attn_mask: attention mask of the personalized terms, flattened here to [batch_size, term_num]

        Returns:
            reduced_tensor: encoder output at the first token of every candidate, [batch_size, cdd_size, embedding_dim]
        """
        bs = cdd_news_embedding.size(0)
        cs = cdd_news_embedding.size(1)
        dim = self.embedding_dim

        type_cdd = self.token_type_embedding[0]
        type_term = self.token_type_embedding[1]

        # [SEP] carries token type 0, the personalized terms carry token type 1
        sep_token = self.sep_embedding.expand(bs, 1, dim) + type_cdd
        term_tokens = ps_terms.view(bs, -1, dim) + type_term
        # [bs, 1 + tn, hd]
        term_segment = torch.cat([sep_token, term_tokens], dim=1)

        # [bs, cs*sl, hd]
        cdd_segment = (cdd_news_embedding + type_cdd).view(bs, -1, dim)

        # dropout (self.dropOut) is deliberately not applied to the encoder input
        bert_input = torch.cat([cdd_segment, term_segment], dim=-2)

        # candidate mask, a single 1 for [SEP], then the term mask: [bs, 1, 1, cs*sl + 1 + tn]
        # NOTE(review): the mask is handed over as raw 0/1 values; presumably the custom
        # BertSelfAttention interprets it that way rather than as an additive mask - confirm.
        mask_parts = [
            cdd_attn_mask.view(bs, -1),
            torch.ones(bs, 1, device=cdd_attn_mask.device),
            his_attn_mask.view(bs, -1),
        ]
        attn_mask = torch.cat(mask_parts, dim=-1).view(bs, 1, 1, -1)

        hidden = self.bert(bert_input, attention_mask=attn_mask).last_hidden_state
        # every signal_length-th position is the first token of a candidate
        first_tokens = hidden[:, 0 : cs * (self.signal_length) : self.signal_length]

        return first_tokens.view(bs, cs, dim)

In [3]:
# --- assemble the ESM pipeline from its components ---
embedding = BERT_Embedding(manager)
encoderN = CNN_Encoder(manager)
encoderU = RNN_User_Encoder(manager)

# Term reducer (alternative: BM25_Reducer(manager)).
# Must be built before the ranker: Matching_Reducer sets manager.term_num,
# which BERT_Onepass_Ranker reads in its constructor.
docReducer = Matching_Reducer(manager)

# Ranker (alternatives: CNN_Ranker(manager), BERT_Original_Ranker(manager)).
ranker = BERT_Onepass_Ranker(manager)

esm = ESM(
    manager,
    embedding,
    encoderN,
    encoderU,
    docReducer,
    None,  # no fuser
    ranker,
)
esm = esm.to(manager.device)

In [7]:
# Run one forward pass on the sample batch; ESM.forward returns (prob, scores).
prob, scores = esm(record)

In [None]:
# Inspect the personalized terms the model selects on the first loader.
# NOTE(review): 3534 is a hard-coded magic number - presumably a sample/impression
# index; confirm against Manager.inspect_ps_terms and lift it into a named constant.
manager.inspect_ps_terms(esm, 3534, loaders[0], bm25=False)

In [6]:
# Element-wise comparison of the user encoder's LSTM hidden-to-hidden weights with `a`.
# NOTE(review): `a` is not defined in any visible cell (hidden kernel state), so this
# cell raises NameError after Restart & Run All - presumably an earlier snapshot of
# the same weights; define it explicitly or remove this debugging cell.
esm.encoderU.lstm.weight_hh_l0 == a

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])