In [2]:
import re
import logging
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math
import pickle
from collections import defaultdict
from utils.utils import prepare, convert_tokens_to_words
from data.configs.demo import config
from collections import defaultdict
from transformers import BertTokenizer,BertModel
from utils.Manager import Manager

from models.Embeddings.BERT import BERT_Embedding
from models.Encoders.CNN import CNN_Encoder,CNN_User_Encoder
from models.Encoders.RNN import RNN_Encoder,RNN_User_Encoder
from models.Encoders.MHA import MHA_Encoder, MHA_User_Encoder
from models.Modules.DRM import Matching_Reducer, BM25_Reducer
from models.Rankers.BERT import BERT_Onepass_Ranker, BERT_Original_Ranker
from models.Rankers.CNN import CNN_Ranker

from models.Encoders.BERT import BERT_Encoder
from models.Encoders.Pooling import *

from models.ESM import ESM
from models.TTMS import TTMS

from models.Modules.Attention import MultiheadAttention

In [3]:
# config.reducer = 'bm25'
manager = Manager(config)
loaders = prepare(manager)

# Only the first batch of loaders[1] is needed for inspection; pull it lazily
# instead of materializing every batch of the loader with list(...).
record = next(iter(loaders[1]))

[2021-08-31 23:12:03,957] INFO (utils.utils) Hyper Parameters are 
scale:demo
batch_size:5
k:5
threshold:-inf
signal_length:100
his_size:50
impr_size:10
lr:0.0001
bert_lr:3e-05
hidden_dim:384
world_size:0
step:0
ascend_history:False
no_dedup:False
diversify:False
granularity:avg
[2021-08-31 23:12:03,959] INFO (utils.utils) preparing dataset...
[2021-08-31 23:12:03,964] INFO (utils.MIND) process NO.0 loading cached user behavior from data/cache/bert/MINDdemo_train/10/behaviors..pkl
[2021-08-31 23:12:03,979] INFO (utils.MIND) process NO.0 loading cached news tokenization from data/cache/bert/MINDdemo_train/news.pkl
[2021-08-31 23:12:04,751] INFO (utils.MIND) reducing news of ../../../Data/MIND/MINDdemo_train/news.tsv...
[2021-08-31 23:12:04,908] INFO (utils.utils) deduplicating...
[2021-08-31 23:12:06,264] INFO (utils.MIND) process NO.0 loading cached user behavior from data/cache/bert/MINDdemo_dev/10/behaviors..pkl
[2021-08-31 23:12:06,267] INFO (utils.MIND) process NO.0 loading cached 

