In [478]:
# Constants #

# BEHAVIORS
DEFAULT_IMPRESSION_TIMESTAMP_COL = "impression_time"
DEFAULT_SCROLL_PERCENTAGE_COL = "scroll_percentage"
DEFAULT_CLICKED_ARTICLES_COL = "article_ids_clicked"
DEFAULT_INVIEW_ARTICLES_COL = "article_ids_inview"
DEFAULT_IMPRESSION_ID_COL = "impression_id"
DEFAULT_IS_SUBSCRIBER_COL = "is_subscriber"
DEFAULT_IS_SSO_USER_COL = "is_sso_user"
DEFAULT_ARTICLE_ID_COL = "article_id"
DEFAULT_SESSION_ID_COL = "session_id"
DEFAULT_READ_TIME_COL = "read_time"
DEFAULT_DEVICE_COL = "device_type"
DEFAULT_POSTCODE_COL = "postcode"
DEFAULT_GENDER_COL = "gender"
DEFAULT_USER_COL = "user_id"
DEFAULT_AGE_COL = "age"

DEFAULT_NEXT_SCROLL_PERCENTAGE_COL = f"next_{DEFAULT_SCROLL_PERCENTAGE_COL}"
DEFAULT_NEXT_READ_TIME_COL = f"next_{DEFAULT_READ_TIME_COL}"

# ARTICLES
DEFAULT_ARTICLE_MODIFIED_TIMESTAMP_COL = "last_modified_time"
DEFAULT_ARTICLE_PUBLISHED_TIMESTAMP_COL = "published_time"
DEFAULT_SENTIMENT_LABEL_COL = "sentiment_label"
DEFAULT_SENTIMENT_SCORE_COL = "sentiment_score"
DEFAULT_TOTAL_READ_TIME_COL = "total_read_time"
DEFAULT_TOTAL_PAGEVIEWS_COL = "total_pageviews"
DEFAULT_TOTAL_INVIEWS_COL = "total_inviews"
DEFAULT_ARTICLE_TYPE_COL = "article_type"
DEFAULT_CATEGORY_STR_COL = "category_str"
DEFAULT_SUBCATEGORY_COL = "subcategory"
DEFAULT_ENTITIES_COL = "entity_groups"
DEFAULT_IMAGE_IDS_COL = "image_ids"
DEFAULT_SUBTITLE_COL = "subtitle"
DEFAULT_CATEGORY_COL = "category"
DEFAULT_NER_COL = "ner_clusters"
DEFAULT_PREMIUM_COL = "premium"
DEFAULT_TOPICS_COL = "topics"
DEFAULT_TITLE_COL = "title"
DEFAULT_BODY_COL = "body"
DEFAULT_URL_COL = "url"

# HISTORY
DEFAULT_HISTORY_IMPRESSION_TIMESTAMP_COL = f"{DEFAULT_IMPRESSION_TIMESTAMP_COL}_fixed"
DEFAULT_HISTORY_SCROLL_PERCENTAGE_COL = f"{DEFAULT_SCROLL_PERCENTAGE_COL}_fixed"
DEFAULT_HISTORY_ARTICLE_ID_COL = f"{DEFAULT_ARTICLE_ID_COL}_fixed"
DEFAULT_HISTORY_READ_TIME_COL = f"{DEFAULT_READ_TIME_COL}_fixed"

# CREATE
DEFAULT_KNOWN_USER_COL = "is_known_user"
DEFAULT_LABELS_COL = "labels"


In [479]:
# Imports #

import torch
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass, field
import pandas as pd
import numpy as np
import polars as pl
import torch
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass, field
import pandas as pd
from pathlib import Path
import numpy as np
import random

