In [1]:
def _load_preprocessed_ds(cfg, sub_cfg):
    ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
    prepared_ds_path = _get_path(ds_hash, cfg)
    dataset = None

    # pylint: disable=duplicate-code
    if (
        cfg.dataset_prepared_path
        and any(prepared_ds_path.glob("*"))
        and not cfg.is_preprocess
    ):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        dataset = load_from_disk(str(prepared_ds_path))

    return dataset

def load_prepare_dpo_datasets(cfg):
    def load_split(dataset_cfgs, _cfg):
        split_datasets: List[Any] = []
        for i, ds_cfg in enumerate(dataset_cfgs):
            if ds_cfg["ds_type"] == "json":
                print(ds_cfg)
                for data_file in ds_cfg["data_files"]:
                    data_files = {ds_cfg["split"]: data_file}
                    print(data_files)
                    ds = load_dataset(  # pylint: disable=invalid-name
                        "json",
                        data_files=data_files,
                        split=ds_cfg["split"],
                    )
                    split_datasets.insert(i, ds)
            else:
                ds = load_dataset(  # pylint: disable=invalid-name
                    ds_cfg["path"],
                    split=ds_cfg["split"],
                )
                split_datasets.insert(i, ds)

        tokenizer = None

        for i, data_set in enumerate(split_datasets):
            _type = dataset_cfgs[i]["type"]
            if _type:
                if isinstance(_type, DictDefault):
                    _type = "user_defined.default"
                if _cfg.rl == "orpo":
                    ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
                elif _cfg.rl == "kto":
                    ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
                else:
                    print("load_dpo")
                    print(_type)
                    print(_cfg)
                    ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)

                print(ds_transform_fn)
                split_datasets[i] = map_dataset(
                    cfg, data_set, ds_transform_fn, tokenizer
                )
            elif _cfg.rl == "kto":
                ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
                split_datasets[i] = map_dataset(
                    cfg, data_set, ds_transform_fn, tokenizer
                )
            else:
                # If no `type` is provided, assume the dataset is already in the expected format with
                # "prompt", "chosen" and "rejected" already preprocessed
                split_datasets[i] = data_set

        return concatenate_datasets(split_datasets)

#     with zero_first(is_main_process()):
    train_is_preprocessed = False
    eval_is_preprocessed = False
    if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
        train_is_preprocessed = True
    else:
        train_dataset = load_split(cfg.datasets, cfg)

    eval_dataset = None
    if cfg.test_datasets:
        if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
            eval_is_preprocessed = True
        else:
            eval_dataset = load_split(cfg.test_datasets, cfg)
    if not eval_dataset:
        eval_dataset = None
        
    print(train_dataset)

    if not train_is_preprocessed:
        _save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
    if eval_dataset and not eval_is_preprocessed:
        _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)

    return train_dataset, eval_dataset


In [2]:

def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
    sig = inspect.signature(ds_transform_fn)
    if "tokenizer" in sig.parameters:
        if not tokenizer:
            tokenizer = load_tokenizer(cfg)
        ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)

    data_set = data_set.map(
        ds_transform_fn,
        desc="Mapping RL Dataset",
    )
    if isinstance(data_set, DatasetDict):
        data_set = data_set["train"]
    return data_set



In [3]:
"""data handling helpers"""

import hashlib


def md5(to_hash: str, encoding: str = "utf-8") -> str:
    try:
        return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
    except TypeError:
        return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec

    
def _get_path(ds_hash, cfg):
    prepared_ds_path = (
        Path(cfg.dataset_prepared_path) / ds_hash
        if cfg.dataset_prepared_path
        else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
    )

    return prepared_ds_path


In [4]:
from addict import Dict
import yaml


class DictDefault(Dict):
    """
    A Dict that returns None instead of returning empty Dict for missing keys.
    """

    def __missing__(self, key):
        return None

    def __or__(self, other):
        return DictDefault(super().__ror__(other))

    
