In [1]:
import re
import logging
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import pickle
from collections import defaultdict
from utils.utils import prepare
from data.configs.demo import config
from collections import defaultdict

from transformers import AutoTokenizer, AutoModel, BertModel, BertConfig
from utils.Manager import Manager

from models.Embeddings.BERT import BERT_Embedding
from models.Encoders.CNN import CNN_Encoder,CNN_User_Encoder
from models.Encoders.RNN import RNN_Encoder,RNN_User_Encoder
from models.Encoders.MHA import MHA_Encoder, MHA_User_Encoder
from models.Modules.DRM import Matching_Reducer, Slicing_Reducer
from models.Rankers.BERT import BERT_Onepass_Ranker, BERT_Original_Ranker
from models.Rankers.CNN import CNN_Ranker
from models.Encoders.Pooling import Attention_Pooling, Average_Pooling

from models.Encoders.BERT import BERT_Encoder
from models.Encoders.Pooling import *

from models.ESM import ESM
from models.TTMS import TTMS
 
from models.Modules.Attention import MultiheadAttention, get_attn_mask
# Raise the element-count threshold above which torch summarizes printed tensors,
# so that larger tensors (e.g. attention masks) are printed in full while debugging.
torch.set_printoptions(threshold=10_000)

In [2]:
# Optional overrides of the demo config; uncomment to switch the reducer,
# the embedding backbone or the device before the Manager is created.
# config.reducer = 'entity'
# config.embedding = 'deberta'
# config.bert = 'microsoft/deberta-base'
# config.device = 0

# Instantiate the Manager from the config and let `prepare` build the data loaders.
manager = Manager(config)
loaders = prepare(manager)
# Materialize the first two loaders into lists so that single batches can be indexed.
# NOTE(review): judging by the log output, loaders[0] is the train split and
# loaders[1] the dev split -- confirm against utils.utils.prepare.
X1 = list(loaders[0])
X2 = list(loaders[1])
# First batch of each loader.
x1 = X1[0]
x2 = X2[0]