In [480]:
def create_lookup_objects(
    lookup_dictionary: dict[int, np.array], unknown_representation: str
) -> tuple[dict[int, pl.Series], np.array]:
    """Creates lookup objects for efficient data retrieval.

    This function generates a dictionary of indexes and a matrix from the given lookup dictionary.
    The generated lookup matrix has an additional row based on the specified unknown representation
    which could be either zeros or the mean of the values in the lookup dictionary.

    Args:
        lookup_dictionary (dict[int, np.array]): A dictionary where keys are unique identifiers (int)
            and values are some representations which can be any data type, commonly used for lookup operations.
        unknown_representation (str): Specifies the method to represent unknown entries.
            It can be either 'zeros' to represent unknowns with a row of zeros, or 'mean' to represent
            unknowns with a row of mean values computed from the lookup dictionary.

    Raises:
        ValueError: If the unknown_representation is not either 'zeros' or 'mean',
            a ValueError will be raised.

    Returns:
        tuple[dict[int, pl.Series], np.array]: A tuple containing two items:
            - A dictionary with the same keys as the lookup_dictionary where values are polars Series
                objects containing a single value, which is the index of the key in the lookup dictionary.
            - A numpy array where the rows correspond to the values in the lookup_dictionary and an
                additional row representing unknown entries as specified by the unknown_representation argument.

    Example:
    >>> data = {
            10: np.array([0.1, 0.2, 0.3]),
            20: np.array([0.4, 0.5, 0.6]),
            30: np.array([0.7, 0.8, 0.9]),
        }
    >>> lookup_dict, lookup_matrix = create_lookup_objects(data, "zeros")

    >>> lookup_dict
        {10: shape: (1,)
            Series: '' [i64]
            [
                    1
            ], 20: shape: (1,)
            Series: '' [i64]
            [
                    2
            ], 30: shape: (1,)
            Series: '' [i64]
            [
                    3
        ]}
    >>> lookup_matrix
        array([[0. , 0. , 0. ],
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6],
            [0.7, 0.8, 0.9]])
    """
    # MAKE LOOKUP DICTIONARY
    lookup_indexes = {
        id: pl.Series("", [i]) for i, id in enumerate(lookup_dictionary, start=1)
    }
    # MAKE LOOKUP MATRIX
    lookup_matrix = np.array(list(lookup_dictionary.values()))

    if unknown_representation == "zeros":
        UNKNOWN_ARRAY = np.zeros(lookup_matrix.shape[1], dtype=lookup_matrix.dtype)
    elif unknown_representation == "mean":
        UNKNOWN_ARRAY = np.mean(lookup_matrix, axis=0, dtype=lookup_matrix.dtype)
    else:
        raise ValueError(
            f"'{unknown_representation}' is not a specified method. Can be either 'zeros' or 'mean'."
        )

    lookup_matrix = np.vstack([UNKNOWN_ARRAY, lookup_matrix])
    return lookup_indexes, lookup_matrix


In [481]:
def generate_unique_name(existing_names: list[str], base_name: str = "new_name"):
    """
    Generate a unique name based on a list of existing names.

    Args:
        existing_names (list of str): The list of existing names.
        base_name (str): The base name to start with. Default is 'newName'.

    Returns:
        str: A unique name.
    Example
    >>> existing_names = ['name1', 'name2', 'newName', 'newName_1']
    >>> generate_unique_name(existing_names, 'newName')
        'newName_2'
    """
    if base_name not in existing_names:
        return base_name

    suffix = 1
    new_name = f"{base_name}_{suffix}"

    while new_name in existing_names:
        suffix += 1
        new_name = f"{base_name}_{suffix}"

    return new_name


