# PPRec Training Modules, BPE Loss

In [1]:
""" Modules """

import torch
from torch import nn
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


"""  Knowledge Aware News Encoder which uses self attention and cross attention modules    """
class KnowledgeAwareNewsEncoder(nn.Module):
    def __init__(self,hparams,
        word2vec_embedding=None,
        seed=None,
        **kwargs,):
        super().__init__()
        

        self.word_self_attention = torch.nn.MultiheadAttention(hparams.embed_dim,hparams.head_num,batch_first=True)
        self.entity_self_attention = torch.nn.MultiheadAttention(hparams.embed_dim,hparams.head_num,batch_first=True)
        self.word_cross_attention = torch.nn.MultiheadAttention(hparams.embed_dim,hparams.head_num, batch_first=True)
        self.entity_cross_attention = torch.nn.MultiheadAttention(hparams.embed_dim,hparams.head_num, batch_first=True)

        
        self.word2vec = nn.Embedding.from_pretrained(word2vec_embedding)
        self.entity2vec = nn.Embedding.from_pretrained(word2vec_embedding)
        self.final_attention_layer = torch.nn.MultiheadAttention(hparams.embed_dim,hparams.head_num, batch_first=True)
        
        
        

    def forward(self, words, entities):
        word_embeddings = self.word2vec(words)
        entity_embeddings =  self.entity2vec(entities)
        word_embeddings = torch.reshape(word_embeddings,(word_embeddings.shape[0],word_embeddings.shape[1]*word_embeddings.shape[2],word_embeddings.shape[3]))
        entity_embeddings = torch.reshape(entity_embeddings,(entity_embeddings.shape[0],entity_embeddings.shape[1]*entity_embeddings.shape[2],entity_embeddings.shape[3]))
        
        """ Word level self attention """
        word_self_attn_output,_ = self.word_self_attention(word_embeddings, word_embeddings, word_embeddings)
        
        """ Entity(in this case NER clusters) level self attention """
        entity_self_attn_output,_ = self.entity_self_attention(entity_embeddings, entity_embeddings, entity_embeddings)
        
        """ Cross attention between words and entities   """
        word_cross_output,_ = self.word_cross_attention(word_embeddings,entity_embeddings,entity_embeddings)
        entity_cross_output,_ = self.word_cross_attention(entity_embeddings, word_embeddings, word_embeddings)
        

        word_output = torch.add(word_self_attn_output,word_cross_output)
        entity_output = torch.add(entity_self_attn_output,entity_cross_output)
        news_encoder,_ = self.final_attention_layer(word_output, entity_output, entity_output)
        return news_encoder



class TimeAwarePopularityEncoder(nn.Module):
    def __init__(self,word2vec_embedding=None,
        seed=None,
        **kwargs,):
        super(TimeAwarePopularityEncoder, self).__init__()
        self.word2vec = nn.Embedding.from_pretrained(word2vec_embedding)
        self.news_model = nn.Sequential(
          nn.Linear(768,256),
          nn.Tanh(),
          nn.Linear(256,256),
          nn.Tanh(),
          nn.Linear(256,128),
          nn.Tanh(),
          nn.Linear(128,1,bias=False)
        )
        self.dense = nn.Linear(30,1)
        
        
        self.recency_model = nn.Sequential(
            nn.Linear(768,64),
            nn.Tanh(),
            nn.Linear(64,64),
            nn.Tanh(),
            nn.Linear(64,1,bias=False)
        )
        self.gate = nn.Sequential(
            nn.Linear(31,128),
            nn.Tanh(),
            nn.Linear(128,64),
            nn.Tanh(),
            nn.Linear(64,1),
            nn.Sigmoid()
        )
        self.ctr_model = nn.Sigmoid()
        self.combined_embed = nn.Linear(1,1)


    def forward(self,news, recency, ctr):
        news_embed = self.word2vec(news)
        recency_embed = self.word2vec(recency)
        ctr_embed = self.word2vec(ctr)
        content_score = self.news_model(news_embed)
        recency_score = self.recency_model(recency_embed)
        recency_tensor = recency.unsqueeze(-1)

        combined_input = torch.cat([news,recency_tensor],2)
        combined_input = combined_input.to(torch.float32)

        combined_score = self.gate(combined_input)
        final_content_score = content_score.squeeze(-1)
        final_content_score = self.dense(final_content_score)
        
        combined_prefinal_score = (1-combined_score)*recency_score+combined_score*final_content_score
        ctr_score = self.ctr_model(ctr_embed)

        combined_final_score = self.combined_embed(combined_prefinal_score)
        return ctr_score+combined_final_score
    