In [4]:
class TTMS(nn.Module):
    """News recommender that re-encodes selected history terms with BERT.

    Candidate news are embedded by ``embedding`` and encoded by ``self.bert``.
    Historical news are embedded, encoded by ``encoderN`` / ``encoderU`` and then
    handed to ``reducer``, which returns the selected terms (``ps_terms``), their
    mask and ``kid``. The selected terms are encoded by ``self.bert`` again to
    derive the user representation; the click score is the dot product between
    candidate and user representations.
    """

    def __init__(self, config, embedding, encoderN, encoderU, reducer, aggregator=None):
        """Wire the sub-modules together and allocate scatter destinations.

        Args:
            config: object exposing ``scale``, ``cdd_size``, ``batch_size``,
                ``his_size``, ``signal_length``, ``device``, ``granularity``,
                ``impr_size``, ``k`` and ``reducer``. ``config.name`` is
                overwritten with the assembled model name.
            embedding: callable used as ``embedding(token_ids, subword_prefix)``.
            encoderN: news encoder; must expose ``name`` and return
                ``(token-level output, news-level representation)``.
            encoderU: user encoder; must expose ``name``.
            reducer: term reducer; must expose ``name`` and return
                ``(ps_terms, ps_term_mask, kid)``.
            aggregator: optional module turning per-news representations into the
                user representation. When falsy, a Linear+Tanh projection over the
                BERT output of the whole history is used instead.
        """
        super().__init__()

        self.scale = config.scale
        self.cdd_size = config.cdd_size
        self.batch_size = config.batch_size
        self.his_size = config.his_size
        self.signal_length = config.signal_length
        self.device = config.device

        self.embedding = embedding
        self.encoderN = encoderN
        self.encoderU = encoderU

        self.reducer = reducer
        self.bert = BERT_Encoder(config)

        self.aggregator = aggregator

        self.granularity = config.granularity
        if self.granularity != 'token':
            # Zero-filled, non-persistent buffers that _forward scatters subword
            # indices into; the last dimension is a flattened square matrix.
            self.register_buffer('cdd_dest', torch.zeros((self.batch_size, config.impr_size, config.signal_length * config.signal_length)), persistent=False)
            if self.reducer.name != 'bm25':
                self.register_buffer('his_dest', torch.zeros((self.batch_size, self.his_size, config.signal_length * config.signal_length)), persistent=False)
            else:
                # with the bm25 reducer the history side is (k + 1) long
                self.register_buffer('his_dest', torch.zeros((self.batch_size, self.his_size, (config.k + 1) * (config.k + 1))), persistent=False)

        if not aggregator:
            self.userProject = nn.Sequential(
                nn.Linear(self.bert.hidden_dim, self.bert.hidden_dim),
                nn.Tanh()
            )

        self.name = '__'.join(['ttms', self.encoderN.name, self.encoderU.name, config.reducer])
        config.name = self.name

    def clickPredictor(self, cdd_news_repr, user_repr):
        """ calculate batch of click probability

        Args:
            cdd_news_repr: news-level representation, [batch_size, cdd_size, hidden_dim]
            user_repr: user representation, [batch_size, 1, hidden_dim]

        Returns:
            score of each candidate news, [batch_size, cdd_size]
        """
        # dot product of every candidate with the single user vector
        score = cdd_news_repr.matmul(user_repr.transpose(-2,-1)).squeeze(-1)
        return score

    def _forward(self,x):
        """Compute unnormalized click scores for one batch.

        Args:
            x: dict of tensors. Always read: ``cdd_encoded_index``,
                ``his_encoded_index``, ``cdd_attn_mask``, ``his_attn_mask``.
                Read when granularity is not ``'token'``: ``cdd_subword_index``,
                ``his_subword_index``, ``his_mask`` and (training only)
                ``cdd_mask``. ``his_refined_mask`` is optional.

        Returns:
            ``(score, kid)`` where ``score`` comes from :meth:`clickPredictor` and
            ``kid`` is passed through from the reducer.
        """
        if self.granularity != 'token':
            batch_size = x['cdd_subword_index'].size(0)
            cdd_size = x['cdd_subword_index'].size(1)

            if self.training:
                if batch_size != self.batch_size:
                    cdd_dest = self.cdd_dest[:batch_size, :cdd_size]
                    his_dest = self.his_dest[:batch_size]
                else:
                    cdd_dest = self.cdd_dest[:, :cdd_size]
                    his_dest = self.his_dest

            # batch_size always equals 1 when evaluating
            else:
                cdd_dest = self.cdd_dest[[0], :cdd_size]
                his_dest = self.his_dest[[0]]

            cdd_subword_index = x['cdd_subword_index'].to(self.device)
            his_subword_index = x['his_subword_index'].to(self.device)
            his_signal_length = his_subword_index.size(-2)
            # flatten (row, col) pairs into offsets of the flattened square matrix
            cdd_subword_index = cdd_subword_index[:, :, :, 0] * self.signal_length + cdd_subword_index[:, :, :, 1]
            his_subword_index = his_subword_index[:, :, :, 0] * his_signal_length + his_subword_index[:, :, :, 1]

            if self.training:
                cdd_subword_prefix = cdd_dest.scatter(dim=-1, index=cdd_subword_index, value=1) * x["cdd_mask"].to(self.device)
            else:
                cdd_subword_prefix = cdd_dest.scatter(dim=-1, index=cdd_subword_index, value=1)
            cdd_subword_prefix = cdd_subword_prefix.view(batch_size, cdd_size, self.signal_length, self.signal_length)

            his_subword_prefix = his_dest.scatter(dim=-1, index=his_subword_index, value=1) * x["his_mask"].to(self.device)
            his_subword_prefix = his_subword_prefix.view(batch_size, self.his_size, his_signal_length, his_signal_length)

            if self.granularity == 'avg':
                # average subword embeddings as the word embedding
                cdd_subword_prefix = F.normalize(cdd_subword_prefix, p=1, dim=-1)
                his_subword_prefix = F.normalize(his_subword_prefix, p=1, dim=-1)

            # project the token-level masks through the same subword matrices
            cdd_attn_mask = cdd_subword_prefix.matmul(x['cdd_attn_mask'].to(self.device).float().unsqueeze(-1)).squeeze(-1)
            his_attn_mask = his_subword_prefix.matmul(x["his_attn_mask"].to(self.device).float().unsqueeze(-1)).squeeze(-1)
            his_refined_mask = None
            if 'his_refined_mask' in x:
                his_refined_mask = his_subword_prefix.matmul(x["his_refined_mask"].to(self.device).float().unsqueeze(-1)).squeeze(-1)

        else:
            cdd_subword_prefix = None
            his_subword_prefix = None
            cdd_attn_mask = x['cdd_attn_mask'].to(self.device)
            his_attn_mask = x["his_attn_mask"].to(self.device)
            his_refined_mask = None
            if 'his_refined_mask' in x:
                his_refined_mask = x["his_refined_mask"].to(self.device)


        cdd_news = x["cdd_encoded_index"].long().to(self.device)
        _, cdd_news_repr = self.bert(
            self.embedding(cdd_news, cdd_subword_prefix), cdd_attn_mask
        )

        his_news = x["his_encoded_index"].long().to(self.device)
        his_news_embedding = self.embedding(his_news, his_subword_prefix)
        his_news_encoded_embedding, his_news_repr = self.encoderN(
            his_news_embedding
        )
        user_repr = self.encoderU(his_news_repr)
        # Debug probe for the notebook's inspection cells. Guarded because
        # his_refined_mask is None when the batch carries no 'his_refined_mask',
        # and indexing slot 3 needs at least four history entries.
        if his_refined_mask is not None and his_refined_mask.size(1) > 3:
            print(his_refined_mask[0][3])
        ps_terms, ps_term_mask, kid = self.reducer(his_news_encoded_embedding, his_news_embedding, user_repr, his_news_repr, his_attn_mask, his_refined_mask)
        # Debug probe: dump the terms selected by the reducer.
        print(ps_terms)
        # append CLS to each historical news, aggregator historical news representation to user repr
        if self.aggregator:
            ps_terms = torch.cat([his_news_embedding[:, :, 0].unsqueeze(-2), ps_terms], dim=-2)
            ps_term_mask = torch.cat([torch.ones(*ps_term_mask.shape[0:2], 1, device=ps_term_mask.device), ps_term_mask], dim=-1)
            ps_terms, his_news_repr = self.bert(ps_terms, ps_term_mask)
            user_repr = self.aggregator(his_news_repr)

        # append CLS to the entire browsing history, directly deriving user repr
        else:
            batch_size = ps_terms.size(0)
            ps_terms = torch.cat([his_news_embedding[:, 0, 0].unsqueeze(1).unsqueeze(1), ps_terms.reshape(batch_size, 1, -1, ps_terms.size(-1))], dim=-2)
            ps_term_mask = torch.cat([torch.ones(batch_size, 1, 1, device=ps_term_mask.device), ps_term_mask.reshape(batch_size, 1, -1)], dim=-1)
            _, user_cls = self.bert(ps_terms, ps_term_mask)
            user_repr = self.userProject(user_cls)

        return self.clickPredictor(cdd_news_repr, user_repr), kid

    def forward(self,x):
        """
        Decoupled function, score is unnormalized click score

        Returns:
            ``(prob, kid)``: log-softmax over candidates when training, sigmoid
            of the score otherwise; ``kid`` is passed through from ``_forward``.
        """
        score, kid = self._forward(x)

        if self.training:
            prob = nn.functional.log_softmax(score, dim=1)
        else:
            prob = torch.sigmoid(score)

        return prob, kid

