-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
82 lines (68 loc) · 2.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from pathlib import Path
import flax
import hydra
import jax
import numpy as np
import torch
from flax.training import checkpoints
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
# Opt out of the Orbax backend so flax.training.checkpoints.save_checkpoint
# (used at the end of main) writes legacy-format Flax checkpoints.
flax.config.update("flax_use_orbax_checkpointing", False)
import wandb
from src.const import (
SPECIAL_TOKENS,
SEGMENT_TYPES,
MAX_SEQUENCE_LENGTH,
MISSING_TITLE,
WHAT_OTHER_PEOPLE_SEARCHED_TITLE,
)
from src.trainer import Trainer
@hydra.main(version_base="1.3", config_path="config", config_name="config")
def main(config: DictConfig):
    """Train a model on the configured dataset and persist the result.

    Driven entirely by the hydra config: seeds RNGs, builds the dataset and
    DataLoader, trains via ``Trainer``, then optionally pushes the trained
    params to the Hugging Face Hub and always writes a local Flax checkpoint.

    Args:
        config: Hydra-resolved configuration (seed, data/model/trainer
            sections, wandb and HF Hub settings).
    """
    # Seed numpy and torch for reproducible data pipelines.
    # NOTE(review): jax PRNG seeding is presumably handled inside Trainer —
    # confirm, since nothing here seeds jax directly.
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)

    directory = Path(config.dataset_directory)
    # Sort the shards: Path.glob yields filesystem order, which is not
    # deterministic across machines/runs and would defeat the seeding above.
    train_files = sorted(directory.glob("part-*"))

    train_dataset = instantiate(
        config.data,
        files=train_files,
        max_sequence_length=MAX_SEQUENCE_LENGTH,
        special_tokens=SPECIAL_TOKENS,
        segment_types=SEGMENT_TYPES,
        ignored_titles=[MISSING_TITLE, WHAT_OTHER_PEOPLE_SEARCHED_TITLE],
    )
    train_loader = DataLoader(
        train_dataset,
        # Global batch size: per-device size scaled by the number of
        # visible jax devices (data-parallel training).
        batch_size=config.per_device_train_batch_size * jax.device_count(),
        collate_fn=train_dataset.collate_fn,
    )

    model = instantiate(config.model)
    trainer = Trainer(**OmegaConf.to_container(config))

    if config.log_metrics:
        wandb.init(
            project=config.wandb_project_name,
            entity=config.wandb_entity,
            sync_tensorboard=False,
            # resolve + throw_on_missing: fail fast on unresolved config keys
            # instead of logging placeholder values to wandb.
            config=OmegaConf.to_container(config, resolve=True, throw_on_missing=True),
            name=config.model.name,
            save_code=True,
        )

    trained_state = trainer.train(model, train_loader)

    # Optionally publish trained params to the Hugging Face Hub.
    if config.hf_hub_push:
        model.save_pretrained(
            save_directory=config.output_dir,
            params=trained_state.params,
            push_to_hub=True,
            repo_id=f"{config.hf_hub_user}/{config.hf_hub_model}",
            token=config.hf_hub_token,
            safe_serialization=True,
        )

    # Always keep a local (legacy-format) Flax checkpoint of the final state.
    checkpoints.save_checkpoint(
        ckpt_dir=config.output_dir,
        target=trained_state,
        step=config.max_steps,
        overwrite=True,
    )
# Script entry point: hydra parses CLI overrides and injects the config.
if __name__ == "__main__":
    main()