class ContentPopularityJointAttention(nn.Module):
    """

    Implementation of the content-popularity joint attention module
    for the popularity-aware user encoder.

    This is based on formula (2) in 3.4 of the paper.

    """

    def __init__(self, max_clicked: int, m_size: int, p_size: int, weight_size: int):
        super().__init__()
        self.Wu = nn.Parameter(torch.rand(weight_size, m_size + p_size))
        self.b = nn.Parameter(torch.rand(weight_size))

        self.weight_size = weight_size
        self.m_size = m_size
        self.p_size = p_size
        self.max_clicked = max_clicked

    def forward(self, m: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        """

        Calculates the user interest embeddings u, based on the
        the popularity embeddings p, and the contextual news
        representations m.

        m is a tensor of shape (batch_size, max_clicked, m_size)
        p is a tensor of shape (batch_size, max_clicked, p_size)
        u is a tensor of shape (batch_size, m_size)
        where max_clicked is the number of clicked articles by the user.

        """

        assert len(m.size()) == 3
        batch_size, max_clicked, m_size = m.size()
        assert m_size == self.m_size
        assert max_clicked == self.max_clicked

        assert len(p.size()) == 3
        assert p.size(0) == batch_size
        assert p.size(1) == max_clicked
        assert p.size(2) == self.p_size

        mp = torch.cat((m, p), dim=2)  # (batch_size, max_clicked, m_size + p_size)
        assert len(mp.size()) == 3
        assert mp.size(0) == batch_size
        assert mp.size(1) == max_clicked
        assert mp.size(2) == self.m_size + self.p_size

        Wu_mp = torch.matmul(mp, self.Wu.T)  # (batch_size, max_clicked, weight_size)
        assert len(Wu_mp.size()) == 3
        assert Wu_mp.size(0) == batch_size
        assert Wu_mp.size(1) == max_clicked
        assert Wu_mp.size(2) == self.weight_size

        tanh_Wu_mp = torch.tanh(Wu_mp)  # (batch_size, max_clicked, weight_size)
        assert len(tanh_Wu_mp.size()) == 3
        assert tanh_Wu_mp.size(0) == batch_size
        assert tanh_Wu_mp.size(1) == max_clicked
        assert tanh_Wu_mp.size(2) == self.weight_size

        b_tanh_Wu_mp = torch.matmul(tanh_Wu_mp, self.b)  # (batch_size, max_clicked)
        assert len(b_tanh_Wu_mp.size()) == 2
        assert b_tanh_Wu_mp.size(0) == batch_size
        assert b_tanh_Wu_mp.size(1) == max_clicked

        sum_b_tanh_Wu_mp = torch.sum(b_tanh_Wu_mp, dim=1)  # (batch_size)
        assert len(sum_b_tanh_Wu_mp.size()) == 1
        assert sum_b_tanh_Wu_mp.size(0) == batch_size

        a = torch.div(
            b_tanh_Wu_mp, sum_b_tanh_Wu_mp.unsqueeze(1)
        )  # (batch_size, max_clicked)
        assert len(a.size()) == 2
        assert a.size(0) == batch_size
        assert a.size(1) == max_clicked

        am = torch.mul(a.unsqueeze(2), m)  # (batch_size, max_clicked, m_size)
        assert len(am.size()) == 3
        assert am.size(0) == batch_size
        assert am.size(1) == max_clicked
        assert am.size(2) == self.m_size

        u = torch.sum(am, dim=1)  # (batch_size, m_size)
        assert len(u.size()) == 2
        assert u.size(0) == batch_size
        assert u.size(1) == self.m_size

        return u
    

class PopularityAwareUserEncoder(nn.Module):
    def __init__(self,
                 hparams,
        word2vec_embedding=None,
                 seed=None,
                **kwargs,):
                 super().__init__()

                 self.word2vec = nn.Embedding.from_pretrained(word2vec_embedding)
        
                 self.pop_embed = nn.Embedding.from_pretrained(word2vec_embedding)
                 self.news_self_attention = torch.nn.MultiheadAttention(hparams.embed_dim,hparams.head_num, batch_first=True)
                 self.cpja = ContentPopularityJointAttention(hparams.max_clicked, hparams.m_size, hparams.p_size,hparams.weight_size)
                 self.max_clicked = hparams.max_clicked
                 self.title_length = hparams.title_size

    def forward(self,news,popularity):
        
        
        popularity_embedding = self.pop_embed(popularity)
        popularity_embedding = popularity_embedding.squeeze(axis=2)
        news_embedding = self.word2vec(news)
        news_embedding = torch.reshape(news_embedding,(news_embedding.shape[0],news_embedding.shape[1]*news_embedding.shape[2],news_embedding.shape[3]))
        news_attention_embedding,_ = self.news_self_attention(news_embedding,news_embedding,news_embedding)

        news_attention_embedding = torch.reshape(news_attention_embedding,(news_attention_embedding.shape[0],self.max_clicked,self.title_length,news_attention_embedding.shape[2]))
        news_attention_embedding = torch.mean(news_attention_embedding, dim=2, keepdim=False)
        
        pop_aware_user_encoder = self.cpja(news_attention_embedding,popularity_embedding)
        return pop_aware_user_encoder
    


class PPRec(nn.Module):
    """

    Implementation of PPRec. Figure 2 in the paper shows the architecture.
    Outputs a ranking score for some candidate news articles.

    """

    def __init__(
        self,
        hparams_pprec,
        word2vec_embedding= None 
    ):

        super().__init__()

        self.knowledge_news_model  = KnowledgeAwareNewsEncoder(hparams_pprec,torch.from_numpy(word2vec_embedding),seed=123)
        self.user_model = PopularityAwareUserEncoder(hparams_pprec, word2vec_embedding=torch.from_numpy(word2vec_embedding), seed=123)
        self.time_news_model = TimeAwarePopularityEncoder(word2vec_embedding=torch.from_numpy(word2vec_embedding), seed=123)

        
        self.aggregator_gate = nn.Sequential(
            nn.Linear(5,5),
            nn.Sigmoid()
        )
        self.title_size = hparams_pprec.title_size
        self.softmax = nn.Softmax(dim=1)

    def forward(
        self,
        title, entities, ctr, recency, hist_title, hist_popularity
    ):
        """

        Returns the ranking scores for a batch of candidate news articles, given the user's
        past click history.

        """

        
        knowledge_news_embed = self.knowledge_news_model(title, entities)
        
        time_aware_pop = self.time_news_model(title, recency, ctr)
        
        user_embed = self.user_model(hist_title,hist_popularity)
        
        time_aware_pop = torch.mean(time_aware_pop, dim=2, keepdim=False)
        knowledge_news_embed = torch.reshape(knowledge_news_embed, (knowledge_news_embed.shape[0],int(knowledge_news_embed.shape[1]/self.title_size), self.title_size,knowledge_news_embed.shape[2]))
        knowledge_news_embed = torch.mean(knowledge_news_embed, dim=2, keepdim=False)
        
        personalized_score = torch.matmul(knowledge_news_embed,user_embed.T)
        
        score1 =  self.aggregator_gate(time_aware_pop) 
        
        personalized_score = torch.mean(personalized_score,dim=2,keepdim=False)
        score2 =  (1-self.aggregator_gate(personalized_score))
        
        score = score1 + score2
        return self.softmax(score)

        
class BPELoss(nn.Module):
    def __init__(self):
        super(BPELoss, self).__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, output, target): 
         
         batch_size = target.shape[0]
         total_no_samples = target.shape[1]
         
         
         mask = target > 0
         postive_index_select = torch.masked_select(output, mask)
         
         neg_mask = target == 0
         negative_index_select = torch.masked_select(output, neg_mask)
         negative_index_select = torch.reshape(negative_index_select,(batch_size,total_no_samples-1))
         
         negative_index_select,_ = torch.min(negative_index_select, dim=1, keepdim = True)
         diff = torch.sub(postive_index_select, negative_index_select)
         diff_sig = self.sigmoid(diff)
         diff_log = torch.log(diff_sig)
         return - torch.mean(diff_log)
         