from pathlib import Path
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"


In [5]:
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk

In [6]:
import importlib
import logging

LOG = logging.getLogger("axolotl")


def load(strategy, cfg, module_base=None, **kwargs):
    try:
        load_fn = strategy.split(".")[-1]
        strategy = ".".join(strategy.split(".")[:-1])
        mod = importlib.import_module(f".{strategy}", module_base)
        func = getattr(mod, load_fn)
        return func(cfg, **kwargs)
    except Exception:  # pylint: disable=broad-exception-caught
        LOG.warning(f"unable to load strategy {strategy}")
        return None



In [7]:
def default(cfg, dataset_idx=0, **kwargs):  # pylint: disable=unused-argument
    ds_cfg = cfg["datasets"][dataset_idx]["type"]
    if not isinstance(ds_cfg, dict):
        raise ValueError(
            f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
        )
    field_prompt = ds_cfg.get("field_prompt", "prompt")
    field_system = ds_cfg.get("field_system", "system")
    field_chosen = ds_cfg.get("field_chosen", "chosen")
    field_rejected = ds_cfg.get("field_rejected", "rejected")
    prompt_format = ds_cfg.get("prompt_format")
    if not prompt_format:
        prompt_format = "{" + field_prompt + "}"
    chosen_format = ds_cfg.get("chosen_format")
    if not chosen_format:
        chosen_format = "{" + field_chosen + "}"
    rejected_format = ds_cfg.get("rejected_format")
    if not rejected_format:
        rejected_format = "{" + field_rejected + "}"

    def transform_fn(sample):
        if (
            "{" + field_system + "}" in prompt_format
            and field_system in sample
            and sample[field_system]
        ):
            sample["prompt"] = prompt_format.format(
                system=sample[field_system], prompt=sample[field_prompt]
            )
        else:
            sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
        sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
        sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
        return sample

    return transform_fn


In [8]:
from axolotl.prompt_strategies.dpo import load as load_dpo

In [9]:
"""
module for base dataset transform strategies
"""

import importlib
import logging

LOG = logging.getLogger("axolotl")


def load(strategy, cfg, module_base=None, **kwargs):
    try:
        load_fn = strategy.split(".")[-1]
        strategy = ".".join(strategy.split(".")[:-1])
        print("strategy", strategy)
        print("module_base", module_base)
        mod = importlib.import_module(f".{strategy}", module_base)
        func = getattr(mod, load_fn)
        return func(cfg, **kwargs)
    except Exception:  # pylint: disable=broad-exception-caught
        LOG.warning(f"unable to load strategy {strategy}")
        return None


In [10]:
import inspect

In [11]:
importlib.import_module(".user_defined", "axolotl.prompt_strategies.dpo")

<module 'axolotl.prompt_strategies.dpo.user_defined' from '/Users/htong/Desktop/tmpsrc/axolotl/src/axolotl/prompt_strategies/dpo/user_defined.py'>

In [12]:
cfg = DictDefault()
cfg["datasets"] = [{"ds_type": "json", "split": "train", "type": {},
                    "data_files": ["finetuning_qwen2_dpo.jsonl"]}]

In [13]:
mod = importlib.import_module(".user_defined", "axolotl.prompt_strategies.dpo")
func = getattr(mod, "default")
func(cfg)

<function axolotl.prompt_strategies.dpo.user_defined.default.<locals>.transform_fn(sample)>

In [14]:
load_prepare_dpo_datasets(cfg)

{'ds_type': 'json', 'split': 'train', 'type': {}, 'data_files': ['finetuning_qwen2_dpo.jsonl']}
{'train': 'finetuning_qwen2_dpo.jsonl'}




Generating train split: 0 examples [00:00, ? examples/s]

Dataset({
    features: ['prompt', 'chosen', 'rejected'],
    num_rows: 208
})


NameError: name '_save_preprocessed_ds' is not defined