In [1]:
# Create a small tab-separated sample corpus: one "text<TAB>label" record per line,
# labelled 1 for the positive-sounding sentences and 0 for the negative ones.
sample_lines = [
    "I love machine learning\t1\n",
    "Deep learning is fascinating\t1\n",
    "I hate waiting in line\t0\n",
    "This is boring\t0\n",
    "AI will change the world\t1\n",
]
with open("torch_dataset.txt", "w", encoding="utf-8") as f:
    f.writelines(sample_lines)


In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer


class TextDataset(Dataset):
    """Map-style PyTorch dataset over a tab-separated ``text<TAB>label`` file.

    Every non-blank line of ``file_path`` holds the raw text, a tab, and an
    integer label. Texts are stored as-is and tokenized lazily in ``__getitem__``.

    Args:
        file_path: Path to the UTF-8 encoded input file.
        tokenizer: Callable with the Hugging Face tokenizer call signature
            (e.g. ``BertTokenizer``); it is called with ``padding``,
            ``truncation``, ``max_length`` and ``return_tensors`` keywords and
            must return a mapping with ``input_ids`` and ``attention_mask``.
        max_length: Sequence length every example is padded/truncated to.

    Raises:
        ValueError: If a non-blank line has no tab separator or its label is
            not an integer.
    """

    def __init__(self, file_path, tokenizer, max_length=128):
        super().__init__()
        self.samples = []
        self.labels = []
        self.tokenizer = tokenizer
        self.max_length = max_length

        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF) instead
                    # of failing with an opaque tuple-unpacking error.
                    continue
                # Split on the LAST tab: the label is always the final field, so a
                # tab inside the text itself no longer breaks parsing.
                text, sep, label = line.rpartition('\t')
                if not sep:
                    raise ValueError(
                        f"{file_path}:{line_no}: expected 'text<TAB>label', got {line!r}"
                    )
                self.samples.append(text)
                self.labels.append(int(label))

    def __len__(self):
        """Return the number of parsed samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return it as a dict of tensors.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (the tokenizer's
            single-sequence output with its leading batch dimension squeezed
            away) and ``label`` as a 0-d ``torch.long`` tensor, so the default
            ``DataLoader`` collate stacks them into batched tensors.
        """
        text = self.samples[idx]
        label = self.labels[idx]

        encoded = self.tokenizer(
            text,
            padding="max_length",
            # 'longest_first' is the strategy that truncation=True selects; both
            # spellings are accepted by Hugging Face tokenizers.
            truncation="longest_first",
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long),
        }


if __name__ == '__main__':
    # Wire tokenizer -> dataset -> loader, then peek at the first batch only to
    # confirm the padded input_ids shape.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    dataset = TextDataset("torch_dataset.txt", tokenizer)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

    try:
        batch = next(iter(dataloader))
    except StopIteration:
        pass  # empty dataset: nothing to show
    else:
        print(f"Input Ids: {batch['input_ids'].shape}")

Input Ids: torch.Size([2, 128])


In [3]:
from datasets import load_dataset, Dataset
from transformers import BertTokenizer
import torch
from torch.utils.data import DataLoader


# Parse the "text<TAB>label" file into a list of records for the Hugging Face
# datasets library. NOTE: `Dataset` here is `datasets.Dataset` (imported in this
# cell), which shadows the torch `Dataset` imported by the previous cell.
lines = []
with open("torch_dataset.txt", 'r', encoding='utf-8') as f:
    for line_no, line in enumerate(f, start=1):
        line = line.strip()
        if not line:
            continue  # skip blank lines (e.g. a trailing newline) instead of crashing
        # Split on the last tab so a tab inside the text cannot break unpacking.
        text, sep, label = line.rpartition('\t')
        if not sep:
            raise ValueError(
                f"torch_dataset.txt:{line_no}: expected 'text<TAB>label', got {line!r}"
            )
        lines.append({"text": text, "label": int(label)})

dataset = Dataset.from_list(lines)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def tokenize_function(example):
    """Tokenize one record's text to a fixed 32-token window, keeping its label.

    Uses the module-level ``tokenizer``. ``example`` is a row dict with "text"
    and "label" keys, as passed by ``dataset.map(..., batched=False)``.

    Returns:
        The tokenizer's output mapping with a "label" entry copied from ``example``.
    """
    tokenizer_kwargs = dict(padding="max_length", truncation=True, max_length=32)
    features = tokenizer(example["text"], **tokenizer_kwargs)
    features["label"] = example["label"]
    return features
SEED = 42  # fixes the shuffle order so the printed labels reproduce on Restart & Run All

tokenized_dataset = dataset.map(tokenize_function, batched=False)
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# A dedicated seeded generator makes shuffle=True deterministic without touching
# torch's global RNG state.
loader_rng = torch.Generator().manual_seed(SEED)
dataloader = DataLoader(tokenized_dataset, batch_size=2, shuffle=True, generator=loader_rng)

# Peek at one batch to check shapes and labels.
for batch in dataloader:
    print("input_ids:", batch["input_ids"].shape)
    print("attention_mask:", batch["attention_mask"].shape)
    print("labels:", batch["label"])
    break

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

input_ids: torch.Size([2, 32])
attention_mask: torch.Size([2, 32])
labels: tensor([1, 0])