def train_one_epoch(epoch_index, tb_writer, train_dataloader,optimizer,model,loss_fn,device):
    running_loss = 0.
    last_loss = 0.

    
    for i, data in enumerate(train_dataloader):
        # Every data instance is an input + label pair
        
        inputs, labels = data
        # Zero your gradients for every batch!
        optimizer.zero_grad()

        title = inputs[5]
        entities = inputs[6]
        ctr = inputs[7]
        recency = inputs[8]
        hist_title = inputs[0]
        hist_popularity = inputs[2]
        
        title = torch.from_numpy(title)
        entities = torch.from_numpy(entities)
        ctr = torch.from_numpy(ctr)
        recency = torch.from_numpy(recency)
        hist_title = torch.from_numpy(hist_title)
        hist_popularity = torch.from_numpy(hist_popularity)
        labels = torch.from_numpy(labels)
        
        title = title.to(device)
        entities = entities.to(device)
        ctr = ctr.to(device)
        recency = recency.to(device)
        hist_title = hist_title.to(device)
        hist_popularity = hist_popularity.to(device)
        labels = labels.to(device)

        outputs = model(title, entities, ctr, recency ,hist_title, hist_popularity )
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        
        
        print('  batch {} loss: {}'.format(i + 1, running_loss))
        tb_x = epoch_index * len(train_dataloader) + i + 1
        tb_writer.add_scalar('Loss/train', running_loss, tb_x)

    return running_loss, i+1       

# PPRec DataLoader

In [2]:
""" Dataloader """

from dataclasses import dataclass, field
import numpy as np
import polars as pl
from torch.utils.data import Dataset, DataLoader

from ebrec.utils._articles_behaviors import map_list_article_id_to_value
from ebrec.utils._python import (
    repeat_by_list_values_from_matrix
)

from ebrec.utils._constants import (
    DEFAULT_INVIEW_ARTICLES_COL,
    DEFAULT_LABELS_COL,
    DEFAULT_USER_COL,
)

