In [None]:
from recbole.data import create_dataset, data_preparation
from recbole.config import Config

In [None]:
# Experiment selection: the RecBole model name and the benchmark dataset name.
model = 'NPE'
dataset = 'ml-100k'

# Overrides applied on top of RecBole's default configuration.
# "TO" requests ordering by the time field; "RS" carries three split fractions
# (presumably train/valid/test ratios — TODO confirm against the RecBole docs).
config_dict = dict(
    eval_args=dict(
        order="TO",
        split={"RS": [0.8, 0.1, 0.1]},
        group_by=None,
    ),
    train_neg_sample_args=None,
)

config = Config(model=model, dataset=dataset, config_dict=config_dict)

In [None]:
# Build the dataset object described by `config`.
# In a top-to-bottom run `create_dataset` here is the function imported from
# `recbole.data`; the notebook-local `create_dataset` is only defined in a later cell.
# Note that this rebinds `dataset`, which previously held the dataset *name* string.
dataset = create_dataset(config)
# Disabled scratch lines: alternative ways of obtaining the train/valid/test objects.
# They are the only place where `train_dataset` / `valid_dataset` / `test_dataset`
# would be assigned.
# train_data, valid_data, test_data = data_preparation(config, dataset)
# model_type = config["MODEL_TYPE"]
# built_datasets = dataset.build()
# train_dataset, valid_dataset, test_dataset = built_datasets

In [None]:
# Inspect the "normalize_field" entry of the config attached to the dataset.
dataset.config["normalize_field"]

In [None]:
# Inspect the dataset's `float_like_fields` attribute
# (presumably the names of float-typed fields — TODO confirm against RecBole).
dataset.float_like_fields

In [None]:
# The original line was a truncated `from date` statement (a SyntaxError that halts
# "Restart & Run All"). Completed as the stdlib datetime import, which is the most likely
# intent given the timestamp exploration below — adjust if another module was meant.
from datetime import datetime

In [None]:
# Field whose backing feature table we want to look at.
field = "timestamp"

assert field in dataset.fields(), f"Dataset not existed field '{field}'"

# Bind `feat` explicitly to the first table returned for `field`.
# The previous `for ... break` idiom left `feat` unbound (or silently stale from an earlier
# run) when the returned list was empty; fail loudly instead.
feats_for_field = dataset.field2feats(field)
assert feats_for_field, f"No feature table contains field '{field}'"
feat = feats_for_field[0]

feat

In [None]:
# Show everything `field2feats` returns for `field` (set in the cell above).
dataset.field2feats(field)

In [None]:
# NOTE(review): `test_dataset` is not assigned by any active line in this notebook — its only
# assignment is in the commented-out `dataset.build()` lines of the dataset-creation cell —
# so on a fresh kernel this cell raises NameError. Un-comment those lines (or drop this cell)
# before relying on "Restart & Run All".
test_dataset.inter_feat

In [None]:
# Name of the interaction column used as the grouping key in the experiments below.
group_by = 'user_id'
# Disabled: grouping through the dataset's private `_grouped_index` helper; it uses `.numpy()`
# instead of `.to_numpy()` (presumably for a tensor-backed `inter_feat` — TODO confirm).
# grouped_inter_feat_index = dataset._grouped_index(dataset.inter_feat[group_by].numpy())

# Display the grouping column of the interaction table as a NumPy array.
dataset.inter_feat[group_by].to_numpy()

## Implement time cutoff Dataset

In [None]:
import copy
import importlib
import os
import pickle
import warnings
from logging import getLogger
from typing import Literal

import numpy as np

from recbole.data.dataloader import *
from recbole.data.dataset import Dataset
from recbole.sampler import KGSampler, Sampler, RepeatableSampler
from recbole.utils import (
    FeatureType,
    ModelType,
    ensure_dir,
    get_local_time,
    set_color,
)
from recbole.utils.argument_list import dataset_arguments

In [None]:
# TODO: HoangLe [Jun-09]: How to replace config["MODEL_TYPE"] to TimeCutoffDataset