[2021-09-11 14:33:59,368] INFO (utils.utils) Hyper Parameters are 
scale:demo
batch_size:5
k:5
threshold:-inf
signal_length:100
his_size:50
impr_size:10
lr:0.0001
bert_lr:3e-05
hidden_dim:384
world_size:0
step:0
granularity:avg
ascend_history:False
diversify:False
no_dedup:False
no_order_embed:False
bert:bert-base-uncased
[2021-09-11 14:33:59,368] INFO (utils.utils) preparing dataset...
[2021-09-11 14:33:59,372] INFO (utils.MIND) process NO.0 loading cached user behavior from data/cache/bert/MINDdemo_train/10/behaviors..pkl
[2021-09-11 14:33:59,387] INFO (utils.MIND) process NO.0 loading cached news tokenization from data/cache/bert/MINDdemo_train/news.pkl
[2021-09-11 14:34:00,248] INFO (utils.utils) deduplicating...
[2021-09-11 14:34:01,906] INFO (utils.MIND) process NO.0 loading cached user behavior from data/cache/bert/MINDdemo_dev/10/behaviors..pkl
[2021-09-11 14:34:01,909] INFO (utils.MIND) process NO.0 loading cached news tokenization from data/cache/bert/MINDdemo_dev/news.pkl
[2

In [3]:
# Assemble the model from interchangeable components; every component is built
# from the shared `manager`. The commented-out lines are the alternative choices.
embedding = BERT_Embedding(manager)

# news encoder
encoderN = CNN_Encoder(manager)
# encoderN = RNN_Encoder(manager)
# encoderN = MHA_Encoder(manager)

# user encoder
# encoderU = CNN_User_Encoder(manager)
encoderU = RNN_User_Encoder(manager)
# encoderU = MHA_User_Encoder(manager)
# encoderU = Attention_Pooling(manager)
# encoderU = Average_Pooling(manager)

# document reducer
reducer = Matching_Reducer(manager)
# reducer = Slicing_Reducer(manager)

# ranker; this is the BERT_Onepass_Ranker imported from models.Rankers.BERT
# ranker = CNN_Ranker(manager)
ranker = BERT_Onepass_Ranker(manager)
# ranker = BERT_Original_Ranker(manager)

# NOTE(review): `ranker` and `model` are rebuilt in the next cell after
# BERT_Onepass_Ranker is redefined there, so the two objects created here are
# replaced before they are used.
# model = TTMS(manager, embedding, encoderN, encoderU, reducer).to(manager.device)
model = ESM(manager, embedding, encoderN, encoderU, reducer, ranker).to(manager.device)

In [4]:
class BertSelfAttention(nn.Module):
    """
    self attention of the one-pass bert

    The input sequence is the concatenation of several candidate news followed by
    the personalized terms:
        cdd1 cdd2 ... cddn | term1 term2 ... term_m
    Every candidate token attends to the tokens of its own candidate and to all
    terms; tokens of the other candidates are masked out. The term positions do
    not attend to anything: their output is simply their value projection.
    """
    def __init__(self, config):
        """
        Args:
            config: object exposing ``num_attention_heads``, ``hidden_size``,
                ``attention_probs_dropout_prob``, ``signal_length`` (tokens per candidate),
                ``cdd_size`` (candidates per sample when training), ``impr_size``
                (maximum candidates per sample when evaluating) and ``term_num``
                (number of trailing positions that are not used as queries)
        """
        super().__init__()

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.softmax = nn.Softmax(dim=-1)

        self.signal_length = config.signal_length
        self.all_length = config.cdd_size * self.signal_length
        self.term_num = config.term_num

        # All masks are additive: 0 where attention is allowed, -10000 where it is blocked.
        # They are non-persistent buffers, so they follow the module across devices
        # without appearing in the state dict.

        # training mask, [1, 1, cdd_size * signal_length, cdd_size * signal_length + term_num]:
        # block-diagonal over the candidates, fully visible over the terms
        cdd_visible = torch.eye(config.cdd_size).repeat_interleave(repeats=self.signal_length, dim=-1).repeat_interleave(repeats=self.signal_length, dim=0)
        train_visible = torch.cat([cdd_visible, torch.ones(self.all_length, self.term_num)], dim=-1)
        self.register_buffer('one_pass_attn_mask_train', ((1 - train_visible) * -10000).unsqueeze(0).unsqueeze(0), persistent=False)

        # evaluation template, [impr_size, impr_size * signal_length]; it is cropped to the
        # actual number of candidates and expanded along the query dimension in forward
        eval_visible = torch.eye(config.impr_size).repeat_interleave(repeats=self.signal_length, dim=-1)
        self.register_buffer('one_pass_attn_mask_eval', (1 - eval_visible) * -10000, persistent=False)

        # the terms are always visible, [1, term_num]
        self.register_buffer('ps_term_mask', (1 - torch.ones(1, self.term_num)) * -10000, persistent=False)

    def transpose_for_scores(self, x):
        """
        split the last dimension into heads and move the head dimension forward,
        so that every head operates in parallel

        Args:
            x: [batch_size, length, all_head_size]

        Returns:
            [batch_size, head_num, length, head_dim]
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """ customized bert self attention, where the candidate tokens attend to
        their own candidate and to the trailing terms

        Args:
            hidden_states: candidate tokens followed by the terms,
                [batch_size, cdd_size * signal_length + term_num, hidden_dim]
            attention_mask: optional additive mask broadcastable to
                [batch_size, head_num, cdd_size * signal_length, cdd_size * signal_length + term_num];
                skipped when None
            head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value,
            output_attentions: unused; presumably kept so that the module can stand in
                for the stock bert self attention

        Returns:
            one-element tuple holding the context tensor, whose shape equals that of hidden_states
        """
        if self.training:
            one_pass_mask = self.one_pass_attn_mask_train
        else:
            # the number of candidates may vary at evaluation time, so derive it from the input
            attn_field_length = hidden_states.size(1) - self.term_num
            cdd_size = attn_field_length // self.signal_length
            cdd_mask = self.one_pass_attn_mask_eval[:cdd_size, :cdd_size * self.signal_length].repeat_interleave(repeats=self.signal_length, dim=0)
            term_mask = self.ps_term_mask.expand(attn_field_length, self.term_num)
            one_pass_mask = torch.cat([cdd_mask, term_mask], dim=-1).unsqueeze(0).unsqueeze(0)

        # only the candidate tokens act as queries
        attn_field = hidden_states[:, :-self.term_num]

        # [batch_size, head_num, *, head_dim]
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        cdd_layer = self.transpose_for_scores(self.query(attn_field))

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # [bs, hn, cdd_length, *]
        attention_scores = torch.matmul(cdd_layer, key_layer.transpose(-1, -2))
        attention_scores = (attention_scores / math.sqrt(self.attention_head_size)) + one_pass_mask
        # the mask is optional in the signature, so only add it when the caller provides one
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.dropout(attention_probs)

        # the term positions are passed through as their value projection
        context_layer = torch.cat([torch.matmul(attention_probs, value_layer), value_layer[:, :, -self.term_num:]], dim=-2)

        # [batch_size, length, head_num, head_dim] -> [batch_size, length, all_head_size]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return (context_layer,)

class BERT_Onepass_Ranker(nn.Module):
    """
    one-pass bert:
        cdd1 cdd2 ... cddn [SEP] pst1 pst2 ...

    All candidate news of a sample are fed through the bert encoder in a single
    pass together with the personalized terms; the self attention of every layer
    is replaced by BertSelfAttention.
    """
    def __init__(self, config):
        """
        Args:
            config: object exposing ``embedding_dim``, ``signal_length``, ``term_num``,
                ``cdd_size``, ``impr_size``, ``bert`` (name of the pretrained model),
                ``path`` (prefix of the bert cache directory) and ``embedding``
        """
        # confirm the hidden dim to be 768
        assert config.embedding_dim == 768
        # NOTE(review): the total input length (cdd_size * signal_length + 1 + term_num) is
        # not checked here; the former assertion on it was disabled by the author

        super().__init__()

        self.name = 'onepass-bert'
        self.signal_length = config.signal_length
        # one extra position for the [SEP] token in front of the terms
        self.term_num = config.term_num + 1
        self.embedding_dim = config.embedding_dim
        self.final_dim = self.embedding_dim

        # NOTE(review): the encoder is created from the default BertConfig, so the pretrained
        # weights loaded below must have the matching architecture -- confirm for
        # checkpoints other than bert-base
        bert_config = BertConfig()
        # primary bert
        prim_bert = BertModel(bert_config).encoder
        bert_config.signal_length = self.signal_length
        bert_config.term_num = self.term_num
        bert_config.cdd_size = config.cdd_size
        bert_config.impr_size = config.impr_size
        # replace the self attention of every layer with the one-pass variant; it keeps the
        # parameter names query/key/value, so the pretrained state dict still applies
        for l in prim_bert.layer:
            l.attention.self = BertSelfAttention(bert_config)

        bert = BertModel.from_pretrained(
            config.bert,
            cache_dir=config.path + 'bert_cache/'
        )
        prim_bert.load_state_dict(bert.encoder.state_dict())
        self.bert = prim_bert

        self.pooler = nn.Sequential(
            nn.Linear(self.embedding_dim, self.final_dim),
            nn.Tanh()
        )
        nn.init.xavier_normal_(self.pooler[0].weight)

        # [2, embedding_dim]; row 0 marks the candidate tokens and [SEP], row 1 marks the terms
        self.token_type_embedding = nn.Parameter(torch.randn(2, self.embedding_dim))
        nn.init.xavier_normal_(self.token_type_embedding)

        # [SEP] token: start from the pretrained word embedding when its id is known
        sep_token_id = {'bert': 102, 'deberta': 2}.get(config.embedding)
        if sep_token_id is not None:
            sep_embedding = bert.embeddings.word_embeddings(torch.tensor([sep_token_id])).detach().clone()
            self.sep_embedding = nn.Parameter(sep_embedding.view(1, 1, self.embedding_dim))
        else:
            self.sep_embedding = nn.Parameter(torch.randn(1, 1, self.embedding_dim))
            nn.init.xavier_normal_(self.sep_embedding)

        # the [SEP] token is never masked
        self.register_buffer('sep_attn_mask', torch.ones(1, 1), persistent=False)

    def forward(self, cdd_news_embedding, ps_terms, cdd_attn_mask, ps_term_mask):
        """
        encode all candidate news together with the personalized terms and
        return one vector per candidate

        Args:
            cdd_news_embedding: word-level representation of candidate news, [batch_size, cdd_size, signal_length, embedding_dim]
            ps_terms: concatenated historical news or personalized terms, [batch_size, term_num, embedding_dim]
            cdd_attn_mask: attention mask of the candidate news, [batch_size, cdd_size, signal_length]
            ps_term_mask: attention mask of the personalized terms, [batch_size, term_num]

        Returns:
            bert_output: pooled hidden state of the first token of every candidate, [batch_size, cdd_size, final_dim]
        """
        batch_size = cdd_news_embedding.size(0)
        cdd_size = cdd_news_embedding.size(1)

        # [bs, tn, hd]
        # out of place on purpose: the caller's tensor must not be modified
        ps_terms = ps_terms + self.token_type_embedding[1]

        # [bs, cs*sl, hd]
        cdd_news_embedding = cdd_news_embedding.reshape(batch_size, -1, self.embedding_dim)

        # [bs, cs*sl + 1, hd], the candidates and [SEP] share token type 0
        sep_embedding = self.sep_embedding.expand(batch_size, 1, self.embedding_dim)
        cdd_segment = torch.cat([cdd_news_embedding, sep_embedding], dim=-2) + self.token_type_embedding[0]

        # [bs, cs*sl + 1 + tn, hd]
        bert_input = torch.cat([cdd_segment, ps_terms], dim=-2)

        # [bs, cs*sl]
        attn_mask = cdd_attn_mask.reshape(batch_size, -1)
        cdd_length = attn_mask.size(-1)

        # [bs, cs*sl + 1 + tn]
        attn_mask = torch.cat([attn_mask, self.sep_attn_mask.expand(batch_size, 1), ps_term_mask], dim=-1)
        # NOTE(review): get_attn_mask presumably expands the padding mask to one row per
        # query position -- confirm against models.Modules.Attention
        attn_mask = get_attn_mask(attn_mask, query_length=cdd_length)
        # turn the 1/0 mask into an additive one
        attn_mask = (1.0 - attn_mask) * -10000.0

        hidden_states = self.bert(bert_input, attention_mask=attn_mask).last_hidden_state
        # the first token of every candidate represents it, [bs, cs, hd]
        bert_output = hidden_states[:, 0 : cdd_size * (self.signal_length) : self.signal_length].view(batch_size, cdd_size, self.embedding_dim)
        bert_output = self.pooler(bert_output)

        return bert_output


# Rebuild the ranker with the BERT_Onepass_Ranker defined above in this cell (it shadows
# the class imported from models.Rankers.BERT) and re-assemble the model around it.
ranker = BERT_Onepass_Ranker(manager)
# ranker = BERT_Original_Ranker(manager)

# model = TTMS(manager, embedding, encoderN, encoderU, reducer).to(manager.device)
model = ESM(manager, embedding, encoderN, encoderU, reducer, ranker).to(manager.device)

In [5]:
# Evaluation mode: BertSelfAttention then sizes its one-pass mask from the input
# instead of using the fixed training mask.
model.eval()
# third batch of the second loader
x2 = X2[2]
# NOTE(review): the model returns two values; their meaning is defined in models.ESM -- confirm
a,b = model(x2)

torch.Size([1, 1, 200, 451])
torch.Size([1, 1, 200, 451])
torch.Size([1, 1, 200, 451])
torch.Size([1, 1, 200, 451])
torch.Size([1, 1, 200, 451])
torch.Size([1, 1, 200, 451])
torch.Size([1, 1, 200, 451])
torch.Size([1, 1, 200, 451])
torch.Size([1, 1, 200, 451])
torch.Size([1, 1, 200, 451])
torch.Size([1, 1, 200, 451])
torch.Size([1, 1, 200, 451])