In [482]:
def map_list_article_id_to_value(
    behaviors: pl.DataFrame,
    behaviors_column: str,
    mapping: dict[int, pl.Series],
    drop_nulls: bool = False,
    fill_nulls: any = None,
) -> pl.DataFrame:
    """

    Maps the values of a column in a DataFrame `behaviors` containing article IDs to their corresponding values
    in a column in another DataFrame `articles`. The mapping is performed using a dictionary constructed from
    the two DataFrames. The resulting DataFrame has the same columns as `behaviors`, but with the article IDs
    replaced by their corresponding values.

    Args:
        behaviors (pl.DataFrame): The DataFrame containing the column to be mapped.
        behaviors_column (str): The name of the column to be mapped in `behaviors`.
        mapping (dict[int, pl.Series]): A dictionary with article IDs as keys and corresponding values as values.
            Note, 'replace' works a lot faster when values are of type pl.Series!
        drop_nulls (bool): If `True`, any rows in the resulting DataFrame with null values will be dropped.
            If `False` and `fill_nulls` is specified, null values in `behaviors_column` will be replaced with `fill_null`.
        fill_nulls (Optional[any]): If specified, any null values in `behaviors_column` will be replaced with this value.

    Returns:
        pl.DataFrame: A new DataFrame with the same columns as `behaviors`, but with the article IDs in
            `behaviors_column` replaced by their corresponding values in `mapping`.

    Example:
    >>> behaviors = pl.DataFrame(
            {"user_id": [1, 2, 3, 4, 5], "article_ids": [["A1", "A2"], ["A2", "A3"], ["A1", "A4"], ["A4", "A4"], None]}
        )
    >>> articles = pl.DataFrame(
            {
                "article_id": ["A1", "A2", "A3"],
                "article_type": ["News", "Sports", "Entertainment"],
            }
        )
    >>> articles_dict = dict(zip(articles["article_id"], articles["article_type"]))
    >>> map_list_article_id_to_value(
            behaviors=behaviors,
            behaviors_column="article_ids",
            mapping=articles_dict,
            fill_nulls="Unknown",
        )
        shape: (4, 2)
        ┌─────────┬─────────────────────────────┐
        │ user_id ┆ article_ids                 │
        │ ---     ┆ ---                         │
        │ i64     ┆ list[str]                   │
        ╞═════════╪═════════════════════════════╡
        │ 1       ┆ ["News", "Sports"]          │
        │ 2       ┆ ["Sports", "Entertainment"] │
        │ 3       ┆ ["News", "Unknown"]         │
        │ 4       ┆ ["Unknown", "Unknown"]      │
        │ 5       ┆ ["Unknown"]                 │
        └─────────┴─────────────────────────────┘
    >>> map_list_article_id_to_value(
            behaviors=behaviors,
            behaviors_column="article_ids",
            mapping=articles_dict,
            drop_nulls=True,
        )
        shape: (4, 2)
        ┌─────────┬─────────────────────────────┐
        │ user_id ┆ article_ids                 │
        │ ---     ┆ ---                         │
        │ i64     ┆ list[str]                   │
        ╞═════════╪═════════════════════════════╡
        │ 1       ┆ ["News", "Sports"]          │
        │ 2       ┆ ["Sports", "Entertainment"] │
        │ 3       ┆ ["News"]                    │
        │ 4       ┆ null                        │
        │ 5       ┆ null                        │
        └─────────┴─────────────────────────────┘
    >>> map_list_article_id_to_value(
            behaviors=behaviors,
            behaviors_column="article_ids",
            mapping=articles_dict,
            drop_nulls=False,
        )
        shape: (4, 2)
        ┌─────────┬─────────────────────────────┐
        │ user_id ┆ article_ids                 │
        │ ---     ┆ ---                         │
        │ i64     ┆ list[str]                   │
        ╞═════════╪═════════════════════════════╡
        │ 1       ┆ ["News", "Sports"]          │
        │ 2       ┆ ["Sports", "Entertainment"] │
        │ 3       ┆ ["News", null]              │
        │ 4       ┆ [null, null]                │
        │ 5       ┆ [null]                      │
        └─────────┴─────────────────────────────┘
    """
   
    GROUPBY_ID = generate_unique_name(behaviors.columns, "_groupby_id")
    behaviors = behaviors.lazy().with_row_index(GROUPBY_ID)
    # =>
    select_column = (
        behaviors.select(pl.col(GROUPBY_ID), pl.col(behaviors_column))
        .explode(behaviors_column)
        .with_columns(pl.col(behaviors_column).replace(mapping, default=None))
        .collect()
    )
    
    if drop_nulls:
        
        select_column = select_column.drop_nulls()
    elif fill_nulls is not None:
        print(fill_nulls)
        print(pl.col(behaviors_column).fill_null(fill_nulls))

        select_column = select_column.with_columns(
            pl.col(behaviors_column).fill_null(fill_nulls)
        )
    
    select_column = (
        select_column.lazy().group_by(GROUPBY_ID).agg(behaviors_column).collect()
    )
    
    #print(behaviors)
    return (
        behaviors.drop(behaviors_column)
        .collect()
        .join(select_column, on=GROUPBY_ID, how="left")
        .drop(GROUPBY_ID)
    )


In [483]:
def repeat_by_list_values_from_matrix(
    input_array: np.array,
    matrix: np.array,
    repeats: np.array,
) -> np.array:
    """
    Example:
        >>> input = np.array([[1, 0], [0, 0]])
        >>> matrix = np.array([[7,8,9], [10,11,12]])
        >>> repeats = np.array([1, 2])
        >>> repeat_by_list_values_from_matrix(input, matrix, repeats)
            array([[[10, 11, 12],
                    [ 7,  8,  9]],
                    [[ 7,  8,  9],
                    [ 7,  8,  9]],
                    [[ 7,  8,  9],
                    [ 7,  8,  9]]])
    """
    return np.repeat(matrix[input_array], repeats=repeats, axis=0)

