In [40]:
import re
import logging
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math
import pickle
from collections import defaultdict
from utils.utils import prepare
from data.configs.demo import config
from collections import defaultdict
from transformers import BertTokenizer,BertModel
from utils.Manager import Manager

from models.Embeddings.BERT import BERT_Embedding
from models.Encoders.CNN import CNN_Encoder,CNN_User_Encoder
from models.Encoders.RNN import RNN_Encoder,RNN_User_Encoder
from models.Encoders.MHA import MHA_Encoder, MHA_User_Encoder
from models.Modules.DRM import Matching_Reducer, BM25_Reducer
from models.Rankers.BERT import BERT_Onepass_Ranker, BERT_Original_Ranker
from models.Rankers.CNN import CNN_Ranker

from models.Encoders.BERT import BERT_Encoder
from models.Encoders.Pooling import *

from models.ESM import ESM
from models.TTMS import TTMS

from models.Modules.Attention import MultiheadAttention

In [None]:
manager = Manager(config)
loaders = prepare(manager)

record = list(loaders[0])[0]

In [34]:
class Matching_Reducer(nn.Module):
    """
    basic document reducer: topk of each historical article
    """
    def __init__(self, config):
        super().__init__()

        self.name = "matching"

        self.k = config.k
        self.diversify = config.diversify

        config.term_num = config.k * config.his_size

        if self.diversify:
            self.newsUserAlign = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
            nn.init.xavier_normal_(self.newsUserAlign.weight)

        if config.threshold != -float('inf'):
            threshold = torch.tensor([config.threshold])
            self.register_buffer('threshold', threshold)

    def forward(self, news_selection_embedding, news_embedding, user_repr, news_repr, his_attn_mask, his_attn_mask_k):
        """
        Extract words from news text according to the overall user interest

        Args:
            news_selection_embedding: encoded word-level embedding, [batch_size, his_size, signal_length, hidden_dim]
            news_embedding: word-level news embedding, [batch_size, his_size, signal_length, hidden_dim]
            news_repr: news-level representation, [batch_size, his_size, hidden_dim]
            user_repr: user representation, [batch_size, 1, hidden_dim]

        Returns:
            ps_terms: weighted embedding for personalized terms, [batch_size, his_size, k, hidden_dim]
            ps_term_mask: attention mask of output terms, [batch_size, his_size, k]
        """
        # strip off [CLS]
        news_selection_embedding = news_selection_embedding[:, :, 1:]
        news_embedding = news_embedding[:, :, 1:]
        if self.diversify:
            news_user_repr = torch.cat([user_repr.expand(news_repr.size()), news_repr], dim=-1)
            selection_query = self.newsUserAlign(news_user_repr).unsqueeze(-1)
        else:
            selection_query = user_repr.expand(news_repr.size()).unsqueeze(-1)

        # [bs, hs, sl - 1]
        scores = F.normalize(news_selection_embedding, dim=-1).matmul(F.normalize(selection_query, dim=-1)).squeeze(-1)
        # mask the padded term
        scores = scores.masked_fill(~his_attn_mask_k[:, :, 1:], -float('inf'))

        score_k, score_kid = scores.topk(dim=-1, k=self.k, sorted=False)

        ps_terms = news_embedding.gather(dim=-2,index=score_kid.unsqueeze(-1).expand(score_kid.size() + (news_embedding.size(-1),)))
        # [bs, hs, k]
        ps_term_mask = his_attn_mask[:, :, 1:].gather(dim=-1, index=score_kid)

        if hasattr(self, 'threshold'):
            mask_pos = score_k < self.threshold
            # ps_terms = personalized_terms * (nn.functional.softmax(score_k.masked_fill(score_k < self.threshold, 0), dim=-1).unsqueeze(-1))
            ps_terms = ps_terms * (score_k.masked_fill(mask_pos, 0).unsqueeze(-1))
            ps_term_mask = ps_term_mask * (~mask_pos)

        else:
            ps_terms = ps_terms * (F.softmax(score_k, dim=-1).unsqueeze(-1))
            # ps_terms = ps_terms * (score_k.unsqueeze(-1))
        return ps_terms, ps_term_mask, score_kid