def create_lookup_objects(
    lookup_dictionary: dict[int, np.array], unknown_representation: str,is_array=True) -> tuple[dict[int, pl.Series], np.array]:
    """Creates lookup objects for efficient data retrieval.

    This function generates a dictionary of indexes and a matrix from the given lookup dictionary.
    The generated lookup matrix has an additional row based on the specified unknown representation
    which could be either zeros or the mean of the values in the lookup dictionary.

    Args:
        lookup_dictionary (dict[int, np.array]): A dictionary where keys are unique identifiers (int)
            and values are some representations which can be any data type, commonly used for lookup operations.
        unknown_representation (str): Specifies the method to represent unknown entries.
            It can be either 'zeros' to represent unknowns with a row of zeros, or 'mean' to represent
            unknowns with a row of mean values computed from the lookup dictionary.

    Raises:
        ValueError: If the unknown_representation is not either 'zeros' or 'mean',
            a ValueError will be raised.

    Returns:
        tuple[dict[int, pl.Series], np.array]: A tuple containing two items:
            - A dictionary with the same keys as the lookup_dictionary where values are polars Series
                objects containing a single value, which is the index of the key in the lookup dictionary.
            - A numpy array where the rows correspond to the values in the lookup_dictionary and an
                additional row representing unknown entries as specified by the unknown_representation argument.

    Example:
    >>> data = {
            10: np.array([0.1, 0.2, 0.3]),
            20: np.array([0.4, 0.5, 0.6]),
            30: np.array([0.7, 0.8, 0.9]),
        }
    >>> lookup_dict, lookup_matrix = create_lookup_objects(data, "zeros")

    >>> lookup_dict
        {10: shape: (1,)
            Series: '' [i64]
            [
                    1
            ], 20: shape: (1,)
            Series: '' [i64]
            [
                    2
            ], 30: shape: (1,)
            Series: '' [i64]
            [
                    3
        ]}
    >>> lookup_matrix
        array([[0. , 0. , 0. ],
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6],
            [0.7, 0.8, 0.9]])
    """
    # MAKE LOOKUP DICTIONARY
    lookup_indexes = {
        id: pl.Series("", [i]) for i, id in enumerate(lookup_dictionary, start=1)
    }
    # MAKE LOOKUP MATRIX
    lookup_matrix = np.array(list(lookup_dictionary.values()))
    if is_array:
        if unknown_representation == "zeros":
            UNKNOWN_ARRAY = np.zeros(lookup_matrix.shape[1], dtype=lookup_matrix.dtype)
        elif unknown_representation == "mean":
            UNKNOWN_ARRAY = np.mean(lookup_matrix, axis=0, dtype=lookup_matrix.dtype)
        else:
            raise ValueError(
                f"'{unknown_representation}' is not a specified method. Can be either 'zeros' or 'mean'."
            )

        lookup_matrix = np.vstack([UNKNOWN_ARRAY, lookup_matrix])
    return lookup_indexes, lookup_matrix




@dataclass
class NewsrecDataLoader(Dataset):
    """
    A DataLoader for news recommendation.
    """

    behaviors: pl.DataFrame
    history_column: str
    history_recency: str
    inview_recency: str
    article_dict: dict[int, any]
    unknown_representation: str
    eval_mode: bool = False
    batch_size: int = 32
    inview_col: str = DEFAULT_INVIEW_ARTICLES_COL
    labels_col: str = DEFAULT_LABELS_COL
    user_col: str = DEFAULT_USER_COL
    kwargs: field(default_factory=dict) = None

    def __post_init__(self):
        """
        Post-initialization method. Loads the data and sets additional attributes.
        """
        self.lookup_article_index, self.lookup_article_matrix = create_lookup_objects(
            self.article_dict, unknown_representation=self.unknown_representation
        )
        
        self.unknown_index = [0]
        self.X, self.y = self.load_data()
        if self.kwargs is not None:
            self.set_kwargs(self.kwargs)

    def __len__(self) -> int:
        return int(np.ceil(len(self.X) / float(self.batch_size)))

    def __getitem__(self):
        raise ValueError("Function '__getitem__' needs to be implemented.")

    def load_data(self) -> tuple[pl.DataFrame, pl.DataFrame]:
        
        X = self.behaviors.drop(self.labels_col).with_columns(
            pl.col(self.inview_col).list.len().alias("n_samples")
        )
        y = self.behaviors[self.labels_col]
        return X, y

    def set_kwargs(self, kwargs: dict):
        for key, value in kwargs.items():
            setattr(self, key, value)






