In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np

# 1. Build a demo DataFrame in which every cell holds a plain Python list.
# A dedicated, seeded generator keeps the generated values identical across
# "Restart Kernel -> Run All" without touching NumPy's global random state.
SEED = 0
N_SAMPLES, X_DIM, Y_DIM = 100, 5, 2
rng = np.random.default_rng(SEED)
df = pd.DataFrame({
    'x': rng.random((N_SAMPLES, X_DIM)).tolist(),   # 5-dim features, one list per row
    'y': rng.random((N_SAMPLES, Y_DIM)).tolist(),   # 2-dim labels, one list per row
})
print(df.head())
# 2. Define a custom Dataset class
class ListDataset(Dataset):
    """Map-style dataset over a DataFrame whose cells hold equal-length lists.

    Both columns are converted to tensors once, at construction time, so
    ``__getitem__`` is a plain tensor index.

    Args:
        dataframe: DataFrame containing the feature and label columns.
        x_col: Name of the feature column (default ``'x'``).
        y_col: Name of the label column (default ``'y'``).
        dtype: Torch dtype used for both tensors (default ``torch.float32``).

    Raises:
        KeyError: If ``x_col`` or ``y_col`` is not a column of ``dataframe``.
        ValueError: Raised by ``torch.tensor`` when the lists in a column do
            not all have the same length.
    """

    def __init__(self, dataframe, x_col='x', y_col='y', dtype=torch.float32):
        missing = [col for col in (x_col, y_col) if col not in dataframe.columns]
        if missing:
            raise KeyError(f"column(s) not found in dataframe: {missing}")
        # .tolist() turns the object column into a nested Python list that
        # torch.tensor can stack into a single 2-D tensor.
        self.x = torch.tensor(dataframe[x_col].tolist(), dtype=dtype)
        self.y = torch.tensor(dataframe[y_col].tolist(), dtype=dtype)

    def __len__(self):
        """Return the number of samples (rows of the feature tensor)."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(features, labels)`` pair stored at position ``idx``."""
        return self.x[idx], self.y[idx]

# 3. Instantiate the dataset
dataset = ListDataset(df)

# 4. Optional: wrap it in a DataLoader
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# 5. Smoke-test: pull exactly one batch and report its shapes.
# The None default mirrors a loop that simply never runs on an empty loader.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    batch_x, batch_y = first_batch
    print("x shape:", batch_x.shape)  # torch.Size([16, 5])
    print("y shape:", batch_y.shape)  # torch.Size([16, 2])

                                                   x  \
0  [0.3030138533582116, 0.6712399866387153, 0.361...   
1  [0.3516107207568504, 0.605682316092597, 0.9822...   
2  [0.30424811830520104, 0.39276607731198265, 0.8...   
3  [0.07439907945787438, 0.6075642291454495, 0.80...   
4  [0.06436289812511797, 0.8000812882522669, 0.13...   

                                            y  
0  [0.08598679291478462, 0.02774622628993273]  
1    [0.7318663135247675, 0.7244321485864631]  
2   [0.9338693048944986, 0.23047153439460855]  
3     [0.850913288204863, 0.9686055308276817]  
4    [0.6035661798811656, 0.6801631786361815]  
x shape: torch.Size([16, 5])
y shape: torch.Size([16, 2])


In [3]:
import pickle

# NOTE(review): absolute, machine-specific path — consider deriving it from a
# configurable data directory. pickle.load can execute arbitrary code, so only
# open files from a trusted source.
PKL_PATH = "/workspace/moe_analysis/outputs/t3/43.pkl"

# Deserialize the whole file into `data`; the cells that follow inspect it.
with open(PKL_PATH, "rb") as pkl_file:
    data = pickle.load(pkl_file)

In [7]:
# Container type and record count scaled by 58.
# NOTE(review): 58 is presumably the number of records logged per token
# (the last record shows layer_idx 57) — confirm and lift into a named constant.
(type(data), len(data) / 58)

(list, 290.0)

In [6]:
# Peek at the final record; the slice keeps it wrapped in a one-element list.
# NOTE(review): each record carries a long 'hidden_states' vector, so this
# display is very large — consider showing only the keys or the small fields.
data[-1:]

[{'mode': 'decode',
  'token_idx': 290,
  'layer_idx': 57,
  'topk_idx': [[35, 99, 167, 176, 231, 234, 255, 175]],
  'hidden_states': [[-0.302734375,
    -0.5,
    0.024658203125,
    -0.609375,
    -0.11083984375,
    -0.375,
    -0.2255859375,
    -0.09814453125,
    -0.279296875,
    0.63671875,
    -0.126953125,
    -0.0211181640625,
    0.185546875,
    -0.30859375,
    0.20703125,
    -0.2275390625,
    0.01153564453125,
    0.5390625,
    -0.14453125,
    0.341796875,
    -0.4453125,
    0.11962890625,
    -0.70703125,
    0.205078125,
    0.18359375,
    0.267578125,
    -0.263671875,
    -0.130859375,
    0.0673828125,
    -0.333984375,
    0.248046875,
    -0.29296875,
    0.03662109375,
    0.130859375,
    -0.1484375,
    0.11083984375,
    0.04443359375,
    -0.09814453125,
    -0.71484375,
    -0.734375,
    -0.287109375,
    -0.06591796875,
    -0.5625,
    0.2001953125,
    0.341796875,
    -0.55859375,
    -0.10302734375,
    -0.14453125,
    0.11279296875,
    0.34179