In [1]:
import pandas as pd
import numpy as np

In [14]:

class DataFrameIterator:
    def __init__(self, df, stock_size, client_size, window_size):
        """
        初始化迭代器
        
        参数:
        df: 输入的DataFrame，包含date, stock_id, client_id和7个特征列
        stock_size: 每次迭代处理的股票数量
        client_size: 每次迭代处理的客户数量
        window_size: 每次迭代处理的时间窗口天数
        """
        self.df = df.copy()
        self.stock_size = stock_size
        self.client_size = client_size
        self.window_size = window_size
        
        # 获取唯一值列表
        self.dates = sorted(self.df['date'].unique())
        self.stocks = sorted(self.df['stock_id'].unique())
        self.clients = sorted(self.df['client_id'].unique())
        
        # 计算总迭代次数
        self.total_windows = len(self.dates) - window_size + 1  # 因为每次step=1
        self.total_stock_batches = (len(self.stocks) + stock_size - 1) // stock_size
        self.total_client_batches = (len(self.clients) + client_size - 1) // client_size
        
        # 初始化迭代计数器
        self.current_window = 0
        self.current_stock_batch = 0
        self.current_client_batch = 0
    
    def __iter__(self):
        return self
    
    def __next__(self):
        # 检查是否完成所有迭代
        if self.current_window >= self.total_windows:
            raise StopIteration
        
        # 获取当前批次的范围
        date_start = self.current_window
        date_end = date_start + self.window_size  # 窗口大小固定为window_size
        current_dates = self.dates[date_start:date_end]
        
        stock_start = self.current_stock_batch * self.stock_size
        stock_end = stock_start + self.stock_size
        current_stocks = self.stocks[stock_start:stock_end]
        
        client_start = self.current_client_batch * self.client_size
        client_end = client_start + self.client_size
        current_clients = self.clients[client_start:client_end]
        
        # 筛选数据
        mask = (self.df['date'].isin(current_dates)) & \
               (self.df['stock_id'].isin(current_stocks)) & \
               (self.df['client_id'].isin(current_clients))
        batch_df = self.df[mask].copy()
        
        # 更新迭代计数器 - 关键修改点：每次只增加1个时间窗口
        self.current_client_batch += 1
        if self.current_client_batch >= self.total_client_batches:
            self.current_client_batch = 0
            self.current_stock_batch += 1
            
            if self.current_stock_batch >= self.total_stock_batches:
                self.current_stock_batch = 0
                self.current_window += 1  # 这里每次只增加1，实现rolling step=1
        
        return batch_df, {
            'window': (date_start, date_end),
            'stocks': (stock_start, stock_end),
            'clients': (client_start, client_end),
            'current_date_window': current_dates  # 新增返回当前窗口的具体日期
        }
    
    def reset(self):
        """重置迭代器"""
        self.current_window = 0
        self.current_stock_batch = 0
        self.current_client_batch = 0


In [7]:
dates = pd.date_range(start='2023-01-01', periods=10)
stocks = [f'STK_{i}' for i in range(1, 21)]
clients = [f'CL_{i}' for i in range(1, 6)]


In [8]:
data = []
for date in dates:
    for stock in stocks:
        for client in clients:
            features = np.random.rand(7)
            data.append([date, stock, client] + list(features))

In [9]:
df = pd.DataFrame(data, columns=['date', 'stock_id', 'client_id'] + [f'feature_{i}' for i in range(1, 8)])


In [17]:
# 使用迭代器
iterator = DataFrameIterator(df, stock_size=205, client_size=2, window_size=3)

for batch_df, metadata in iterator:
    print(f"处理时间窗口: {metadata['window']}, 股票范围: {metadata['stocks']}, 客户范围: {metadata['clients']}")
    print(f"批次数据形状: {batch_df.shape}")
    print("---")

处理时间窗口: (0, 3), 股票范围: (0, 205), 客户范围: (0, 2)
批次数据形状: (120, 10)
---
处理时间窗口: (0, 3), 股票范围: (0, 205), 客户范围: (2, 4)
批次数据形状: (120, 10)
---
处理时间窗口: (0, 3), 股票范围: (0, 205), 客户范围: (4, 6)
批次数据形状: (60, 10)
---
处理时间窗口: (1, 4), 股票范围: (0, 205), 客户范围: (0, 2)
批次数据形状: (120, 10)
---
处理时间窗口: (1, 4), 股票范围: (0, 205), 客户范围: (2, 4)
批次数据形状: (120, 10)
---
处理时间窗口: (1, 4), 股票范围: (0, 205), 客户范围: (4, 6)
批次数据形状: (60, 10)
---
处理时间窗口: (2, 5), 股票范围: (0, 205), 客户范围: (0, 2)
批次数据形状: (120, 10)
---
处理时间窗口: (2, 5), 股票范围: (0, 205), 客户范围: (2, 4)
批次数据形状: (120, 10)
---
处理时间窗口: (2, 5), 股票范围: (0, 205), 客户范围: (4, 6)
批次数据形状: (60, 10)
---
处理时间窗口: (3, 6), 股票范围: (0, 205), 客户范围: (0, 2)
批次数据形状: (120, 10)
---
处理时间窗口: (3, 6), 股票范围: (0, 205), 客户范围: (2, 4)
批次数据形状: (120, 10)
---
处理时间窗口: (3, 6), 股票范围: (0, 205), 客户范围: (4, 6)
批次数据形状: (60, 10)
---
处理时间窗口: (4, 7), 股票范围: (0, 205), 客户范围: (0, 2)
批次数据形状: (120, 10)
---
处理时间窗口: (4, 7), 股票范围: (0, 205), 客户范围: (2, 4)
批次数据形状: (120, 10)
---
处理时间窗口: (4, 7), 股票范围: (0, 205), 客户范围: (4, 6)
批次数据形状: (60, 10)
---


In [None]:

    # 创建示例数据

# 创建DataFrame
data = []
for date in dates:
    for stock in stocks:
        for client in clients:
            features = np.random.rand(7)
            data.append([date, stock, client] + list(features))


# # 使用迭代器
# iterator = DataFrameIterator(df, stock_size=5, window_size=3)

# for i, batch in enumerate(iterator):
#     print(f"Batch {i+1}:")
#     print(f"日期范围: {batch['date'].min()} 到 {batch['date'].max()}")
#     print(f"股票数量: {batch['stock_id'].nunique()}")
#     print(f"记录数: {len(batch)}")
#     print("------")

IndentationError: expected an indented block after 'for' statement on line 8 (4240020543.py, line 9)