@dataclass(kw_only=True)
class PPRecDataLoader(NewsrecDataLoader):
    """ PPRec DataLoader which inherits from the NewsrecDataLoader"""
    entity_mapping: dict[int, list[int]] = None
    ctr_mapping: dict[int, int] = None
    popularity_mapping: dict[int, int] = None
   
    

    def __post_init__(self):
        self.title_prefix = "title_"
        self.entity_prefix = "ner_clusters_text_"
        self.ctr_prefix = "ctr_"
        self.pop_prefix = "popularity_"
        
        (
            self.lookup_article_index_entity,
            self.lookup_article_matrix_entity,
        ) = create_lookup_objects(
            self.entity_mapping, unknown_representation=self.unknown_representation
        )

        (
            self.lookup_article_index_ctr,
            self.lookup_article_matrix_ctr,
        ) = create_lookup_objects(
            self.ctr_mapping, unknown_representation=self.unknown_representation,is_array=False
        )

        (
            self.lookup_article_index_pop,
            self.lookup_article_matrix_pop,
        ) = create_lookup_objects(
            self.popularity_mapping, unknown_representation=self.unknown_representation,is_array=False
        )

        return super().__post_init__()

    def transform(self, df: pl.DataFrame) -> tuple[pl.DataFrame]:
        """
        Special case for NAML as it requires body-encoding, verticals, & subvertivals
        """
        
        title = df.pipe(
            map_list_article_id_to_value,
            behaviors_column=self.history_column,
            mapping=self.lookup_article_index,
            fill_nulls=self.unknown_index,
            drop_nulls=False,
        ).pipe(
            map_list_article_id_to_value,
            behaviors_column=self.inview_col,
            mapping=self.lookup_article_index,
            fill_nulls=self.unknown_index,
            drop_nulls=False,
        )
        
        entities = df.pipe(
            map_list_article_id_to_value,
            behaviors_column=self.history_column,
            mapping=self.lookup_article_index_entity,
            fill_nulls=self.unknown_index,
            drop_nulls=False,
        ).pipe(
            map_list_article_id_to_value,
            behaviors_column=self.inview_col,
            mapping=self.lookup_article_index_entity,
            fill_nulls=self.unknown_index,
            drop_nulls=False,
        )
        ctr = df.pipe(
            map_list_article_id_to_value,
            behaviors_column=self.history_column,
            mapping=self.lookup_article_index_ctr,
            fill_nulls=0,
            drop_nulls=False,
        ).pipe(
            map_list_article_id_to_value,
            behaviors_column=self.inview_col,
            mapping=self.lookup_article_index_ctr,
            fill_nulls=0,
            drop_nulls=False,
        )
        popularity = df.pipe(
            map_list_article_id_to_value,
            behaviors_column=self.history_column,
            mapping=self.lookup_article_index_pop,
            fill_nulls=0,
            drop_nulls=False,
        ).pipe(
            map_list_article_id_to_value,
            behaviors_column=self.inview_col,
            mapping=self.lookup_article_index_pop,
            fill_nulls=0,
            drop_nulls=False,
        )
        
        transformed_df =  (pl.DataFrame()
            .with_columns(title.select(pl.all().name.prefix(self.title_prefix)))
            .with_columns(entities.select(pl.all().name.prefix(self.entity_prefix)))
            .with_columns(ctr.select(pl.all().name.prefix(self.ctr_prefix)))
            .with_columns(popularity.select(pl.all().name.prefix(self.pop_prefix)))
            )
      
        return transformed_df

    def __getitem__(self, idx) -> tuple[tuple[np.ndarray], np.ndarray]:
        batch_X = self.X[idx * self.batch_size : (idx + 1) * self.batch_size].pipe(
            self.transform
        )
        
        batch_y = self.y[idx * self.batch_size : (idx + 1) * self.batch_size]
        

        if self.eval_mode:
            """ Evaluation mode """

            batch_y = np.array(batch_y.to_list())
            
            his_input_title = np.array(
                batch_X[self.title_prefix + self.history_column].to_list()
            )
            his_input_entity = np.array(
                batch_X[self.entity_prefix + self.history_column].to_list()
            )
            his_input_ctr = np.array(
                batch_X[self.ctr_prefix + self.history_column].to_list()
            )
            his_input_recency = np.array(
                batch_X[self.title_prefix +self.history_recency].to_list()
            )
            his_input_pop = np.array(
                batch_X[self.pop_prefix + self.history_column].to_list()
            )

            pred_input_title = np.array(
                batch_X[self.title_prefix + self.inview_col].to_list()
            )
            
            pred_input_entity = np.array(
                batch_X[self.entity_prefix + self.inview_col].to_list()
            )
            pred_input_ctr = np.array(
                batch_X[self.ctr_prefix + self.inview_col].to_list()
            )
            pred_input_recency = np.array(
                batch_X[self.title_prefix + self.inview_recency].to_list()
            )
            pred_input_pop = np.array(
                batch_X[self.pop_prefix + self.inview_col].to_list()
            )
            
            pred_input_title = np.squeeze(
                self.lookup_article_matrix[pred_input_title], axis=2
            )
            
            pred_input_entity = np.squeeze(
                self.lookup_article_matrix_entity[pred_input_entity], axis=2
            )
            
            his_input_title = np.squeeze(
                self.lookup_article_matrix[his_input_title], axis=2
                )
                
            his_input_entity = np.squeeze(
                    self.lookup_article_matrix_entity[his_input_entity], axis=2
                    )
            pred_input_ctr = np.squeeze(
                    pred_input_ctr,axis=2
                )
            
            
            

        else:

            """ Train mode """
            
            batch_y = np.array(batch_y.to_list())
            
            his_input_title = np.array(
                batch_X[self.title_prefix + self.history_column].to_list()
            )
            his_input_entity = np.array(
                batch_X[self.entity_prefix + self.history_column].to_list()
            )
            his_input_ctr = np.array(
                batch_X[self.ctr_prefix + self.history_column].to_list()
            )
            his_input_recency = np.array(
                batch_X[self.title_prefix +self.history_recency].to_list()
            )
            his_input_pop = np.array(
                batch_X[self.pop_prefix + self.history_column].to_list()
            )
            
            pred_input_title = np.array(
                batch_X[self.title_prefix + self.inview_col].to_list()
            )
           
            pred_input_entity = np.array(
                batch_X[self.entity_prefix + self.inview_col].to_list()
            )
            pred_input_ctr = np.array(
                batch_X[self.ctr_prefix + self.inview_col].to_list()
            )
            pred_input_recency = np.array(
                batch_X[self.title_prefix + self.inview_recency].to_list()
            )
            pred_input_pop = np.array(
                batch_X[self.pop_prefix + self.inview_col].to_list()
            )
            
            pred_input_title = np.squeeze(
                self.lookup_article_matrix[pred_input_title], axis=2
            )
            
            pred_input_entity = np.squeeze(
                self.lookup_article_matrix_entity[pred_input_entity], axis=2
            )
             
            his_input_title = np.squeeze(
                self.lookup_article_matrix[his_input_title], axis=2
                )
                
            his_input_entity = np.squeeze(
                    self.lookup_article_matrix_entity[his_input_entity], axis=2
                    )
            pred_input_ctr = np.squeeze(
                    pred_input_ctr,axis=2
                )
        
        
        final_X, final_Y =(
            his_input_title,
            his_input_entity,
            his_input_ctr,
            his_input_recency,
            his_input_pop,
            pred_input_title,
            pred_input_entity,
            pred_input_ctr,
            pred_input_recency,
            pred_input_pop
        ), batch_y
        
        
        
        return final_X,final_Y