class TimeCutoffDataset(Dataset):
    """Dataset whose train/valid/test split is defined by a timestamp cutoff.

    The split is requested with ``eval_args = {"order": "TO", "split": {"CO": <unix time>}, ...}``.

    Attributes:
        field_timestamp (str): name of the interaction field holding the timestamp.
        timestamp_min (float): smallest raw timestamp, recorded before normalization.
        timestamp_max (float): largest raw timestamp, recorded before normalization.
        timestamp_normalized (bool): whether the base ``_normalize`` changed the range of the
            timestamp field, i.e. whether a raw cutoff must be rescaled before comparison.
    """

    def __init__(self, config):
        # These attributes are assigned *before* calling the base constructor because
        # ``_normalize`` (overridden below) fills them in; the base constructor is expected to
        # run the preprocessing pipeline that calls it — verify against RecBole's ``Dataset``.
        self.timestamp_max, self.timestamp_min = 0.0, 0.0
        self.timestamp_normalized = False
        # NOTE: HoangLe [Jun-09]: May allow the user to specify the timestamp field in config
        self.field_timestamp = "timestamp"

        # The existence of the timestamp field is checked in ``_normalize``: no field is
        # loaded yet at this point, and the check must not depend on notebook globals.
        super().__init__(config)

    def _normalize(self):
        """Record the raw timestamp range, then delegate to the base normalization.

        Raises:
            ValueError: if no feature table contains ``self.field_timestamp``.
        """
        feats = self.field2feats(self.field_timestamp)
        # Explicit checks instead of ``assert feat and ...``: the truth value of a pandas
        # DataFrame is ambiguous and raises.
        if not feats or feats[0] is None or self.field_timestamp not in feats[0]:
            raise ValueError(f"Feat not exist field '{self.field_timestamp}'")

        raw_timestamps = feats[0][self.field_timestamp]
        self.timestamp_max = float(np.max(raw_timestamps))
        self.timestamp_min = float(np.min(raw_timestamps))

        result = super()._normalize()

        # Detect whether the base class actually rescaled the timestamps. If the range is
        # unchanged, either nothing was normalized or normalization was the identity; in both
        # cases a raw cutoff can be compared directly.
        timestamps_after = self.field2feats(self.field_timestamp)[0][self.field_timestamp]
        self.timestamp_normalized = (
            float(np.max(timestamps_after)) != self.timestamp_max
            or float(np.min(timestamps_after)) != self.timestamp_min
        )
        return result

    def _fill_nan(self):
        """Missing value imputation.

        For fields with type :obj:`~recbole.utils.enum_type.FeatureType.TOKEN`, missing value will be filled by
        ``[PAD]``, which indexed as 0.

        For fields with type :obj:`~recbole.utils.enum_type.FeatureType.FLOAT`, missing value will be filled by
        the average of original data.

        Note:
            This mirrors RecBole's original implementation; the only difference is that the in-place
            ``fillna`` calls are replaced by assignments, to suit pandas 3.0.
        """
        self.logger.debug(set_color("Filling nan", "green"))

        for feat_name in self.feat_name_list:
            feat = getattr(self, feat_name)
            for field in feat:
                ftype = self.field2type[field]
                if ftype == FeatureType.TOKEN:
                    feat[field] = feat[field].fillna(value=0)
                elif ftype == FeatureType.FLOAT:
                    feat[field] = feat[field].fillna(value=feat[field].mean())
                else:
                    # Sequence fields: a missing entry shows up as a float NaN and is replaced
                    # by an empty array of the matching dtype.
                    dtype = np.int64 if ftype == FeatureType.TOKEN_SEQ else np.float64
                    feat[field] = feat[field].apply(
                        lambda x: (
                            np.array([], dtype=dtype) if isinstance(x, float) else x
                        )
                    )

    def build(self):
        """Sort the interactions by time and split them at the configured cutoff.

        Returns:
            list: the train/valid/test datasets. When benchmark files are configured, the
            result of the base ``build`` is returned unchanged.

        Raises:
            AssertionError: if ``eval_args.order`` is not ``"TO"``.
            ValueError: if ``eval_args.split`` is missing, not a dict, or has more than one key.
            NotImplementedError: if the split mode is not ``"CO"``.
        """
        # Pre-split benchmark files are handled entirely by the base class. Delegate *before*
        # touching the feature format so it is not converted twice, and return the result
        # (it was previously discarded).
        if self.benchmark_filename_list is not None:
            return super().build()

        eval_args = self.config["eval_args"]

        # Validate the whole configuration before mutating the dataset.
        if eval_args["order"] != "TO":
            raise AssertionError("The ordering_method must be 'TO'.")

        split_args = eval_args["split"]
        if split_args is None:
            raise ValueError("The split_args in eval_args should not be None.")
        if not isinstance(split_args, dict):
            raise ValueError(f"The split_args [{split_args}] should be a dict.")
        if len(split_args) != 1:
            raise ValueError(
                f"The split_args [{split_args}] should contain exactly one split mode."
            )

        split_mode = next(iter(split_args))
        if split_mode != "CO":
            raise NotImplementedError("The split_mode must be 'CO'.")
        # NOTE: HoangLe [Jun-05]: cutoff may come with different types: string, int
        cutoff = split_args["CO"]

        # Translate the config value into an interaction field name, following the convention
        # of RecBole's own ``build``: "user" means the user-id field, None/"none" means no
        # grouping. Any other value is passed through as a field name.
        group_by = eval_args["group_by"]
        if group_by is None or (isinstance(group_by, str) and group_by.lower() == "none"):
            group_by = None
        elif group_by == "user":
            group_by = self.uid_field

        self._change_feat_format()
        self.sort(by=self.time_field)

        return self.split_by_cuttoff(cutoff=cutoff, group_by=group_by)

    def _normalize_cutoff(self, timestamp: float) -> float:
        """Map a raw timestamp into the min-max scale recorded by ``_normalize``."""
        mx, mn = self.timestamp_max, self.timestamp_min
        if mx == mn:
            self.logger.warning(
                f"All the same value in [{self.field_timestamp}]."
            )
            # Same constant the original helper returned for a degenerate range.
            return 1.0
        return (timestamp - mn) / (mx - mn)

    def split_by_cuttoff(self, cutoff: str | int | float, group_by: str) -> list[Dataset]:
        """Split the interactions by a cutoff date.

        Within each group (usually: each user) the interactions are ordered by timestamp and

        * interactions strictly after the cutoff go to the **test** set;
        * the last interaction at or before the cutoff goes to the **validation** set,
          provided the group has at least two such interactions;
        * all remaining interactions at or before the cutoff go to the **training** set.

        Args:
            cutoff (str | int | float): cutoff date as a Unix timestamp, e.g. ``'1717923174'``
                or ``1717923174``.
            group_by (str): interaction field to group by, usually the user-id field.

        Returns:
            list[Dataset]: training/validation/testing datasets whose interaction features
            have been split.

        Raises:
            ValueError: if ``cutoff`` cannot be converted to a number or ``group_by`` is None.
        """
        self.logger.debug(f"split by cutoff date = '{cutoff}', group_by=[{group_by}]")

        if group_by is None:
            raise ValueError(
                "Splitting by cutoff requires a group_by field (e.g. eval_args.group_by='user')."
            )

        try:
            cutoff_conv = float(cutoff)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"The cutoff [{cutoff}] should be a Unix timestamp (number or numeric string)."
            ) from err

        # Rescale the cutoff only when the stored timestamps were rescaled too; otherwise the
        # comparison below would mix two different scales.
        if self.timestamp_normalized:
            cutoff_conv = self._normalize_cutoff(cutoff_conv)

        timestamps = self.inter_feat[self.field_timestamp].numpy()
        grouped_inter_feat_index = self._grouped_index(
            self.inter_feat[group_by].numpy()
        )

        # Three *independent* lists (``[[]] * 3`` would alias a single list three times).
        train_index, valid_index, test_index = [], [], []
        for grouped_index in grouped_inter_feat_index:
            group = np.asarray(grouped_index)
            # Stable sort keeps the existing order of interactions sharing a timestamp.
            group = group[np.argsort(timestamps[group], kind="stable")]

            n_before = int(np.sum(timestamps[group] <= cutoff_conv))
            before = group[:n_before].tolist()
            after = group[n_before:].tolist()

            test_index.extend(after)
            if len(before) > 1:
                train_index.extend(before[:-1])
                valid_index.append(before[-1])
            else:
                # A single pre-cutoff interaction stays in training: holding it out would
                # leave this group without any training data.
                train_index.extend(before)

        self._drop_unused_col()
        next_df = [
            self.inter_feat[index] for index in (train_index, valid_index, test_index)
        ]
        next_ds = [self.copy(_) for _ in next_df]
        return next_ds

