-
Notifications
You must be signed in to change notification settings - Fork 263
/
main.py
192 lines (164 loc) · 6.18 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import argparse
import logging
import tempfile
from argparse import Namespace
from typing import Optional, Tuple
import torch
from libfb.py import parutil
from torch.utils.data import random_split
from torch.utils.data.dataset import Dataset, Subset
from torchtnt.examples.mingpt.char_dataset import CharDataset, DataConfig
from torchtnt.examples.mingpt.model import (
create_optimizer,
GPT,
GPTConfig,
OptimizerConfig,
)
from torchtnt.framework.auto_unit import AutoUnit, TrainStepResults
from torchtnt.framework.fit import fit
from torchtnt.framework.state import State
from torchtnt.utils import init_from_env, seed, TLRScheduler
from torchtnt.utils.loggers import TensorBoardLogger
_logger: logging.Logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Batch = Tuple[torch.Tensor, torch.Tensor]
PATH: str = parutil.get_file_path("data/input.txt", pkg=__package__)
def prepare_dataloader(
dataset: Dataset[CharDataset],
batch_size: int,
device: torch.device,
) -> torch.utils.data.DataLoader:
"""Instantiate DataLoader"""
# pin_memory enables faster host to GPU copies
on_cuda = device.type == "cuda"
return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=batch_size,
pin_memory=on_cuda,
)
def get_datasets(
data_cfg: DataConfig,
) -> Tuple[Subset[CharDataset], Subset[CharDataset], CharDataset]:
dataset = CharDataset(data_cfg)
train_len = int(len(dataset) * data_cfg.train_split)
train_set, eval_set = random_split(dataset, [train_len, len(dataset) - train_len])
return train_set, eval_set, dataset
class MinGPTUnit(AutoUnit[Batch]):
def __init__(
self,
tb_logger: TensorBoardLogger,
opt_cfg: OptimizerConfig,
log_every_n_steps: int,
module: torch.nn.Module,
device: torch.device,
strategy: str,
gradient_accumulation_steps: int,
detect_anomaly: bool,
clip_grad_norm: float,
) -> None:
super().__init__(
module=module,
device=device,
strategy=strategy,
gradient_accumulation_steps=gradient_accumulation_steps,
detect_anomaly=detect_anomaly,
clip_grad_norm=clip_grad_norm,
)
self.tb_logger = tb_logger
self.opt_cfg = opt_cfg
self.log_every_n_steps = log_every_n_steps
def configure_optimizers_and_lr_scheduler(
self,
module: torch.nn.Module,
) -> Tuple[torch.optim.Optimizer, Optional[TLRScheduler]]:
optimizer = create_optimizer(module, self.opt_cfg)
return optimizer, None
def compute_loss(
self, state: State, data: Batch
) -> Tuple[torch.Tensor, torch.Tensor]:
input, target = data
outputs, loss = self.module(input, target)
return loss, outputs
def on_train_step_end(
self,
state: State,
data: Batch,
step: int,
results: TrainStepResults,
) -> None:
loss = results.loss
if step % self.log_every_n_steps == 0:
self.tb_logger.log("loss", loss, step)
def main(args: Namespace) -> None:
seed(args.seed)
device = init_from_env()
data_cfg = DataConfig(
path=PATH,
block_size=args.block_size,
train_split=args.train_split,
truncate=args.truncate,
)
train_data, eval_data, dataset = get_datasets(data_cfg)
train_dataloader = prepare_dataloader(train_data, args.batch_size, device)
eval_dataloader = prepare_dataloader(eval_data, args.batch_size, device)
gpt_cfg = GPTConfig(
n_layer=args.n_layer,
n_head=args.n_head,
n_embd=args.n_embd,
vocab_size=dataset.vocab_size,
block_size=dataset.block_size,
device=str(device),
)
module = GPT(gpt_cfg)
path = tempfile.mkdtemp()
tb_logger = TensorBoardLogger(path)
my_unit = MinGPTUnit(
tb_logger=tb_logger,
opt_cfg=OptimizerConfig(learning_rate=args.lr, weight_decay=args.weight_decay),
module=module,
device=device,
strategy="ddp" if torch.distributed.is_initialized() else None,
log_every_n_steps=args.log_every_n_steps,
gradient_accumulation_steps=4,
detect_anomaly=True,
clip_grad_norm=args.clip_grad_norm,
)
fit(
my_unit,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
max_epochs=args.max_epochs,
max_steps=args.max_steps,
max_train_steps_per_epoch=args.max_steps_epoch,
)
def get_args() -> Namespace:
"""Parse command line arguments"""
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--truncate", type=float, default=0.05, help="truncate data")
parser.add_argument("--train-split", type=float, default=0.9, help="training split")
parser.add_argument("--block-size", type=int, default=128, help="block size")
parser.add_argument("--batch-size", type=int, default=216, help="batch size")
parser.add_argument("--max-steps", type=int, help="training steps")
parser.add_argument("--max-steps-epoch", type=int, help="training steps per epoch")
parser.add_argument("--max-epochs", type=int, default=10, help="training epochs")
parser.add_argument("--n-layer", type=int, default=8, help="number of block layers")
parser.add_argument("--n-head", type=int, default=8, help="number of heads")
parser.add_argument("--n-embd", type=int, default=512, help="embedding dimension")
parser.add_argument("--lr", type=float, default=0.0003, help="learning rate")
parser.add_argument("--weight-decay", type=float, default=0.1, help="weight decay")
parser.add_argument(
"--clip-grad-norm", type=float, default=1.0, help="clip gradient norm"
)
parser.add_argument(
"--log-every-n-steps", type=int, default=10, help="log every n steps"
)
return parser.parse_args()
if __name__ == "__main__":
main(get_args())