# PPRec Hyper params

In [3]:
class hparams_pprec:
    title_size: int = 30
    history_size: int = 50
    body_size: int = 40
    vert_num: int = 100
    vert_emb_dim: int = 10
    subvert_num: int = 100
    subvert_emb_dim: int = 10
    # MODEL ARCHITECTURE
    dense_activation: str = "relu"
    cnn_activation: str = "relu"
    attention_hidden_dim: int = 200
    filter_num: int = 400
    window_size: int = 3
    # MODEL OPTIMIZER:
    optimizer: str = "adam"
    dropout: float = 0.2
    learning_rate: float = 0.0001
    head_num: int = 8
    embed_dim: int = 768
    max_clicked: int = 10
    m_size: int = 768
    p_size: int = 768
    weight_size: int = 256
    

# Necessary imports

In [4]:
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
import torch 
import polars as pl

from ebrec.utils._constants import (
    DEFAULT_HISTORY_ARTICLE_ID_COL,
    DEFAULT_CLICKED_ARTICLES_COL,
    DEFAULT_INVIEW_ARTICLES_COL,
    DEFAULT_IMPRESSION_ID_COL,
    DEFAULT_SUBTITLE_COL,
    DEFAULT_LABELS_COL,
    DEFAULT_TITLE_COL,
    DEFAULT_USER_COL,
    DEFAULT_ARTICLE_MODIFIED_TIMESTAMP_COL,
    DEFAULT_IMPRESSION_TIMESTAMP_COL,
    DEFAULT_HISTORY_IMPRESSION_TIMESTAMP_COL
)

from ebrec.utils._behaviors import (
    create_binary_labels_column,
    sampling_strategy_wu2019,
    add_known_user_column,
    add_prediction_scores,
    truncate_history,
)
from ebrec.evaluation import MetricEvaluator, AucScore, NdcgScore, MrrScore
from ebrec.utils._articles import convert_text2encoding_with_transformers,concat_list_to_text
from ebrec.utils._polars import concat_str_columns, slice_join_dataframes
from ebrec.utils._articles import create_article_id_to_value_mapping
from ebrec.utils._nlp import get_transformers_word_embeddings
from ebrec.utils._python import write_submission_file, rank_predictions_by_score

%load_ext autoreload
%autoreload 2


import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

  from .autonotebook import tqdm as notebook_tqdm


cpu


In [5]:
""" Path to articles.parquet"""
df_articles = pl.read_parquet("/Users/sohamchatterjee/Documents/UvA/RecSYS/Project/PP-Rec/DocumentedCode/articles.parquet")
""" Roberta Base is used as the transofrmer model for the doc embeddings"""
TRANSFORMER_MODEL_NAME = "FacebookAI/xlm-roberta-base"
TEXT_COLUMNS_TO_USE = [DEFAULT_SUBTITLE_COL, DEFAULT_TITLE_COL]
MAX_TITLE_LENGTH = 30

# LOAD HUGGINGFACE:
transformer_model = AutoModel.from_pretrained(TRANSFORMER_MODEL_NAME)
transformer_tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL_NAME)

# We'll init the word embeddings using the
word2vec_embedding = get_transformers_word_embeddings(transformer_model)
#
df_articles, cat_cal = concat_str_columns(df_articles, columns=TEXT_COLUMNS_TO_USE)
df_articles, token_col_title = convert_text2encoding_with_transformers(
    df_articles, transformer_tokenizer, cat_cal, max_length=MAX_TITLE_LENGTH
)
# =>
article_mapping_title = create_article_id_to_value_mapping(
    df=df_articles, value_col=token_col_title
)



In [6]:
import pickle
article_mapping_title, article_mapping_entity, articles_ctr, popularity_mapping = {},{},{},{}
with open('/Users/sohamchatterjee/Documents/UvA/RecSYS/Project/PP-Rec/DocumentedCode/demo_processed/article_mapping_title_DEMO.pkl', 'rb') as handle:
    article_mapping_title = pickle.load(handle)
