# BERT model from ACL'22

In [None]:
import pandas as pd
import os
from typing import List
from tira.third_party_integrations import is_running_as_inference_server, get_input_directory_and_output_directory
from components.setup import (load_values_from_json, load_arguments_from_tsv, split_arguments,
                              write_tsv_dataframe, create_dataframe_head)
from components.models_bert import (predict_bert_model, load_tokenizer)

In [None]:
# Resource locations (relative paths, resolved against the working directory).
data_dir, model_dir, tokenizer_dir = 'core_data/', 'models/bert/', 'tokenizer/'
# Label level to predict. It is a string because it is used both as a key into
# `values` and inside the model sub-directory name (bert_train_level<level>).
level = "2"

In [None]:
# Presumably loads/initialises the BERT tokenizer from `tokenizer_dir`. The
# return value is discarded, so the helper likely caches it as module state
# for later predictions (NOTE(review): confirm in components.models_bert).
load_tokenizer(tokenizer_dir)

# Label names read from `values.json` inside `data_dir`. Later cells index this
# with `level` and use the result as the prediction column names, so it is
# presumably a mapping from level string to a list of label names — TODO confirm.
values_filepath = os.path.join(data_dir, 'values.json')
values = load_values_from_json(values_filepath)

In [None]:
def predict(input_list: List) -> List:
    """Run the level-`level` BERT model on a batch of premises.

    Args:
        input_list: Plain premise strings, one per argument to classify.

    Returns:
        The model output converted with ``.tolist()``. It is one entry per
        premise, with the label columns given by ``values[level]``.
    """
    premises = pd.DataFrame(input_list, columns=['Premise'])
    # Each level has its own fine-tuned checkpoint directory under model_dir.
    checkpoint_path = os.path.join(model_dir, f'bert_train_level{level}')
    label_names = values[level]
    raw_prediction = predict_bert_model(premises, checkpoint_path, label_names)
    return raw_prediction.tolist()

In [None]:
if not is_running_as_inference_server():
    # Batch mode: this branch is skipped when running as an inference server.
    # It reads arguments.tsv from the dataset directory, predicts the test split
    # and writes predictions.tsv to the output directory.
    dataset_dir, output_dir = get_input_directory_and_output_directory('./')
    argument_filepath = os.path.join(dataset_dir, 'arguments.tsv')

    # load arguments
    df_arguments = load_arguments_from_tsv(argument_filepath)

    # format dataset: only the third split returned by split_arguments is used
    _, _, df_test = split_arguments(df_arguments)

    # predict with Bert model
    df_prediction = create_dataframe_head(df_test['Argument ID'])
    print("===> Bert: Predicting Level %s..." % level)
    result_list = predict(df_test['Premise'].tolist())  # call uniform predict function

    # pd.concat(axis=1) aligns rows by index label, not by position. The head
    # frame is derived from a split of df_arguments and may carry a
    # non-contiguous index. Reusing its index here keeps each prediction on its
    # own argument instead of producing misaligned/NaN rows. If the row counts
    # differ, pandas raises a ValueError rather than silently corrupting output.
    # Assumes create_dataframe_head yields one row per argument, in df_test
    # order — TODO confirm.
    df_labels = pd.DataFrame(result_list, columns=values[level], index=df_prediction.index)
    df_prediction = pd.concat([df_prediction, df_labels], axis=1)

    # write predictions
    print("===> Writing predictions...")
    write_tsv_dataframe(os.path.join(output_dir, 'predictions.tsv'), df_prediction)