In [1]:
import re
import logging
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import pickle
from collections import defaultdict
from utils.utils import prepare, convert_tokens_to_words
from data.configs.demo import config
from collections import defaultdict

from transformers import AutoTokenizer, AutoModel
from utils.Manager import Manager

from models.Embeddings.BERT import BERT_Embedding
from models.Encoders.CNN import CNN_Encoder,CNN_User_Encoder
from models.Encoders.RNN import RNN_Encoder,RNN_User_Encoder
from models.Encoders.MHA import MHA_Encoder, MHA_User_Encoder
from models.Modules.DRM import Matching_Reducer, Slicing_Reducer
from models.Rankers.BERT import BERT_Onepass_Ranker, BERT_Original_Ranker
from models.Rankers.CNN import CNN_Ranker
from models.Encoders.Pooling import Attention_Pooling, Average_Pooling

from models.Encoders.BERT import BERT_Encoder
from models.Encoders.Pooling import *

from models.ESM import ESM
from models.TTMS import TTMS
 
from models.Modules.Attention import MultiheadAttention

# NLLLoss expects log-probabilities, which TTMS.forward returns in training mode.
loss = nn.NLLLoss()

# Stand-alone DeBERTa encoder + tokenizer, used below to inspect the architecture.
DEBERTA_NAME = 'microsoft/deberta-base'
BERT_CACHE_DIR = '../../../Data/bert_cache'
m = AutoModel.from_pretrained(DEBERTA_NAME, cache_dir=BERT_CACHE_DIR)
t = AutoTokenizer.from_pretrained(DEBERTA_NAME, cache_dir=BERT_CACHE_DIR)

In [5]:
# Run the raw encoder on a short sentence and show the resulting sub-word tokens.
inputs = t('I love you', return_tensors="pt")
token_strings = t.convert_ids_to_tokens(inputs.input_ids[0].tolist())
m(**inputs), inputs, token_strings

(BaseModelOutput(last_hidden_state=tensor([[[ 0.0323, -0.0305, -0.0739,  ...,  0.0099,  0.0663, -0.0718],
          [-0.6678, -0.7011, -0.5957,  ...,  0.5535, -0.1700, -0.1757],
          [-0.8192,  0.0860, -0.5883,  ...,  0.6707, -0.5009, -0.1986],
          [-0.9229,  0.1552, -0.1609,  ...,  1.1191, -0.5578,  0.1927],
          [ 0.1646,  0.0510, -0.1161,  ...,  0.0469,  0.1763,  0.0288]]],
        grad_fn=<AddBackward0>), hidden_states=None, attentions=None),
 {'input_ids': tensor([[  1, 100, 657,  47,   2]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])},
 ['[CLS]', 'I', 'Ġlove', 'Ġyou', '[SEP]'])

In [6]:
# Probe two DeBERTa encoder helpers: the attention mask derived from a 2-D padding
# mask, and the relative-position matrix for a batch of hidden states.
attn_mask = torch.tensor([[1., 0, 1, 1, 0],
                          [1, 0, 0, 0, 0]])
# NOTE(review): hidden_states has sequence length 3 while attn_mask has length 5,
# so rel_pos and extended_attn_mask below describe different sequence lengths.
hidden_states = torch.rand((2, 3, 4))

extended_attn_mask = m.encoder.get_attention_mask(attn_mask)
rel_pos = m.encoder.get_rel_pos(hidden_states)

# NOTE(review): the saved output's first mask looks like it was produced from
# [1, 1, 1, 1, 0] rather than the first row above — re-run to refresh it.
rel_pos, extended_attn_mask

(tensor([[[ 0, -1, -2],
          [ 1,  0, -1],
          [ 2,  1,  0]]]),
 tensor([[[[1, 1, 1, 1, 0],
           [1, 1, 1, 1, 0],
           [1, 1, 1, 1, 0],
           [1, 1, 1, 1, 0],
           [0, 0, 0, 0, 0]]],
 
 
         [[[1, 0, 0, 0, 0],
           [0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0],
           [0, 0, 0, 0, 0]]]], dtype=torch.uint8))

In [2]:
# Point the demo config at DeBERTa before the Manager is built from it.
# Other toggles used while experimenting:
# config.reducer = 'entity'
# config.device = 0
config.embedding = 'deberta'
config.bert = 'microsoft/deberta-base'

manager = Manager(config)
loaders = prepare(manager)
# Grab only the first batch of each loader. `list(loader)[0]` would iterate and
# collate every batch of the dataset just to throw all but the first away.
x1 = next(iter(loaders[0]))
x2 = next(iter(loaders[1]))

