In [1]:
import re
import logging
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import pickle
from collections import defaultdict
from data.configs.demo import config
from collections import defaultdict

from transformers import AutoTokenizer, AutoModel, BertModel, BertConfig
from utils.Manager import Manager

from models.Embeddings.BERT import BERT_Embedding
from models.Encoders.CNN import CNN_Encoder,CNN_User_Encoder
from models.Encoders.RNN import RNN_Encoder,RNN_User_Encoder
from models.Encoders.MHA import MHA_Encoder, MHA_User_Encoder
from models.Modules.DRM import Matching_Reducer, Slicing_Reducer
from models.Rankers.BERT import BERT_Onepass_Ranker, BERT_Original_Ranker
from models.Rankers.CNN import CNN_Ranker
from models.Encoders.Pooling import Attention_Pooling, Average_Pooling

from models.BaseModel import BaseModel

from models.Encoders.BERT import BERT_Encoder
from models.Encoders.Pooling import *

from models.ESM import ESM
from models.TTMS import TTMS
 
from models.Modules.Attention import MultiheadAttention, get_attn_mask, XSoftmax
torch.set_printoptions(threshold=100000)

In [None]:
# m = AutoModel.from_pretrained('bert-base-uncased',cache_dir=config.path + 'bert_cache/')
# m2 = AutoModel.from_pretrained('microsoft/deberta-base',cache_dir=config.path + 'bert_cache/')
# m3 = AutoModel.from_pretrained("microsoft/unilm-base-cased",cache_dir=config.path + 'bert_cache/')

# t = AutoTokenizer.from_pretrained('bert-base-uncased', cache_dir=config.path + "bert_cache/")
# t2 = DebertaTokenizerFast.from_pretrained('microsoft/deberta-base', cache_dir=config.path + "bert_cache/")

In [2]:
# config.reducer = 'entity'
# config.embedding = 'deberta'
# config.bert = 'microsoft/deberta-base'
# config.device = 0
config.seed = None
manager = Manager(config)
loaders = manager.prepare()
X1 = list(loaders[0])
X2 = list(loaders[1])
x1 = X1[0]
x2 = X2[0]

