In [28]:
import pandas as pd
import numpy as np

In [29]:
NAMES = [
    'movie',                          # dataset folder under ./data and prefix of its *_interaction.csv
    'Amazon-KG-5core-Movies_and_TV',  # file prefix of the .link / .kg knowledge-graph files
]

In [30]:
def build_graph_and_interactions(name=None, data_dir='./data'):
    """Build the static knowledge graph and the user-entity interaction table.

    Users and KG entities share one integer id space: users get ids
    ``1..num_users`` (in order of first appearance in the interaction file)
    and entities get ids starting at ``num_users + 1`` (in sorted token
    order), so the two ranges never collide.

    Args:
        name: pair ``(dataset_dir, kg_prefix)``. ``dataset_dir`` is the folder
            under ``data_dir`` and the prefix of ``<dataset_dir>_interaction.csv``;
            ``kg_prefix`` is the prefix of the ``<kg_prefix>.link`` and
            ``<kg_prefix>.kg`` files inside that folder. Defaults to the
            module-level ``NAMES``.
        data_dir: root directory that contains the dataset folder.

    Returns:
        (static_graph, interactions):
          * static_graph -- columns ``head_id, relation_id, tail_id`` (int64,
            relation ids start at 1) followed by their ``:token`` columns.
            Rows containing any missing value are dropped.
          * interactions -- columns ``user_id, entity_id`` (int64),
            ``timestamp``, ``user_id:token``, ``entity_id:token`` and
            ``item_id:token``. Rows with any missing value (e.g. items that
            have no linked entity) are dropped.
    """
    if name is None:
        name = NAMES
    name0, name1 = name
    base_dir = f'{data_dir}/{name0}'

    #########################################################
    #### 1. Build static knowledge graph
    #########################################################
    interactions = pd.read_csv(f'{base_dir}/{name0}_interaction.csv', sep=',')
    user_tokens = interactions['user_id:token'].unique()
    num_users = len(user_tokens)

    link = pd.read_csv(f'{base_dir}/{name1}.link', sep='\t')
    static_graph = pd.read_csv(f'{base_dir}/{name1}.kg', sep='\t')

    # Every entity seen in the item->entity links or on either side of a KG triple.
    all_entity_tokens = np.unique(np.concatenate([
        link['entity_id:token'].unique(),
        static_graph['head_id:token'].unique(),
        static_graph['tail_id:token'].unique(),
    ]))
    # Offset by num_users so entity ids never overlap user ids (1..num_users).
    entity2id = {entity: num_users + idx + 1 for idx, entity in enumerate(all_entity_tokens)}

    item2entity = dict(zip(link['item_id:token'], link['entity_id:token']))
    item2entity_id = {item: entity2id[entity] for item, entity in item2entity.items()}

    static_graph['head_id'] = static_graph['head_id:token'].map(entity2id)
    static_graph['tail_id'] = static_graph['tail_id:token'].map(entity2id)
    # Category codes are 0-based; shift so relation ids start at 1.
    static_graph['relation_id'] = static_graph['relation_id:token'].astype('category').cat.codes + 1

    static_graph = static_graph.dropna()
    # Explicit int64: the 'long' alias maps to a 32-bit C long on Windows.
    static_graph = static_graph.astype({'head_id': 'int64', 'relation_id': 'int64', 'tail_id': 'int64'})
    static_graph = static_graph[['head_id', 'relation_id', 'tail_id',
                                 'head_id:token', 'relation_id:token', 'tail_id:token']]

    #########################################################
    #### 2. Build interactions
    #########################################################
    user2id = {user: idx + 1 for idx, user in enumerate(user_tokens)}

    interactions['entity_id:token'] = interactions['item_id:token'].map(item2entity)
    interactions['user_id'] = interactions['user_id:token'].map(user2id)
    interactions['entity_id'] = interactions['item_id:token'].map(item2entity_id)

    # Unmapped items produce NaN ids; those rows are dropped here.
    interactions = interactions.dropna()
    interactions = interactions.astype({'user_id': 'int64', 'entity_id': 'int64'})
    interactions = interactions[['user_id', 'entity_id', 'timestamp', 'user_id:token',
                                 'entity_id:token', 'item_id:token']]

    return static_graph, interactions

