/
config.py
290 lines (238 loc) · 9.62 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
import os
import sys
import tempfile
from dataclasses import dataclass, field
from typing import Dict, Optional

from metaflow import IncludeFile, Parameter, JSONType
### DATASET ###
@dataclass
class DataStoreConfig:
    """Source dataset and local/S3 locations for the tokenized corpus."""

    # Hugging Face dataset identifier and its config name.
    hf_dataset_name: str = "wikicorpus"
    hf_dataset_config_name: str = "raw_en"
    # Where the tokenized dataset lives on disk and under the S3 root.
    local_path: str = "data/wikicorpus_llama2_7B_tokenized_4k"
    s3_prefix: str = "wikicorpus_llama2_7B_tokenized_4k"
    # Tokens per packed training example (matches the "_4k" paths above).
    block_size: int = 4096
### TOKENIZER ###
@dataclass
class TokenizerStoreConfig:
    """Local path and S3 prefix for the tokenizer model file."""

    local_path: str = "tokenizer.model"
    s3_prefix: str = "llama2_7b/tokenizer"
### MODEL ###
@dataclass
class ModelStoreConfig:
    """Local/S3 layout for model weights, checkpoints, and experiment outputs."""

    local_weights_path: str = "model"
    local_checkpoints_path: str = "model/checkpoints"
    # S3 root for the model; checkpoint/experiment keys are nested under it.
    s3_prefix: str = "llama2_7b/model"
    s3_checkpoints_key: str = "checkpoints"
    s3_experiments_key: str = "experiments"
@dataclass
class ModelArchitectureConfig:
    """Transformer architecture hyperparameters (HF `LlamaConfig`-style fields).

    NOTE(review): the defaults appear to mirror the Hugging Face Llama-2 7B
    `config.json` (hidden_size 4096, 32 layers/heads, vocab 32000) — confirm
    against the upstream checkpoint before changing individual values.
    """

    architectures: list = field(default_factory=lambda: ["LlamaForCausalLM"])
    # Special-token ids.
    bos_token_id: int = 1
    eos_token_id: int = 2
    hidden_act: str = "silu"
    hidden_size: int = 4096
    initializer_range: float = 0.02
    intermediate_size: int = 11008
    # NOTE(review): 2048 here vs sequence_length/block_size of 4096 elsewhere
    # in this file — verify which value the training code actually honors.
    max_position_embeddings: int = 2048
    model_type: str = "llama"
    num_attention_heads: int = 32
    num_hidden_layers: int = 32
    num_key_value_heads: int = 32
    pad_token_id: int = 0
    pretraining_tp: int = 1
    rms_norm_eps: float = 1e-05
    rope_scaling: Optional[str] = None
    tie_word_embeddings: bool = False
    torch_dtype: str = "float16"
    transformers_version: str = "4.31.0"
    use_cache: bool = True
    vocab_size: int = 32000
    # Parallelism/runtime toggles (not part of the stock HF config fields).
    sequence_parallel_enabled: bool = False
    selective_checkpoint_enabled: bool = False
    move_model_to_device: bool = True
@dataclass
class TrainingConfig:
    """Hyperparameters and run controls for the pretraining loop."""

    tensor_parallelism_degree: int = (
        8  # NOTE: always keep this lower than num devices per node.
    )
    use_mix_precision: bool = True
    use_zero_1: bool = True  # NOTE: 0 --> pure data parallelism, 1 --> ZeRO-1
    # Effective batch across all workers; micro batch is per device per step.
    global_batch_size: int = 1024
    micro_batch_size: int = 1
    learning_rate: float = 3.0e-4
    sequence_length: int = 4096
    # Optional warm-up pass before real training starts.
    do_pre_compilation: bool = True
    pre_compilation_steps: int = 1
    warmup_steps: int = 3
    # NOTE(review): steps_this_run == total_steps == 5 looks like a smoke-test
    # default; raise both for a real pretraining run.
    steps_this_run: int = 5
    total_steps: int = 5
    logging_interval: int = 1  # affects TensorBoard & CLI
    # Steps between checkpoint saves.
    checkpoint_frequency: int = 50
    metrics_file: str = "metrics.json"
