In [1]:
import re
import logging
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math
import pickle
from collections import defaultdict
from utils.utils import prepare
from data.configs.demo import config
from collections import defaultdict
from transformers import BertTokenizer,BertModel
from utils.Manager import Manager

from models.Embeddings.BERT import BERT_Embedding
from models.Encoders.CNN import CNN_Encoder,CNN_User_Encoder
from models.Encoders.RNN import RNN_Encoder,RNN_User_Encoder
from models.Encoders.MHA import MHA_Encoder, MHA_User_Encoder
from models.Modules.DRM import Matching_Reducer, BM25_Reducer
from models.Rankers.BERT import BERT_Onepass_Ranker, BERT_Original_Ranker
from models.Rankers.CNN import CNN_Ranker

from models.Encoders.BERT import BERT_Encoder
from models.Encoders.Pooling import *

from models.ESM import ESM
from models.TTMS import TTMS

from models.Modules.Attention import MultiheadAttention

In [3]:
manager = Manager(config)
loaders = prepare(manager)

record = list(loaders[0])[0]

[2021-08-25 15:14:26,642] INFO (utils.utils) Hyper Parameters are 
scale:demo
k:5
threshold:-inf
signal_length:100
his_size:50
impr_size:10
lr:0.0001
hidden_dim:384
world_size:0
step:0
[2021-08-25 15:14:26,643] INFO (utils.utils) preparing dataset...
[2021-08-25 15:14:26,649] INFO (utils.MIND) loading cached user behavior from data/cache/bert/MINDdemo_train/10/behaviors..pkl
[2021-08-25 15:14:26,666] INFO (utils.MIND) loading cached news tokenization from data/cache/bert/MINDdemo_train/news.pkl
[2021-08-25 15:14:27,226] INFO (utils.utils) deduplicating...
[2021-08-25 15:14:28,816] INFO (utils.MIND) loading cached user behavior from data/cache/bert/MINDdemo_dev/10/behaviors..pkl
[2021-08-25 15:14:28,819] INFO (utils.MIND) loading cached news tokenization from data/cache/bert/MINDdemo_dev/news.pkl
[2021-08-25 15:14:29,249] INFO (utils.utils) deduplicating...


In [4]:
class TTMS(nn.Module):
    def __init__(self, config, embedding, encoderN, encoderU):
        super().__init__()

        self.scale = config.scale
        self.cdd_size = config.cdd_size
        self.batch_size = config.batch_size
        self.his_size = config.his_size
        self.device = config.device

        self.embedding = embedding
        self.encoderN = encoderN
        self.encoderU = encoderU

        self.reducer = Matching_Reducer(config)
        self.bert = BERT_Encoder(config)

        config.hidden_dim = config.embedding_dim
        self.aggregate = Attention_Pooling(config)

        self.name = '__'.join(['ttms', self.encoderN.name, self.encoderU.name])
        config.name = self.name

    def clickPredictor(self, cdd_news_repr, user_repr):
        """ calculate batch of click probabolity

        Args:
            cdd_news_repr: news-level representation, [batch_size, cdd_size, hidden_dim]
            user_repr: user representation, [batch_size, 1, hidden_dim]

        Returns:
            score of each candidate news, [batch_size, cdd_size]
        """
        print(user_repr[0])
        print(cdd_news_repr[0][0])

        score = cdd_news_repr.matmul(user_repr.transpose(-2,-1)).squeeze(-1)

        return score

    def _forward(self,x):
        his_news = x["his_encoded_index"].long().to(self.device)
        his_news_embedding = self.embedding(his_news)
        his_news_encoded_embedding, his_news_repr = self.encoderN(
            his_news_embedding
        )

        user_repr = self.encoderU(his_news_repr)

        ps_terms, ps_term_mask = self.reducer(his_news_encoded_embedding, his_news_embedding, user_repr, his_news_repr, x["his_attn_mask"].to(self.device), x["his_attn_mask_k"].to(self.device).bool())

        # append CLS to each historical news, aggregate historical news representation to user repr
        ps_terms = torch.cat([his_news_embedding[:, :, 0].unsqueeze(-2), ps_terms], dim=-2)
        ps_term_mask = torch.cat([torch.ones(*ps_term_mask.shape[0:2], 1, device=ps_term_mask.device), ps_term_mask], dim=-1)
        ps_terms, his_news_repr = self.bert(ps_terms, ps_term_mask)
        user_repr = self.aggregate(his_news_repr)

        # append CLS to the entire browsing history, directly deriving user repr
        # batch_size = ps_terms.size(0)
        # ps_terms = torch.cat([his_news_embedding[:, 0, 0].unsqueeze(1).unsqueeze(1), ps_terms.view(batch_size, 1, -1, ps_terms.size(-1))], dim=-2)
        # ps_term_mask = torch.cat([torch.ones(batch_size, 1, 1, device=ps_term_mask.device), ps_term_mask.view(batch_size, 1, -1)], dim=-1)
        # _, user_repr = self.bert(ps_terms, ps_term_mask)


        cdd_news = x["cdd_encoded_index"].long().to(self.device)
        _, cdd_news_repr = self.bert(
            self.embedding(cdd_news), x['cdd_attn_mask'].to(self.device)
        )

        return self.clickPredictor(cdd_news_repr, user_repr)

    def forward(self,x):
        """
        Decoupled function, score is unormalized click score
        """
        score = self._forward(x)

        if self.training:
            prob = nn.functional.log_softmax(score, dim=1)
        else:
            prob = torch.sigmoid(score)

        return prob