In [37]:
class TTMS(nn.Module):
    def __init__(self, config, embedding, encoderN, encoderU, reducer, aggregator=None):
        super().__init__()

        self.scale = config.scale
        self.cdd_size = config.cdd_size
        self.batch_size = config.batch_size
        self.his_size = config.his_size
        self.device = config.device

        self.embedding = embedding
        self.encoderN = encoderN
        self.encoderU = encoderU

        self.reducer = reducer
        self.bert = BERT_Encoder(config)

        self.aggregator = aggregator

        if not aggregator:
            self.userProject = nn.Sequential(
                nn.Linear(self.bert.hidden_dim, self.bert.hidden_dim),
                nn.Tanh()
            )

        self.name = '__'.join(['ttms', self.encoderN.name, self.encoderU.name, config.reducer])
        config.name = self.name

    def clickPredictor(self, cdd_news_repr, user_repr):
        """ calculate batch of click probabolity

        Args:
            cdd_news_repr: news-level representation, [batch_size, cdd_size, hidden_dim]
            user_repr: user representation, [batch_size, 1, hidden_dim]

        Returns:
            score of each candidate news, [batch_size, cdd_size]
        """
        # print(user_repr.mean(), cdd_news_repr.mean(), user_repr.max(), cdd_news_repr.max(), user_repr.sum(), cdd_news_repr.sum())
        score = cdd_news_repr.matmul(user_repr.transpose(-2,-1)).squeeze(-1)
        return score

    def _forward(self,x):
        cdd_subword_prefix = F.normalize(x["cdd_subword_prefix"].to(self.device), p=1, dim=-1)
        his_subword_prefix = F.normalize(x["his_subword_prefix"].to(self.device), p=1, dim=-1)
        if self.reducer.name == 'matching':
            his_news = x["his_encoded_index"].long().to(self.device)
            his_news_embedding = self.embedding(his_news, his_subword_prefix)
            his_news_encoded_embedding, his_news_repr = self.encoderN(
                his_news_embedding
            )
            user_repr = self.encoderU(his_news_repr)

            # his_attn_mask = his_subword_prefix.matmul(x["his_attn_mask"].to(self.device).float().unsqueeze(-1)).squeeze(-1)
            # his_reduced_mask = his_subword_prefix.matmul(x["his_reduced_mask"].to(self.device).float().unsqueeze(-1)).squeeze(-1)
            ps_terms, ps_term_mask, kid = self.reducer(his_news_encoded_embedding, his_news_embedding, user_repr, his_news_repr, x["his_attn_mask"].to(self.device), x["his_reduced_mask"].to(self.device).bool())
            
        elif self.reducer.name == 'bow':
            his_reduced_news = x["his_reduced_index"].long().to(self.device)
            his_news_embedding = self.embedding(his_reduced_news, bow=True)
            his_reduced_encoded_embedding, his_reduced_repr = self.encoderN(his_news_embedding)
            user_repr = self.encoderU(his_reduced_repr)
            ps_terms, ps_term_mask, kid = self.reducer(his_reduced_encoded_embedding, his_news_embedding, user_repr, his_reduced_repr, x["his_attn_mask"].to(self.device))
            del user_repr, his_reduced_encoded_embedding, his_reduced_repr

        elif self.reducer.name == 'bm25':
            his_news = x["his_reduced_index"].long().to(self.device)
            his_news_embedding = self.embedding(his_news)
            his_news_encoded_embedding, his_news_repr = self.encoderN(
                his_news_embedding
            )

            kid = None
            user_repr = None
            ps_terms, ps_term_mask = self.reducer(his_news_encoded_embedding, his_news_embedding, user_repr, his_news_repr, x["his_reduced_mask"].to(self.device))

        # append CLS to each historical news, aggregator historical news representation to user repr
        if self.aggregator:
            ps_terms = torch.cat([his_news_embedding[:, :, 0].unsqueeze(-2), ps_terms], dim=-2)
            ps_term_mask = torch.cat([torch.ones(*ps_term_mask.shape[0:2], 1, device=ps_term_mask.device), ps_term_mask], dim=-1)
            ps_terms, his_news_repr = self.bert(ps_terms, ps_term_mask)
            user_repr = self.aggregator(his_news_repr)

        # append CLS to the entire browsing history, directly deriving user repr
        else:
            batch_size = ps_terms.size(0)
            ps_terms = torch.cat([his_news_embedding[:, 0, 0].unsqueeze(1).unsqueeze(1), ps_terms.reshape(batch_size, 1, -1, ps_terms.size(-1))], dim=-2)
            ps_term_mask = torch.cat([torch.ones(batch_size, 1, 1, device=ps_term_mask.device), ps_term_mask.reshape(batch_size, 1, -1)], dim=-1)
            _, user_cls = self.bert(ps_terms, ps_term_mask)
            user_repr = self.userProject(user_cls)

        cdd_news = x["cdd_encoded_index"].long().to(self.device)
        _, cdd_news_repr = self.bert(
            self.embedding(cdd_news, cdd_subword_prefix), x['cdd_attn_mask'].to(self.device)#cdd_subword_prefix.matmul(x['cdd_attn_mask'].to(self.device).float().unsqueeze(-1)).squeeze(-1)
        )

        return self.clickPredictor(cdd_news_repr, user_repr), kid

    def forward(self,x):
        """
        Decoupled function, score is unormalized click score
        """
        score, kid = self._forward(x)

        if self.training:
            prob = nn.functional.log_softmax(score, dim=1)
        else:
            prob = torch.sigmoid(score)

        return prob, kid

