<a href="https://colab.research.google.com/github/zzwony/Start_0920/blob/main/01_12_bert_qa_deploy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install ratsnlp

In [None]:
# Mount Google Drive under /gdrive. force_remount=True is passed so that
# re-running this cell asks for a fresh mount rather than reusing an old one.
from google.colab import drive

drive.mount(
    '/gdrive',
    force_remount=True,
)

In [None]:
# Inference settings

from ratsnlp.nlpbook.qa import QADeployArguments

# Deployment configuration read by the cells below: the pretrained model name
# used for the tokenizer and the BERT config, the directory holding the
# fine-tuned QA checkpoint, and the length limits for the question and for the
# whole input sequence.
args = QADeployArguments(
    downstream_model_dir="/gdrive/My Drive/nlpbook/checkpoint-qa",
    pretrained_model_name="beomi/kcbert-base",
    max_query_length=32,
    max_seq_length=128,
)

In [None]:
# Load the tokenizer and the fine-tuned checkpoint

import torch
from transformers import BertTokenizer

# Tokenizer of the pretrained model, with lower-casing disabled.
tokenizer = BertTokenizer.from_pretrained(args.pretrained_model_name, do_lower_case=False)

# Read the fine-tuned checkpoint file, mapping its tensors onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
fine_tuned_model_ckpt = torch.load(args.downstream_model_checkpoint_path, map_location=torch.device("cpu"))

In [None]:
# Load the BERT configuration and initialise the model

from transformers import BertConfig, BertForQuestionAnswering

# Configuration that belongs to the pretrained model named in `args`.
pretrained_model_config = BertConfig.from_pretrained(args.pretrained_model_name)

# The QA model is constructed from the configuration only; this cell does not
# load any weights into it.
model = BertForQuestionAnswering(pretrained_model_config)

In [None]:
# Load the checkpoint weights into the model.
# Keys of the checkpoint's 'state_dict' are renamed by dropping a leading
# "model." so that they line up with the parameter names of `model`
# (presumably the prefix comes from the training wrapper - confirm).
# Only the leading prefix is removed: a blanket str.replace would also delete
# any later "model." that happens to occur inside a parameter name.
model.load_state_dict({
    (k[len("model."):] if k.startswith("model.") else k): v
    for k, v in fine_tuned_model_ckpt['state_dict'].items()
})
# Switch to evaluation mode. Kept as the last expression of the cell so the
# notebook still displays the returned model.
model.eval()

In [None]:
# 인퍼런스

def inference_fn(question, context):
    """Extract the answer to ``question`` from ``context`` with the QA model.

    Uses the notebook-level ``tokenizer``, ``model`` and ``args`` objects.

    Args:
        question: Question text.
        context: Passage the answer is extracted from.

    Returns:
        A dict with the keys ``'question'`` and ``'context'`` (the inputs,
        unchanged) and ``'answer'`` (the decoded predicted span). The answer
        is ``""`` when either input is empty/falsy or when the predicted end
        position lies before the predicted start position.
    """
    pred_text = ""
    if question and context:
        # Encode the question on its own first, without special tokens, so it
        # is cut to at most max_query_length ids before being paired.
        truncated_query = tokenizer.encode(
            question,
            add_special_tokens=False,
            truncation=True,
            max_length=args.max_query_length,
        )
        inputs = tokenizer.encode_plus(
            text=truncated_query,
            text_pair=context,
            # With a pair input, truncate only the second sequence (the context).
            truncation="only_second",
            padding="max_length",
            max_length=args.max_seq_length,
            return_token_type_ids=True,
        )
        with torch.no_grad():
            # Wrap every encoded field in a list to add a batch dimension of 1.
            outputs = model(**{k: torch.tensor([v]) for k, v in inputs.items()})
            # Positions with the highest start / end logit.
            start_pred = outputs.start_logits.argmax(dim=-1).item()
            end_pred = outputs.end_logits.argmax(dim=-1).item()
        if start_pred <= end_pred:
            # skip_special_tokens drops markers such as [CLS], [SEP] and [PAD]
            # when the predicted span happens to cover them, so they do not
            # leak into the answer text.
            pred_text = tokenizer.decode(
                inputs['input_ids'][start_pred:end_pred + 1],
                skip_special_tokens=True,
            )
        # An end position before the start position is an empty span: the
        # answer stays "".
    return {
        'question': question,
        'context': context,
        'answer': pred_text,
    }

In [None]:
# Web service

# Build the web application around inference_fn using the helper from ratsnlp.
# NOTE(review): the type of the returned app is not visible here - presumably a
# Flask-style application; confirm against ratsnlp.
from ratsnlp.nlpbook.qa import get_web_service_app
app = get_web_service_app(inference_fn)
# Start the application.
# NOTE(review): run() presumably blocks this cell until the server is stopped - confirm.
app.run()