[2021-09-08 15:46:47,558] INFO (utils.utils) Hyper Parameters are 
scale:demo
batch_size:5
k:5
threshold:-inf
signal_length:100
his_size:50
impr_size:10
lr:0.0001
bert_lr:3e-05
hidden_dim:384
world_size:0
step:0
ascend_history:False
no_dedup:False
diversify:False
granularity:avg
no_sep_his:False
no_order_embed:False
bert:microsoft/deberta-base
[2021-09-08 15:46:47,559] INFO (utils.utils) preparing dataset...
[2021-09-08 15:46:47,564] INFO (utils.MIND) process NO.0 loading cached user behavior from data/cache/deberta/MINDdemo_train/10/behaviors..pkl
[2021-09-08 15:46:47,577] INFO (utils.MIND) process NO.0 loading cached news tokenization from data/cache/deberta/MINDdemo_train/news.pkl
[2021-09-08 15:46:48,481] INFO (utils.utils) deduplicating...
[2021-09-08 15:46:49,979] INFO (utils.MIND) process NO.0 loading cached user behavior from data/cache/deberta/MINDdemo_dev/10/behaviors..pkl
[2021-09-08 15:46:50,016] INFO (utils.MIND) process NO.0 loading cached news tokenization from data/cach

In [None]:
class TTMS(nn.Module):
    """Click predictor scoring candidate news against a user built from reduced history terms.

    Candidate news are embedded and encoded by ``BERT_Encoder``. The browsing history
    is embedded, encoded by ``encoderN`` and condensed by ``reducer`` into a set of
    terms. Those terms are prefixed with the history embedding at position [0, 0]
    (first token of the first history news). The same ``BERT_Encoder`` encodes the
    result, and ``userProject`` turns it into the user representation. The click score
    is the dot product of each candidate representation with the user representation.

    Args:
        config: hyper-parameter namespace; ``config.name`` is overwritten as a side effect.
        embedding: called as ``embedding(token_index, subword_prefix)``.
        encoderN: called on the history embedding, returns ``(encoded_embedding, news_repr)``.
        encoderU: called on the history news representations; only used when
            ``reducer.name == 'matching'``.
        reducer: called with the history encodings and masks, returns
            ``(ps_terms, ps_term_mask, kid)``.
    """

    def __init__(self, config, embedding, encoderN, encoderU, reducer):
        super().__init__()

        self.scale = config.scale
        self.cdd_size = config.cdd_size
        self.batch_size = config.batch_size
        self.his_size = config.his_size
        self.signal_length = config.signal_length
        self.device = config.device

        self.embedding = embedding
        self.encoderN = encoderN
        self.encoderU = encoderU

        self.reducer = reducer
        self.bert = BERT_Encoder(config)

        self.granularity = config.granularity
        if self.granularity != 'token':
            # All-zero templates that _subword_prefixes scatters ones into; sized for
            # config.batch_size and sliced to the actual batch on every call.
            self.register_buffer('cdd_dest', torch.zeros((self.batch_size, config.impr_size, config.signal_length * config.signal_length)), persistent=False)
            if config.reducer in ["bm25", "entity", "first"]:
                # presumably these reducers work on history news cut to k + 1 tokens — TODO confirm
                self.register_buffer('his_dest', torch.zeros((self.batch_size, self.his_size, (config.k + 1) * (config.k + 1))), persistent=False)
            else:
                self.register_buffer('his_dest', torch.zeros((self.batch_size, self.his_size, config.signal_length * config.signal_length)), persistent=False)

        self.userProject = nn.Sequential(
            nn.Linear(self.bert.hidden_dim, self.bert.hidden_dim),
            nn.Tanh()
        )

        # mask entry for the CLS term prepended to the personalized terms in _forward
        self.register_buffer('extra_cls_mask', torch.ones(1,1), persistent=False)

        config.name = '__'.join(['ttms', config.embedding, config.encoderN, config.encoderU, config.reducer, config.granularity])


    def clickPredictor(self, cdd_news_repr, user_repr):
        """ calculate batch of click probability

        Args:
            cdd_news_repr: news-level representation, [batch_size, cdd_size, hidden_dim]
            user_repr: user representation, [batch_size, 1, hidden_dim]

        Returns:
            unnormalized score of each candidate news (dot product), [batch_size, cdd_size]
        """
        score = cdd_news_repr.matmul(user_repr.transpose(-2,-1)).squeeze(-1)
        return score

    def _subword_prefixes(self, x):
        """Build the subword-to-word aggregation matrices and the matching masks.

        For ``granularity == 'token'`` nothing is aggregated. The prefixes are None
        and the token-level masks from ``x`` are returned, moved to ``self.device``.

        Otherwise, each ``(row, column)`` pair in ``x['cdd_subword_index']`` /
        ``x['his_subword_index']`` sets a 1 in a square matrix per news. Rows are
        output words and columns are input subword tokens. The attention masks are
        projected through those matrices.

        Returns:
            tuple ``(cdd_subword_prefix, his_subword_prefix, cdd_attn_mask,
            his_attn_mask, his_refined_mask)``; ``his_refined_mask`` is None when
            ``'his_refined_mask'`` is not in ``x``.
        """
        if self.granularity == 'token':
            his_refined_mask = None
            if 'his_refined_mask' in x:
                his_refined_mask = x["his_refined_mask"].to(self.device)
            return None, None, x['cdd_attn_mask'].to(self.device), x["his_attn_mask"].to(self.device), his_refined_mask

        batch_size, cdd_size = x['cdd_subword_index'].shape[:2]
        # The templates are allocated for config.batch_size. Slicing to the actual
        # batch covers the smaller last training batch, and evaluation batches of any
        # size up to config.batch_size.
        cdd_dest = self.cdd_dest[:batch_size, :cdd_size]
        his_dest = self.his_dest[:batch_size]

        cdd_subword_index = x['cdd_subword_index'].to(self.device)
        his_subword_index = x['his_subword_index'].to(self.device)
        his_signal_length = his_subword_index.size(-2)
        # flatten (row, column) pairs into indices of the flattened square matrix
        cdd_subword_index = cdd_subword_index[:, :, :, 0] * self.signal_length + cdd_subword_index[:, :, :, 1]
        his_subword_index = his_subword_index[:, :, :, 0] * his_signal_length + his_subword_index[:, :, :, 1]

        # out-of-place scatter: the shared templates stay all-zero
        cdd_subword_prefix = cdd_dest.scatter(dim=-1, index=cdd_subword_index, value=1)
        if self.training:
            cdd_subword_prefix = cdd_subword_prefix * x["cdd_mask"].to(self.device)
        cdd_subword_prefix = cdd_subword_prefix.view(batch_size, cdd_size, self.signal_length, self.signal_length)

        his_subword_prefix = his_dest.scatter(dim=-1, index=his_subword_index, value=1) * x["his_mask"].to(self.device)
        his_subword_prefix = his_subword_prefix.view(batch_size, self.his_size, his_signal_length, his_signal_length)

        if self.granularity == 'avg':
            # average subword embeddings as the word embedding (L1-normalize each row)
            cdd_subword_prefix = F.normalize(cdd_subword_prefix, p=1, dim=-1)
            his_subword_prefix = F.normalize(his_subword_prefix, p=1, dim=-1)

        cdd_attn_mask = cdd_subword_prefix.matmul(x['cdd_attn_mask'].to(self.device).float().unsqueeze(-1)).squeeze(-1)
        his_attn_mask = his_subword_prefix.matmul(x["his_attn_mask"].to(self.device).float().unsqueeze(-1)).squeeze(-1)
        his_refined_mask = None
        if 'his_refined_mask' in x:
            his_refined_mask = his_subword_prefix.matmul(x["his_refined_mask"].to(self.device).float().unsqueeze(-1)).squeeze(-1)

        return cdd_subword_prefix, his_subword_prefix, cdd_attn_mask, his_attn_mask, his_refined_mask

    def _forward(self, x):
        """Compute unnormalized click scores [batch_size, cdd_size] and the reducer's ``kid``."""
        cdd_subword_prefix, his_subword_prefix, cdd_attn_mask, his_attn_mask, his_refined_mask = self._subword_prefixes(x)

        cdd_news = x["cdd_encoded_index"].long().to(self.device)
        _, cdd_news_repr = self.bert(
            self.embedding(cdd_news, cdd_subword_prefix), cdd_attn_mask
        )

        his_news = x["his_encoded_index"].long().to(self.device)
        his_news_embedding = self.embedding(his_news, his_subword_prefix)
        his_news_encoded_embedding, his_news_repr = self.encoderN(
            his_news_embedding
        )
        # no need to calculate this if ps_terms are fixed in advance
        if self.reducer.name == 'matching':
            user_repr = self.encoderU(his_news_repr)
        else:
            user_repr = None

        ps_terms, ps_term_mask, kid = self.reducer(his_news_encoded_embedding, his_news_embedding, user_repr, his_news_repr, his_attn_mask, his_refined_mask)

        # Prepend the embedding at history position [0, 0] (first token of the first
        # history news) as CLS, so BERT directly yields a user-level CLS output.
        batch_size = ps_terms.size(0)
        ps_terms = torch.cat([his_news_embedding[:, 0, 0].unsqueeze(1), ps_terms], dim=-2)
        ps_term_mask = torch.cat([self.extra_cls_mask.expand(batch_size, 1), ps_term_mask], dim=-1)
        _, user_cls = self.bert(ps_terms.unsqueeze(1), ps_term_mask.unsqueeze(1))
        user_repr = self.userProject(user_cls)

        return self.clickPredictor(cdd_news_repr, user_repr), kid

    def forward(self, x):
        """Score a batch.

        Returns:
            ``(prob, kid)``. In training mode ``prob`` is the log-softmax of the
            unnormalized scores over candidates (dim 1). In evaluation mode it is their
            element-wise sigmoid. ``kid`` is passed through from the reducer.
        """
        score, kid = self._forward(x)

        if self.training:
            prob = F.log_softmax(score, dim=1)
        else:
            prob = torch.sigmoid(score)

        return prob, kid

