In [5]:
#!/usr/bin/env python
# coding: utf-8

import argparse
import random

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import QuestionAnsweringPipeline, AutoModelForQuestionAnswering, AutoTokenizer, logging, set_seed

# Silence everything from transformers below CRITICAL (numeric level 50).
logging.set_verbosity(logging.CRITICAL)

# Single seed shared by every RNG seeded in this cell.
SEED = 42

# `g` is the dedicated generator handed to the DataLoader further down.
g = torch.Generator()
g.manual_seed(SEED)
torch.manual_seed(SEED)
random.seed(SEED)
set_seed(SEED)

# NOTE(review): 'bert-base-uncased' looks like a base checkpoint rather than
# one fine-tuned for question answering -- confirm the QA head weights are
# what you expect before trusting answers from `nlp`.
model_checkpoint = 'bert-base-uncased'
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Question-answering pipeline built from the model/tokenizer pair above.
nlp = QuestionAnsweringPipeline(model=model, tokenizer=tokenizer)

# NOTE(review): use_auth_token=True presumably relies on a stored Hugging Face
# login; newer `datasets` releases may prefer `token=` -- confirm against the
# installed version.
raw_datasets = load_dataset('Saptarshi7/techqa-squad-style', use_auth_token=True)


def seed_worker(worker_id):
    """Seed Python's and NumPy's global RNGs from torch's initial seed.

    Passed to the DataLoader below as ``worker_init_fn``.

    Args:
        worker_id: Supplied by the caller; not used here.
    """
    # Keep only the low 32 bits: np.random.seed rejects larger values.
    derived_seed = torch.initial_seed() & 0xFFFFFFFF
    random.seed(derived_seed)
    np.random.seed(derived_seed)



In [40]:
class CustomDataset(Dataset):
    """Map-style dataset that tokenizes one question/context record per item.

    Tokenization uses the module-level ``tokenizer``.

    Notes:
        * ``return_tensors="pt"`` keeps a leading dimension of 1 on every
          item, so batches collated by the DataLoader come out as
          (batch, 1, 384) rather than (batch, 384).
        * NOTE(review): ``stride`` presumably only matters when overflowing
          tokens are requested from the tokenizer -- confirm whether sliding
          windows over long contexts were intended here.
    """

    def __init__(self, records):
        """Store ``records``.

        Args:
            records: Indexable, sized collection whose items are mappings
                with 'question' and 'context' keys.
        """
        self.records = records

    def __len__(self):
        """Return the number of records."""
        return len(self.records)

    def __getitem__(self, idx):
        """Tokenize the question/context pair stored at ``idx``.

        Returns:
            Whatever the tokenizer returns for the pair, truncated and
            padded to exactly 384 tokens.
        """
        record = self.records[idx]
        question = record['question']
        context = record['context']

        # Pad as well as truncate: without padding, a pair shorter than
        # max_length yields a shorter tensor and the DataLoader's default
        # collate cannot stack items of different lengths into one batch.
        inputs = tokenizer(
            question,
            context,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=384,
            stride=128,
        )

        return inputs


# Wrap the validation split and serve it in order, two records per batch.
validation_records = raw_datasets['validation']
dataset = CustomDataset(validation_records)
# NOTE(review): num_workers is left at its default, so `seed_worker` is
# presumably only invoked once worker processes are enabled -- confirm.
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    generator=g,
    worker_init_fn=seed_worker,
)


In [41]:
# Sanity check: pull only the first batch and report the shape of its input ids.
batch_iter = iter(dataloader)
try:
    batch = next(batch_iter)
except StopIteration:
    # An empty loader prints nothing.
    pass
else:
    print(batch['input_ids'].shape)

torch.Size([2, 1, 384])
