-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_utils.py
106 lines (95 loc) · 3.6 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import enum
import logging
import os
from mlfoundry_utils import (
download_mlfoundry_artifact,
is_mlfoundry_artifact,
sanitize_name,
)
logger = logging.getLogger("axolotl")
def find_all_jsonl_files(path):
    """Recursively yield the path of every non-hidden ``.jsonl`` file under ``path``.

    Args:
        path: Root directory to walk (relative or absolute).

    Yields:
        str: Path of each file ending in ``.jsonl`` whose basename does not
        start with a dot. Paths are joined onto the walked root, so they are
        relative/absolute matching ``path``.
    """
    for root, _dirs, files in os.walk(path):
        for filename in files:
            # Skip hidden files (e.g. ".foo.jsonl") — typically editor/OS
            # artifacts rather than real dataset shards.
            if filename.endswith(".jsonl") and not filename.startswith("."):
                yield os.path.join(root, filename)
class DatasetType(str, enum.Enum):
    """Supported dataset record formats.

    Mixes in ``str`` so members compare equal to (and serialize as) their
    plain string values, which is convenient for config handling.
    """

    completion = "completion"  # single prompt -> completion records
    chat = "chat"  # multi-turn "messages" records
def _make_dataset_file_source(
    path,
    split="train",
    dataset_type: DatasetType = DatasetType.completion,
):
    """Build one axolotl datasource config dict for a single JSON file.

    Axolotl dynamically loads a prompt strategy based on the `type` key;
    the modules live at axolotl.prompt_strategies.* and each module's
    `load` function is called with the tokenizer, cfg and ds_cfg.
    Ideally the HF tokenizers library would apply the base model's chat
    template, but axolotl's chat-template strategy forces choosing one of
    the built-in templates.

    Raises:
        ValueError: when ``dataset_type`` is not a known ``DatasetType``.
    """
    if dataset_type == DatasetType.completion:
        # Inline strategy config mapping record fields onto a plain
        # instruction/output prompt format.
        completion_strategy = {
            "system_prompt": "",
            "field_system": "system",
            "field_instruction": "prompt",
            "field_output": "completion",
            "format": "{instruction}\n{input}\n",
            "no_input_format": "{instruction}\n",
            "system_format": "{system}\n",
        }
        return {
            "path": path,
            "ds_type": "json",
            "type": completion_strategy,
            "split": split,
        }

    if dataset_type == DatasetType.chat:
        return {
            "path": path,
            "ds_type": "json",
            "type": "chat_template",
            "field_messages": "messages",
            "message_field_role": "role",
            "message_field_content": "content",
            # Map canonical roles to the aliases accepted in the data.
            "roles": {
                "system": ["system"],
                "user": ["user", "human"],
                "assistant": ["assistant"],
                "tool": ["tool"],
            },
            "split": split,
            "roles_to_train": ["gpt", "assistant", "ipython"],
            "train_on_eos": "last",
        }

    raise ValueError(f"Unsupported dataset type: {dataset_type}")
def dataset_uri_to_axolotl_datasources(
    uri,
    download_dir,
    dataset_type: DatasetType = DatasetType.completion,
):
    """Resolve a dataset URI into a list of axolotl datasource config dicts.

    Args:
        uri: An ``https://`` URL, an mlfoundry artifact FQN, or a local
            file/directory path.
        download_dir: Base directory under which mlfoundry artifacts are
            downloaded (a sanitized per-URI subdirectory is created).
        dataset_type: Record format of the dataset (completion or chat).

    Returns:
        list: Datasource dicts as produced by ``_make_dataset_file_source``.

    Raises:
        ValueError: If the URI is unsupported or the local path does not exist.
    """
    # TODO: Add support for HF datasets
    if uri.startswith("https://"):
        return [_make_dataset_file_source(path=uri, dataset_type=dataset_type)]
    elif is_mlfoundry_artifact(uri):
        datasources = []
        logger.info("Downloading artifact from mlfoundry")
        artifact_download_dir = os.path.join(download_dir, sanitize_name(uri))
        download_path = download_mlfoundry_artifact(
            artifact_version_fqn=uri, download_dir=artifact_download_dir, overwrite=True
        )
        for filepath in find_all_jsonl_files(download_path):
            # Fix: the original logged the literal text "{filepath}" because
            # the f-string prefix was missing; use lazy %-style logging args.
            logger.info("Adding jsonl file %s", filepath)
            datasources.append(_make_dataset_file_source(path=filepath, dataset_type=dataset_type))
        return datasources
    elif os.path.exists(uri):
        if os.path.isdir(uri):
            return [
                _make_dataset_file_source(path=filepath, dataset_type=dataset_type)
                for filepath in find_all_jsonl_files(uri)
            ]
        return [_make_dataset_file_source(path=uri, dataset_type=dataset_type)]
    else:
        raise ValueError(f"Unsupported data uri or path does not exist: {uri}")