In [7]:
# Embedding layer applied to both candidate and history token ids.
embedding = BERT_Embedding(manager)

# News-level encoder. Alternatives: RNN_Encoder(manager), MHA_Encoder(manager).
encoderN = CNN_Encoder(manager)

# User-level encoder. Alternatives: CNN_User_Encoder(manager), MHA_User_Encoder(manager).
encoderU = RNN_User_Encoder(manager)

# History term reducer. Alternative: BM25_Reducer(manager).
reducer = Matching_Reducer(manager)

# Rankers (CNN_Ranker, BERT_Onepass_Ranker, BERT_Original_Ranker) are not
# consumed by TTMS, so none is instantiated here.

# The manager doubles as the config object expected by TTMS.
model = TTMS(
    config=manager,
    embedding=embedding,
    encoderN=encoderN,
    encoderU=encoderU,
    reducer=reducer,
)
model = model.to(manager.device)

In [11]:
# Pretrained uncased BERT tokenizer, used below to turn encoded ids back into word pieces.
t = BertTokenizer.from_pretrained('bert-base-uncased')

In [10]:
# Token ids and refined mask of history slot 3 for the first sample in the batch.
record['his_encoded_index'][0,3], record['his_refined_mask'][0,3]

(tensor([  101,  5606,  1997,  5190,  1997,  2111,  1999,  2662,  2024,  2091,
         24352,  1997,  1037,  5477,  2008,  1005,  2071,  8246,  1005,  5606,
          1997,  5190,  1997,  2111,  2444,  2091, 24352,  2013,  1037,  5477,
          1999,  2662,  2008,  3728,  2018,  2049,  3891, 23191,  2904,  1000,
          2013,  2659,  2000,  2152, 19353,  1997,  2895,  1000,  2011,  1996,
          2149,  2390,  3650,  1997,  6145,  1012,  2739,  2739,  2271,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]),
 tensor([1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
         1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0,
         1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0

In [18]:
# Map the token ids of history slot 3 (first sample) back to WordPiece tokens.
t.convert_ids_to_tokens(record['his_encoded_index'][0,3])

['[CLS]',
 'hundreds',
 'of',
 'thousands',
 'of',
 'people',
 'in',
 'california',
 'are',
 'down',
 '##river',
 'of',
 'a',
 'dam',
 'that',
 "'",
 'could',
 'fail',
 "'",
 'hundreds',
 'of',
 'thousands',
 'of',
 'people',
 'live',
 'down',
 '##river',
 'from',
 'a',
 'dam',
 'in',
 'california',
 'that',
 'recently',
 'had',
 'its',
 'risk',
 'characterization',
 'changed',
 '"',
 'from',
 'low',
 'to',
 'high',
 'urgency',
 'of',
 'action',
 '"',
 'by',
 'the',
 'us',
 'army',
 'corps',
 'of',
 'engineers',
 '.',
 'news',
 'news',
 '##us',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]']

In [17]:
# Decode the same history slot into a single string (padding tokens included).
# NOTE(review): `tokens` is a str here but is rebound to a list of tokens in a later cell.
tokens = t.decode(record['his_encoded_index'][0,3])
tokens

'[CLS] hundreds of thousands of people in california are downriver of a dam that\'could fail\'hundreds of thousands of people live downriver from a dam in california that recently had its risk characterization changed " from low to high urgency of action " by the us army corps of engineers. news newsus [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [15]:
# Subword index pairs of history slot 3 (first sample); TTMS._forward flattens
# each pair as first * length + second before scattering.
record['his_subword_index'][0,3]

tensor([[ 0,  0],
        [ 1,  1],
        [ 2,  2],
        [ 3,  3],
        [ 4,  4],
        [ 5,  5],
        [ 6,  6],
        [ 6,  7],
        [ 6,  8],
        [ 6,  9],
        [ 6, 10],
        [ 7, 11],
        [ 8, 12],
        [ 9, 13],
        [10, 14],
        [11, 15],
        [12, 16],
        [13, 17],
        [14, 18],
        [15, 19],
        [16, 20],
        [17, 21],
        [18, 22],
        [19, 23],
        [20, 24],
        [21, 25],
        [22, 26],
        [23, 27],
        [24, 28],
        [25, 29],
        [26, 30],
        [27, 31],
        [28, 32],
        [29, 33],
        [30, 34],
        [31, 35],
        [32, 36],
        [33, 37],
        [34, 38],
        [35, 39],
        [36, 40],
        [37, 41],
        [38, 42],
        [39, 43],
        [39, 44],
        [40, 45],
        [40, 46],
        [41, 47],
        [42, 48],
        [43, 49],
        [44, 50],
        [45, 51],
        [46, 52],
        [47, 53],
        [48, 54],
        [4

In [13]:
# Re-display the decoded string produced by t.decode(...) above.
tokens

'[CLS] hundreds of thousands of people in california are downriver of a dam that\'could fail\'hundreds of thousands of people live downriver from a dam in california that recently had its risk characterization changed " from low to high urgency of action " by the us army corps of engineers. news newsus [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [8]:
# Inference only: switch to eval mode and disable autograd so no graph is
# recorded for the forward pass. `x` holds the (prob, kid) tuple from forward().
model.eval()
with torch.no_grad():
    x = model(record)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000, 1.0000, 0.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        0.0000, 1.0000, 1.0000, 1.0000, 0.5000, 0.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 0.0000, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000])
tensor([[[[ 0.0915,  0.3206,  0.0029,  ...,  0.0651, -0.1172, -0.1191],
          [ 0.1801,  0.0201, -0

In [None]:
# Display the (prob, kid) tuple returned by model(record).
x

In [None]:
# Candidate mask of the batch; TTMS._forward multiplies it into the candidate
# subword matrix during training.
record['cdd_mask']

In [None]:
# NOTE(review): `c` is not defined in any visible cell — this relies on hidden
# kernel state; confirm where it comes from before re-running.
c[5][0]

In [None]:
# NOTE(review): `c` is not defined in any visible cell — hidden kernel state.
c[6][0].sum()

In [None]:
# Sum of the history mask for the first sample — presumably the number of
# non-padded history entries; TODO confirm.
record['his_mask'][0].sum()

In [None]:
# Rebuild the word sequence of candidate 0 in sample 1 and drop padding.
words = convert_tokens_to_words(t.convert_ids_to_tokens(record['cdd_encoded_index'][1,0]))
words = [i for i in words if i!='[PAD]']

# Print index and word for every position (below 40) whose projected reduced mask is 0.
# NOTE(review): `c` is not defined in any visible cell; c[0][1] is presumably the
# candidate subword matrix of sample 1 — confirm.
for i,j in enumerate(c[0][1].matmul(record['cdd_reduced_mask'][1].float().unsqueeze(-1)).squeeze(-1)[0]):

    if j == 0 and i < 40:
        print(i)
        print(words[i])

len(words), words

In [None]:
# Project the reduced mask of sample 1 through c[0][1].
# NOTE(review): `c` is not defined in any visible cell — hidden kernel state.
c[0][1].matmul(record['cdd_reduced_mask'][1].float().unsqueeze(-1)).squeeze(-1)

In [None]:
# Scratch check: a negative start that reaches back to row 0 selects every row,
# so the (3, 1) zero tensor ends up filled with ones.
a = torch.zeros((3, 1))
a[-3:, :] = 1.0
a

In [None]:
# Show the candidate mask next to the candidate ids of the batch.
record['cdd_mask'], record['cdd_id']

In [None]:
# NOTE(review): `c` is not defined in any visible cell — hidden kernel state.
c[0][1]#, record['cdd_subword_index'][1]

In [None]:
# Same projection TTMS._forward applies to the candidate attention mask, shown for sample 1.
# NOTE(review): `c` is not defined in any visible cell; c[0] is presumably the
# candidate subword matrix — confirm.
c[0].matmul(record['cdd_attn_mask'].float().unsqueeze(-1)).squeeze(-1)[1]

In [None]:
# Raw candidate attention mask of the batch.
record['cdd_attn_mask']

In [None]:
# Pretrained uncased BERT tokenizer.
# NOTE(review): duplicates the earlier tokenizer cell; only needed if that cell was not run.
t = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# WordPiece tokens of the first history slot of the first sample.
# NOTE(review): rebinds `tokens`, which an earlier cell assigned a decoded str, to a list.
tokens = t.convert_ids_to_tokens(record['his_encoded_index'][0][0])
tokens

In [None]:
# Project the candidate attention mask through `c` and keep entry [0][1].
# NOTE(review): `c` is not defined in any visible cell, and other cells index it
# like a sequence (c[0], c[1], ...) while this one calls c.matmul — confirm which is intended.
d = c.matmul(record['cdd_attn_mask'].to(config.device).float().unsqueeze(-1)).squeeze(-1)[0][1]

In [None]:
# Print the position of every WordPiece continuation token (those prefixed with '##').
for pos in (idx for idx, piece in enumerate(tokens) if piece.startswith('##')):
    print(pos)
# Show the token sequence with padding stripped.
[piece for piece in tokens if piece != '[PAD]']

In [None]:
# Subword index pairs of the first history slot of the first sample.
record['his_subword_index'][0][0]

In [None]:
# NOTE(review): `c` is not defined in any visible cell — hidden kernel state.
c[1][0][0][29]

In [None]:
# Print index and word for every position (below 30) where c[3][0,0] is 0, then
# show c[2][0][0] alongside c[3][0,0].
# NOTE(review): `c` is not defined in any visible cell, and `words` comes from an
# earlier cell built from a candidate news item — confirm both match what is inspected here.
for i,j in enumerate(c[3][0,0]):
    if j == 0 and i < 30:
        print(i)
        print(words[i])
c[2][0][0], c[3][0,0]