In [6]:
import logging
import os
import sys
from typing import Callable, List, Dict, NoReturn, Tuple

import numpy as np
from configure import *
from preprocess import *
from datasets import (
    load_metric,
    load_dataset
)

from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer

from transformers import (
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
)

from utils_qa import postprocess_qa_predictions, check_no_error
from trainer_qa import QuestionAnsweringTrainer
from sparse_retrieval import SparseRetrieval
from retrieval_common_part import build_faiss, retrieve_faiss
import retrieval_common_part
from postprocessing import post_processing_function
from run_mrc import run_combine_mrc

In [12]:
import easydict

# Ad-hoc run settings with attribute access (``args.dataset_name`` etc.).
# NOTE(review): no later cell reads ``args``; they read ``data_args``,
# ``model_args`` and ``training_args``, which no visible cell defines. The
# stale output under this cell shows an HfArgumentParser run (argv taken from
# ipykernel_launcher.py) exiting with SystemExit: 2 because the required
# ``--output_dir`` was not supplied. Those dataclasses presumably need to be
# built in this cell instead, e.g. ``parse_args_into_dataclasses(args=[...])``
# with explicit arguments -- TODO confirm against the project's arguments module.
args = easydict.EasyDict(
    dict(
        dataset_name='../data/test_dataset',
        epoch=20,
        gpu=0,
        out="result",
        resume=False,
        unit=1000,
    )
)

usage: ipykernel_launcher.py [-h] [--model_name_or_path MODEL_NAME_OR_PATH] [--config_name CONFIG_NAME]
                             [--tokenizer_name TOKENIZER_NAME] [--run_extraction [RUN_EXTRACTION]] [--no_run_extraction]
                             [--run_generation [RUN_GENERATION]] [--dataset_name DATASET_NAME]
                             [--overwrite_cache [OVERWRITE_CACHE]] [--preprocessing_num_workers PREPROCESSING_NUM_WORKERS]
                             [--max_seq_length MAX_SEQ_LENGTH] [--pad_to_max_length [PAD_TO_MAX_LENGTH]]
                             [--doc_stride DOC_STRIDE] [--max_answer_length MAX_ANSWER_LENGTH]
                             [--eval_retrieval [EVAL_RETRIEVAL]] [--no_eval_retrieval] [--num_clusters NUM_CLUSTERS]
                             [--top_k_retrieval TOP_K_RETRIEVAL] [--use_faiss [USE_FAISS]] [--sparse_name SPARSE_NAME]
                             [--dense_name DENSE_NAME] [--run_seq2seq [RUN_SEQ2SEQ]] --output_dir OUTPUT_DIR
            

SystemExit: 2

In [None]:
# Load the test questions from ``<dataset_name>/test.json`` (records live under
# the top-level "data" field) and expose them as the "validation" split.
datasets = load_dataset(
    'json',
    data_files={
        'validation': os.path.join(data_args.dataset_name, 'test.json'),
    },
    field='data',
)

# Load the pretrained model and tokenizer through the project helper
# ``configure_model``; changing the model name in the arguments changes which
# checkpoint is loaded.
model, tokenizer = configure_model(model_args, training_args, data_args)

# If eval_retrieval is True: run passage retrieval and replace ``datasets`` with
# its result.
# NOTE(review): ``run_retrieval`` is not imported explicitly; presumably it comes
# from one of the star imports (``configure`` / ``preprocess``) -- confirm.
if data_args.eval_retrieval:
    datasets = run_retrieval(tokenizer, datasets, training_args, data_args)

In [None]:
# Raw column names of the validation split; the feature-building cell passes
# them to ``map(remove_columns=...)`` so they are dropped after tokenization.
column_names = (
    datasets["validation"]
    .column_names
)

In [None]:
# Sanity-check the run configuration with the project helper ``check_no_error``;
# it hands back the last checkpoint and the max sequence length to tokenize with.
last_checkpoint, max_seq_length = check_no_error(data_args, training_args, datasets, tokenizer)

In [None]:
# Tokenized validation features are only needed when evaluating or predicting.
if training_args.do_eval or training_args.do_predict:
    eval_dataset = datasets["validation"]

    # Batched feature function built from the tokenizer and the arguments.
    prepare_valid_features = preprocess_extract_valid(
        tokenizer, data_args, column_names, max_seq_length
    )

    # Create validation features. The raw columns are removed, so the mapped
    # dataset keeps only what ``prepare_valid_features`` produces.
    eval_dataset = eval_dataset.map(
        prepare_valid_features,
        remove_columns=column_names,
        batched=True,
        load_from_cache_file=not data_args.overwrite_cache,
        num_proc=data_args.preprocessing_num_workers,
    )

In [None]:
# Dynamic-padding collator: with fp16 enabled, padded lengths are rounded up to
# a multiple of 8; otherwise no multiple is enforced.
if training_args.fp16:
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
else:
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=None)

In [None]:
# Build the QA trainer used to predict on the test split.
# This notebook never builds a training split, so ``train_dataset`` is always
# None. Referencing an undefined ``train_dataset`` name would raise NameError
# whenever ``training_args.do_train`` is set.
# NOTE(review): ``compute_metrics`` is not defined in any visible cell; it must
# come from one of the star imports (``configure`` / ``preprocess``) -- confirm.
# (``load_metric``, ``EvalPrediction`` and ``Dict`` are imported but unused,
# which suggests a local definition may have been lost.)
trainer = QuestionAnsweringTrainer(
    model=model,
    args=training_args,
    train_dataset=None,
    eval_dataset=eval_dataset if training_args.do_eval else None,
    eval_examples=datasets["validation"] if training_args.do_eval else None,
    tokenizer=tokenizer,
    data_collator=data_collator,
    post_process_function=post_processing_function,
    compute_metrics=compute_metrics,
)

In [1]:
# Run inference on the test split. predictions.json is already saved when
# postprocess_qa_predictions() is called, so nothing else is written here.
if training_args.do_predict:
    predictions = trainer.predict(
        test_dataset=eval_dataset,
        test_examples=datasets["validation"],
    )
    # The test split carries no gold answers, so no metric can be reported.
    print("No metric can be presented because there is no correct answer given. Job done!")


DatasetDict({
    validation: Dataset({
        features: ['question', 'id'],
        num_rows: 9584
    })
})
