# In [None]:
import json
import pandas as pd
import numpy as np
from collections import defaultdict

# Global settings; the main program passes both to generate_sequences.
max_length = 10  # length every history sequence is truncated/padded to
pad_item = 2359  # filler id for short histories; presumably outside the real book-id range — TODO confirm

def load_data(file_path):
    """Load and return the parsed content of a JSON file.

    The file is decoded explicitly as UTF-8. Without an explicit encoding,
    ``open`` uses the platform's locale encoding, which can fail on or
    mis-decode non-ASCII JSON text.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The value produced by ``json.load`` for the file's content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def pad_history(itemlist, length, pad_item):
    """Return a copy of *itemlist* that is exactly *length* items long.

    Shorter lists are right-padded with *pad_item*; lists that are already
    long enough keep only their last *length* items. The input list is never
    modified.

    Args:
        itemlist: List of item ids.
        length: Target length of the result.
        pad_item: Value appended to fill up short lists.

    Returns:
        A new list built from *itemlist* as described above.
    """
    missing = length - len(itemlist)
    if missing > 0:
        return itemlist + [pad_item] * missing
    # Long enough already: keep the trailing window.
    return itemlist[-length:]

def _build_sample(history, target, max_length, pad_item):
    """Return one padded sample dict, or None if *history* holds no real item."""
    if not history:
        return None
    # pad_history returns a new list, so the sample never aliases *history*.
    seq = pad_history(history, max_length, pad_item)
    # Drop sequences that consist of nothing but the padding value.
    if all(item == pad_item for item in seq):
        return None
    return {
        'seq': seq,
        'len_seq': min(len(history), max_length),
        'next': target,
    }


def generate_sequences(data, is_train=True, max_length=10, pad_item=2359):
    """Build fixed-length (history -> next item) samples per user.

    Args:
        data: Mapping of user id to that user's list of interaction dicts.
            Every dict must contain a 'book_id' key. Interactions are used in
            their stored order (presumably chronological — confirm upstream).
        is_train: If True, every interaction that has at least one preceding
            interaction yields a sample whose history is the preceding
            interactions (at most the last ``max_length``). If False, each
            user yields at most one sample: all interactions but the last as
            history, the last one as target.
        max_length: Length every 'seq' list is truncated/padded to.
        pad_item: Value used to right-pad short histories.

    Returns:
        A DataFrame with the columns 'seq' (padded history list), 'len_seq'
        (number of history items, capped at ``max_length``) and 'next'
        (target book id), indexed 0..n-1. The columns are present even when
        no sample could be built, so callers can always access them.
    """
    columns = ['seq', 'len_seq', 'next']
    samples = []

    for books in data.values():
        if is_train:
            # Training split: one sample per interaction with a history.
            history = []
            for book in books:
                sample = _build_sample(history, book['book_id'], max_length, pad_item)
                if sample is not None:
                    samples.append(sample)

                history.append(book['book_id'])
                # Only the trailing window can ever end up in a sample.
                if len(history) > max_length:
                    history = history[-max_length:]
        elif len(books) >= 2:
            # Validation/test split: a single sample per user. Users with
            # fewer than two interactions have no history and are skipped.
            history = [book['book_id'] for book in books[:-1]]
            sample = _build_sample(history, books[-1]['book_id'], max_length, pad_item)
            if sample is not None:
                samples.append(sample)

    # Explicit columns keep the frame's schema stable for empty results.
    return pd.DataFrame(samples, columns=columns)


def split_data_by_time(data, time_field='read_at'):
    """Split user ids into train/validation/test groups by earliest time.

    Each user with at least one interaction is ranked by the smallest
    *time_field* value among their interactions; users with an empty
    interaction list are left out of every group. The first 80% of the
    ranking (rounded down) form the training group, the next 10% (rounded
    down) the validation group, and the remainder the test group.

    NOTE(review): an interaction without *time_field* counts as time 0. If
    the real values are not numbers, comparing them with 0 raises TypeError,
    so confirm the field is always present for such data.

    Args:
        data: Mapping of user id to a list of interaction dicts.
        time_field: Key of the time value inside each interaction dict.

    Returns:
        A tuple ``(train_ids, val_ids, test_ids)`` of user-id lists.
    """
    first_seen = {
        uid: min(entry.get(time_field, 0) for entry in entries)
        for uid, entries in data.items()
        if entries
    }

    ranked = sorted(first_seen, key=first_seen.get)

    n_users = len(ranked)
    train_end = int(0.8 * n_users)
    val_end = train_end + int(0.1 * n_users)

    return ranked[:train_end], ranked[train_end:val_end], ranked[val_end:]

def print_dataset_info(name, df):
    """Print a readable summary of a sequence DataFrame to stdout.

    Output, in order: a header with *name*, the number of rows, the number of
    rows per 'len_seq' value (sorted by value), the first five rows, and the
    'seq' cell of each of the first five rows.

    Side effect: sets the pandas options 'display.max_columns' (to None) and
    'display.width' (to 1000) globally; they stay in effect after the call.

    Args:
        name: Label shown in the header line.
        df: DataFrame that has at least the columns 'len_seq' and 'seq'.
    """
    print(f"\n{name} 信息:")
    print(f"  总序列数: {len(df)}")

    length_counts = df['len_seq'].value_counts().sort_index()
    print(f"  len_seq 分布:\n{length_counts}")

    print("\n  前5行数据:")
    for option, value in (('display.max_columns', None), ('display.width', 1000)):
        pd.set_option(option, value)
    preview = df.head()
    print(preview.to_string())

    print("\n  seq列的前5个元素:")
    for rank, sequence in enumerate(preview['seq'], start=1):
        print(f"    序列 {rank}: {sequence}")

# Main program: load the sessions, split users by time, build the sequence
# samples for each split, save them as pickled DataFrames and print a summary.
if __name__ == "__main__":
    file_path = '/workspace/Goodreads/history/remapped_ya_user_sessions.json'
    data = load_data(file_path)

    # Inspect the data layout. Probe the first user that actually has
    # interactions: indexing the very first user's list raises IndexError when
    # it is empty, even though split_data_by_time tolerates empty lists.
    sample_user = next((user for user, books in data.items() if books), None)
    if sample_user is None:
        raise ValueError(f"No interactions found in {file_path}")
    sample_book = data[sample_user][0]
    print("Sample book data:", sample_book)

    # Decide which field carries the interaction time.
    time_field = 'read_at'  # adjust this name to match the actual data layout
    if time_field not in sample_book:
        print(f"Warning: '{time_field}' not found in the data. Please specify the correct time field.")
        time_field = input("Please enter the correct time field name: ")

    # Split the users into train/validation/test groups by time.
    train_ids, val_ids, test_ids = split_data_by_time(data, time_field)

    train_data = {user: data[user] for user in train_ids}
    val_data = {user: data[user] for user in val_ids}
    test_data = {user: data[user] for user in test_ids}

    # Build the sequence samples for every split.
    train_df = generate_sequences(train_data, is_train=True, max_length=max_length, pad_item=pad_item)
    val_df = generate_sequences(val_data, is_train=False, max_length=max_length, pad_item=pad_item)
    test_df = generate_sequences(test_data, is_train=False, max_length=max_length, pad_item=pad_item)

    # Save the splits.
    train_df.to_pickle('/workspace/Goodreads/history/train_data.df')
    val_df.to_pickle('/workspace/Goodreads/history/val_data.df')
    test_df.to_pickle('/workspace/Goodreads/history/test_data.df')

    # Test_data.df is the same frame as the test split (each user already
    # contributes a single sequence there), saved under a second name.
    test_df.to_pickle('/workspace/Goodreads/history/Test_data.df')

    # Print a summary of every split.
    print_dataset_info("训练集", train_df)
    print_dataset_info("验证集", val_df)
    print_dataset_info("测试集", test_df)

    # Additional statistics over the whole data set.
    print("\n数据统计:")
    print(f"总用户数: {len(data)}")
    print(f"训练集用户数: {len(train_data)}")
    print(f"验证集用户数: {len(val_data)}")
    print(f"测试集用户数: {len(test_data)}")

    all_book_ids = {
        book['book_id']
        for user_books in data.values()
        for book in user_books
    }
    print(f"总图书数: {len(all_book_ids)}")

    print("\n数据集生成并保存完成。")