In [484]:
def _check_columns_in_df(df: pl.DataFrame, columns: list[str]) -> None:
    """
    Checks whether all specified columns are present in a Polars DataFrame.
    Raises a ValueError if any of the specified columns are not present in the DataFrame.

    Args:
        df (pl.DataFrame): The input DataFrame.
        columns (list[str]): The names of the columns to check for.

    Returns:
        None.

    Examples:
    >>> df = pl.DataFrame({"user_id": [1], "first_name": ["J"]})
    >>> check_columns_in_df(df, columns=["user_id", "not_in"])
        ValueError: Invalid input provided. The dataframe does not contain columns ['not_in'].
    """
    columns_not_in_df = [col for col in columns if col not in df.columns]
    if columns_not_in_df:
        raise ValueError(
            f"Invalid input provided. The DataFrame does not contain columns {columns_not_in_df}."
        )



In [485]:
def shuffle_rows(df: pl.DataFrame, seed: int = None) -> pl.DataFrame:
    """
    Shuffle the rows of a DataFrame. This methods allows for LazyFrame,
    whereas, 'df.sample(fraction=1)' is not compatible.

    Examples:
    >>> df = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
    >>> shuffle_rows(df.lazy(), seed=123).collect()
        shape: (3, 3)
        ┌─────┬─────┬─────┐
        │ a   ┆ b   ┆ c   │
        │ --- ┆ --- ┆ --- │
        │ i64 ┆ i64 ┆ i64 │
        ╞═════╪═════╪═════╡
        │ 1   ┆ 1   ┆ 1   │
        │ 3   ┆ 3   ┆ 3   │
        │ 2   ┆ 2   ┆ 2   │
        └─────┴─────┴─────┘
    >>> shuffle_rows(df.lazy(), seed=None).collect().sort("a")
        shape: (3, 3)
        ┌─────┬─────┬─────┐
        │ a   ┆ b   ┆ c   │
        │ --- ┆ --- ┆ --- │
        │ i64 ┆ i64 ┆ i64 │
        ╞═════╪═════╪═════╡
        │ 1   ┆ 1   ┆ 1   │
        │ 2   ┆ 2   ┆ 2   │
        │ 3   ┆ 3   ┆ 3   │
        └─────┴─────┴─────┘

    Test_:
    >>> all([sum(row) == row[0]*3 for row in shuffle_rows(df, seed=None).iter_rows()])
        True

    Note:
        Be aware that 'pl.all().shuffle()' shuffles columns-wise, i.e., with if pl.all().shuffle(None)
        each column's element are shuffled independently from each other (example might change with no seed):
    >>> df_ = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]}).select(pl.all().shuffle(None)).sort("a")
    >>> df_
        shape: (3, 3)
        ┌─────┬─────┬─────┐
        │ a   ┆ b   ┆ c   │
        │ --- ┆ --- ┆ --- │
        │ i64 ┆ i64 ┆ i64 │
        ╞═════╪═════╪═════╡
        │ 1   ┆ 3   ┆ 1   │
        │ 2   ┆ 2   ┆ 3   │
        │ 3   ┆ 1   ┆ 2   │
        └─────┴─────┴─────┘
    >>> all([sum(row) == row[0]*3 for row in shuffle_rows(df_, seed=None).iter_rows()])
        False
    """
    seed = seed if seed is not None else random.randint(1, 1_000_000)
    return df.select(pl.all().shuffle(seed))