with open('/Users/sohamchatterjee/Documents/UvA/RecSYS/Project/PP-Rec/DocumentedCode/demo_processed/article_mapping_entity_DEMO.pkl', 'rb') as handle:
    article_mapping_entity = pickle.load(handle)
with open('/Users/sohamchatterjee/Documents/UvA/RecSYS/Project/PP-Rec/DocumentedCode/demo_processed/articles_ctr_DEMO.pkl', 'rb') as handle:
    articles_ctr = pickle.load(handle)
with open('/Users/sohamchatterjee/Documents/UvA/RecSYS/Project/PP-Rec/DocumentedCode/demo_processed/popularity_mapping_DEMO.pkl', 'rb') as handle:
    popularity_mapping = pickle.load(handle)

COLUMNS = [
   'user_id',
   'article_id_fixed',
   'article_ids_inview',
   'article_ids_clicked',
   'impression_id',
   'labels',
   'recency_inview',
   'recency_hist'  
]
df_train  = pl.scan_parquet("/Users/sohamchatterjee/Documents/UvA/RecSYS/Project/PP-Rec/DocumentedCode/demo_processed/train_DEMO.parquet").select(COLUMNS).collect()

df_validation =  pl.scan_parquet("/Users/sohamchatterjee/Documents/UvA/RecSYS/Project/PP-Rec/DocumentedCode/demo_processed/val_DEMO.parquet").select(COLUMNS).collect()

In [7]:
# This is just for sanity checking. Comment this out when running the traning pipeline
df_train = df_train.head(10)
df_validation = df_validation.head(10)

In [8]:
train_dataloader = PPRecDataLoader(
    behaviors=df_train,
    article_dict=article_mapping_title,
    entity_mapping=article_mapping_entity,
    ctr_mapping=articles_ctr,
    popularity_mapping = popularity_mapping,
    unknown_representation="zeros",
    history_column=DEFAULT_HISTORY_ARTICLE_ID_COL,
    history_recency = 'recency_hist',
    inview_recency = 'recency_inview',
    eval_mode=False,
    batch_size=512,
)

In [9]:
val_dataloader = PPRecDataLoader(
    behaviors=df_validation,
    article_dict=article_mapping_title,
    entity_mapping=article_mapping_entity,
    ctr_mapping=articles_ctr,
    popularity_mapping = popularity_mapping,
    unknown_representation="zeros",
    history_column=DEFAULT_HISTORY_ARTICLE_ID_COL,
    history_recency = 'recency_hist',
    inview_recency = 'recency_inview',
    eval_mode=True,
    batch_size=512,
)

In [10]:
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/pprec_check{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5
loss_fn = BPELoss()
model = PPRec(hparams_pprec,word2vec_embedding)
model.to(device)
optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=0.01,
            weight_decay=1e-4
        )

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    running_train_loss, total_no_batches = train_one_epoch(epoch_number, writer,train_dataloader, optimizer,model,loss_fn,device)
    avg_loss = running_train_loss / total_no_batches


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(val_dataloader):
            vinputs, vlabels = vdata
            vtitle = vinputs[5]
            ventities = vinputs[6]
            vctr = vinputs[7]
            vrecency = vinputs[8]
            vhist_title = vinputs[0]
            vhist_popularity = vinputs[2]

            vtitle = torch.from_numpy(vtitle)
            ventities = torch.from_numpy(ventities)
            vctr = torch.from_numpy(vctr)
            vrecency = torch.from_numpy(vrecency)
            vhist_title = torch.from_numpy(vhist_title)
            vhist_popularity = torch.from_numpy(vhist_popularity)
            vlabels = torch.from_numpy(vlabels)
        
            vtitle = vtitle.to(device)
            ventities = ventities.to(device)
            vctr = vctr.to(device)
            vrecency = vrecency.to(device)
            vhist_title = vhist_title.to(device)
            vhist_popularity = vhist_popularity.to(device)
            vlabels = vlabels.to(device)
            


            voutputs = model(vtitle, ventities, vctr, vrecency , vhist_title, vhist_popularity )
          
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
  batch 1 loss: 0.6892244219779968
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
LOSS train 0.6892244219779968 valid 0.6577883958816528
EPOCH 2:
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
  batch 1 loss: 0.6899724006652832
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here

 # EVAL MODE

In [11]:
import torch
from transformers import AutoTokenizer, AutoModel
from ebrec.evaluation import MetricEvaluator, AucScore, NdcgScore, MrrScore
import polars as pl
from ebrec.utils._constants import (
    DEFAULT_HISTORY_ARTICLE_ID_COL,
    DEFAULT_CLICKED_ARTICLES_COL,
    DEFAULT_INVIEW_ARTICLES_COL,
    DEFAULT_IMPRESSION_ID_COL,
    DEFAULT_SUBTITLE_COL,
    DEFAULT_LABELS_COL,
    DEFAULT_TITLE_COL,
    DEFAULT_USER_COL,
    DEFAULT_ARTICLE_MODIFIED_TIMESTAMP_COL,
    DEFAULT_IMPRESSION_TIMESTAMP_COL,
    DEFAULT_HISTORY_IMPRESSION_TIMESTAMP_COL
)
from ebrec.utils._nlp import get_transformers_word_embeddings
import numpy as np


