-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathconfig.py
35 lines (32 loc) · 1.12 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from pathlib import Path
def get_config():
    """Build and return the default configuration dictionary.

    A new dict is created on every call, so callers may mutate the result
    without affecting later calls. ``model_basename`` and ``tokenizer_file``
    are ``str.format`` templates with a single positional placeholder.
    """
    settings = dict(
        batch_size=4,
        num_epochs=30,
        lr=10**-4,
        seq_len=350,
        d_model=512,
        lang_src="en",
        lang_tgt="it",
        model_folder="weights",
        model_basename="tmodel_{0:02d}.pt",
        preload="latest",
        tokenizer_file="tokenizer_{0}.json",
    )
    return settings
def get_weights_file_path(config, epoch: "int | str") -> str:
    """Return the path of the weights file for a given epoch.

    The file name is produced by formatting ``config["model_basename"]``
    with ``epoch`` and is placed inside ``config["model_folder"]``.

    Args:
        config: Mapping providing the ``"model_folder"`` and
            ``"model_basename"`` entries.
        epoch: Epoch identifier. May be an ``int`` or a numeric string such
            as ``"5"`` or ``"05"``.

    Returns:
        The relative path to the weights file, as a string.

    Raises:
        ValueError: If ``epoch`` cannot be rendered by the basename template
            (for example a non-numeric string with an integer format spec).
    """
    model_folder = config["model_folder"]
    model_basename = config["model_basename"]
    try:
        model_filename = model_basename.format(epoch)
    except ValueError:
        # An integer format spec such as "{0:02d}" rejects str arguments.
        # Retry with the epoch converted to int so numeric strings work;
        # a non-numeric epoch still surfaces as ValueError from int().
        model_filename = model_basename.format(int(epoch))
    return str(Path('.') / model_folder / model_filename)
def get_latest_weights_file_path(config) -> "str | None":
    """Return the path of the ``.pt`` file with the highest epoch number.

    The epoch number is taken from the last ``_``-separated segment of each
    file's stem (``tmodel_07.pt`` -> ``7``). Files whose last segment is not
    an integer (e.g. ``best.pt``) are ignored rather than aborting the scan.

    Args:
        config: Mapping providing the ``"model_folder"`` entry.

    Returns:
        The path of the newest weights file as a string, or ``None`` when
        the folder contains no ``.pt`` file with a numeric suffix.
    """
    model_folder = config["model_folder"]

    # Collect (epoch, path) pairs for every file with a parseable epoch.
    numbered_files = []
    for candidate in Path(model_folder).glob("*.pt"):
        try:
            epoch = int(candidate.stem.split('_')[-1])
        except ValueError:
            # Not an epoch checkpoint; skip it.
            continue
        numbered_files.append((epoch, candidate))

    if not numbered_files:
        return None

    # Sort by epoch number (ascending order) and take the last one.
    numbered_files.sort(key=lambda pair: pair[0])
    return str(numbered_files[-1][1])