# dataset

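> PyTorch datasets of pre-tokenized text: `PairDataset` (chosen/rejected pairs for the reward model) and `PromptDataset` (prompts for the PPO agent).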

In [None]:
#| default_exp dataset

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
from typing import Callable

from torch.utils.data import Dataset
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


### Dataset for Reward Model

In [None]:
#| export
class PairDataset(Dataset):
    """Pre-tokenized ("chosen", "rejected") text pairs for a reward model.

    Every record of the source dataset is tokenized once, at construction
    time. The results are kept in two parallel lists, ``self.chosen`` and
    ``self.rejected``, whose entries are dicts holding the ``"input_ids"``
    and ``"attention_mask"`` values produced by the tokenizer.
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int):
        """Tokenize the "chosen" and "rejected" fields of each record.

        Args:
            dataset: Iterable of records that can be indexed with the keys
                ``"chosen"`` and ``"rejected"``.
            tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
                padding="max_length", truncation=True, return_tensors="pt")``.
                Its result must be indexable by ``"input_ids"`` and
                ``"attention_mask"``.
            max_length: Forwarded unchanged to the tokenizer.
        """
        # NOTE(review): with return_tensors="pt" a Hugging Face tokenizer
        # presumably returns tensors with a leading batch dimension of 1;
        # they are stored exactly as returned -- confirm callers expect that.
        def encode(text):
            return tokenizer(
                text,
                max_length=max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

        kept_fields = ("input_ids", "attention_mask")
        self.chosen = []
        self.rejected = []

        for record in tqdm(dataset):
            # Read both texts before tokenizing either of them.
            chosen_text = record["chosen"]
            rejected_text = record["rejected"]

            chosen_enc = encode(chosen_text)
            rejected_enc = encode(rejected_text)

            self.chosen.append({name: chosen_enc[name] for name in kept_fields})
            self.rejected.append({name: rejected_enc[name] for name in kept_fields})

    def __len__(self):
        """Return the number of stored pairs."""
        return len(self.chosen)

    def __getitem__(self, idx: int):
        """Return the pair stored at position ``idx``.

        The result is a 4-tuple: ``(chosen input_ids, chosen attention_mask,
        rejected input_ids, rejected attention_mask)``.
        """
        chosen_item = self.chosen[idx]
        rejected_item = self.rejected[idx]
        return (
            chosen_item["input_ids"],
            chosen_item["attention_mask"],
            rejected_item["input_ids"],
            rejected_item["attention_mask"],
        )

### Dataset for PPO Agent

In [None]:
#| export
class PromptDataset(Dataset):
    """Pre-tokenized prompts for a PPO agent.

    Each record's ``"prompt"`` field is tokenized once, at construction
    time, and kept in ``self.prompts`` as a dict holding the
    ``"input_ids"`` and ``"attention_mask"`` values produced by the tokenizer.
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int):
        """Tokenize the "prompt" field of every record in ``dataset``.

        Args:
            dataset: Iterable of records that can be indexed with the key
                ``"prompt"``.
            tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
                padding="max_length", truncation=True, return_tensors="pt")``.
                Its result must be indexable by ``"input_ids"`` and
                ``"attention_mask"``.
            max_length: Forwarded unchanged to the tokenizer.
        """
        # Keyword arguments shared by every tokenizer call.
        # NOTE(review): with return_tensors="pt" a Hugging Face tokenizer
        # presumably returns tensors with a leading batch dimension of 1;
        # they are stored exactly as returned -- confirm callers expect that.
        tokenizer_kwargs = {
            "max_length": max_length,
            "padding": "max_length",
            "truncation": True,
            "return_tensors": "pt",
        }

        self.prompts = []
        for record in tqdm(dataset):
            encoded = tokenizer(record["prompt"], **tokenizer_kwargs)
            entry = {
                "input_ids": encoded["input_ids"],
                "attention_mask": encoded["attention_mask"],
            }
            self.prompts.append(entry)

    def __len__(self):
        """Return the number of stored prompts."""
        return len(self.prompts)

    def __getitem__(self, idx: int):
        """Return ``(input_ids, attention_mask)`` for the prompt at ``idx``."""
        item = self.prompts[idx]
        input_ids, attention_mask = item["input_ids"], item["attention_mask"]
        return input_ids, attention_mask