In [1]:
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict, load_from_disk
from transformers import T5Tokenizer, T5ForConditionalGeneration
from pathlib import Path

In [2]:
# Notebook-wide settings, grouped in one cell so they are easy to find and tune.

# Presumably toggles reuse of previously preprocessed data — confirm against
# the cell that reads it.
use_data_cache = False

# Directory that holds the locally stored model checkpoints.
model_path = Path("./models")

# Selects which checkpoint the model-loading cell reads.
pretrained_mdl = "t5-small"

In [3]:
# Load the tokenizer and the model selected by `pretrained_mdl`.
# Both are read from the same local checkpoint directory under `model_path`.
if pretrained_mdl == "t5-small":
    # Announce before loading so the message is visible while the load runs.
    print("loading t5-small model...")
    checkpoint_dir = model_path / "t5-small-new"
    tokenizer = T5Tokenizer.from_pretrained(checkpoint_dir)
    model = T5ForConditionalGeneration.from_pretrained(checkpoint_dir)
else:
    # Fail here with a clear message instead of leaving `tokenizer` / `model`
    # undefined and hitting a NameError in a later cell.
    raise ValueError(f"unsupported pretrained_mdl: {pretrained_mdl!r}")


loading t5-small model...


In [5]:
def load_data(file_path):
    """Read a text file that stores one bracketed token list per line.

    Each line is stripped of surrounding whitespace, its first and last
    characters are dropped (assumed to be the enclosing brackets of a line
    such as ``['a', 'b']`` — TODO confirm against the data files), the
    remainder is split on ``', '`` and single quotes are trimmed from both
    ends of every piece.

    Args:
        file_path: Path of the UTF-8 encoded text file to read.

    Returns:
        A list containing one list of token strings per line of the file.
        A blank line yields ``['']``.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        raw_lines = handle.readlines()

    return [
        [piece.strip("'") for piece in raw.strip()[1:-1].split(', ')]
        for raw in raw_lines
    ]

#preprocess data
def preprocess_function(examples):
    """Tokenize a batch of pre-split source and target sequences.

    Uses the notebook-level ``tokenizer``. Both columns are tokenized with
    ``is_split_into_words=True`` and with padding and truncation disabled.

    Args:
        examples: Mapping with ``'input_ids'`` and ``'labels'`` entries that
            are passed straight to the tokenizer (presumably lists of
            token-string lists, as built by the dataset cell — confirm).

    Returns:
        Dict whose ``'input_ids'`` and ``'attention_mask'`` come from the
        tokenized inputs and whose ``'labels'`` are the ids of the tokenized
        targets.
    """
    tokenizer_options = dict(is_split_into_words=True, padding=False, truncation=False)
    source_encoding = tokenizer(examples['input_ids'], **tokenizer_options)
    target_encoding = tokenizer(examples['labels'], **tokenizer_options)

    return {
        'input_ids': source_encoding['input_ids'],
        'attention_mask': source_encoding['attention_mask'],
        'labels': target_encoding['input_ids'],
    }


In [6]:
# Build the tokenized train/eval DatasetDict and save it to disk, or reuse the
# saved copy when `use_data_cache` is set and a previous run already wrote it.
preprocessed_dir = "./data/train/preprocessed_data"

if use_data_cache and Path(preprocessed_dir).exists():
    # Cached path: only `tokenized_dataset` is defined; the raw token lists and
    # the untokenized `dataset` are NOT rebuilt.
    # NOTE(review): the cache is not tied to the tokenizer that produced it —
    # rebuild (use_data_cache=False) after changing the tokenizer or the data.
    print("loading preprocessed data from disk ...")
    tokenized_dataset = load_from_disk(preprocessed_dir)
else:
    print("loading origin data ...")
    train_origin_prefix = load_data('./data/train/random_ns_nt/prefix_origin.txt')
    print("loading simple data ...")
    train_simple_prefix = load_data('./data/train/random_ns_nt/prefix_simple.txt')
    # The two files are paired line by line, so their lengths must match.
    assert len(train_origin_prefix) == len(train_simple_prefix), \
        "origin and simple files must contain the same number of lines"

    # Hold out 10% of the pairs for evaluation; the fixed seed keeps the split
    # reproducible across runs.
    train_origin_prefix, eval_origin_prefix, train_simple_prefix, eval_simple_prefix = train_test_split(
        train_origin_prefix, train_simple_prefix, test_size=0.1, random_state=42
    )

    # The raw token lists are stored under the column names 'input_ids' and
    # 'labels'; the map() call below replaces them with the tokenized ids.
    train_dataset = Dataset.from_dict({
        'input_ids': train_origin_prefix,
        'labels': train_simple_prefix
    })
    eval_dataset = Dataset.from_dict({
        'input_ids': eval_origin_prefix,
        'labels': eval_simple_prefix
    })
    dataset = DatasetDict({
        'train': train_dataset,
        'eval': eval_dataset
    })
    tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=["input_ids", "labels"])

    tokenized_dataset.save_to_disk(preprocessed_dir)
    print("preprocessed data save to disk sucessfully! ")

loading origin data ...
loading simple data ...


Map:   0%|          | 0/450000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/450000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/50000 [00:00<?, ? examples/s]

preprocessed data save to disk sucessfully! 


In [7]:
# Spot-check the tokenized training split: the input ids of row 100 and the
# label ids of row 0 (note that the two rows differ).
train_split = tokenized_dataset['train']
print(train_split[100]['input_ids'])
print(train_split[0]['labels'])

[4115, 32101, 32100, 4115, 32103, 209, 172, 13039, 32103, 305, 17, 32103, 209, 4115, 32101, 32100, 4115, 172, 32101, 13039, 32102, 220, 4115, 32102, 314, 17, 32103, 209, 4115, 32101, 13039, 32102, 220, 4115, 32102, 314, 17, 32103, 209, 13039, 32103, 209, 4115, 32103, 209, 17, 32103, 209, 4115, 32101, 32100, 4115, 32103, 209, 4115, 172, 32101, 13039, 32103, 431, 4115, 32102, 305, 17, 32103, 209, 4115, 32101, 13039, 32103, 431, 4115, 32102, 305, 17, 32103, 209, 13039, 32103, 209, 489, 4115, 32102, 209, 314, 17, 32103, 209, 4115, 32100, 4115, 32103, 209, 172, 13039, 32103, 431, 17, 32100, 4115, 172, 32101, 13039, 32102, 220, 4115, 32102, 314, 17, 32103, 209, 4115, 32101, 13039, 32102, 220, 4115, 32102, 314, 17, 32103, 209, 13039, 32102, 505, 4115, 32102, 209, 209, 17, 1]
[4115, 32100, 4115, 172, 32101, 13039, 32103, 431, 4115, 32102, 489, 17, 32103, 209, 4115, 32101, 13039, 32103, 431, 4115, 32102, 489, 17, 32103, 209, 13039, 32102, 305, 4115, 32103, 431, 17, 32100, 4115, 32103, 209, 4115

In [8]:
from collections import Counter
from tqdm import tqdm

# Length statistics of the tokenized training inputs, ignoring pad tokens.
# The pad id is taken from the tokenizer instead of a hard-coded 0 so this cell
# agrees with the anomaly check, which also uses tokenizer.pad_token_id.
pad_token_id = tokenizer.pad_token_id
nonzero_lens = [
    sum(1 for t in x['input_ids'] if t != pad_token_id)
    for x in tqdm(tokenized_dataset['train'], desc='统计非0长度', unit='样本')
]
print(f"非0长度最大值: {max(nonzero_lens)}")
print(f"非0长度最小值: {min(nonzero_lens)}")
print("非0长度分布:")
# Show the 20 most frequent lengths as (length, count) pairs.
print(Counter(nonzero_lens).most_common(20))

统计非0长度:   0%|          | 0/450000 [00:00<?, ?样本/s]

统计非0长度: 100%|██████████| 450000/450000 [00:33<00:00, 13374.49样本/s]

非0长度最大值: 347
非0长度最小值: 4
非0长度分布:
[(169, 3886), (172, 3823), (166, 3724), (170, 3508), (171, 3495), (175, 3472), (163, 3407), (173, 3395), (174, 3379), (167, 3354), (105, 3296), (102, 3282), (168, 3250), (176, 3250), (177, 3142), (99, 3073), (160, 3033), (178, 3022), (164, 3013), (165, 2982)]





In [15]:
from collections import Counter
from tqdm import tqdm

# Scan the tokenized training split for malformed samples: sequences with no
# valid tokens, token ids outside the vocabulary, and very short sequences.
print("=" * 50)
print("数据异常检查")
print("=" * 50)

# Use len(tokenizer) rather than vocab_size because tokens were added to the
# vocabulary.
actual_vocab_size = len(tokenizer)
pad_token_id = tokenizer.pad_token_id
eos_token_id = tokenizer.eos_token_id
# -100 is the conventional ignore index for padded label positions; such
# positions are treated as padding, not as real tokens.
ignore_label_id = -100
# Sequences with at least one but fewer than this many valid tokens are
# reported as "very short".
min_valid_tokens = 3

print(f"实际词表大小: {actual_vocab_size}")

# Each entry collects the indices of the training samples that hit that check.
anomalies = {
    'empty_input': [],      # input_ids is all padding
    'empty_label': [],      # labels is all padding
    'invalid_token_input': [],  # input contains a token id outside the vocabulary
    'invalid_token_label': [],  # label contains a token id outside the vocabulary
    'very_short_input': [],     # input is very short (< 3 valid tokens)
    'very_short_label': [],     # label is very short
}

for i, sample in enumerate(tqdm(tokenized_dataset['train'], desc='检查异常样本')):
    input_ids = sample['input_ids']
    labels = sample['labels']

    # Count the valid (non-padding) tokens. For labels, ignore-index positions
    # are excluded as well, matching the invalid-token check below.
    input_valid = sum(1 for t in input_ids if t != pad_token_id)
    label_valid = sum(1 for t in labels if t != pad_token_id and t != ignore_label_id)

    # Empty sequences.
    if input_valid == 0:
        anomalies['empty_input'].append(i)
    if label_valid == 0:
        anomalies['empty_label'].append(i)

    # Illegal token ids, checked against the actual vocabulary size.
    if any(t < 0 or t >= actual_vocab_size for t in input_ids):
        anomalies['invalid_token_input'].append(i)
    if any(t < 0 or t >= actual_vocab_size for t in labels if t != ignore_label_id):
        anomalies['invalid_token_label'].append(i)

    # Sequences that are too short (empty ones are already reported above).
    if 0 < input_valid < min_valid_tokens:
        anomalies['very_short_input'].append(i)
    if 0 < label_valid < min_valid_tokens:
        anomalies['very_short_label'].append(i)

print("\n检查结果:")
for key, indices in anomalies.items():
    count = len(indices)
    if count > 0:
        print(f"  ❌ {key}: {count} 个样本")
        print(f"     前5个索引: {indices[:5]}")
    else:
        print(f"  ✅ {key}: 0 个异常")

# A sample that fails several checks is counted once per failed check.
total_anomalies = sum(len(v) for v in anomalies.values())
if total_anomalies > 0:
    print(f"\n⚠️  共发现 {total_anomalies} 处异常!")
else:
    print("\n✅ 数据检查通过，未发现异常!")

数据异常检查
实际词表大小: 32104


检查异常样本:   0%|          | 0/450000 [00:00<?, ?it/s]

检查异常样本: 100%|██████████| 450000/450000 [01:46<00:00, 4240.54it/s]


检查结果:
  ✅ empty_input: 0 个异常
  ✅ empty_label: 0 个异常
  ✅ invalid_token_input: 0 个异常
  ✅ invalid_token_label: 0 个异常
  ✅ very_short_input: 0 个异常
  ✅ very_short_label: 0 个异常

✅ 数据检查通过，未发现异常!



