In [None]:
# default_exp datasets.kkbox

# KKBox
> KKBox music-listening dataset (the `KKBox_x1` benchmark split), prepared for click-through-rate (CTR) prediction.

In [None]:
#hide
from nbdev.showdoc import *
from fastcore.nb_imports import *
from fastcore.test import *

In [None]:
#export
from recohut.datasets.bases.ctr import *
from recohut.utils.common_utils import download_url, extract_zip

import pandas as pd
import numpy as np
import os
from datetime import datetime, date

In [None]:
#export
class KKBoxDataset(CTRDataset):
    """KKBox music CTR dataset (the ``KKBox_x1`` archive hosted on Zenodo).

    This class declares the column schema, the label, the download location and two
    per-column preprocessing hooks. Everything else (reading the raw CSVs, encoding,
    building tensors) is inherited from ``CTRDataset``.
    """

    # Column schema consumed by CTRDataset. A 'preprocess' value names a method of this
    # class; presumably the base class looks it up by name and applies it to that column
    # (TODO confirm against CTRDataset).
    feature_cols = [
        {'name': ["msno", "song_id", "source_system_tab", "source_screen_name", "source_type",
                  "city", "gender", "registered_via", "language"],
         'active': True, 'dtype': 'str', 'type': 'categorical'},
        {'name': 'genre_ids', 'active': True, 'dtype': 'str', 'type': 'sequence', 'max_len': 3},
        {'name': 'artist_name', 'active': True, 'dtype': 'str', 'type': 'sequence', 'max_len': 3},
        {'name': 'isrc', 'active': True, 'dtype': 'str', 'type': 'categorical',
         'preprocess': 'extract_country_code'},
        {'name': 'bd', 'active': True, 'dtype': 'str', 'type': 'categorical',
         'preprocess': 'bucketize_age'},
    ]

    # Target column (0/1 clicks for the binary-classification task).
    label_col = {'name': 'label', 'dtype': float}

    url = "https://zenodo.org/record/5700987/files/KKBox_x1.zip"

    @property
    def raw_file_names(self):
        """CSV files expected in ``raw_dir`` once the archive has been extracted."""
        return ['train.csv',
                'valid.csv',
                'test.csv']

    def download(self):
        """Fetch the KKBox_x1 zip into ``raw_dir``, extract it there, then delete the zip."""
        path = download_url(self.url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.unlink(path)

    def extract_country_code(self, df, col_name):
        """Reduce ISRC codes in ``df[col_name]`` to their leading two characters.

        The first two characters of an ISRC are its country code. Missing values map to "".
        Returns a Series aligned with ``df``.
        """
        return df[col_name].apply(lambda isrc: "" if pd.isnull(isrc) else isrc[:2])

    def bucketize_age(self, df, col_name):
        """Map raw ages in ``df[col_name]`` to coarse, mostly-decade buckets.

        Buckets: [1, 10] -> "1", (10, 20] -> "2", ..., (50, 60] -> "6", (60, 95] -> "7".
        These map to "": missing values, ages outside [1, 95], and values that do not
        parse as a number (e.g. an empty string or "nan").

        Returns a Series of string bucket labels aligned with ``df``.
        """
        upper_bounds = (10, 20, 30, 40, 50, 60)

        def _bucketize(age):
            if pd.isnull(age):
                return ""
            try:
                age = float(age)
            except (TypeError, ValueError):
                # Non-numeric text (such as "") used to raise here and abort preprocessing.
                return ""
            # The chained comparison is False for NaN. A "nan" string is therefore treated
            # as missing instead of slipping past every `<=` test into the last bucket.
            if not 1 <= age <= 95:
                return ""
            for bucket, upper in enumerate(upper_bounds, start=1):
                if age <= upper:
                    return str(bucket)
            return str(len(upper_bounds) + 1)

        return df[col_name].apply(_bucketize)

In [None]:
#export
class KKBoxDataModule(CTRDataModule):
    """Data module for the KKBox dataset.

    This class only binds ``dataset_cls`` to ``KKBoxDataset``. ``prepare_data``, ``setup``
    and the dataloaders are inherited from ``CTRDataModule``, which presumably instantiates
    ``dataset_cls`` (confirm against the base class).
    """
    dataset_cls = KKBoxDataset

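## Example

Clear any previously processed files, then download, preprocess, and load one training batch of KKBox.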

In [None]:
# Demo configuration passed straight through to KKBoxDataModule(**params).
# Several entries look like DeepCross model/training settings that the data module may
# not use. NOTE(review): confirm which keys CTRDataModule actually reads.
params = dict(
    # identification and storage
    model_id='DeepCross_demo',
    data_dir='/content/data',
    model_root='./checkpoints/',
    # network architecture
    dnn_hidden_units=[64, 64],
    dnn_activations="relu",
    crossing_layers=3,
    # optimisation
    learning_rate=1e-3,
    net_dropout=0,
    batch_norm=False,
    optimizer='adamw',
    # objective and evaluation
    task='binary_classification',
    loss='binary_crossentropy',
    metrics=['logloss', 'AUC'],
    # embeddings and data loading
    embedding_dim=10,
    batch_size=10000,
    epochs=3,
    shuffle=True,
    seed=2019,
    use_hdf5=True,
    workers=1,
    verbose=0,
)

In [None]:
import glob
import shutil

# Remove previously processed artifacts so the dataset is rebuilt from the raw files.
# Unlike `!rm -r .../processed/*`, this is a no-op (not an error) on a fresh environment,
# and it derives the path from params['data_dir'] instead of repeating it.
for stale_path in glob.glob(os.path.join(params['data_dir'], 'processed', '*')):
    if os.path.isdir(stale_path) and not os.path.islink(stale_path):
        shutil.rmtree(stale_path)
    else:
        os.remove(stale_path)

ds = KKBoxDataModule(**params)
ds.prepare_data()
ds.setup()

# Inspect one training batch; the recorded output shows a [features, labels] pair.
# NOTE(review): that output holds 64 labels although params['batch_size'] is 10000.
# Confirm that CTRDataModule honours batch_size.
first_batch = next(iter(ds.train_dataloader()))
first_batch

rm: cannot remove '/content/data/processed/*': No such file or directory


  "DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7."
Processing...
Done!


[tensor([[2.9800e+02, 3.8693e+04, 1.0000e+00,  ..., 4.0296e+04, 6.0000e+00,
         4.0000e+00],
        [1.1104e+04, 4.1240e+03, 1.0000e+00,  ..., 4.0296e+04, 1.0000e+00,
         1.0000e+00],
        [1.2177e+04, 9.8400e+02, 1.0000e+00,  ..., 4.0296e+04, 1.0000e+00,
         2.0000e+00],
        ...,
        [6.5680e+03, 3.4050e+03, 2.0000e+00,  ..., 4.0296e+04, 0.0000e+00,
         0.0000e+00],
        [2.7100e+02, 3.4463e+04, 2.0000e+00,  ..., 4.0296e+04, 6.0000e+00,
         4.0000e+00],
        [3.7290e+03, 1.2920e+03, 1.0000e+00,  ..., 4.0296e+04, 1.0000e+00,
         1.0000e+00]], dtype=torch.float64), tensor([1., 0., 1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1.,
        0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0.,
        0., 0., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 1., 1., 0., 1., 1., 0.,
        1., 0., 1., 0., 1., 1., 1., 0., 0., 1.], dtype=torch.float64)]


> **References**
> - https://github.com/xue-pai/FuxiCTR/blob/main/config/dataset_config/KKBox.yaml
> - https://github.com/openbenchmark/BARS/tree/master/ctr_prediction/datasets/KKBox