[2021-10-06 23:51:55,530] INFO (utils.Manager) Hyper Parameters are 
{
    "scale": "demo",
    "mode": "tune",
    "batch_size": 5,
    "batch_size_news": 100,
    "batch_size_history": 100,
    "k": 5,
    "threshold": -Infinity,
    "abs_length": 40,
    "signal_length": 100,
    "his_size": 50,
    "cdd_size": 5,
    "impr_size": 10,
    "dropout_p": 0.2,
    "lr": 0.0001,
    "bert_lr": 3e-05,
    "embedding": "bert",
    "encoderN": "cnn",
    "encoderU": "rnn",
    "selector": "sfi",
    "reducer": "matching",
    "ranker": "onepass",
    "embedding_dim": 768,
    "hidden_dim": 384,
    "base_rank": 0,
    "world_size": 0,
    "seed": null,
    "granularity": "avg",
    "debias": true,
    "full_attn": true,
    "descend_history": false,
    "shuffle_pos": false,
    "save_pos": false,
    "sep_his": false,
    "diversify": false,
    "no_dedup": false,
    "no_order_embed": false,
    "no_rm_punc": false,
    "fast": false,
    "scheduler": "linear",
    "warmup": 100,
    "shu

In [3]:
class TTMS(BaseModel):
    """
    Tow tower model with selection

    1. encode candidate news with bert
    2. encode ps terms with the same bert, using [CLS] embedding as user representation
    3. predict by scaled dot product
    """
    def __init__(self, manager, embedding, encoderN, encoderU, reducer):
        super().__init__(manager)

        self.embedding = embedding
        self.encoderN = encoderN
        self.encoderU = encoderU

        self.reducer = reducer
        self.bert = BERT_Encoder(manager)

        self.newsUserProject = nn.Sequential(
            nn.Linear(self.bert.hidden_dim, self.bert.hidden_dim),
            nn.Tanh()
        )

        if manager.debias:
            self.userBias = nn.Parameter(torch.randn(1,self.bert.hidden_dim))
            nn.init.xavier_normal_(self.userBias)

        self.hidden_dim = self.bert.hidden_dim

        self.granularity = manager.granularity
        if self.granularity != 'token':
            self.register_buffer('cdd_dest', torch.zeros((self.batch_size, manager.impr_size, manager.signal_length * manager.signal_length)), persistent=False)
            if manager.reducer in ["bm25", "entity", "first"]:
                self.register_buffer('his_dest', torch.zeros((self.batch_size, self.his_size, (manager.k + 1) * (manager.k + 1))), persistent=False)
            else:
                self.register_buffer('his_dest', torch.zeros((self.batch_size, self.his_size, manager.signal_length * manager.signal_length)), persistent=False)


        manager.name = '__'.join(['ttms', manager.embedding, manager.encoderN, manager.encoderU, manager.reducer, manager.granularity])
        self.name = manager.name


    def clickPredictor(self, cdd_news_repr, user_repr):
        """ calculate batch of click probabolity

        Args:
            cdd_news_repr: news-level representation, [batch_size, cdd_size, hidden_dim]
            user_repr: user representation, [batch_size, 1, hidden_dim]

        Returns:
            score of each candidate news, [batch_size, cdd_size]
        """
        score = cdd_news_repr.matmul(user_repr.transpose(-2,-1)).squeeze(-1)/math.sqrt(self.embedding.embedding_dim)
        return score


    def _forward(self,x):
        if self.granularity != 'token':
            batch_size = x['cdd_subword_index'].size(0)
            cdd_size = x['cdd_subword_index'].size(1)

            if self.training:
                cdd_dest = self.cdd_dest[:batch_size, :cdd_size]
                his_dest = self.his_dest[:batch_size]

            # batch_size always equals 1 when evaluating
            else:
                cdd_dest = self.cdd_dest[[0], :cdd_size]
                his_dest = self.his_dest[[0]]

            cdd_subword_index = x['cdd_subword_index'].to(self.device)
            his_subword_index = x['his_subword_index'].to(self.device)
            his_signal_length = his_subword_index.size(-2)
            cdd_subword_index = cdd_subword_index[:, :, :, 0] * self.signal_length + cdd_subword_index[:, :, :, 1]
            his_subword_index = his_subword_index[:, :, :, 0] * his_signal_length + his_subword_index[:, :, :, 1]

            if self.training:
                # * cdd_mask to filter out padded cdd news
                cdd_subword_prefix = cdd_dest.scatter(dim=-1, index=cdd_subword_index, value=1) * x["cdd_mask"].to(self.device)
            else:
                cdd_subword_prefix = cdd_dest.scatter(dim=-1, index=cdd_subword_index, value=1)
            # FIXME historical news not need this
            cdd_subword_prefix = cdd_subword_prefix.view(batch_size, cdd_size, self.signal_length, self.signal_length)

            his_subword_prefix = his_dest.scatter(dim=-1, index=his_subword_index, value=1) * x["his_mask"].to(self.device)
            his_subword_prefix = his_subword_prefix.view(batch_size, self.his_size, his_signal_length, his_signal_length)

            if self.granularity == 'avg':
                # average subword embeddings as the word embedding
                cdd_subword_prefix = F.normalize(cdd_subword_prefix, p=1, dim=-1)
                his_subword_prefix = F.normalize(his_subword_prefix, p=1, dim=-1)

            cdd_attn_mask = cdd_subword_prefix.matmul(x['cdd_attn_mask'].to(self.device).float().unsqueeze(-1)).squeeze(-1)
            his_attn_mask = his_subword_prefix.matmul(x["his_attn_mask"].to(self.device).float().unsqueeze(-1)).squeeze(-1)
            his_refined_mask = None
            if 'his_refined_mask' in x:
                his_refined_mask = his_subword_prefix.matmul(x["his_refined_mask"].to(self.device).float().unsqueeze(-1)).squeeze(-1)

        else:
            cdd_subword_prefix = None
            his_subword_prefix = None
            cdd_attn_mask = x['cdd_attn_mask'].to(self.device)
            his_attn_mask = x["his_attn_mask"].to(self.device)
            his_refined_mask = None
            if 'his_refined_mask' in x:
                his_refined_mask = x["his_refined_mask"].to(self.device)

        cdd_news = x["cdd_encoded_index"].to(self.device)
        _, cdd_news_repr = self.bert(
            self.embedding(cdd_news, cdd_subword_prefix), cdd_attn_mask
        )
        cdd_news_repr = self.newsUserProject(cdd_news_repr)

        his_news = x["his_encoded_index"].to(self.device)
        his_news_embedding = self.embedding(his_news, his_subword_prefix)
        his_news_encoded_embedding, his_news_repr = self.encoderN(
            his_news_embedding, his_attn_mask
        )
        # no need to calculate this if ps_terms are fixed in advance
        if self.reducer.name == 'matching':
            user_repr = self.encoderU(his_news_repr, his_mask=x['his_mask'].to(self.device), user_index=x["user_id"].to(self.device))
        else:
            user_repr = None

        ps_terms, ps_term_mask, kid = self.reducer(his_news_encoded_embedding, his_news_embedding, user_repr, his_news_repr, his_attn_mask, his_refined_mask)

        _, user_cls = self.bert(ps_terms, ps_term_mask)
        user_repr = self.newsUserProject(user_cls)
        if hasattr(self, 'userBias'):
            user_repr = user_repr + self.userBias

        return self.clickPredictor(cdd_news_repr, user_repr), kid, (cdd_news_repr, user_repr)


    def forward(self,x):
        """
        Decoupled function, score is unormalized click score
        """
        score, kid,c = self._forward(x)

        if self.training:
            prob = nn.functional.log_softmax(score, dim=1)
        else:
            prob = torch.sigmoid(score)

        return prob, kid, c


    def encode_news(self, x):
        """
        encode news of loader_news
        """
        if not self.ready_encode:
            self._init_encoding()

        # encode news with MIND_news
        if self.granularity != 'token':
            batch_size = x['cdd_subword_index'].size(0)
            cdd_dest = self.cdd_dest[:batch_size]
            cdd_subword_index = x['cdd_subword_index'].to(self.device)
            cdd_subword_index = cdd_subword_index[:, :, 0] * self.signal_length + cdd_subword_index[:, :, 1]
            cdd_subword_prefix = cdd_dest.scatter(dim=-1, index=cdd_subword_index, value=1)

            cdd_subword_prefix = cdd_subword_prefix.view(batch_size, self.signal_length, self.signal_length)

            if self.granularity == 'avg':
                # average subword embeddings as the word embedding
                cdd_subword_prefix = F.normalize(cdd_subword_prefix, p=1, dim=-1)

            cdd_attn_mask = cdd_subword_prefix.matmul(x['cdd_attn_mask'].to(self.device).float().unsqueeze(-1)).squeeze(-1)

        else:
            cdd_subword_prefix = None
            cdd_attn_mask = x['cdd_attn_mask'].to(self.device)

        cdd_news = x["cdd_encoded_index"].to(self.device)
        _, cdd_news_repr = self.bert(
            self.embedding(cdd_news, cdd_subword_prefix), cdd_attn_mask
        )
        cdd_news_repr = self.newsUserProject(cdd_news_repr.squeeze(1))

        return cdd_news_repr


    def encode_user(self, x):
        """
        encode user of loader_history
        """
        if not self.ready_encode:
            self._init_encoding()

        if self.granularity != 'token':
            batch_size = x['his_encoded_index'].size(0)
            his_dest = self.his_dest[:batch_size]

            his_subword_index = x['his_subword_index'].to(self.device)
            his_signal_length = his_subword_index.size(-2)
            his_subword_index = his_subword_index[:, :, :, 0] * his_signal_length + his_subword_index[:, :, :, 1]

            his_subword_prefix = his_dest.scatter(dim=-1, index=his_subword_index, value=1) * x["his_mask"].to(self.device)
            his_subword_prefix = his_subword_prefix.view(batch_size, self.his_size, his_signal_length, his_signal_length)

            if self.granularity == 'avg':
                # average subword embeddings as the word embedding
                his_subword_prefix = F.normalize(his_subword_prefix, p=1, dim=-1)

            his_attn_mask = his_subword_prefix.matmul(x["his_attn_mask"].to(self.device).float().unsqueeze(-1)).squeeze(-1)
            his_refined_mask = None
            if 'his_refined_mask' in x:
                his_refined_mask = his_subword_prefix.matmul(x["his_refined_mask"].to(self.device).float().unsqueeze(-1)).squeeze(-1)

        else:
            his_subword_prefix = None
            his_attn_mask = x["his_attn_mask"].to(self.device)
            his_refined_mask = None
            if 'his_refined_mask' in x:
                his_refined_mask = x["his_refined_mask"].to(self.device)

        his_news = x["his_encoded_index"].to(self.device)
        his_news_embedding = self.embedding(his_news, his_subword_prefix)
        his_news_encoded_embedding, his_news_repr = self.encoderN(
            his_news_embedding, his_attn_mask
        )
        # no need to calculate this if ps_terms are fixed in advance
        if self.reducer.name == 'matching':
            user_repr = self.encoderU(his_news_repr, his_mask=x['his_mask'].to(self.device), user_index=x['user_id'].to(self.device))
        else:
            user_repr = None

        ps_terms, ps_term_mask, _ = self.reducer(his_news_encoded_embedding, his_news_embedding, user_repr, his_news_repr, his_attn_mask, his_refined_mask)

        _, user_cls = self.bert(ps_terms, ps_term_mask)
        user_repr = self.newsUserProject(user_cls.squeeze(1))
        if hasattr(self, 'userBias'):
            user_repr = user_repr + self.userBias
        return user_repr

In [4]:
embedding = BERT_Embedding(manager)

encoderN = CNN_Encoder(manager)
# encoderN = RNN_Encoder(manager)
# encoderN = MHA_Encoder(manager)

# encoderU = CNN_User_Encoder(manager)
encoderU = RNN_User_Encoder(manager)
# encoderU = MHA_User_Encoder(manager)
# encoderU = Attention_Pooling(manager)
# encoderU = Average_Pooling(manager)

reducer = Matching_Reducer(manager)
# reducer = Slicing_Reducer(manager)

# ranker = CNN_Ranker(manager)
# ranker = BERT_Onepass_Ranker(manager)
# ranker = BERT_Original_Ranker(manager)

model = TTMS(manager, embedding, encoderN, encoderU, reducer).to(manager.device)
# model = ESM(manager, embedding, encoderN, encoderU, reducer, ranker).to(manager.device)

manager.load(model, 589, strict=False)

[2021-10-06 23:52:09,656] INFO (utils.Manager) loading model from data/model_params/ttms__bert__cnn__rnn__matching__avg/demo_step589_[k=5].model...


In [9]:
model.eval()
a,b,c = model(x2)
d = model.encode_news(xn)
e = model.encode_user(xu)

In [8]:
a

tensor([[0.8009, 0.8085, 0.8065, 0.7089, 0.7540, 0.7926, 0.8030, 0.8054, 0.6884,
         0.8302]], grad_fn=<SigmoidBackward>)

In [14]:
torch.sigmoid(torch.dot(e[0],d[28])/math.sqrt(768))

tensor(0.8085, grad_fn=<SigmoidBackward>)

In [5]:
manager.fast = True
manager.mode = "dev"

loaders = manager.prepare()
xn = list(loaders[1])[413]
xu = list(loaders[2])[0]

[2021-10-06 23:53:05,736] INFO (utils.Manager) Hyper Parameters are 
{
    "scale": "demo",
    "mode": "dev",
    "batch_size": 5,
    "batch_size_news": 100,
    "batch_size_history": 100,
    "k": 5,
    "threshold": -Infinity,
    "abs_length": 40,
    "signal_length": 100,
    "his_size": 50,
    "cdd_size": 5,
    "impr_size": 10,
    "dropout_p": 0.2,
    "lr": 0.0001,
    "bert_lr": 3e-05,
    "embedding": "bert",
    "encoderN": "cnn",
    "encoderU": "rnn",
    "selector": "sfi",
    "reducer": "matching",
    "ranker": "onepass",
    "embedding_dim": 768,
    "hidden_dim": 384,
    "base_rank": 0,
    "world_size": 0,
    "seed": null,
    "granularity": "avg",
    "debias": true,
    "full_attn": true,
    "descend_history": false,
    "shuffle_pos": false,
    "save_pos": false,
    "sep_his": false,
    "diversify": false,
    "no_dedup": false,
    "no_order_embed": false,
    "no_rm_punc": false,
    "fast": true,
    "scheduler": "linear",
    "warmup": 100,
    "shuff

In [8]:
xn['cdd_id'][28]

tensor(41328)