def shuffle_list_column(
    df: pl.DataFrame, column: str, seed: int = None
) -> pl.DataFrame:
    """Shuffles the values in a list column of a DataFrame.

    Args:
        df (pl.DataFrame): The input DataFrame.
        column (str): The name of the column to shuffle.
        seed (int, optional): An optional seed value.
            Defaults to None.

    Returns:
        pl.DataFrame: A new DataFrame with the specified column shuffled.

    Example:
    >>> df = pl.DataFrame(
            {
                "id": [1, 2, 3],
                "list_col": [["a-", "b-", "c-"], ["a#", "b#"], ["a@", "b@", "c@"]],
                "rdn": ["h", "e", "y"],
            }
        )
    >>> shuffle_list_column(df, 'list_col', seed=1)
        shape: (3, 3)
        ┌─────┬────────────────────┬─────┐
        │ id  ┆ list_col           ┆ rdn │
        │ --- ┆ ---                ┆ --- │
        │ i64 ┆ list[str]          ┆ str │
        ╞═════╪════════════════════╪═════╡
        │ 1   ┆ ["c-", "b-", "a-"] ┆ h   │
        │ 2   ┆ ["a#", "b#"]       ┆ e   │
        │ 3   ┆ ["b@", "c@", "a@"] ┆ y   │
        └─────┴────────────────────┴─────┘

    No seed:
    >>> shuffle_list_column(df, 'list_col', seed=None)
        shape: (3, 3)
        ┌─────┬────────────────────┬─────┐
        │ id  ┆ list_col           ┆ rdn │
        │ --- ┆ ---                ┆ --- │
        │ i64 ┆ list[str]          ┆ str │
        ╞═════╪════════════════════╪═════╡
        │ 1   ┆ ["b-", "a-", "c-"] ┆ h   │
        │ 2   ┆ ["a#", "b#"]       ┆ e   │
        │ 3   ┆ ["a@", "c@", "b@"] ┆ y   │
        └─────┴────────────────────┴─────┘

    Test_:
    >>> assert (
            sorted(shuffle_list_column(df, "list_col", seed=None)["list_col"].to_list()[0])
            == df["list_col"].to_list()[0]
        )

    >>> df = pl.DataFrame({
            'id': [1, 2, 3],
            'list_col': [[6, 7, 8], [-6, -7, -8], [60, 70, 80]],
            'rdn': ['h', 'e', 'y']
        })
    >>> shuffle_list_column(df.lazy(), 'list_col', seed=2).collect()
        shape: (3, 3)
        ┌─────┬──────────────┬─────┐
        │ id  ┆ list_col     ┆ rdn │
        │ --- ┆ ---          ┆ --- │
        │ i64 ┆ list[i64]    ┆ str │
        ╞═════╪══════════════╪═════╡
        │ 1   ┆ [7, 6, 8]    ┆ h   │
        │ 2   ┆ [-8, -7, -6] ┆ e   │
        │ 3   ┆ [60, 80, 70] ┆ y   │
        └─────┴──────────────┴─────┘

    Test_:
    >>> assert (
            sorted(shuffle_list_column(df, "list_col", seed=None)["list_col"].to_list()[0])
            == df["list_col"].to_list()[0]
        )
    """
    _COLUMN_ORDER = df.columns
    GROUPBY_ID = generate_unique_name(_COLUMN_ORDER, "_groupby_id")

    # Row count is
    df = df.with_row_index(GROUPBY_ID)
    df_shuffle = (
        df.explode(column)
        .pipe(shuffle_rows, seed=seed)
        .group_by(GROUPBY_ID)
        .agg(column)
    )
    return (
        df.drop(column)
        .join(df_shuffle, on=GROUPBY_ID, how="left")
        .drop(GROUPBY_ID)
        .select(_COLUMN_ORDER)
    )