device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
TRANSFORMER_MODEL_NAME = "FacebookAI/xlm-roberta-base"

transformer_model = AutoModel.from_pretrained(TRANSFORMER_MODEL_NAME)
transformer_tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL_NAME)

word2vec_embedding = get_transformers_word_embeddings(transformer_model)

saved_model = PPRec(hparams_pprec,word2vec_embedding)

""" Best Model checkpoint path generated in the training loop """
PATH = '/Users/sohamchatterjee/Documents/UvA/RecSYS/Project/PP-Rec/DocumentedCode/model_20240628_214608_0'
saved_model.load_state_dict(torch.load(PATH,weights_only=True,map_location=torch.device('cpu')))
saved_model.to(device)
import pickle
article_mapping_title, article_mapping_entity, articles_ctr, popularity_mapping = {},{},{},{}

""" Article title dictionary  """
with open('/Users/sohamchatterjee/Documents/UvA/RecSYS/Project/PP-Rec/DocumentedCode/demo_processed/article_mapping_title_DEMO.pkl', 'rb') as handle:
    article_mapping_title = pickle.load(handle)
    """ Article Entity dictionary  """
with open('/Users/sohamchatterjee/Documents/UvA/RecSYS/Project/PP-Rec/DocumentedCode/demo_processed/article_mapping_entity_DEMO.pkl', 'rb') as handle:
    article_mapping_entity = pickle.load(handle)
""" Article CTR dictionary  """
with open('/Users/sohamchatterjee/Documents/UvA/RecSYS/Project/PP-Rec/DocumentedCode/demo_processed/articles_ctr_DEMO.pkl', 'rb') as handle:
    articles_ctr = pickle.load(handle)
""" Article Popularitydictionary  """
with open('/Users/sohamchatterjee/Documents/UvA/RecSYS/Project/PP-Rec/DocumentedCode/demo_processed/popularity_mapping_DEMO.pkl', 'rb') as handle:
    popularity_mapping = pickle.load(handle)


COLUMNS = [
   'user_id',
   'article_id_fixed',
   'article_ids_inview',
   'article_ids_clicked',
   'impression_id',
   'labels',
   'recency_inview',
   'recency_hist'  
]

""" Demo validation parquet file"""
df_validation =  pl.scan_parquet("/Users/sohamchatterjee/Documents/UvA/RecSYS/Project/PP-Rec/DocumentedCode/demo_processed/val_DEMO.parquet").collect()

""" This is just for sanity checking"""
df_validation = df_validation.head(10)

val_dataloader = PPRecDataLoader(
    behaviors=df_validation,
    article_dict=article_mapping_title,
    entity_mapping=article_mapping_entity,
    ctr_mapping=articles_ctr,
    popularity_mapping = popularity_mapping,
    unknown_representation="zeros",
    history_column=DEFAULT_HISTORY_ARTICLE_ID_COL,
    history_recency = 'recency_hist',
    inview_recency = 'recency_inview',
    eval_mode=True,
    batch_size=1024,
)


saved_model.eval()

predictions = np.empty(shape=(4,5))
with torch.no_grad():
    for i, vdata in enumerate(val_dataloader):
            vinputs, vlabels = vdata
            vtitle = vinputs[5]
            ventities = vinputs[6]
            vctr = vinputs[7]
            vrecency = vinputs[8]
            vhist_title = vinputs[0]
            vhist_popularity = vinputs[2]

            vtitle = torch.from_numpy(vtitle)
            ventities = torch.from_numpy(ventities)
            vctr = torch.from_numpy(vctr)
            vrecency = torch.from_numpy(vrecency)
            vhist_title = torch.from_numpy(vhist_title)
            vhist_popularity = torch.from_numpy(vhist_popularity)
            vlabels = torch.from_numpy(vlabels)
            
            
            
            vtitle = vtitle.to(device)
            ventities = ventities.to(device)
            vctr = vctr.to(device)
            vrecency = vrecency.to(device)
            vhist_title = vhist_title.to(device)
            vhist_popularity = vhist_popularity.to(device)
            vlabels = vlabels.to(device)
            
            

            outputs = saved_model(vtitle, ventities, vctr, vrecency , vhist_title, vhist_popularity).cpu().detach().numpy()
            predictions = np.concatenate([predictions,outputs],axis=0)
            
           
            

predictions = predictions[4:]





df_validation = df_validation.with_columns(pl.Series(name="predicted_scores", values=predictions)) 

metrics = MetricEvaluator(
    labels=df_validation["labels"].to_list(),
    predictions=df_validation["predicted_scores"].to_list(),
    metric_functions=[AucScore(), MrrScore(), NdcgScore(k=5), NdcgScore(k=10)],
)
print(metrics.evaluate())


cpu




Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
<MetricEvaluator class>: 
 {
    "auc": 0.475,
    "mrr": 0.39499999999999996,
    "ndcg@5": 0.5440741988597636,
    "ndcg@10": 0.5440741988597636
}