In [31]:
def leave_k_out_split(interaction_df: pd.DataFrame, k: int, user_col='user_id', time_col='timestamp',
                      time_unit=None):
    """
    Split interactions chronologically per user using a Leave-K-Out strategy.

    For every user with at least ``3 * k`` interactions, the ``k`` most recent
    interactions go to test, the ``k`` before those go to validation, and the
    remaining (at least ``k``) go to train. Users with fewer interactions are
    dropped from all three splits.

    Args:
        interaction_df (pd.DataFrame): user-item interactions. It is NOT modified.
        k (int): number of interactions per user in validation and in test; must be >= 1.
        user_col (str): column identifying the user.
        time_col (str): column holding the interaction time. In the returned frames
            it is converted with ``pd.to_datetime``.
        time_unit (str | None): forwarded as ``unit`` to ``pd.to_datetime``. Keep None
            for date strings; pass e.g. ``'s'`` for Unix-epoch seconds (with None,
            integer values are interpreted as nanoseconds).

    Returns:
        train_df, val_df, test_df -- each ordered by (user_col, time_col) and
        indexed by position in the sorted, reset-index frame.

    Raises:
        ValueError: if k < 1, or if no user has at least 3 * k interactions.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}.")

    #########################################################
    #### 1. Sort dataframe by user, then timestamp
    #### (on a copy, so the caller's frame keeps its original column)
    #########################################################
    sorted_df = interaction_df.copy()
    sorted_df[time_col] = pd.to_datetime(sorted_df[time_col], unit=time_unit)
    sorted_df = sorted_df.sort_values(by=[user_col, time_col]).reset_index(drop=True)

    #########################################################
    #### 2. Filter data
    #### Each user needs >= 3*k interactions:
    #### Test (k items) + Val (k items) + Train (>= k items)
    #########################################################
    min_interactions = 3 * k

    user_counts = sorted_df[user_col].value_counts()
    valid_users = user_counts[user_counts >= min_interactions].index

    # Keep only users that satisfy the minimum-interaction requirement.
    df_filtered = sorted_df[sorted_df[user_col].isin(valid_users)].copy()

    if df_filtered.empty:
        raise ValueError(
            f"K value is too high (k={k}): no user has at least {min_interactions} "
            "interactions. Please reduce K value."
        )

    #########################################################
    #### 3. Split train, validation, test dataset
    #########################################################
    # 0 = the user's most recent interaction, 1 = the one before it, ...
    rank_from_end = df_filtered.groupby(user_col).cumcount(ascending=False)

    test_interaction_df = df_filtered[rank_from_end < k].copy()
    val_interaction_df = df_filtered[(rank_from_end >= k) & (rank_from_end < 2 * k)].copy()
    train_interaction_df = df_filtered[rank_from_end >= 2 * k].copy()

    return train_interaction_df, val_interaction_df, test_interaction_df

In [32]:
if __name__ == '__main__':
    # Number of most-recent interactions per user held out for validation and for test.
    K = 10

    name0 = NAMES[0]
    name1 = NAMES[1]
    out_dir = f'./data/{name0}'

    static_graph, interaction_df = build_graph_and_interactions(NAMES)

    # Stable sorts keep rows with equal keys in their original (file) order.
    static_graph = static_graph.sort_values(by=['tail_id'], kind='stable')
    interaction_df = interaction_df.sort_values(by=['user_id'], kind='stable')

    static_graph.to_csv(f'{out_dir}/{name0}_processed_static_graph.csv', index=False)
    interaction_df.to_csv(f'{out_dir}/{name0}_processed_interactions.csv', index=False)

    train_interaction_df, val_interaction_df, test_interaction_df = leave_k_out_split(interaction_df, k=K)

    for split_name, split_df in (('train', train_interaction_df),
                                 ('val', val_interaction_df),
                                 ('test', test_interaction_df)):
        split_df.to_csv(f'{out_dir}/{name0}_{split_name}_interactions.csv', index=False)