def create_binary_labels_column(
    df: pl.DataFrame,
    shuffle: bool = True,
    seed: int = None,
    clicked_col: str = DEFAULT_CLICKED_ARTICLES_COL,
    inview_col: str = DEFAULT_INVIEW_ARTICLES_COL,
    label_col: str = DEFAULT_LABELS_COL,
) -> pl.DataFrame:
    """Creates a new column in a DataFrame containing binary labels indicating
    whether each article ID in the "article_ids" column is present in the corresponding
    "list_destination" column.

    Args:
        df (pl.DataFrame): The input DataFrame.

    Returns:
        pl.DataFrame: A new DataFrame with an additional "labels" column.

    Examples:
    >>> from ebrec.utils._constants import (
            DEFAULT_CLICKED_ARTICLES_COL,
            DEFAULT_INVIEW_ARTICLES_COL,
            DEFAULT_LABELS_COL,
        )
    >>> df = pl.DataFrame(
            {
                DEFAULT_INVIEW_ARTICLES_COL: [[1, 2, 3], [4, 5, 6], [7, 8]],
                DEFAULT_CLICKED_ARTICLES_COL: [[2, 3, 4], [3, 5], None],
            }
        )
    >>> create_binary_labels_column(df)
        shape: (3, 3)
        ┌────────────────────┬─────────────────────┬───────────┐
        │ article_ids_inview ┆ article_ids_clicked ┆ labels    │
        │ ---                ┆ ---                 ┆ ---       │
        │ list[i64]          ┆ list[i64]           ┆ list[i8]  │
        ╞════════════════════╪═════════════════════╪═══════════╡
        │ [1, 2, 3]          ┆ [2, 3, 4]           ┆ [0, 1, 1] │
        │ [4, 5, 6]          ┆ [3, 5]              ┆ [0, 1, 0] │
        │ [7, 8]             ┆ null                ┆ [0, 0]    │
        └────────────────────┴─────────────────────┴───────────┘
    >>> create_binary_labels_column(df.lazy(), shuffle=True, seed=123).collect()
        shape: (3, 3)
        ┌────────────────────┬─────────────────────┬───────────┐
        │ article_ids_inview ┆ article_ids_clicked ┆ labels    │
        │ ---                ┆ ---                 ┆ ---       │
        │ list[i64]          ┆ list[i64]           ┆ list[i8]  │
        ╞════════════════════╪═════════════════════╪═══════════╡
        │ [3, 1, 2]          ┆ [2, 3, 4]           ┆ [1, 0, 1] │
        │ [5, 6, 4]          ┆ [3, 5]              ┆ [1, 0, 0] │
        │ [7, 8]             ┆ null                ┆ [0, 0]    │
        └────────────────────┴─────────────────────┴───────────┘
    Test_:
    >>> assert create_binary_labels_column(df, shuffle=False)[DEFAULT_LABELS_COL].to_list() == [
            [0, 1, 1],
            [0, 1, 0],
            [0, 0],
        ]
    >>> assert create_binary_labels_column(df, shuffle=True)[DEFAULT_LABELS_COL].list.sum().to_list() == [
            2,
            1,
            0,
        ]
    """
    _check_columns_in_df(df, [inview_col, clicked_col])
    _COLUMNS = df.columns
    GROUPBY_ID = generate_unique_name(_COLUMNS, "_groupby_id")

    df = df.with_row_index(GROUPBY_ID)

    if shuffle:
        df = shuffle_list_column(df, column=inview_col, seed=seed)

    df_labels = (
        df.explode(inview_col)
        .with_columns(
            pl.col(inview_col).is_in(pl.col(clicked_col)).cast(pl.Int8).alias(label_col)
        )
        .group_by(GROUPBY_ID)
        .agg(label_col)
    )
    return (
        df.join(df_labels, on=GROUPBY_ID, how="left")
        .drop(GROUPBY_ID)
        .select(_COLUMNS + [label_col])
    )

def create_lookup_dict(df: pl.DataFrame, key: str, value: str) -> dict:
    """
    Creates a dictionary lookup table from a Pandas-like DataFrame.

    Args:
        df (pl.DataFrame): The DataFrame from which to create the lookup table.
        key (str): The name of the column containing the keys for the lookup table.
        value (str): The name of the column containing the values for the lookup table.

    Returns:
        dict: A dictionary where the keys are the values from the `key` column of the DataFrame
            and the values are the values from the `value` column of the DataFrame.

    Example:
        >>> df = pl.DataFrame({'id': [1, 2, 3], 'name': ['Alice', 'Bob', 'Charlie']})
        >>> create_lookup_dict(df, 'id', 'name')
            {1: 'Alice', 2: 'Bob', 3: 'Charlie'}
    """
    return dict(zip(df[key], df[value]))




def create_article_id_to_value_mapping(
    df: pl.DataFrame,
    value_col: str,
    article_col: str = DEFAULT_ARTICLE_ID_COL,
):
    return create_lookup_dict(
        df.select(article_col, value_col), key=article_col, value=value_col
    )


In [486]:
def create_user_id_to_int_mapping(
    df: pl.DataFrame, user_col: str = DEFAULT_USER_COL, value_str: str = "id"
):
    return create_lookup_dict(
        df.select(pl.col(user_col).unique()).with_row_index(value_str),
        key=user_col,
        value=value_str,
    )



In [487]:
# LOAD DATA:

TOKEN_COL = "tokens"
N_SAMPLES = "n"
BATCH_SIZE = 100


PATH_DATA = Path("/Users/sohamchatterjee/Documents/UvA/RecSYS/Project/ebnerd_data")
df_articles = (
    pl.scan_parquet(PATH_DATA.joinpath("articles.parquet"))
    .select(pl.col(DEFAULT_ARTICLE_ID_COL, DEFAULT_CATEGORY_COL))
    .with_columns(pl.Series(TOKEN_COL, np.random.randint(0, 20, (1, 10))))
    .collect()
)
df_history = (
    pl.scan_parquet(PATH_DATA.joinpath("ebnerd_demo","train", "history.parquet"))
    .select(DEFAULT_USER_COL, DEFAULT_HISTORY_ARTICLE_ID_COL)
    .with_columns(pl.col(DEFAULT_HISTORY_ARTICLE_ID_COL).list.tail(3))
)
df_behaviors = (
    pl.scan_parquet(PATH_DATA.joinpath("ebnerd_demo","train", "behaviors.parquet"))
    .select(DEFAULT_USER_COL, DEFAULT_INVIEW_ARTICLES_COL, DEFAULT_CLICKED_ARTICLES_COL)
    .with_columns(pl.col(DEFAULT_INVIEW_ARTICLES_COL).list.len().alias(N_SAMPLES))
    .join(df_history, on=DEFAULT_USER_COL, how="left")
    .collect()
    .pipe(create_binary_labels_column)
)
# => MAPPINGS:
article_mapping = create_article_id_to_value_mapping(
    df=df_articles, value_col=TOKEN_COL
)
user_mapping = create_user_id_to_int_mapping(df=df_behaviors)
# => NPRATIO IMPRESSION - SAME LENGTHS:
df_behaviors_train = df_behaviors.filter(pl.col(N_SAMPLES) == pl.col(N_SAMPLES).min())
# => FOR TEST-DATALOADER
label_lengths = df_behaviors[DEFAULT_INVIEW_ARTICLES_COL].list.len().to_list()



In [488]:
df_articles['article_id'].max()

9803607

