In [17]:
import json
from tqdm import tqdm
from datasets import load_dataset, load_from_disk, concatenate_datasets, DatasetDict

In [12]:
def process_metamathqa(examples, prefix="Please answer the following question: "):
    """Map a batch of MetaMathQA rows to prompt/answer string pairs.

    No tokenization happens here: the function only prepends ``prefix`` to each
    query and passes the responses through unchanged. It is written for
    ``Dataset.map(..., batched=True)``, so each column arrives as a list.

    Args:
        examples (dict): Batch with ``"query"`` and ``"response"`` columns.
        prefix (str, optional): Text prepended to every query. Defaults to
            "Please answer the following question: ".

    Returns:
        dict: ``{"question": [prefix + query, ...], "answer": <the "response" column>}``.
    """
    # The model input is the instruction prefix followed by the raw query text.
    questions = [prefix + query for query in examples["query"]]
    # Targets are the original responses, untouched.
    answers = examples["response"]
    return {"question": questions, "answer": answers}

def load_metamathqa(dev_mode=False, num_samples=50000, dev_samples=20, test_size=0.2, seed=42):
    """Load MetaMathQA, subsample it, split it, and map it to question/answer pairs.

    Args:
        dev_mode (bool, optional): If True, keep only ``dev_samples`` rows for
            quick iteration. Defaults to False.
        num_samples (int, optional): Number of leading rows of the ``train``
            split to keep when not in dev mode; clamped to the split size.
            Defaults to 50000.
        dev_samples (int, optional): Number of leading rows kept in dev mode;
            clamped to the split size. Defaults to 20.
        test_size (float, optional): Fraction of the kept rows held out as the
            ``test`` split. Defaults to 0.2.
        seed (int | None, optional): Seed for the train/test split so a re-run
            reproduces the same split. Pass None for an unseeded split.
            Defaults to 42.

    Returns:
        DatasetDict: ``train``/``test`` splits containing only the
        ``question`` and ``answer`` columns produced by ``process_metamathqa``.
    """
    raw_train = load_dataset("meta-math/MetaMathQA")["train"]

    n_keep = dev_samples if dev_mode else num_samples
    # Never request more rows than the split actually has.
    subset = raw_train.select(range(min(n_keep, len(raw_train))))

    # Seeded so the held-out set is stable across notebook re-runs.
    splits = subset.train_test_split(test_size=test_size, seed=seed)

    return splits.map(
        process_metamathqa,
        batched=True,
        # Drop every original column so only the mapped question/answer pair remains.
        remove_columns=subset.column_names,
    )

In [13]:
# Config for this cell. DEV_MODE=True keeps only a tiny MetaMathQA sample for
# fast iteration; set it to False before building the real fine-tuning mix.
CLASS_SFT_PATH = "../datasets/class_sft_datasetdict"
DEV_MODE = True

dataset_class = load_from_disk(CLASS_SFT_PATH)
# The mixing cell indexes both splits, so fail early with a clear message if they are missing.
assert isinstance(dataset_class, DatasetDict) and {"train", "test"} <= set(dataset_class), (
    f"expected a DatasetDict with 'train' and 'test' splits at {CLASS_SFT_PATH}"
)
metamath_qa = load_metamathqa(dev_mode=DEV_MODE)

Map: 100%|██████████| 16/16 [00:00<00:00, 7240.14 examples/s]
Map: 100%|██████████| 4/4 [00:00<00:00, 2385.50 examples/s]


In [19]:
# Combine the class SFT data with MetaMathQA one split at a time, then shuffle
# each combined split with a fixed seed so the interleaving is reproducible.
dataset = DatasetDict({
    split_name: concatenate_datasets(
        [dataset_class[split_name], metamath_qa[split_name]]
    ).shuffle(seed=42)
    for split_name in ("train", "test")
})