In [63]:
import re
import numpy as np
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from collections import defaultdict
from utils.utils import prepare
from data.configs.demo import config

from transformers import BertTokenizer,BertModel
from utils.Manager import Manager

from models.Embeddings.BERT import BERT_Embedding
from models.Encoders.CNN import CNN_Encoder
from models.Encoders.RNN import RNN_User_Encoder
from models.Modules.DRM import Matching_Reducer
from models.Modules.DRM import BM25_Reducer
from models.Rankers.BERT import BERT_Onepass_Ranker
from models.Rankers.BERT import BERT_Original_Ranker
from models.Rankers.CNN import CNN_Ranker
from models.ESM import ESM

In [2]:
# Build the experiment manager from the demo config and let `prepare` create the data loaders.
manager = Manager(config)
loaders = prepare(manager)
# Pull only the first batch from the first loader. `list(loaders[0])[0]` would iterate the
# whole loader just to keep its first element; `next(iter(...))` stops after one batch.
record = next(iter(loaders[0]))

[2021-08-19 20:41:28,350] INFO (utils.utils) Hyper Parameters are \mscale:demo
mode:tune
epochs:8
batch_size:10
k:3
threshold:0
title_length:20
abs_length:40
signal_length:80
npratio:4
his_size:50
cdd_size:5
impr_size:10
dropout_p:0.2
device:cpu
lr:0.0001
bert_lr:3e-05
metrics:auc,mean_mrr,ndcg@5,ndcg@10
embedding:bert
selector:sfi
reducer:matching
interactor:onepass
embedding_dim:300
hidden_dim:150
rank:0
world_size:0
step:0
seeds:42
interval:10
val_freq:2
schedule:linear
order_history:False
warmup:100
pin_memory:False
shuffle:False
num_workers:0
path:../../../Data/
tb:False
bert:bert-base-uncased
[2021-08-19 20:41:28,351] INFO (utils.utils) preparing dataset...
[2021-08-19 20:41:28,355] INFO (utils.MIND) using cached user behavior from data/cache/bert/MINDdemo_train/10/behaviors..pkl
[2021-08-19 20:41:28,369] INFO (utils.MIND) using cached news tokenization from data/cache/bert/MINDdemo_train/news.pkl
[2021-08-19 20:41:28,601] INFO (utils.MIND) using cached user behavior from data/ca

