In [3]:
import re
import logging
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import pickle
# from thop import profile
from collections import defaultdict
from data.configs.demo import config
from collections import defaultdict

from transformers import AutoTokenizer, AutoModel, BertModel, BertConfig, AutoModelForSequenceClassification
from utils.Manager import Manager

from models.Embeddings.BERT import BERT_Embedding
from models.Encoders.CNN import CNN_Encoder,CNN_User_Encoder
from models.Encoders.RNN import RNN_Encoder,RNN_User_Encoder
from models.Encoders.MHA import MHA_Encoder, MHA_User_Encoder
from models.Modules.DRM import Matching_Reducer, Identical_Reducer
from models.Rankers.BERT import BERT_Onepass_Ranker, BERT_Original_Ranker
from models.Rankers.CNN import CNN_Ranker
from models.Encoders.Pooling import Attention_Pooling, Average_Pooling
from models.UniLM.modeling import TuringNLRv3Model, TuringNLRv3ForSequenceClassification, relative_position_bucket
from models.UniLM.configuration_tnlrv3 import TuringNLRv3Config

from models.BaseModel import BaseModel

from models.Encoders.BERT import BERT_Encoder
from models.Encoders.Pooling import *

from models.ESM import ESM
from models.TESRec import TESRec
from models.XFormer import XFormer
 
from models.Modules.Attention import MultiheadAttention, get_attn_mask, XSoftmax
# Raise the summarisation threshold: tensors with up to 100k elements are printed in full.
torch.set_printoptions(threshold=100_000)

In [4]:
# Tokenizer used by the inspection cells to turn token ids back into text
# (convert_ids_to_tokens / batch_decode).
# Other checkpoints tried before: AutoModel 'bert-base-uncased' / 'microsoft/deberta-base',
# the UniLM TuringNLRv3 classifier, and the 'microsoft/deberta-base' /
# 'allenai/longformer-base-4096' tokenizers.
t = AutoTokenizer.from_pretrained(
    'bert-base-uncased',
    cache_dir=config.path + "bert_cache/",
)

In [5]:
# Inspection settings layered on top of the demo config.
# (Other overrides tried: config.reducer='entity', config.embedding='deberta',
#  config.bert='microsoft/deberta-base' or 'longformer', config.device=0, config.seed=None.)
config.mode = "inspect"
config.recall_type = "s"
config.scale = "large"
config.case = True

manager = Manager(config)

loaders = manager.prepare()
# Only the first batch is inspected below, so fetch just that batch instead of
# materialising the whole (large-scale) loader with list(). loaders[1] can be
# inspected the same way.
x1 = next(iter(loaders[0]))