In [None]:
def create_dataset(config):
    """Create dataset according to the split mode, :attr:`config['model']` and :attr:`config['MODEL_TYPE']`.

    Class selection, in order of precedence:

    1. ``TimeCutoffDataset`` when ``config['eval_args']['split']`` uses the ``"CO"`` mode.
       Caveat: this takes precedence over model-specific and sequential dataset classes.
    2. ``<model>Dataset`` when such a class exists in ``recbole.data.dataset``.
    3. The default class for ``config['MODEL_TYPE']``.

    If :attr:`config['dataset_save_path']` file exists and
    its :attr:`config` of dataset is equal to current :attr:`config` of dataset.
    It will return the saved dataset in :attr:`config['dataset_save_path']`.

    Args:
        config (Config): An instance object of Config, used to record parameter information.

    Returns:
        Dataset: Constructed dataset.

    Raises:
        ValueError: if ``config['MODEL_TYPE']`` has no associated dataset class.
    """
    eval_args = config["eval_args"]
    split_args = eval_args.get("split") if isinstance(eval_args, dict) else None

    if isinstance(split_args, dict) and "CO" in split_args:
        # TimeCutoffDataset is defined in this notebook, not in ``recbole.data.dataset``, so
        # it cannot be resolved by name from that module (and the enum class ``ModelType``
        # itself is never a valid ``MODEL_TYPE`` value to key on).
        dataset_class = TimeCutoffDataset
    else:
        dataset_module = importlib.import_module("recbole.data.dataset")
        model_dataset_name = config["model"] + "Dataset"
        if hasattr(dataset_module, model_dataset_name):
            dataset_class = getattr(dataset_module, model_dataset_name)
        else:
            model_type = config["MODEL_TYPE"]
            type2class = {
                ModelType.GENERAL: "Dataset",
                ModelType.SEQUENTIAL: "SequentialDataset",
                ModelType.CONTEXT: "Dataset",
                ModelType.KNOWLEDGE: "KnowledgeBasedDataset",
                ModelType.TRADITIONAL: "Dataset",
                ModelType.DECISIONTREE: "Dataset",
            }
            if model_type not in type2class:
                raise ValueError(f"No dataset class is registered for MODEL_TYPE [{model_type}].")
            dataset_class = getattr(dataset_module, type2class[model_type])

    default_file = os.path.join(
        config["checkpoint_dir"], f'{config["dataset"]}-{dataset_class.__name__}.pth'
    )
    file = config["dataset_save_path"] or default_file
    if os.path.exists(file):
        # NOTE(security): ``pickle.load`` can execute arbitrary code. Only point
        # ``dataset_save_path`` / ``checkpoint_dir`` at cache files you created yourself.
        with open(file, "rb") as f:
            dataset = pickle.load(f)
        # Reuse the cached dataset only if every dataset-related argument is unchanged.
        dataset_args_unchanged = all(
            config[arg] == dataset.config[arg]
            for arg in dataset_arguments + ["seed", "repeatable"]
        )
        if dataset_args_unchanged:
            logger = getLogger()
            logger.info(set_color("Load filtered dataset from", "pink") + f": [{file}]")
            return dataset

    dataset = dataset_class(config)
    if config["save_dataset"]:
        dataset.save()
    return dataset