In [5]:
embedding = BERT_Embedding(manager)

encoderN = CNN_Encoder(manager)
# encoderN = RNN_Encoder(manager)
# encoderN = MHA_Encoder(manager)

# encoderU = CNN_User_Encoder(manager)
encoderU = RNN_User_Encoder(manager)
# encoderU = MHA_User_Encoder(manager)

# docReducer = Matching_Reducer(manager)
# docReducer = BM25_Reducer(manager)

# ranker = CNN_Ranker(manager)
# ranker = BERT_Onepass_Ranker(manager)
# ranker = BERT_Original_Ranker(manager)

model = TTMS(manager, embedding, encoderN, encoderU).to(manager.device)

In [None]:
model(record)

tensor([[ 3.9672e-02,  2.4889e-01, -3.2479e-01,  3.2325e-01,  4.9139e-01,
         -3.2715e-01,  3.6095e-01,  3.7133e-01, -2.9406e-01, -3.1607e-01,
          1.2882e-02, -2.6066e-01, -9.8634e-02,  1.0472e-01,  6.2066e-02,
          8.5859e-02,  3.4953e-01, -5.8611e-02,  6.2024e-01,  5.8177e-03,
          2.0220e-01, -4.7170e-01,  5.9268e-02,  2.0477e-01,  7.9676e-02,
         -1.9533e-01,  1.6602e-01,  1.6525e-01, -3.4361e-01, -3.0317e-01,
          3.3046e-01, -4.2683e-01, -1.2095e-01,  2.8960e-01,  5.0703e-02,
         -1.1985e-01,  2.3184e-01, -5.4352e-01, -3.5424e-01,  2.5918e-02,
         -2.4820e-01, -6.1629e-01, -2.4545e-01, -1.2437e-01,  3.0897e-01,
         -2.8026e-01,  3.0536e-02,  7.3582e-02,  5.4243e-01,  3.3926e-02,
         -1.6106e-01,  4.6877e-01,  2.2723e-01, -5.1455e-01,  2.1250e-01,
          1.0438e+00, -3.5191e-01, -2.0370e-01, -1.7700e-02, -5.3558e-01,
          3.6929e-01, -4.9108e-01,  5.4051e-01, -6.5740e-01,  6.7561e-01,
         -2.8046e-02,  3.0360e-01,  1.

tensor([[-3.3258e+01, -2.3292e+01, -3.0447e-01, -1.4518e+00, -3.5635e+00],
        [-1.6287e+00, -7.1115e-01, -1.3718e+00, -8.3284e+00, -2.8328e+00],
        [-2.4748e+01, -1.5640e+01, -1.5555e+01, -1.4681e+01, -7.1526e-07],
        [-2.7020e+01, -1.9071e+01, -2.3484e+01, -2.7617e+01,  0.0000e+00],
        [-3.7350e+01, -2.8984e+01, -3.9298e+01, -2.6822e+01,  0.0000e+00],
        [-1.0016e+01, -1.0926e+01, -1.0405e+00, -4.2831e+00, -4.5748e-01],
        [-6.2562e+00, -1.0601e+01, -4.3730e-01, -2.8990e+00, -1.2133e+00],
        [-1.3679e-02, -7.7812e+00, -8.7475e+00, -5.3554e+00, -4.7931e+00],
        [-1.7614e+01, -4.8645e+00, -1.0341e+01, -7.7940e-03, -1.1084e+01],
        [-1.5301e+01, -2.2115e+01, -2.0745e+01, -1.5321e+01, -4.7684e-07]],
       grad_fn=<LogSoftmaxBackward>)