In [None]:
from datasets import load_dataset
from transformers import BertTokenizer

def filter_short_context(data):
    """Tell whether an example's context is short enough to keep.

    Args:
        data: A mapping with a ``"context"`` entry that supports ``len()``.

    Returns:
        ``True`` when ``len(data["context"])`` is at most 1024, else ``False``.
    """
    max_context_length = 1024
    context_length = len(data["context"])
    return context_length <= max_context_length

def tokenize_function(tokenizer, data):
    """Tokenize a batch of question/context examples as sequence pairs.

    The questions and contexts are handed to the tokenizer as two parallel
    lists (text and text-pair) instead of being glued together with a literal
    ``" [SEP] "``. Encoding a single concatenated string makes the tokenizer
    treat the whole input as one segment, so ``token_type_ids`` cannot tell
    the question from the context; pair encoding lets the tokenizer insert its
    own separator and mark the two segments.

    Args:
        tokenizer: A callable tokenizer accepting ``(text, text_pair, **kwargs)``,
            e.g. a Hugging Face ``BertTokenizer``.
        data: A batch mapping with parallel ``"question"`` and ``"context"``
            sequences of strings.

    Returns:
        Whatever ``tokenizer`` returns for the batch, padded/truncated to
        512 tokens and requested as PyTorch tensors.
    """
    # Materialize as plain lists so any sequence-like batch column is accepted.
    questions = list(data["question"])
    contexts = list(data["context"])
    return tokenizer(
        questions,
        contexts,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

# Load a small slice of the SQuAD training split (the "train[:10]" split
# expression) and the pretrained "bert-base-uncased" tokenizer.
dataset = load_dataset("squad", split="train[:10]")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Keep only the examples accepted by filter_short_context.
filtered_dataset = dataset.filter(filter_short_context)
# Tokenize in batches; the lambda binds the module-level tokenizer because
# map() passes only the batch. The original dataset columns are removed so
# the result holds just the tokenizer outputs.
tokenized_dataset = filtered_dataset.map(
    lambda x: tokenize_function(tokenizer, x),
    batched=True,
    remove_columns=dataset.column_names
)
# Ask the dataset to return the listed columns in "torch" format.
tokenized_dataset.set_format(
    type="torch",
    columns=["input_ids", "token_type_ids", "attention_mask"]
)
# Show the dataset summary, then the first tokenized example.
# NOTE(review): indexing [0] presumably fails if the filter removed every
# row — confirm that cannot happen for the chosen slice.
print(tokenized_dataset)
print(tokenized_dataset[0])