In [489]:
class PPRecDataLoader(Dataset):
    """
    NPA and LSTUR shares the same DataLoader
    """

    def __init__(self,behaviors: pd.DataFrame, history_column: str, article_dict: dict[int, any], unknown_representation: str,
                 eval_mode: bool = False,
    batch_size: int = 32,
    inview_col: str = DEFAULT_INVIEW_ARTICLES_COL,
    labels_col: str = DEFAULT_LABELS_COL,
    user_col: str = DEFAULT_USER_COL, user_id_mapping: dict[int, int] = None, unknown_user_value: int = 0,kwargs: dict = field(default_factory=dict)):
        self.behaviors = behaviors
        self.history_column = history_column
        self.article_dict = article_dict
        self.unknown_representation = unknown_representation
        self.eval_mode: bool = False
        self.batch_size = batch_size
        self.inview_col = inview_col 
        self.labels_col = labels_col
        self.user_col = user_col 
        self.kwargs = kwargs
        self.lookup_article_index, self.lookup_article_matrix = create_lookup_objects(
            self.article_dict, unknown_representation=self.unknown_representation
        )
        #self.unknown_index = torch.tensor([0])
        self.unknown_index = 0
        
        
        
        self.user_id_mapping = user_id_mapping
        self.unknown_user_value = unknown_user_value

        
    def __len__(self) -> int:
        return int(np.ceil(len(self.X) / float(self.batch_size)))
    def set_kwargs(self, kwargs: dict):
        for key, value in kwargs.items():
            setattr(self, key, value)
    
    def load_data(self): 
    # -> tuple[pd.DataFrame, pd.DataFrame]:
        
        #print(self.behaviors)
     
        X = self.behaviors.drop(self.labels_col)
       
        X= X.with_columns(
            n_samples=self.behaviors[self.inview_col].map_elements(len)
        )
       
        y = self.behaviors[self.labels_col]
       
        self.X, self.y = X,y
        return self.X, self.y

    

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        #print(df)
        
        return (
            df.pipe(
                map_list_article_id_to_value,
                behaviors_column=self.history_column,
                mapping=self.lookup_article_index,
                fill_nulls=self.unknown_index,
                drop_nulls=False,
            )
            .pipe(
                map_list_article_id_to_value,
                behaviors_column=self.inview_col,
                mapping=self.lookup_article_index,
                fill_nulls=self.unknown_index,
                drop_nulls=False,
            )
            .with_columns(
                **{self.user_col: lambda x: x[self.user_col].replace(
                    self.user_id_mapping, default=self.unknown_user_value
                )}
            )
        )

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
        """
        user_indexes:       ()
        his_input_title:    (samples, history_size, document_dimension)
        pred_input_title:   (samples, npratio, document_dimension)
        batch_y:            (samples, npratio)
        """
        batch_X1 = self.X[idx * self.batch_size : (idx + 1) * self.batch_size]
        #print(batch_X1)
        batch_X = batch_X1.pipe(
            self.transform
        )
        
        batch_y = self.y[idx * self.batch_size : (idx + 1) * self.batch_size]
        # =>
        if self.eval_mode:
            repeats = np.array(batch_X["n_samples"])
            # =>
            batch_y = np.array(batch_y.explode().to_list()).reshape(-1, 1)
            # =>
            user_indexes = np.array(
                batch_X[self.user_col].repeat(repeats).explode()
            ).reshape(-1, 1)
            # =>
            his_input_title = repeat_by_list_values_from_matrix(
                batch_X[self.history_column].to_list(),
                matrix=self.lookup_article_matrix,
                repeats=repeats,
            )
            # =>
            pred_input_title = self.lookup_article_matrix[
                batch_X[self.inview_col].explode().to_list()
            ]
        else:
            # =>
            batch_y = np.array(batch_y.to_list())
            # =>
            user_indexes = np.array(batch_X[self.user_col].to_list()).reshape(-1, 1)
            # =>
            his_input_title = self.lookup_article_matrix[
                batch_X[self.history_column].to_list()
            ]
            # =>
            print(batch_X[self.inview_col].to_list())
            pred_input_title = self.lookup_article_matrix[
                batch_X[self.inview_col].to_list()
            ]
            
            pred_input_title = np.squeeze(pred_input_title, axis=2)
        # =>
        his_input_title = np.squeeze(his_input_title, axis=2)
        return (
            torch.from_numpy(user_indexes).float(),
            torch.from_numpy(his_input_title).float(),
            torch.from_numpy(pred_input_title).float(),
        ), torch.from_numpy(batch_y).float()
    
    def __iter__(self):
        for idx in range(len(self)):
            yield self.__getitem__(idx)


In [490]:
def test_PPRecDataLoader():
    train_dataloader = PPRecDataLoader(
        behaviors=df_behaviors_train,
        article_dict=article_mapping,
        user_id_mapping=user_mapping,
        history_column=DEFAULT_HISTORY_ARTICLE_ID_COL,
        unknown_representation="zeros",
        batch_size=4,
    )
    train_dataloader.load_data()
   

    batch = train_dataloader.__iter__().__next__()

    assert train_dataloader.__len__() == int(np.ceil(df_behaviors_train.shape[0] / 100))
    assert len(batch) == 2, "There should be two outputs: (inputs, labels)"
    assert (
        len(batch[0]) == 3
    ), "LSTUR has two outputs (user_indexes, his_input_title, pred_input_title_one)"

    for type_in_batch in batch[0][0]:
        assert isinstance(
            type_in_batch.ravel()[0], np.integer
        ), "Expected output to be integer; used for lookup value"

    assert isinstance(
        batch[1].ravel()[0], np.integer
    ), "Expected output to be integer; this is label"

    test_dataloader = PPRecDataLoader(
        user_id_mapping=user_mapping,
        behaviors=df_behaviors,
        article_dict=article_mapping,
        history_column=DEFAULT_HISTORY_ARTICLE_ID_COL,
        unknown_representation="zeros",
        batch_size=BATCH_SIZE,
        eval_mode=True,
    )

    batch = test_dataloader.__iter__().__next__()
    assert len(batch[1]) == sum(
        label_lengths[:BATCH_SIZE]
    ), "Should have unfolded all the test samples"

In [491]:
test_PPRecDataLoader()

0
col("article_id_fixed").fill_null([0])
0
col("article_ids_inview").fill_null([0])


TypeError: invalid literal value: '<function PPRecDataLoader.transform.<locals>.<lambda> at 0x2d213bec0>'