In [3]:
class ESM(nn.Module):
    """Click-score model assembled from pluggable components.

    The model embeds and encodes candidate and history news, builds a user
    representation from the encoded history, lets ``docReducer`` pick terms from
    the history, scores candidates against those terms with ``ranker`` and finally
    mixes that fine-grained score with a coarse news/user dot product.

    Args:
        config: object exposing ``scale``, ``cdd_size``, ``batch_size``, ``his_size``,
            ``device`` and ``k``; its ``name`` attribute is overwritten with the model name
        embedding: module mapping token ids to embeddings
        encoderN: news encoder returning ``(token_level_output, news_repr)``; must expose
            ``hidden_dim`` and ``name``
        encoderU: user encoder applied to the history news representations; must expose ``name``
        docReducer: module returning ``(ps_terms, ps_term_ids)``; must expose ``name``
        termFuser: stored on the instance but not called by the current forward pass
        ranker: module producing the reduced tensor; must expose ``final_dim`` and ``name``
    """

    def __init__(self, config, embedding, encoderN, encoderU, docReducer, termFuser, ranker):
        super().__init__()

        # scalar settings mirrored from the config
        self.k = config.k
        self.scale = config.scale
        self.device = config.device
        self.cdd_size = config.cdd_size
        self.his_size = config.his_size
        self.batch_size = config.batch_size

        # sub-modules (registration order is kept stable on purpose)
        self.embedding = embedding
        self.encoderN = encoderN
        self.encoderU = encoderU
        self.docReducer = docReducer
        self.termFuser = termFuser
        self.ranker = ranker

        self.hidden_dim = encoderN.hidden_dim
        self.final_dim = ranker.final_dim

        # the extra input feature (+1) is the coarse news/user matching score
        half_dim = int(self.final_dim / 2)
        self.learningToRank = nn.Sequential(
            nn.Linear(self.final_dim + 1, half_dim),
            nn.ReLU(),
            nn.Linear(half_dim, 1),
        )

        name_parts = ('esm', encoderN.name, encoderU.name, docReducer.name, ranker.name)
        self.name = '__'.join(name_parts)
        # publish the composed name back onto the config object
        config.name = self.name

    def clickPredictor(self, reduced_tensor, cdd_news_repr, user_repr):
        """ compute the click score of every candidate in the batch

        Args:
            reduced_tensor: [batch_size, cdd_size, final_dim]
            cdd_news_repr: news-level representation, [batch_size, cdd_size, hidden_dim]
            user_repr: user representation, [batch_size, 1, hidden_dim]

        Returns:
            score of each candidate news, [batch_size, cdd_size]
        """
        # coarse matching: dot product between each candidate and the user vector
        coarse = torch.matmul(cdd_news_repr, user_repr.transpose(-2, -1))
        features = torch.cat((reduced_tensor, coarse), dim=-1)
        logits = self.learningToRank(features)

        return logits.squeeze(dim=-1)

    def _forward(self, x):
        """ run the full pipeline on one batch

        Args:
            x: mapping with the keys ``cdd_encoded_index``, ``his_encoded_index``,
                ``his_attn_mask`` and ``cdd_attn_mask``

        Returns:
            tuple of (unnormalized click score, ps_term_ids returned by the reducer)
        """
        # track the actual size of the incoming batch
        self.batch_size = x["cdd_encoded_index"].size(0)
        dev = self.device

        # candidate news: only the news-level representation is needed
        cdd_news = x["cdd_encoded_index"].long().to(dev)
        cdd_news_embedding = self.embedding(cdd_news)
        _unused, cdd_news_repr = self.encoderN(cdd_news_embedding)

        # history news: keep both the token-level and the news-level outputs
        his_news = x["his_encoded_index"].long().to(dev)
        his_news_embedding = self.embedding(his_news)
        his_news_encoded_embedding, his_news_repr = self.encoderN(his_news_embedding)

        user_repr = self.encoderU(his_news_repr)

        # the reducer receives the history attention mask as booleans
        his_mask = x["his_attn_mask"].to(dev).bool()
        ps_terms, ps_term_ids = self.docReducer(
            his_news_encoded_embedding, his_news_embedding, user_repr, his_mask
        )

        cdd_mask = x["cdd_attn_mask"].to(dev)
        reduced_tensor = self.ranker(cdd_news_embedding, ps_terms, cdd_mask)

        score = self.clickPredictor(reduced_tensor, cdd_news_repr, user_repr)
        return score, ps_term_ids

    def forward(self, x):
        """
        Turn the unnormalized click score of ``_forward`` into the returned value:
        log-softmax over dim 1 while training, element-wise sigmoid otherwise.

        Returns:
            tuple of (prob, ps_term_ids)
        """
        score, ps_term_ids = self._forward(x)

        prob = torch.log_softmax(score, dim=1) if self.training else torch.sigmoid(score)

        return prob, ps_term_ids

In [4]:
# Assemble the ESM pipeline component by component; every component is configured from `manager`.
embedding = BERT_Embedding(manager)
encoderN = CNN_Encoder(manager)
encoderU = RNN_User_Encoder(manager)

# Reducer that selects terms from the user history (alternative kept commented out below).
docReducer = Matching_Reducer(manager)
# docReducer = BM25_Reducer(manager)

# Ranker that scores candidates against the selected terms (alternatives kept commented out).
# ranker = CNN_Ranker(manager)
ranker = BERT_Onepass_Ranker(manager)
# ranker = BERT_Original_Ranker(manager)

