/
run_tabs.py
42 lines (33 loc) · 1.24 KB
/
run_tabs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from baselines.tabs import TypeDiscoveryModel
from utils.utils_common import define_callbacks, define_trainer, load_state_dict, save_confidence
from utils.utils_config import load_config
from utils.utils_tabs_datamodule import OpenTypeDataModule
def main():
    """Train (unless ``eval_only``) and test a TABS type-discovery model.

    Every step is driven by the config returned from ``load_config()``:

    1. Load the config and logger, and check that ``config.model_name``
       refers to a TABS model.
    2. Build and set up ``OpenTypeDataModule``.
    3. Build ``TypeDiscoveryModel``. Optionally load pretrained weights
       (``config.load_pretrained``).
    4. Build callbacks (monitoring ``val/loss``, lower is better) and the
       trainer. Fit unless ``config.eval_only`` is set.
    5. Run the test loop:
       - After fitting, test the best checkpoint tracked by the trainer's
         checkpoint callback.
       - Otherwise, or if no best checkpoint was recorded, test the model's
         current in-memory weights.
    6. Optionally save confidence scores (``config.supervised_pretrain``).

    Raises:
        ValueError: if ``config.model_name`` does not contain ``'tabs'``.
    """
    # Load config
    config, logger = load_config()
    # Explicit check rather than ``assert``, which is removed under ``python -O``.
    if 'tabs' not in config.model_name:
        raise ValueError(
            f"This script runs TABS models only; got model_name={config.model_name!r}"
        )

    # DataModule
    dm = OpenTypeDataModule(config)
    dm.setup()

    model = TypeDiscoveryModel(config, len(dm.train_dataloader()), tokenizer=dm.tokenizer)

    # Load pretrained model
    if config.load_pretrained:
        model.load_pretrained_model(load_state_dict(config, load_pretrained=True))

    # Trainer & callbacks
    callbacks = define_callbacks(config, monitor='val/loss', mode='min')
    trainer = define_trainer(config, logger, callbacks)

    if not config.eval_only:
        trainer.fit(model, datamodule=dm)

    # Test.
    # The model is passed explicitly because in eval_only mode ``fit`` was never
    # called, so the trainer has no model attached.
    # ``best_model_path`` is empty when no checkpoint was written in this run.
    # In that case ``ckpt_path=None`` tells the trainer to test the current
    # (possibly pretrained) weights instead of looking for a missing file.
    best_model_path = getattr(trainer.checkpoint_callback, 'best_model_path', '') or None
    if best_model_path is None and not config.eval_only:
        print('No best checkpoint recorded; testing the final in-memory weights.')
    trainer.test(
        model=model,
        ckpt_path=best_model_path,
        datamodule=dm,
    )

    # Save confidence scores from pre-trained model
    if config.supervised_pretrain:
        save_confidence(config, model, dm)

    print(f'Run finished --> {config.run_name}')
if __name__ == "__main__":
    # Script entry point: bracket the run with start/finish markers on stdout.
    print("Starting...")
    main()
    print("Done")