Document classification tutorial #1
Reading arguments

(0) Example JSON file

```json
{
  "model_name_or_path": "test",
  "task_name": "document_classification",
  "data_dir": "data",
  "output_dir": "checkpoint"
}
```

(1) Reading the JSON file from the Python console

```python
from ratsnlp.arguments import load_arguments

model_args, data_args, training_args = load_arguments(json_file_path="examples/document_classification.json")
```

(2) Passing the JSON file path as a command-line argument

```python
from ratsnlp.arguments import load_arguments

model_args, data_args, training_args = load_arguments()
```

```bash
python examples/document_classification.py examples/document_classification.json
```

(3) Passing each argument directly on the command line

```python
from ratsnlp.arguments import load_arguments

model_args, data_args, training_args = load_arguments()
```

```bash
python examples/document_classification.py --model_name_or_path test2 --task_name doc --data_dir data --output_dir check
```
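The three modes above differ only in where the values come from. As a rough sketch of how a loader could dispatch between them, here is a standard-library version built on `argparse` and `json`. `load_arguments_sketch` and its `FIELDS` tuple are hypothetical names that mirror the example JSON; this is not ratsnlp's implementation, which returns three argument groups (`model_args`, `data_args`, `training_args`) rather than one dict.

```python
import argparse
import json
import sys
from typing import Optional, Sequence

# Hypothetical field names, mirroring the example JSON file above.
FIELDS = ("model_name_or_path", "task_name", "data_dir", "output_dir")


def load_arguments_sketch(json_file_path: Optional[str] = None,
                          argv: Optional[Sequence[str]] = None) -> dict:
    """Hypothetical loader illustrating the three ways of supplying arguments."""
    if argv is None:
        argv = sys.argv[1:]
    # Mode (2): a single *.json path given on the command line.
    if json_file_path is None and len(argv) == 1 and argv[0].endswith(".json"):
        json_file_path = argv[0]
    # Modes (1) and (2): read every value from the JSON file.
    if json_file_path is not None:
        with open(json_file_path, encoding="utf-8") as f:
            values = json.load(f)
        missing = [name for name in FIELDS if name not in values]
        if missing:
            raise ValueError(f"missing keys in {json_file_path}: {missing}")
        return {name: values[name] for name in FIELDS}
    # Mode (3): individual --flag value pairs on the command line.
    parser = argparse.ArgumentParser()
    for name in FIELDS:
        parser.add_argument(f"--{name}", required=True)
    return vars(parser.parse_args(list(argv)))
```

Calling `load_arguments_sketch(argv=["--model_name_or_path", "test2", "--task_name", "doc", "--data_dir", "data", "--output_dir", "check"])` corresponds to mode (3). Passing `argv` explicitly keeps the function testable without touching `sys.argv`.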
Code

```python
from ratsnlp.nlpbook import *
from ratsnlp.nlpbook.classification import NsmcCorpus, Runner

if __name__ == "__main__":
    # Read settings from the JSON config; the commented line reads them from the command line instead
    args = load_arguments(json_file_path="examples/document_classification.json")
    # args = load_arguments()
    set_logger(args)
    # Download the downstream corpus (NSMC) and the pretrained model into their cache directories
    download_downstream_dataset(
        args.downstream_corpus_name,
        cache_dir=args.downstream_corpus_dir,
        force_download=False
    )
    download_pretrained_model(
        args.pretrained_model_name,
        cache_dir=args.pretrained_model_cache_dir,
        force_download=False
    )
    check_exist_checkpoints(args)
    seed_setting(args)
    # Build the tokenizer, the corpus, and the train/validation/test dataloaders
    tokenizer = get_tokenizer(args)
    corpus = NsmcCorpus()
    train_dataloader, val_dataloader, test_dataloader = get_dataloaders(corpus, tokenizer, args)
    # NSMC is binary sentiment classification (negative/positive), hence num_labels=2
    model = get_pretrained_model(args, num_labels=2)
    runner = Runner(model, args)
    checkpoint_callback, trainer = get_trainer(args)
    if args.do_train:
        trainer.fit(
            runner,
            train_dataloader=train_dataloader,
            val_dataloaders=val_dataloader,
        )
    if args.do_predict:
        # Evaluate on the test set using the best checkpoint saved during training
        trainer.test(
            runner,
            test_dataloaders=test_dataloader,
            ckpt_path=checkpoint_callback.best_model_path,
        )
```

Config:

```json
{
  "pretrained_model_name": "kobert",
  "pretrained_model_cache_dir": "/Users/david/works/cache/kobert",
  "downstream_corpus_name": "nsmc",
  "downstream_corpus_dir": "/Users/david/works/cache/nsmc",
  "downstream_task_name": "document-classification",
  "downstream_model_dir": "/Users/david/works/cache/checkpoint",
  "do_train": true,
  "do_eval": true,
  "do_predict": false,
  "batch_size": 32
}
```
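`NsmcCorpus()` and `num_labels=2` rely on NSMC being a binary sentiment dataset. To illustrate what such a corpus reader has to do, here is a minimal sketch. It assumes the public NSMC files (`ratings_train.txt`, `ratings_test.txt`) use a tab-separated layout with an `id`, `document`, `label` header and labels `0` (negative) / `1` (positive). `read_nsmc` and `NsmcExample` are hypothetical names and are not part of ratsnlp.

```python
import csv
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class NsmcExample:
    text: str
    label: int  # 0 = negative, 1 = positive (binary, hence num_labels=2)


def read_nsmc(lines: Iterable[str]) -> List[NsmcExample]:
    """Parse NSMC-style TSV lines (header: id, document, label) into examples."""
    # QUOTE_NONE keeps quote characters inside reviews as literal text.
    reader = csv.DictReader(lines, delimiter="\t", quoting=csv.QUOTE_NONE)
    examples = []
    for row in reader:
        text = row["document"]
        if not text:  # skip rows whose review text is empty
            continue
        label = int(row["label"])
        if label not in (0, 1):
            raise ValueError(f"unexpected label {label!r} in row {row['id']}")
        examples.append(NsmcExample(text=text, label=label))
    return examples
```

With a real file, open it with `open(path, encoding="utf-8", newline="")` and pass the file object to `read_nsmc`.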
Training locally

The following three approaches are equivalent:

```bash
python train_local.py
```

```bash
python train_local.py train_local.json
```

Example `train_local.json`:

```json
{
  "pretrained_model_name": "beomi/kcbert-base",
  "downstream_corpus_name": "nsmc",
  "downstream_corpus_root_dir": "data",
  "downstream_task_name": "document-classification",
  "downstream_model_dir": "checkpoint/document-classification",
  "do_train": true,
  "do_eval": true,
  "batch_size": 32
}
```
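The first two commands would be equivalent if the script's built-in defaults match the values in the JSON file, because loading the file then changes nothing. The sketch below assumes that design. `TrainLocalArgs` and `load_train_args` are hypothetical names that only mirror the keys shown above. It is not the actual contents of `train_local.py`.

```python
import json
import sys
from dataclasses import dataclass, fields, replace
from typing import Optional, Sequence


@dataclass(frozen=True)
class TrainLocalArgs:
    # Hypothetical defaults copied from the example train_local.json above.
    pretrained_model_name: str = "beomi/kcbert-base"
    downstream_corpus_name: str = "nsmc"
    downstream_corpus_root_dir: str = "data"
    downstream_task_name: str = "document-classification"
    downstream_model_dir: str = "checkpoint/document-classification"
    do_train: bool = True
    do_eval: bool = True
    batch_size: int = 32


def load_train_args(argv: Optional[Sequence[str]] = None) -> TrainLocalArgs:
    """Return the defaults, overridden by a JSON file given as the first argument."""
    if argv is None:
        argv = sys.argv[1:]
    args = TrainLocalArgs()
    if not argv:
        return args  # `python train_local.py`
    with open(argv[0], encoding="utf-8") as f:  # `python train_local.py train_local.json`
        overrides = json.load(f)
    known = {fld.name for fld in fields(TrainLocalArgs)}
    unknown = set(overrides) - known
    if unknown:
        raise ValueError(f"unknown config keys: {sorted(unknown)}")
    return replace(args, **overrides)
```

Rejecting unknown keys catches typos in the JSON file that would otherwise be silently ignored.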
Running inference locally

The following three approaches are equivalent:
ratsgo added a commit that referenced this issue on Dec 26, 2020.
ratsgo added a commit that referenced this issue on Jan 23, 2021.
Overview
Build a document classification tutorial.