In [3]:
# Assemble TTMS from interchangeable components, constructed in this order:
# embedding, news encoder, user encoder, reducer.
embedding = BERT_Embedding(manager)

# news encoder — alternatives: RNN_Encoder, MHA_Encoder
encoderN = CNN_Encoder(manager)

# user encoder — alternatives: CNN_User_Encoder, RNN_User_Encoder,
# MHA_User_Encoder, Average_Pooling
encoderU = Attention_Pooling(manager)

# term reducer — alternative: Slicing_Reducer
reducer = Matching_Reducer(manager)

# The imported rankers (CNN_Ranker, BERT_Onepass_Ranker, BERT_Original_Ranker)
# are not arguments of the TTMS constructor.

# NOTE(review): `TTMS` resolves to whichever definition ran last in the kernel —
# the `from models.TTMS import TTMS` import or the class cell in this notebook.
model = TTMS(manager, embedding, encoderN, encoderU, reducer).to(manager.device)

In [4]:
# Forward pass on the first training batch. The saved output shows log-softmax
# scores (grad_fn=LogSoftmaxBackward) followed by the reducer's `kid` indices.
model_outputs = model(x1)
a, b = model_outputs
(a, b)

(tensor([[-2.3317, -2.3317, -1.3146, -1.3146, -1.3146],
         [-1.6094, -1.6094, -1.6094, -1.6094, -1.6094],
         [-1.6094, -1.6094, -1.6094, -1.6094, -1.6094],
         [-1.9124, -1.9123, -1.9123, -1.9124, -0.8939],
         [-1.6094, -1.6094, -1.6094, -1.6094, -1.6094]],
        grad_fn=<LogSoftmaxBackward>),
 tensor([[[26, 25, 12, 16,  9],
          [ 3, 36,  2,  8, 26],
          [11, 38,  1, 23,  4],
          ...,
          [ 1,  0,  4,  2,  3],
          [ 1,  0,  4,  2,  3],
          [ 1,  0,  4,  2,  3]],
 
         [[22, 21, 27, 14,  1],
          [ 9, 30, 16, 25, 11],
          [ 3, 20, 19, 24, 23],
          ...,
          [31, 24, 28, 29, 25],
          [16, 12,  3,  7, 14],
          [18, 33, 19, 37, 17]],
 
         [[32, 17,  1, 19,  3],
          [ 1, 12,  2,  7, 26],
          [32,  4,  2,  5, 15],
          ...,
          [ 1,  0,  4,  2,  3],
          [ 1,  0,  4,  2,  3],
          [ 1,  0,  4,  2,  3]],
 
         [[ 7, 19, 21, 20, 11],
          [10, 26,

In [15]:
# `a` is a non-leaf tensor (it has a grad_fn), so autograd discards its gradient
# unless asked to keep it *before* backward(); without this, `a.grad` is None.
a.retain_grad()
# keep the target on the same device as the scores (no-op on CPU)
ls = loss(a, target=x1['label'].to(a.device))
ls.backward()

In [16]:
# Gradient of the loss w.r.t. the log-probabilities `a`. Because `a` is a non-leaf
# tensor, this is None unless `a.retain_grad()` was called before `ls.backward()`.
a.grad

tensor([[-0.2000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.2000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.2000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.2000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.2000,  0.0000,  0.0000,  0.0000,  0.0000]])