gpt2_finetune.py
import os
import pickle

import fire

from models.gpt2 import GPT2

# Fine-tuning hyperparameters.
BATCH_SIZE = 5
LR = 4e-5
ADAM_EPSILON = 1e-8
WEIGHT_DECAY = 0.
WARMUP_PROPORTION = 0.1


def load_data(dataset, split, vocab_size):
    """Return parallel lists of conditions and texts for a dataset split."""
    if vocab_size == 'full':
        path = f'data/{dataset}/{split}.pickle'
        text_key = 'text'
    else:
        # Variant of the data whose text was extracted down to `vocab_size` words.
        path = f'data/{dataset}/extracted_{split}_{vocab_size}words.pickle'
        text_key = 'extracted_text'
    with open(path, 'rb') as f:
        examples = pickle.load(f)
    return ([example['condition'] for example in examples],
            [example[text_key] for example in examples])
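
# Illustrative pickle record layout (an assumption inferred from the keys read
# in load_data, not a documented format):
#   [{'condition': '<prompt text>', 'text': '<target text>'}, ...]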


def main(dataset='wp', vocab='full', gpt2_type='gpt2', n_epochs=3):
    """Fine-tune a GPT-2 model on `dataset` with the given vocabulary setting."""
    if os.path.exists(f'training_logs/{gpt2_type}_{dataset}_{vocab}words'):
        print('Training log directory already exists! Remove it to re-train.')
        return

    gpt2 = GPT2(gpt2_type=gpt2_type)
    # Load train and dev splits; dev presumably drives the periodic evaluation.
    for split in ['train', 'dev']:
        conds, texts = load_data(dataset, split, vocab)
        gpt2.load_data(split=split, conds=conds, texts=texts)

    # One step is one batch; warm up over the first WARMUP_PROPORTION of steps.
    train_steps = n_epochs * (len(gpt2.train_dataset) // BATCH_SIZE + 1)
    warmup_steps = int(train_steps * WARMUP_PROPORTION)
    gpt2.get_optimizer(
        lr=LR,
        train_steps=train_steps,
        warmup_steps=warmup_steps,
        weight_decay=WEIGHT_DECAY,
        adam_epsilon=ADAM_EPSILON)

    # Evaluate once per epoch's worth of steps. `creat_log_dir` (sic) is kept
    # as-is to match the method name in models.gpt2.
    gpt2.creat_log_dir(
        eval_steps=len(gpt2.train_dataset) // BATCH_SIZE,
        label=f'{gpt2_type}_{dataset}_{vocab}words')
    for _ in range(n_epochs):
        gpt2.train_epoch(batch_size=BATCH_SIZE)


if __name__ == '__main__':
    fire.Fire(main)
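
# Example invocation (python-fire exposes main's keyword arguments as CLI
# flags; the values shown are just the defaults above):
#   python gpt2_finetune.py --dataset=wp --vocab=full --gpt2_type=gpt2 --n_epochs=3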