In [None]:
import re
import logging
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import pickle
from collections import defaultdict
from utils.utils import prepare
from data.configs.demo import config
from collections import defaultdict

from transformers import AutoTokenizer, AutoModel
from utils.Manager import Manager

from models.Embeddings.BERT import BERT_Embedding
from models.Encoders.CNN import CNN_Encoder,CNN_User_Encoder
from models.Encoders.RNN import RNN_Encoder,RNN_User_Encoder
from models.Encoders.MHA import MHA_Encoder, MHA_User_Encoder
from models.Modules.DRM import Matching_Reducer, Slicing_Reducer
from models.Rankers.BERT import BERT_Onepass_Ranker, BERT_Original_Ranker
from models.Rankers.CNN import CNN_Ranker
from models.Encoders.Pooling import Attention_Pooling, Average_Pooling

from models.Encoders.BERT import BERT_Encoder
from models.Encoders.Pooling import *

from models.ESM import ESM
from models.TTMS import TTMS
 
from models.Modules.Attention import MultiheadAttention

loss = nn.NLLLoss()

# m = AutoModel.from_pretrained('microsoft/deberta-base', cache_dir='../../../Data/bert_cache')
t = AutoTokenizer.from_pretrained('microsoft/deberta-base', cache_dir='../../../Data/bert_cache')

In [None]:
# config.reducer = 'entity'
config.embedding = 'deberta'
config.bert = 'microsoft/deberta-base'
# config.device = 0

manager = Manager(config)
loaders = prepare(manager)
x1 = list(loaders[0])[0]
x2 = list(loaders[1])[0]

In [None]:
manager.hidden_dim = 768

In [None]:
embedding = BERT_Embedding(manager)

encoderN = CNN_Encoder(manager)
# encoderN = RNN_Encoder(manager)
# encoderN = MHA_Encoder(manager)

# encoderU = CNN_User_Encoder(manager)
encoderU = RNN_User_Encoder(manager)
# encoderU = MHA_User_Encoder(manager)
# encoderU = Attention_Pooling(manager)
# encoderU = Average_Pooling(manager)

reducer = Matching_Reducer(manager)
# reducer = Slicing_Reducer(manager)

# ranker = CNN_Ranker(manager)
# ranker = BERT_Onepass_Ranker(manager)
# ranker = BERT_Original_Ranker(manager)

model = TTMS(manager, embedding, encoderN, encoderU, reducer).to(manager.device)

In [None]:
manager.load(model,3534)

In [None]:
class Matching_Reducer(nn.Module):
    """
    basic document reducer: topk of each historical article
    """
    def __init__(self, config):
        super().__init__()

        self.name = "matching"

        self.k = config.k
        self.diversify = config.diversify
        self.his_size = config.his_size
        self.embedding_dim = config.embedding_dim

        config.term_num = config.k * config.his_size

        keep_k_modifier = torch.zeros(1, config.signal_length)
        keep_k_modifier[:, :self.k+1] = 1
        self.register_buffer('keep_k_modifier', keep_k_modifier, persistent=False)

        if self.diversify:
            self.newsUserAlign = nn.Linear(config.hidden_dim * 2, config.hidden_dim)
            nn.init.xavier_normal_(self.newsUserAlign.weight)

        if config.threshold != -float('inf'):
            threshold = torch.tensor([config.threshold])
            self.register_buffer('threshold', threshold)

        if config.sep_his:
            config.term_num += (self.his_size - 1)
            self.sep_embedding = nn.Parameter(torch.randn(1, 1, 1, config.embedding_dim))
            self.register_buffer('extra_sep_mask', torch.ones(1, 1, 1), persistent=False)
            nn.init.xavier_normal_(self.sep_embedding)

        if not config.no_order_embed:
            self.order_embedding = nn.Parameter(torch.randn(config.his_size, 1, config.embedding_dim))
            nn.init.xavier_normal_(self.order_embedding)


    def forward(self, news_selection_embedding, news_embedding, user_repr, news_repr, his_attn_mask, his_refined_mask):
        """
        Extract words from news text according to the overall user interest

        Args:
            news_selection_embedding: encoded word-level embedding, [batch_size, his_size, signal_length, hidden_dim]
            news_embedding: word-level news embedding, [batch_size, his_size, signal_length, hidden_dim]
            news_repr: news-level representation, [batch_size, his_size, hidden_dim]
            user_repr: user representation, [batch_size, 1, hidden_dim]

        Returns:
            ps_terms: weighted embedding for personalized terms, [batch_size, term_num, embedding_dim]
            ps_term_mask: attention mask of output terms, [batch_size, term_num]
            kid: the index of personalized terms
        """
        batch_size = news_embedding.size(0)

        # strip off [CLS]
        news_selection_embedding = news_selection_embedding[:, :, 1:]
        news_embedding = news_embedding[:, :, 1:]
        if self.diversify:
            news_user_repr = torch.cat([user_repr.expand(news_repr.size()), news_repr], dim=-1)
            selection_query = self.newsUserAlign(news_user_repr).unsqueeze(-1)
        else:
            selection_query = user_repr.unsqueeze(-1)

        # [bs, hs, sl - 1]
        scores = F.normalize(news_selection_embedding, dim=-1).matmul(F.normalize(selection_query, dim=-2)).squeeze(-1)
        # print(scores[0])
        pad_pos = ~(((his_refined_mask + self.keep_k_modifier)[:, :, 1:]).bool())
        # mask the padded term
        scores = scores.masked_fill(pad_pos, -float('inf'))
        print(scores[0,0])

        score_k, score_kid = scores.topk(dim=-1, k=self.k)

        ps_terms = news_embedding.gather(dim=-2,index=score_kid.unsqueeze(-1).expand(score_kid.size() + (news_embedding.size(-1),)))
        # [bs, hs, k]
        ps_term_mask = his_attn_mask[:, :, 1:].gather(dim=-1, index=score_kid)

        if hasattr(self, 'threshold'):
            mask_pos = score_k < self.threshold
            # ps_terms = personalized_terms * (nn.functional.softmax(score_k.masked_fill(score_k < self.threshold, 0), dim=-1).unsqueeze(-1))
            ps_terms = ps_terms * (score_k.masked_fill(mask_pos, 0).unsqueeze(-1))
            ps_term_mask = ps_term_mask * (~mask_pos)

        else:
            ps_terms = ps_terms * (F.softmax(score_k, dim=-1).unsqueeze(-1))
            # ps_terms = ps_terms * (score_k.unsqueeze(-1))
        if hasattr(self, 'order_embedding'):
            ps_terms += self.order_embedding

        if hasattr(self, 'sep_embedding'):
            ps_terms = torch.cat([ps_terms, self.sep_embedding.expand(batch_size, self.his_size, 1, self.embedding_dim)], dim=-2).view(batch_size, -1, self.embedding_dim)[:, :-1]
            ps_term_mask = torch.cat([ps_term_mask, self.extra_sep_mask.expand(batch_size, self.his_size, 1)], dim=-1).view(batch_size, -1)[:, :-1]
        else:
            ps_terms = ps_terms.view(batch_size, -1, self.embedding_dim)
            ps_term_mask = ps_term_mask.view(batch_size, -1)

        return ps_terms, ps_term_mask, score_kid