@dataclass
class TrainiumLlama2PretrainConfig:
    """Top-level config schema for the Llama2-on-Trainium pretraining flow."""

    data_store: DataStoreConfig = field(default_factory=DataStoreConfig)
    tokenizer_store: TokenizerStoreConfig = field(default_factory=TokenizerStoreConfig)
    model_store: ModelStoreConfig = field(default_factory=ModelStoreConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    model_architecture: ModelArchitectureConfig = field(
        default_factory=ModelArchitectureConfig
    )
### ENVIRONMENT ###
# Pinned package versions for the dataset-caching step (@step cache_dataset in flow.py).
caching_env_config = {
    "transformers": "4.31.0",
    "regex": "2023.12.25",
    "datasets": "2.16.1",
    "sentencepiece": "0.1.99",
    "protobuf": "3.20.0",
    "omegaconf": "2.3.0",
}


@dataclass
class CachingEnvironmentConfig:
    """Environment settings for the dataset tokenization/caching step."""

    batch_enabled: bool = False  # NOTE: Turn this on to tokenize data remotely.
    # Copy the module-level dict so every instance gets its own mapping —
    # `lambda: caching_env_config` handed all instances (and the module) the
    # identical dict object, so mutating one instance's `packages` leaked into
    # every other instance and into `caching_env_config` itself.
    packages: Dict[str, str] = field(default_factory=lambda: dict(caching_env_config))
# Unused at runtime — these versions are baked into the training step's docker
# image; kept here so the image's pins are visible next to the caching step's.
training_env_config = {
    "transformers": "4.31.0",
    "regex": "2023.12.25",
    "tensorboard": "2.15.1",
    "datasets": "2.16.1",
    "sentencepiece": "0.1.99",
    "protobuf": "3.20.0",
    "omegaconf": "2.3.0",
}
# Derived from: https://github.com/aws-neuron/neuronx-distributed/blob/main/examples/training/llama2/tp_zero1_llama2_7b_hf_pretrain/tp_zero1_llama2_7b_hf_pretrain.sh
NUM_RT_NEURON_CORES = 32  # trn1.32xlarge instance property.
# Environment variables injected into the training job.
env_vars_config = {
    # EFA (Elastic Fabric Adapter) networking settings for multi-node comms.
    "FI_EFA_USE_DEVICE_RDMA": "1",
    "FI_PROVIDER": "efa",
    "FI_EFA_FORK_SAFE": "1",
    "CCOM_SOCKET_IFNAME": "eth0",
    "MALLOC_ARENA_MAX": "64",  # cap malloc arenas — guards against host OOM
    # XLA / Neuron compiler and runtime knobs.
    "XLA_USE_BF16": "1",
    "TF_NUM_INTEROP_THREADS": "8192",
    "PROCESSES_PER_NODE": "32",
    "NEURON_CC_FLAGS": "--model-type transformer --distribution-strategy=llm-training --cache_dir=~/neuron_compile_cache/",
    "NEURON_FUSE_SOFTMAX": "1",
    "NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS": "3",  # Controls number of asynchronous execution requests to be supported. Reduces latency.
    # Core-count settings, all keyed off the trn1.32xlarge's 32 Neuron cores.
    "NEURON_RT_NUM_CORES": str(NUM_RT_NEURON_CORES),
    "NUM_NEURONCORES": str(NUM_RT_NEURON_CORES),
    "TPU_NUM_DEVICES": str(NUM_RT_NEURON_CORES),
    "TPU_CHIPS_PER_HOST_BOUNDS": str(NUM_RT_NEURON_CORES),
    "NEURON_RT_ROOT_COMM_ID": "localhost:48620",
}
@dataclass
class BatchJobConfig:
    """AWS Batch resource requests for the multi-node training job."""

    n_nodes: int = 4
    # NOTE(review): presumably Trainium chips per node (a trn1.32xlarge has 16);
    # confirm whether this is per-node or a cluster total.
    n_trainium_devices: int = 16
    n_cpu: int = 96
    # NOTE(review): unit looks like MiB — verify against the job queue limits.
    memory: int = 500000
    n_efa_interfaces: int = 8
    image: str = "public.ecr.aws/outerbounds/trainium:llama2"
    job_queue: str = "trn1-batch-trn1_32xl_batch_job_queue"
@dataclass
class TrainLlama2EnvConfig:
    """Environment settings for the Llama2 training step."""

    # Copy the module-level dicts so instances never share (and cannot
    # mutate) the same mapping — `lambda: training_env_config` /
    # `lambda: env_vars_config` handed every instance the identical objects,
    # so edits on one instance leaked into all others and into the globals.
    packages: Dict[str, str] = field(default_factory=lambda: dict(training_env_config))
    env_vars: Dict[str, str] = field(default_factory=lambda: dict(env_vars_config))
    batch_job: BatchJobConfig = field(default_factory=BatchJobConfig)
    # Shown to users who want to resume training from a saved checkpoint.
    continue_from_checkpoint_instructions: str = "To continue from a checkpoint, specify the checkpoint name in the --checkpoint parameter."
@dataclass
class EnvironmentConfig:
    """Groups the per-@step environment configs used by the flow."""

    dataset_cache_step: CachingEnvironmentConfig = field(
        default_factory=CachingEnvironmentConfig
    )
    train_llama2_step: TrainLlama2EnvConfig = field(
        default_factory=TrainLlama2EnvConfig
    )
### CONFIG HELPERS ###
def create_config(filepath, _class):
    """Serialize the structured defaults of *_class* as YAML to *filepath*."""
    from omegaconf import OmegaConf

    OmegaConf.save(OmegaConf.structured(_class), filepath)
def load_config(filepath, _class):
    """Load YAML from *filepath* and validate/merge it against the *_class* schema.

    Returns the merged OmegaConf object; schema defaults fill any keys the
    file omits.
    """
    from omegaconf import OmegaConf

    schema = OmegaConf.structured(_class)
    loaded = OmegaConf.load(filepath)
    return OmegaConf.merge(schema, loaded)
def _to_file(file_bytes, extension=None):
params = {
"suffix": f".{extension.replace('.', '')}" if extension is not None else None,
"delete": True,
"dir": "./",
}
latent_temp = tempfile.NamedTemporaryFile(**params)
latent_temp.write(file_bytes)
latent_temp.seek(0)
return latent_temp
class ConfigBase:
    """
    Base class for all config needed for this flow as well as any dependent flows.

    This class can be inherited by downstream classes or even used as a mixin.
    It is meant for reuse in Metaflow flows that want to reuse the configuration
    parameters of this training flow so that they can call downstream flows with
    the same configuration parameters.

    Example Usecases:
    --------
    - An upstream flow preparing data inherits the configuration schema /
      parameters from this class. The configuration is then parsed correctly in
      both flows, and the upstream flow can pass it to the downstream flow while
      ensuring that it is valid.
    - This pattern is very useful with a complex configuration schema reused in
      multiple flows. These flows may be invoked asynchronously using event
      handlers, so having a common configuration schema parser is useful.

    All downstream flows have to inherit this class and set the `config` property,
    giving them direct access to the parsed config. The `_CORE_CONFIG_CLASS`
    attribute must be set to the class used to parse the configuration.

    Usage Example:
    --------
    ```
    _CORE_CONFIG_CLASS = TrainiumLlama2PretrainConfig
    @property
    def config(self) -> TrainiumLlama2PretrainConfig:
        return self._get_config()
    ```
    """

    def _resolve_config(self):
        """Parse and validate the experiment config from --config or --config-file.

        Raises:
            ValueError: if `_CORE_CONFIG_CLASS` is unset, if both sources are
                supplied, or if neither is.
        """
        if self._CORE_CONFIG_CLASS is None:
            raise ValueError(
                "Please set the _CORE_CONFIG_CLASS property of this class to the class which will be used to parse the configuration"
            )
        if (
            self.experiment_config is not None
            and self.experiment_config_file is not None
        ):
            raise ValueError("Cannot specify both --config and --config-file")
        elif self.experiment_config is None and self.experiment_config_file is None:
            raise ValueError("Must specify either --config or --config-file")
        if self.experiment_config is not None:
            # --config arrives JSON-parsed (a dict/list via JSONType), not a file
            # path, so it cannot be handed to `load_config` — `OmegaConf.load`
            # only accepts paths / file objects. Build the config object from
            # the parsed data and merge it against the schema for validation.
            from omegaconf import OmegaConf

            schema = OmegaConf.structured(self._CORE_CONFIG_CLASS)
            return OmegaConf.merge(schema, OmegaConf.create(self.experiment_config))
        if self.experiment_config_file is not None:
            # IncludeFile yields the file *contents*; round-trip through a
            # temporary file so OmegaConf can load it by name.
            temf = _to_file(
                bytes(self.experiment_config_file, "utf-8"),
            )
            return load_config(temf.name, self._CORE_CONFIG_CLASS)

    # Cached parsed config; populated lazily by `_get_config`.
    _config = None
    # Subclasses must point this at the dataclass describing their schema.
    _CORE_CONFIG_CLASS = None

    def _get_config(self):
        """Return the parsed config, resolving and caching it on first use."""
        if self._config is not None:
            return self._config
        self._config = self._resolve_config()
        return self._config

    experiment_config_file = IncludeFile(
        "config-file", help="experiment config file path", default=None
    )
    experiment_config = Parameter(
        "config", help="experiment config", default=None, type=JSONType
    )

    def config_report(self):
        """Metaflow card components rendering the resolved config as YAML."""
        from metaflow.cards import Markdown
        from omegaconf import OmegaConf

        return [
            Markdown(f"## Experiment Config"),
            Markdown(f"```\n{OmegaConf.to_yaml(self.config)}```"),
        ]
if __name__ == "__main__":
    # Regenerate the default config file, asking before clobbering an existing one.
    if os.path.exists("config.yaml"):
        user_input = input(
            "config.yaml already exists. Type 'y/Y' and enter to overwrite: "
        )
        # Slice instead of `[0]` so a bare Enter (empty reply) aborts cleanly
        # instead of raising IndexError; `sys` is now imported at the top of
        # the file (the original abort path raised NameError).
        if user_input[:1].upper() != "Y":
            sys.exit("Exiting...")
    create_config("config.yaml", TrainiumLlama2PretrainConfig)