# `manager` doubles as the config object here and no term fuser is used (None).
# NOTE(review): `ESM` is both imported from models.ESM and redefined in this notebook;
# which one is bound here depends on cell execution order.
esm = ESM(manager, embedding, encoderN, encoderU, docReducer, None, ranker).to(manager.device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transfo

In [5]:
# Load the parameters saved at step 1178 into `esm`; the resolved checkpoint path is
# printed in the log output of this cell.
manager.load(esm, 1178)

[2021-08-19 20:41:47,285] INFO (utils.Manager) loading model from data/model_params/esm__cnn__rnn-u__matching-reducer__onepass-bert-without-order/demo_step1178_[k=3].model...


In [6]:
# This cell only inspects which terms the reducer selects, so run the model in inference
# mode: `eval()` disables training-only behaviour such as dropout (the config uses
# dropout_p=0.2) and `no_grad()` skips autograd bookkeeping.
# In eval mode ESM.forward returns sigmoid scores instead of log-softmax.
# NOTE(review): harmless if manager.load already switched the model to eval mode.
esm.eval()
with torch.no_grad():
    # a: click scores per candidate, b: ps_term_ids returned by the document reducer
    a, b = esm(record)

In [7]:
# Look up, along the last axis of the history token ids, the entries addressed by `b`
# (the ps_term_ids returned by the model).
c = torch.gather(record['his_encoded_index'], dim=-1, index=b)

In [8]:
# Tokenizer used below to decode token ids back into text; same checkpoint name as the
# `bert` entry in the logged hyper parameters (bert-base-uncased).
t = BertTokenizer.from_pretrained('bert-base-uncased')

In [27]:
import pickle

# Load the cached BM25-sorted news. The file is opened in a `with` block so the handle
# is closed as soon as unpickling finishes.
# NOTE(review): pickle.load can execute arbitrary code -- only load caches you produced yourself.
# NOTE(review): hard-coded absolute, machine-specific path; consider deriving it from the
# project's cache directory instead.
with open('/data/workspace/Peitian/Code/Document-Reduction/Code/data/cache/bert/MINDdemo_train/news_bm25.pkl', 'rb') as bm25_file:
    bm25 = pickle.load(bm25_file)

# token ids of the news as stored in the cache, and the variant stored under 'encoded_news_sorted'
his = bm25['encoded_news']
his_sorted = bm25['encoded_news_sorted']

In [60]:
# Position of the history entry to inspect (second index into the batch tensors below).
# presumably the k-th clicked news of the first sample in the batch -- TODO confirm axis meaning
k = 8

# Decode three views of the same history entry for a side-by-side comparison:
#   1. the tokens gathered with the model's ps_term_ids (`c`),
#   2. the full token sequence of that entry,
#   3. the cached 'encoded_news_sorted' entry looked up through the entry's `his_id`.
t.decode(c[0][k]), t.decode(record['his_encoded_index'][0][k]), t.decode(his_sorted[record['his_id'][0][k]])

('aboutjak,',
 '[CLS]\'wheel of fortune\'guest delivers hilarious, off the rails introduction we\'d like to solve the puzzle, pat : blair davis\'loveless marriage? on monday, " wheel of fortune " welcomed as a new contestant trucking business owner blair davis, who offered a biting introduction for himself. when host pat sajak asked the man from cardiff, california, about his family, davis plunged into',
 '[CLS] loveless introduction blair summaries fortune wheel davis cardiff marriage pat trucking darkest rails biting puzzle plunged sajak contestant welcomed delivers solve guest hilarious trapped offered personal heard host tvnews owner likely asked ever business 12 show california tv family like years monday last man one? new cls,. [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD

In [62]:
# Scratch check: multiplying a list by a negative count yields an empty list, so a negative
# pad length (as can occur in BM25.__call__) silently adds no padding instead of raising.
[101] + [12, 13] + [102]*-1

[101, 12, 13]

In [77]:
class BM25(object):
    """
    Rank the terms of every document by their BM25 score.

    Documents are sequences of token ids. The first id of each document is skipped when
    counting terms (treated as [CLS]) and every output row starts with 101 again.

    Args:
        k: term-frequency saturation parameter of BM25
        epsilon: stored on the instance; NOTE(review): not read anywhere in this class
    """
    # id prepended to every sorted document
    _CLS_ID = 101
    # id used to fill sorted documents up to the length of the first input document
    # NOTE(review): 102 is [SEP] in the bert-base-uncased vocabulary, not [PAD] -- confirm this is intended
    _FILL_ID = 102

    def __init__(self, k=2, epsilon=0.5):
        self.k = k
        self.epsilon = epsilon

    def _build_tf_idf(self, documents):
        """
        build term frequencies (how many times a term occurs in one news) and document frequencies (how many documents contains a term)

        Sets:
            self.tfs: list with one ``{term: count}`` mapping per document
            self.idf: ``{term: inverse document frequency}``
        """
        doc_count = len(documents)

        tfs = []
        df = defaultdict(int)
        for document in documents:
            tf = defaultdict(int)
            # ignore [CLS]
            for term in document[1:]:
                tf[term] += 1

            # a document counts once towards a term's document frequency,
            # no matter how often the term occurs inside it
            for term in tf:
                df[term] += 1

            tfs.append(tf)

        self.tfs = tfs

        idf = defaultdict(float)
        for term, freq in df.items():
            idf[term] = math.log((doc_count - freq + 0.5) / (freq + 0.5) + 1)

        self.idf = idf

    def __call__(self, documents):
        """
        compute bm25 score of each term in each document and sort the terms by it
        with b=0, totally ignoring the effect of document length

        Args:
            documents: sequence of documents, each a sequence of token ids whose first
                element is skipped; the length of the first document sets the padded
                length of every output row

        Returns:
            sorted_documents: one list per document, ``[101] + unique terms in descending
                score order + filler ids``
            sorted_attn_mask: one list per document, 1 for the leading id and the terms,
                0 for the filler positions
        """
        self._build_tf_idf(documents)

        # nothing to sort; also avoids indexing documents[0] on empty input
        if len(documents) == 0:
            return [], []

        document_length = len(documents[0])

        sorted_documents = []
        sorted_attn_mask = []
        for tf in self.tfs:
            score = {}
            for term, freq in tf.items():
                score[term] = (self.idf[term] * freq * (self.k + 1)) / (freq + self.k)

            # stable sort: terms with equal score keep their order of first occurrence
            ranked_terms = sorted(score, key=score.get, reverse=True)

            # +1 accounts for the leading id
            bm25_length = len(ranked_terms) + 1
            # a negative pad length simply produces no padding
            pad_length = document_length - bm25_length

            sorted_documents.append([self._CLS_ID] + ranked_terms + [self._FILL_ID] * pad_length)
            sorted_attn_mask.append([1] * bm25_length + [0] * pad_length)

        return sorted_documents, sorted_attn_mask

In [78]:
# BM25 scorer with its default parameters (k=2, epsilon=0.5).
# NOTE(review): this rebinds `b`, which an earlier cell uses for the model's ps_term_ids;
# re-running the gather cell after this one would use the scorer object as the index.
b = BM25()

In [83]:
# Two toy documents of four token ids each; BM25 skips the first id (101) of every row.
documents = np.array([[101,102,102,102],[101,105,105,106]])

In [82]:
# Returns (sorted_documents, sorted_attn_mask) for the toy documents.
# NOTE(review): the recorded output appears stale -- this cell's execution count precedes
# the cell defining `documents`, and BM25 as defined here skips the first id of each row,
# so it cannot emit a second 101 inside a sorted row.
b(documents)

([[101, 102, 102, 102], [101, 101, 105, 106]], [[1, 1, 0, 0], [1, 1, 1, 1]])