reducer = Matching_Reducer(manager)
model = TTMS(manager, embedding, encoderN, encoderU, reducer).to(manager.device)

In [71]:
dest = torch.zeros((config.signal_length * config.signal_length))
index = x2['his_subword_index'][0,0,:,0] * config.signal_length + x2['his_subword_index'][0,0,:,1]
index

tensor([   0,  101,  202,  303,  404,  505,  606,  707,  808,  909, 1010, 1111,
        1212, 1313, 1414, 1515, 1616, 1717, 1818, 1919, 2020, 2121, 2222, 2323,
        2424, 2525, 2626, 2727, 2828, 2929, 3030, 3131, 3232, 3333, 3434, 3535,
        3636, 3737, 3838, 3939, 4040, 4141, 4242, 4343, 4444, 4545, 4646, 4747,
        4748,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0])

In [72]:
dest = F.normalize(dest.scatter(dim=-1,index=index,value=1).view(config.signal_length, config.signal_length),p=1,dim=-1)
dest.matmul(x2['his_refined_mask'][0,0].float())

In [75]:
t.convert_ids_to_tokens(x2['his_encoded_index'][0,0])[:49]

['[CLS]',
 'Donald',
 'ĠTrump',
 'ĠJr',
 '.',
 'Ġreflects',
 'Ġon',
 'Ġexplosive',
 'ĠView',
 'Ġchat',
 ':',
 'ĠI',
 'Ġdont',
 'Ġthink',
 'Ġthey',
 'Ġlike',
 'Ġme',
 'Ġmuch',
 'Ġanymore',
 'ĠAfter',
 'Ġa',
 'Ġheated',
 'Ġappearance',
 'Ġon',
 'ĠThe',
 'ĠView',
 'ĠThursday',
 ',',
 'ĠDonald',
 'ĠTrump',
 'ĠJr',
 '.',
 'Ġtalked',
 'Ġabout',
 'Ġthe',
 'Ġexperience',
 'Ġlater',
 'Ġthat',
 'Ġday',
 'Ġwith',
 'ĠSean',
 'ĠHannity',
 'Ġon',
 'ĠFox',
 'ĠNews',
 '.',
 'Ġtv',
 'Ġtv',
 'news']

In [79]:
model.eval()
a,b = model(x2)
a,b

tensor(0.0150, grad_fn=<SelectBackward>)


(tensor([[0.5936, 0.5504, 0.5749, 0.5547, 0.5725, 0.5211, 0.5547, 0.5392, 0.5201,
          0.5793]], grad_fn=<SigmoidBackward>),
 tensor([[[39, 17, 46, 25, 27],
          [ 8, 11, 43, 41, 16],
          [ 5, 11, 46, 50, 12],
          [14, 13, 37, 35, 31],
          [30,  1, 19, 12, 10],
          [ 4,  9,  3, 17, 14],
          [16, 18,  3, 40,  9],
          [ 6,  3, 28, 16, 10],
          [ 1,  2, 23,  4,  6],
          [22,  6, 20,  2,  9],
          [ 8, 27, 19,  5,  9],
          [16, 74, 57, 71, 31],
          [ 6,  8, 25, 10, 17],
          [ 9, 24, 11, 22, 14],
          [48, 22, 20,  7, 34],
          [ 1,  0,  4,  2,  3],
          [ 1,  0,  4,  2,  3],
          [ 1,  0,  4,  2,  3],
          [ 1,  0,  4,  2,  3],
          [ 1,  0,  4,  2,  3],
          [ 1,  0,  4,  2,  3],
          [ 1,  0,  4,  2,  3],
          [ 1,  0,  4,  2,  3],
          [ 1,  0,  4,  2,  3],
          [ 1,  0,  4,  2,  3],
          [ 1,  0,  4,  2,  3],
          [ 1,  0,  4,  2,  3],
      

In [None]:
ls = loss(a,target=x1['label'])
ls.backward()

In [None]:
a.grad