In [38]:
reducer = Matching_Reducer(manager)
model = TTMS(manager, embedding, encoderN, encoderU, reducer).to(manager.device)

In [39]:
model(record)

tensor([[[[ 2.8471e-01,  4.7919e-02,  2.6784e-02,  ...,  0.0000e+00,
            1.6809e-01,  3.6454e-01],
          [ 0.0000e+00, -3.0080e-01, -3.4460e-01,  ..., -3.3522e-02,
           -1.8364e-01, -1.5247e-02],
          [-0.0000e+00, -9.3748e-02,  0.0000e+00,  ...,  6.2561e-02,
           -0.0000e+00,  1.0163e-01],
          [-0.0000e+00,  3.0413e-02,  7.0649e-02,  ...,  5.8038e-02,
           -1.5277e-01,  0.0000e+00],
          [ 1.0416e-01, -5.7094e-02,  8.2009e-03,  ..., -3.2584e-02,
            0.0000e+00,  4.0886e-02]],

         [[ 1.8386e-01,  4.0401e-01, -1.5275e-01,  ...,  2.5630e-01,
            3.8620e-02, -2.3655e-01],
          [-1.6229e-01, -2.4824e-02,  1.0714e-01,  ...,  3.1961e-01,
           -2.6707e-01, -2.2870e-01],
          [-0.0000e+00,  8.5058e-02,  6.0091e-02,  ..., -1.8317e-02,
            0.0000e+00, -6.7271e-02],
          [ 1.2432e-01, -1.9674e-02, -2.7509e-01,  ..., -6.3117e-02,
           -2.3561e-02,  6.8449e-02],
          [-1.1163e-01,  7.8947e-02

(tensor([[-0.1970, -1.9182, -5.6736, -4.7352, -3.9270],
         [-1.4232, -1.6907, -1.0101, -1.7027, -3.5647],
         [-3.5522, -1.8727, -2.6749, -0.3170, -3.8924],
         [-0.8469, -1.1921, -1.5415, -2.9597, -6.3218],
         [-2.1022, -4.6903, -6.2862, -0.7599, -0.9186]],
        grad_fn=<LogSoftmaxBackward>),
 tensor([[[ 7,  5, 21, 29,  1],
          [28, 47,  6, 19, 16],
          [ 7, 64, 49, 55, 16],
          ...,
          [ 0,  1,  4,  2,  3],
          [ 0,  1,  2,  4,  3],
          [ 0,  1,  4,  2,  3]],
 
         [[37,  7, 10, 26, 20],
          [40, 14,  7, 23, 81],
          [ 5, 47, 10, 17,  8],
          ...,
          [75, 63, 45, 32, 14],
          [22, 20, 50, 47, 49],
          [35, 28, 19, 34, 20]],
 
         [[64,  0, 11, 26, 34],
          [26,  8, 25, 10, 22],
          [ 6,  0,  2, 37,  7],
          ...,
          [ 0,  1,  2,  4,  3],
          [ 0,  1,  2,  4,  3],
          [ 1,  0,  4,  3,  2]],
 
         [[ 5, 10,  0,  7, 14],
          [25,  0,

In [None]:
embedding = BERT_Embedding(manager)

encoderN = CNN_Encoder(manager)
# encoderN = RNN_Encoder(manager)
# encoderN = MHA_Encoder(manager)

# encoderU = CNN_User_Encoder(manager)
encoderU = RNN_User_Encoder(manager)
# encoderU = MHA_User_Encoder(manager)

reducer = Matching_Reducer(manager)
# reducer = BM25_Reducer(manager)

# ranker = CNN_Ranker(manager)
# ranker = BERT_Onepass_Ranker(manager)
# ranker = BERT_Original_Ranker(manager)