[2021-11-08 08:05:30,871] INFO (utils.Manager) Hyper Parameters are 
{
    "scale": "large",
    "mode": "inspect",
    "batch_size": 5,
    "batch_size_news": 100,
    "batch_size_history": 100,
    "k": 3,
    "threshold": -Infinity,
    "abs_length": 40,
    "signal_length": 100,
    "news": null,
    "his_size": 50,
    "cdd_size": 5,
    "impr_size": 10,
    "dropout_p": 0.2,
    "lr": 0.0001,
    "bert_lr": 3e-05,
    "embedding": "bert",
    "encoderN": "cnn",
    "encoderU": "lstm",
    "selector": "sfi",
    "reducer": "matching",
    "ranker": "onepass",
    "pooler": "attn",
    "bert_dim": 768,
    "embedding_dim": 768,
    "hidden_dim": 384,
    "base_rank": 0,
    "world_size": 0,
    "seed": 42,
    "granularity": "token",
    "debias": false,
    "full_attn": true,
    "descend_history": false,
    "shuffle_pos": false,
    "save_pos": false,
    "sep_his": false,
    "diversify": false,
    "no_dedup": false,
    "segment_embed": false,
    "no_rm_punc": false,
    "fa

In [9]:
class Matching_Reducer(nn.Module):
    """
    Select the top-k terms of every historical news by cosine similarity with the selection query.

    1. the first k term positions of every news are never masked, so top-k always has k finite candidates
    2. optionally add a per-history-news segment embedding to the selected terms (segment_embed)
    3. optionally insert a learned [SEP] embedding between the terms of different history news (sep_his)

    Notebook copy; it shadows the ``Matching_Reducer`` imported from ``models.Modules.DRM``.
    """
    def __init__(self, manager):
        super().__init__()
        self.name = "matching"
        self.k = manager.k
        self.his_size = manager.his_size
        self.embedding_dim = manager.embedding_dim

        self.diversify = manager.diversify
        self.sep_his = manager.sep_his
        # if aggregator is enabled, do not flatten the personalized terms
        self.flatten = (manager.aggregator is None)

        # number of terms handed to the downstream encoder (read from manager by other modules)
        manager.term_num = manager.k * manager.his_size

        # strip [CLS]: covers token positions 1..signal_length-1; the first k positions are
        # forced unmasked so topk never has to pick a -inf score
        keep_k_modifier = torch.zeros(1, manager.signal_length - 1)
        keep_k_modifier[:, :self.k] = 1
        self.register_buffer('keep_k_modifier', keep_k_modifier, persistent=False)

        if self.diversify:
            # fuse user and per-news representations into a per-news selection query
            self.newsUserAlign = nn.Linear(manager.hidden_dim * 2, manager.hidden_dim)
            nn.init.xavier_normal_(self.newsUserAlign.weight)

        if manager.threshold != -float('inf'):
            threshold = torch.tensor([manager.threshold])
            self.register_buffer('threshold', threshold)

        if self.sep_his:
            # one [SEP] between consecutive history news
            manager.term_num += (self.his_size - 1)
            self.sep_embedding = nn.Parameter(torch.randn(1, 1, self.embedding_dim))
            self.register_buffer('extra_sep_mask', torch.ones(1, 1, 1), persistent=False)

        if manager.segment_embed:
            self.segment_embedding = nn.Parameter(torch.randn(manager.his_size, 1, manager.embedding_dim))
            nn.init.xavier_normal_(self.segment_embedding)


    def forward(self, news_selection_embedding, news_embedding, user_repr, news_repr, his_attn_mask, his_refined_mask=None):
        """
        Extract words from news text according to the overall user interest

        Args:
            news_selection_embedding: encoded word-level embedding used for scoring, [batch_size, his_size, signal_length, hidden_dim]
            news_embedding: word-level news embedding the terms are gathered from, [batch_size, his_size, signal_length, embedding_dim]
            user_repr: user representation, [batch_size, 1, hidden_dim]
            news_repr: news-level representation, [batch_size, his_size, hidden_dim] (only read when diversify)
            his_attn_mask: attention mask, [batch_size, his_size, signal_length]
            his_refined_mask: deduplicated attention mask, [batch_size, his_size, signal_length];
                when None, his_attn_mask is used instead
        Returns:
            ps_terms: weighted embedding for personalized terms, [batch_size, term_num, embedding_dim]
                when flattened, otherwise [batch_size, his_size, k, embedding_dim]
            ps_term_mask: attention mask of output terms, [batch_size, term_num] (or [batch_size, his_size, k])
            kid: the index of personalized terms along the [CLS]-stripped token axis, [batch_size, his_size, k]
            scores: cosine score of every [CLS]-stripped token, -inf where masked, [batch_size, his_size, signal_length - 1]
        """
        batch_size = news_embedding.size(0)

        if self.diversify:
            news_user_repr = torch.cat([user_repr.expand(news_repr.size()), news_repr], dim=-1)
            selection_query = self.newsUserAlign(news_user_repr).unsqueeze(-1)
        else:
            selection_query = user_repr.unsqueeze(-1)

        # drop the [CLS] position everywhere
        news_selection_embedding = news_selection_embedding[:, :, 1:]

        news_embedding_text = news_embedding[:, :, 1:]
        his_attn_mask = his_attn_mask[:, :, 1:]

        if his_refined_mask is None:
            # no deduplicated mask supplied: fall back to the plain (already stripped) attention mask
            refined_mask = his_attn_mask
        else:
            refined_mask = his_refined_mask[:, :, 1:]

        # cosine similarity, [bs, hs, sl - 1]
        scores = F.normalize(news_selection_embedding, dim=-1).matmul(F.normalize(selection_query, dim=-2)).squeeze(-1)
        # scores = news_selection_embedding.matmul(selection_query).squeeze(-1)/math.sqrt(selection_query.size(-1))
        pad_pos = ~((refined_mask + self.keep_k_modifier).bool())

        # mask the padded term
        scores = scores.masked_fill(pad_pos, -float('inf'))

        score_k, score_kid = scores.topk(dim=-1, k=self.k)

        # [bs, hs, k, ed]
        ps_terms = news_embedding_text.gather(dim=-2,index=score_kid.unsqueeze(-1).expand(*score_kid.size(), news_embedding_text.size(-1)))
        # [bs, hs, k]
        ps_term_mask = his_attn_mask.gather(dim=-1, index=score_kid)

        if hasattr(self, 'threshold'):
            mask_pos = score_k < self.threshold
            # ps_terms = personalized_terms * (nn.functional.softmax(score_k.masked_fill(score_k < self.threshold, 0), dim=-1).unsqueeze(-1))
            ps_terms = ps_terms * (F.softmax(score_k.masked_fill(mask_pos, 0), dim=-1).unsqueeze(-1))
            # terms below the threshold are excluded through the mask
            ps_term_mask = ps_term_mask * (~mask_pos)
        else:
            ps_terms = ps_terms * (F.softmax(score_k, dim=-1).unsqueeze(-1))

        if hasattr(self, 'segment_embedding'):
            ps_terms += self.segment_embedding

        # flatten the selected terms into one dimension
        if self.flatten:
            # separate historical news only practical when squeeze=True
            if self.sep_his:
                # [bs, hs, 1, ed]
                sep_embedding = self.sep_embedding.expand(batch_size, self.his_size, 1, self.embedding_dim)
                # add extra [SEP] token to separate terms from different history news, slice to -1 to strip off the last [SEP]
                ps_terms = torch.cat([ps_terms, sep_embedding], dim=-2).view(batch_size, -1, self.embedding_dim)[:, :-1]
                ps_term_mask = torch.cat([ps_term_mask, self.extra_sep_mask.expand(batch_size, self.his_size, 1)], dim=-1).view(batch_size, -1)[:, :-1]

            else:
                # [bs, hs * k, ed]
                ps_terms = ps_terms.reshape(batch_size, -1, self.embedding_dim)
                ps_term_mask = ps_term_mask.reshape(batch_size, -1)

        return ps_terms, ps_term_mask, score_kid, scores

In [11]:
# Two tower baseline
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class TESRec(BaseModel):
    """
    Two-tower model with personalized term selection

    1. encode candidate news with bert
    2. encode the personalized terms selected by the reducer with the same bert, using [CLS] embedding as user representation
    3. predict by scaled dot product

    Notebook copy; it shadows the ``TESRec`` imported from ``models.TESRec``.
    """
    def __init__(self, manager, embedding, encoderN, encoderU, reducer, aggregator=None):
        super().__init__(manager)

        self.embedding = embedding
        # only these reducers need selection encoding
        if manager.reducer in manager.get_need_encode_reducers():
            self.encoderN = encoderN
            self.encoderU = encoderU
        self.reducer = reducer
        self.aggregator = aggregator
        self.bert = BERT_Encoder(manager)

        if manager.debias:
            # learned additive bias on the user representation
            self.userBias = nn.Parameter(torch.randn(1,self.bert.hidden_dim))
            nn.init.xavier_normal_(self.userBias)

        self.hidden_dim = manager.bert_dim

        self.granularity = manager.granularity
        if self.granularity != 'token':
            # zero scatter targets used to build the subword-to-word aggregation matrices
            self.register_buffer('cdd_dest', torch.zeros((self.batch_size, self.impr_size, self.signal_length * self.signal_length)), persistent=False)
            if manager.reducer in ["bm25", "entity", "first"]:
                self.register_buffer('his_dest', torch.zeros((self.batch_size, self.his_size, (manager.k + 2) * (manager.k + 2))), persistent=False)
            else:
                self.register_buffer('his_dest', torch.zeros((self.batch_size, self.his_size, self.signal_length * self.signal_length)), persistent=False)

        # run name: tesrec__<bert>__<encoderN>__<encoderU>__<reducer>[__<aggregator>]__<granularity>__<k>
        name_parts = ['tesrec', manager.bert, manager.encoderN, manager.encoderU, manager.reducer]
        if aggregator is not None:
            name_parts.append(manager.aggregator)
        name_parts += [manager.granularity, str(manager.k)]
        manager.name = '__'.join(name_parts)

        self.name = manager.name


    def encode_news(self, x):
        """
        encode candidate news

        Returns:
            cdd_news_repr: the news-level output of bert for every candidate
        """
        if self.granularity != 'token':
            # NOTE(review): `cdd_dest` is [batch_size, impr_size, signal_length ** 2], yet the result is
            # viewed as [batch_size, signal_length, signal_length], and the subword index is sliced with
            # [:, :, 0] whereas encode_user uses [:, :, :, 0]. This branch looks inconsistent with
            # encode_user; verify it before enabling granularity != 'token'.
            batch_size = x['cdd_subword_index'].size(0)
            cdd_dest = self.cdd_dest[:batch_size]
            cdd_subword_index = x['cdd_subword_index'].to(self.device)
            cdd_subword_index = cdd_subword_index[:, :, 0] * self.signal_length + cdd_subword_index[:, :, 1]

            cdd_subword_prefix = cdd_dest.scatter(dim=-1, index=cdd_subword_index, value=1)
            cdd_subword_prefix = cdd_subword_prefix.view(batch_size, self.signal_length, self.signal_length)

            if self.granularity == 'avg':
                # average subword embeddings as the word embedding
                cdd_subword_prefix = F.normalize(cdd_subword_prefix, p=1, dim=-1)
            cdd_attn_mask = cdd_subword_prefix.matmul(x['cdd_attn_mask'].to(self.device).float().unsqueeze(-1)).squeeze(-1)

        else:
            cdd_subword_prefix = None
            cdd_attn_mask = x['cdd_attn_mask'].to(self.device)

        cdd_news = x["cdd_encoded_index"].to(self.device)
        _, cdd_news_repr, _ = self.bert(
            self.embedding(cdd_news, cdd_subword_prefix), cdd_attn_mask
        )

        return cdd_news_repr


    def encode_user(self, x):
        """
        encode user from the personalized terms selected out of the history news

        Returns:
            user_repr: bert output over the selected terms (after aggregator / userBias when enabled)
            kid: indices of the selected terms, as returned by the reducer
            his_news_encoded_embedding: encoderN output, or None when encoderN is not used
            user_repr_ext: encoderU output for the matching reducer, otherwise None
            scores: the term scores returned by the reducer
        """
        if self.granularity != 'token':
            batch_size = x['his_encoded_index'].size(0)
            his_dest = self.his_dest[:batch_size]

            his_subword_index = x['his_subword_index'].to(self.device)
            his_signal_length = his_subword_index.size(-2)
            # flatten (row, col) pairs into indices of a his_signal_length x his_signal_length matrix
            his_subword_index = his_subword_index[:, :, :, 0] * his_signal_length + his_subword_index[:, :, :, 1]

            his_subword_prefix = his_dest.scatter(dim=-1, index=his_subword_index, value=1) * x["his_mask"].to(self.device)
            his_subword_prefix = his_subword_prefix.view(batch_size, self.his_size, his_signal_length, his_signal_length)

            if self.granularity == 'avg':
                # average subword embeddings as the word embedding
                his_subword_prefix = F.normalize(his_subword_prefix, p=1, dim=-1)

            his_attn_mask = his_subword_prefix.matmul(x["his_attn_mask"].to(self.device).float().unsqueeze(-1)).squeeze(-1)
            his_refined_mask = None
            if 'his_refined_mask' in x:
                his_refined_mask = his_subword_prefix.matmul(x["his_refined_mask"].to(self.device).float().unsqueeze(-1)).squeeze(-1)

        else:
            his_subword_prefix = None
            his_attn_mask = x["his_attn_mask"].to(self.device)
            his_refined_mask = None
            if 'his_refined_mask' in x:
                his_refined_mask = x["his_refined_mask"].to(self.device)

        his_news = x["his_encoded_index"].to(self.device)
        his_news_embedding = self.embedding(his_news, his_subword_prefix)
        if hasattr(self, 'encoderN'):
            his_news_encoded_embedding, his_news_repr = self.encoderN(
                his_news_embedding, his_attn_mask
            )
        else:
            his_news_encoded_embedding = None
            his_news_repr = None
        # no need to calculate this if ps_terms are fixed in advance

        if self.reducer.name == 'matching':
            user_repr_ext = self.encoderU(his_news_repr, his_mask=x['his_mask'].to(self.device), user_index=x['user_id'].to(self.device))
        else:
            user_repr_ext = None

        ps_terms, ps_term_mask, kid, scores = self.reducer(his_news_encoded_embedding, his_news_embedding, user_repr_ext, his_news_repr, his_attn_mask, his_refined_mask)

        _, user_repr, _ = self.bert(ps_terms, ps_term_mask, ps_term_input=True)

        if self.aggregator is not None:
            user_repr = self.aggregator(user_repr)

        if hasattr(self, 'userBias'):
            user_repr = user_repr + self.userBias

        return user_repr, kid, his_news_encoded_embedding, user_repr_ext, scores


    def compute_score(self, cdd_news_repr, user_repr):
        """ calculate batch of click probability

        Args:
            cdd_news_repr: news-level representation, [batch_size, cdd_size, hidden_dim]
            user_repr: user representation, [batch_size, 1, hidden_dim]

        Returns:
            score of each candidate news, [batch_size, cdd_size]
        """
        # dot product scaled by sqrt(hidden_dim)
        score = cdd_news_repr.matmul(user_repr.transpose(-2,-1)).squeeze(-1)/math.sqrt(cdd_news_repr.size(-1))
        return score


    def forward(self,x):
        """
        Returns:
            logits: log-softmax over candidates while training, sigmoid scores otherwise
            kid: indices of the selected personalized terms
        """
        cdd_repr = self.encode_news(x)
        # encode_user returns five values; only user_repr and kid are needed here
        user_repr, kid = self.encode_user(x)[:2]

        score = self.compute_score(cdd_repr, user_repr)

        if self.training:
            logits = nn.functional.log_softmax(score, dim=1)
        else:
            logits = torch.sigmoid(score)

        return logits, kid


    def predict_fast(self, x):
        """ score candidates with pre-computed news representations (self.news_reprs) """
        # [bs, cs, hd]
        cdd_repr = self.news_reprs(x['cdd_id'].to(self.device))
        # encode_user returns five values; only user_repr is needed here
        user_repr = self.encode_user(x)[0]
        scores = self.compute_score(cdd_repr, user_repr)
        logits = torch.sigmoid(scores)
        return logits

In [12]:
# Assemble TESRec from its components; the trailing comments list drop-in alternatives.
embedding = BERT_Embedding(manager)
encoderN = CNN_Encoder(manager)          # or RNN_Encoder / MHA_Encoder
encoderU = RNN_User_Encoder(manager)     # or CNN_User_Encoder / MHA_User_Encoder / Attention_Pooling / Average_Pooling
reducer = Matching_Reducer(manager)      # or Identical_Reducer

model = TESRec(manager, embedding, encoderN, encoderU, reducer)
model = model.to(manager.device)
# Other models: ESM(manager, embedding, encoderN, encoderU, reducer, ranker) with a
# CNN_Ranker / BERT_Onepass_Ranker / BERT_Original_Ranker ranker, or XFormer(manager).

In [30]:
# load the step-230000 checkpoint into the freshly built model
manager.load(model, 230000)

[2021-11-08 08:24:01,966] INFO (utils.Manager) loading model from data/model_params/tesrec__bert__cnn__lstm__matching__token__3/large_step230000.model...


In [31]:
# Unpack everything encode_user returns for the first batch; the single-letter names are kept
# because the inspection cells below refer to them.
(
    a,  # user_repr
    b,  # kid: indices of the selected terms per history news ([CLS] stripped)
    c,  # his_news_encoded_embedding (encoderN output)
    d,  # user_repr_ext (encoderU output, used as the selection query)
    e,  # scores: matching score of every token, -inf where masked
) = model.encode_user(x1)

In [29]:
# Tokens of user 0's first history news with [CLS] dropped, picked at three term positions.
# NOTE(review): [86, 2, 93] look like hand-copied selected-term indices; TODO derive them from `b[0, 0]`.
np.asarray(
    t.convert_ids_to_tokens(x1['his_encoded_index'][0, 0])[1:]
)[[86, 2, 93]]

array(['##gb', 'fix', 'he'], dtype='<U11')

In [38]:
# convolution kernel of the CNN news encoder
model.encoderN.cnn.weight

Parameter containing:
tensor([[[-0.0491, -0.0237,  0.0978],
         [ 0.0875,  0.0161, -0.0656],
         [ 0.0285, -0.0399,  0.0502],
         ...,
         [-0.0014,  0.0307, -0.0174],
         [ 0.0100, -0.0201,  0.0153],
         [ 0.0079,  0.0301, -0.0928]],

        [[-0.0331, -0.0284,  0.0490],
         [-0.0505, -0.0271, -0.0014],
         [ 0.0446,  0.0101,  0.0496],
         ...,
         [ 0.0815, -0.0096,  0.0206],
         [ 0.0013,  0.0414,  0.0997],
         [-0.0093,  0.0101,  0.0060]],

        [[ 0.0264, -0.0080, -0.0189],
         [ 0.0627,  0.0110, -0.0748],
         [ 0.0237,  0.0505,  0.0461],
         ...,
         [-0.0432, -0.0009,  0.0765],
         [-0.0431, -0.0080, -0.0108],
         [ 0.0030, -0.0535,  0.0183]],

        ...,

        [[-0.0090, -0.0659, -0.0015],
         [ 0.0181,  0.0028, -0.0396],
         [ 0.0556,  0.0066, -0.0121],
         ...,
         [ 0.0344,  0.0430,  0.0057],
         [ 0.0281,  0.0134, -0.0368],
         [-0.0003, -0.0285, 

In [37]:
# encoderU output (user_repr_ext) of the first user in the batch
d[0, 0]

tensor([-4.4400e-03,  6.3619e-04, -2.2063e-01,  1.8354e-04,  3.8074e-04,
        -8.7050e-04,  3.6028e-04,  3.1499e-04,  9.9705e-05,  5.7517e-04,
        -2.1719e-02,  5.3033e-05,  5.7555e-05, -4.0111e-04, -1.7791e-03,
         5.1019e-04, -2.2478e-04,  3.9305e-04, -3.9745e-01, -7.8493e-03,
         2.3317e-04, -2.2985e-03,  8.6744e-01,  3.8022e-04, -7.0838e-01,
         1.2778e-04, -5.1304e-01,  1.8707e-04, -1.3358e-04,  4.0864e-05,
         2.8037e-03, -1.3580e-03, -9.4156e-04,  4.6346e-01, -1.2117e-03,
        -5.8578e-02, -9.8103e-04, -6.7563e-04,  2.5512e-03, -6.6214e-01,
         1.5321e-05, -1.7117e-01, -1.9915e-01, -1.5334e-04, -5.9781e-03,
        -6.9705e-05, -1.6098e-02, -2.7902e-03, -2.5514e-01,  4.5431e-05,
         1.7989e-03, -1.2389e-04, -4.8516e-04, -8.4373e-04, -2.3518e-04,
        -2.6455e-03,  5.0970e-04,  3.8207e-02, -1.5263e-03, -2.8502e-05,
         2.3914e-04,  9.7725e-05,  3.3706e-03,  1.7348e-04, -1.8785e-04,
        -1.6702e-03, -2.8307e-04,  7.0267e-04, -7.6

In [32]:
# Scores of user 0's 4th history news, the full score shape, and the top-3 scores/indices
# over the token axis.
e[0, 3], e.shape, e.topk(k=3, dim=-1)

(tensor([ 0.2419, -0.5074, -0.6057, -0.5703, -0.5971, -0.0613, -0.5894,  0.5420,
         -0.3725, -0.6369, -0.6338, -0.5033, -0.4872, -0.6208, -0.5951,    -inf,
            -inf, -0.6143, -0.5980, -0.6238, -0.6118, -0.6045, -0.6209, -0.5754,
         -0.5565, -0.5947,    -inf,    -inf, -0.4800, -0.4611, -0.6024, -0.6179,
         -0.5408,    -inf, -0.5862, -0.3453,    -inf, -0.3828,    -inf, -0.4651,
            -inf,    -inf, -0.3774, -0.5738, -0.6290,    -inf, -0.4551,    -inf,
         -0.4164,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,  

In [None]:
# encoderN output (his_news_encoded_embedding) for the whole batch
c

In [12]:
# Encoded (encoderN) embeddings of the terms selected for each history news of user 0,
# shaped [his_size, k, c.size(-1)]. Gather along the token axis ([CLS] stripped, which is how
# `b` indexes it) and derive the expand shape from `b` / `c` instead of hard-coding k.
trump = c[0, :, 1:].gather(
    dim=-2,
    index=b[0].unsqueeze(-1).expand(*b[0].shape, c.size(-1)),
)

In [17]:
# Indices of the terms selected from user 0's 4th history news.
# NOTE(review): the saved output lists 5 indices, so it predates the current k=3 model; re-run.
b[0, 3]

tensor([44, 32, 21, 20, 43])

In [None]:
# decode every history news of user 0 back to text
t.batch_decode(x1['his_encoded_index'][0])

In [8]:
# vocabulary id of "trump" (between the [CLS] and [SEP] ids)
t("trump")

{'input_ids': [101, 8398, 102], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]}

In [None]:
# Collect, batch by batch, the token ids of the terms the reducer selects for every history news.
# Plain iteration: `tqdm` is not among this notebook's imports, so the previous loop raised NameError.
# TODO: `recall` / `count` are placeholders for a recall metric that has not been written yet.
recall = 0
count = 0
selected_term_ids = []
with torch.no_grad():  # inspection only: no autograd graph is needed for a full pass
    for x in loaders[0]:
        kid = model.encode_user(x)[1]
        ps_terms = x['his_encoded_index'][:, :, 1:].to(manager.device).gather(index=kid, dim=-1).view(-1).tolist()
        selected_term_ids.append(ps_terms)
    

In [None]:
def _bert_flops(n_tokens):
    """Rough forward FLOPs of a 12-layer BERT-base encoder over `n_tokens` tokens in total
    (projection + attention + intermediate per layer, plus the pooler; embedding excluded)."""
    project = n_tokens * 768 * 2 * 3
    attn = n_tokens * 12 * 64 * 64 * 2 + n_tokens * 12 * 12 * 64 * 2
    intermediate = n_tokens * 768 * 768 * 2 + n_tokens * 768 * 2 + n_tokens * 768 * 3072 * 4
    pool = n_tokens * 768 * 4
    return (project + attn + intermediate) * 12 + pool


def estimate_flops():
    """Back-of-the-envelope FLOPs per batch: TESRec/ISRec vs. a BERT baseline over the full history.

    Uses the same literal constants as the original estimate. They appear to correspond to
    batch 32, his_size 50, signal_length 100, k 3 (150 selected terms), bert dim 768 and
    hidden dim 384; TODO confirm against the config.
    Everything is computed in local scope so notebook globals such as `embedding`
    (the BERT_Embedding module) are not overwritten.

    Returns:
        (total_isrec, total_baseline)
    """
    # --- ISRec / TESRec: encode the history, select terms, run BERT over 150 terms
    his_embedding = 32 * 50 * 2 * 100 * 768
    ext_encoding = 32 * 50 * (2 * 3 * 384 * 100 + 4 * 100 * 384) + 32 * 8 * 2 * 384 * 384 * 50
    reduction = (2 * 100 * 768 + 100 * math.log2(100)) * 32 * 50 * 2 + 32 * 50 * 3
    # BERT embedding cost of the selected terms: previously computed but left out of the total,
    # although the baseline below does count its own BERT embedding cost
    bert_embedding = 32 * 150 * 768 * 2
    total_isrec = his_embedding + ext_encoding + reduction + bert_embedding + _bert_flops(32 * 150)

    # --- baseline: run BERT over all 50 * 100 = 5000 history tokens
    baseline_embedding = 32 * 50 * 2 * 100 * 768 + 32 * 50 * 100 * 768 * 2
    user_encoding = 32 * 8 * 2 * 768 * 768 * 50
    total_baseline = baseline_embedding + _bert_flops(32 * 5000) + user_encoding

    return total_isrec, total_baseline


total_isrec, total = estimate_flops()
total_isrec, total