In [None]:
import rootutils
rootutils.setup_root('.', indicator='.project-root', pythonpath=True)

import os

import pandas as pd
import json

from relbench.base import Table
from IPython.display import display
import duckdb

from src.helpers.task_vector_generation_helper import get_task_columns_dict, generate_task_vector_clauses, get_timestamps_for_split, transform_features
from src.helpers.relbench_helper import get_dataset
from src.helpers.dataset_helper import get_default_db
from src.helpers.database_helper import read_stypes_cache
from src.definitions import DATA_DIR

# --- Notebook configuration -------------------------------------------------
# Dataset to process, and the stype cache stored under DATA_DIR/<dataset_name>/stypes.json.
dataset_name = 'rel-amazon'
stypes_path = os.path.join(DATA_DIR, dataset_name, 'stypes.json')
stype_dict = read_stypes_cache(stypes_path)
dataset = get_dataset(dataset_name)

# Window length used by the SQL below: a quarter of a year in whole days (365 // 4 == 91).
timedelta = pd.Timedelta(days=365 // 4)
# Passed unchanged to get_timestamps_for_split in every split cell.
num_eval_timestamps = 1

In [None]:
# Build the SQL text that enumerates the (timestamp, customer_id) rows of a task table.
# The query refers to `timestamp_df`, `all_timestamp_df` and `review` by name; those
# relations are registered with DuckDB by the split cells that execute it.

# Only the first value returned by get_default_db is kept.
# NOTE(review): 'glove' presumably selects a text embedder — confirm against the helper.
db, _, _ = get_default_db(dataset_name, dataset, 'glove')
task_columns_dict = get_task_columns_dict(db, stype_dict)

# NOTE(review): none of the three values unpacked here is referenced by the query
# below or by any other cell visible in this notebook — confirm whether this call
# is still needed.
sub_customer_review_clauses, final_customer_review_clauses, when_clause_customer_review = \
    generate_task_vector_clauses(['customer', 'review'], 'customer_id', task_columns_dict, stype_dict)
# Selection logic of the query:
#   * `r` holds one row per customer with the time of that customer's earliest review.
#   * Every row of `timestamp_df` is cross-joined with `r`; DISTINCT pairs are kept when
#       - Case 1: some timestamp t2 in `all_timestamp_df`, not later than the row's
#         timestamp, has the first review inside (t2, t2 + timedelta], or
#       - Case 2: the first review is not later than the minimum timestamp in
#         `all_timestamp_df` (independent of the row's timestamp, so such customers
#         are kept for every row of `timestamp_df`).
# The only interpolated value is `timedelta`, inserted via its str() form.
# NOTE(review): relies on DuckDB accepting str(pd.Timedelta) as an INTERVAL literal — confirm.
sql_query = f"""
    SELECT
        universal.timestamp as timestamp,
        universal.customer_id as customer_id,
    FROM
        -- universal table
        (
            SELECT DISTINCT
                t.timestamp,
                r.customer_id
            FROM
                timestamp_df t
                CROSS JOIN (
                    SELECT 
                        customer_id,
                        MIN(review_time) AS first_review
                    FROM review
                    GROUP BY customer_id
                ) r
            WHERE
                -- Case 1: Reviews that fall within a window
                EXISTS (
                    SELECT 1
                    FROM all_timestamp_df t2
                    WHERE t2.timestamp <= t.timestamp
                    AND r.first_review > t2.timestamp
                    AND r.first_review <= t2.timestamp + INTERVAL '{timedelta}'
                )
                -- Case 2 (modified): If first review is before earliest timestamp, 
                -- include for ALL timestamps
                OR (
                    r.first_review <= (SELECT MIN(timestamp) FROM all_timestamp_df)
                )
        ) as universal
"""
# Show the rendered query so the interpolated interval can be inspected.
print(sql_query)

In [None]:
# Train split: fetch its timestamps, expose every input relation to DuckDB, run the query.
split = 'train'
db, train_timestamps = get_timestamps_for_split(
    dataset, stype_dict, split, timedelta, num_eval_timestamps
)

# For train, the query timestamps and the reference ("all") timestamps hold the same values.
timestamp_df = pd.DataFrame({"timestamp": train_timestamps})
all_timestamp_df = pd.DataFrame({"timestamp": train_timestamps})
df_dict = {}
for table_name, table in db.table_dict.items():
    df_dict[f"{table_name}"] = table.df

# Timestamp frames are registered first, then each database table under its own name.
duckdb.register("timestamp_df", timestamp_df)
duckdb.register("all_timestamp_df", all_timestamp_df)
for table_name in df_dict:
    duckdb.register(table_name, df_dict[table_name])

train_df = duckdb.sql(sql_query).df()

# Sanity check: the largest customer id must be below the length of the customer table.
largest_customer_id = train_df['customer_id'].max()
customer_table_len = len(db.table_dict['customer'])
assert largest_customer_id < customer_table_len

distinct_customers = train_df['customer_id'].unique()
print("Number of customers: ", len(distinct_customers))

In [None]:
# Validation split: query timestamps come from val; the reference set is val followed by train.
split = 'val'
db, val_timestamps = get_timestamps_for_split(
    dataset, stype_dict, split, timedelta, num_eval_timestamps
)

timestamp_df = pd.DataFrame({"timestamp": val_timestamps})
reference_series = [pd.Series(val_timestamps), pd.Series(train_timestamps)]
all_timestamp_df = pd.DataFrame(
    {"timestamp": pd.concat(reference_series, ignore_index=True)}
)
df_dict = {}
for table_name, table in db.table_dict.items():
    df_dict[f"{table_name}"] = table.df

# Timestamp frames are registered first, then each database table under its own name.
duckdb.register("timestamp_df", timestamp_df)
duckdb.register("all_timestamp_df", all_timestamp_df)
for table_name in df_dict:
    duckdb.register(table_name, df_dict[table_name])

val_df = duckdb.sql(sql_query).df()

# Sanity check: the largest customer id must be below the length of the customer table.
largest_customer_id = val_df['customer_id'].max()
customer_table_len = len(db.table_dict['customer'])
assert largest_customer_id < customer_table_len

distinct_customers = val_df['customer_id'].unique()
print("Number of customers: ", len(distinct_customers))

In [None]:
# Test split: query timestamps come from test; the reference set is test, then val, then train.
split = 'test'
db, test_timestamps = get_timestamps_for_split(
    dataset, stype_dict, split, timedelta, num_eval_timestamps
)

timestamp_df = pd.DataFrame({"timestamp": test_timestamps})
reference_series = [
    pd.Series(test_timestamps),
    pd.Series(val_timestamps),
    pd.Series(train_timestamps),
]
all_timestamp_df = pd.DataFrame(
    {"timestamp": pd.concat(reference_series, ignore_index=True)}
)
df_dict = {}
for table_name, table in db.table_dict.items():
    df_dict[f"{table_name}"] = table.df

# Timestamp frames are registered first, then each database table under its own name.
duckdb.register("timestamp_df", timestamp_df)
duckdb.register("all_timestamp_df", all_timestamp_df)
for table_name in df_dict:
    duckdb.register(table_name, df_dict[table_name])

test_df = duckdb.sql(sql_query).df()

# Sanity check: the largest customer id must be below the length of the customer table.
largest_customer_id = test_df['customer_id'].max()
customer_table_len = len(db.table_dict['customer'])
assert largest_customer_id < customer_table_len

distinct_customers = test_df['customer_id'].unique()
print("Number of customers: ", len(distinct_customers))

In [6]:
# Every split must expose the same number of columns before being persisted.
column_counts = {train_df.shape[1], val_df.shape[1], test_df.shape[1]}
assert len(column_counts) == 1

entity_col, entity_table, time_col = "customer_id", "customer", "timestamp"


def _to_task_table(split_df):
    """Wrap a split frame in a Table: customer_id is a foreign key to the customer
    table, there is no primary key, and timestamp is the time column."""
    return Table(
        df=split_df,
        fkey_col_to_pkey_table={entity_col: entity_table},
        pkey_col=None,
        time_col=time_col,
    )


# Create Table objects
train_table = _to_task_table(train_df)
val_table = _to_task_table(val_df)
test_table = _to_task_table(test_df)

# Task files are written under DATA_DIR/relbench/<dataset_name>/tasks/<task_name>/.
task_name = "user-ssl"
path_dir = os.path.join(DATA_DIR, 'relbench', dataset_name, 'tasks', task_name)
os.makedirs(path_dir, exist_ok=True)

# Save in the same order as the splits were built: train, val, test.
for file_name, split_table in (
    ('train.parquet', train_table),
    ('val.parquet', val_table),
    ('test.parquet', test_table),
):
    split_table.save(os.path.join(path_dir, file_name))