# BERT model from ACL'22

In [None]:
import pandas as pd
import numpy as np
import os
from typing import List
from components.setup import (load_arguments_from_tsv, split_arguments, write_tsv_dataframe)
from components.models_bert import (predict_bert_model, load_tokenizer, pre_load_saved_model, get_available_values_by_subtask)

# TIRA runtime configuration, read from the environment.
# Server mode is flagged by the mere presence of TIRA_INFERENCE_SERVER,
# whatever its value.
runs_as_inference_server = 'TIRA_INFERENCE_SERVER' in os.environ
# Input and output locations fall back to local folders when unset.
dataset_dir = os.getenv('TIRA_INPUT_DIRECTORY', './dataset')
output_dir = os.getenv('TIRA_OUTPUT_DIRECTORY', './output')

## Setup

In [None]:
# Directories (relative paths) that hold the saved models and the tokenizer.
model_dir = 'models/'
tokenizer_dir = 'tokenizer/'

# The subtask is chosen through the SUBTASK environment variable; any value
# other than "1" or "2" is reported and replaced by subtask "1".
subtask = os.getenv('SUBTASK', "1")
if subtask != '1' and subtask != '2':
    print(f'Unknown subtask "{subtask}". Defaulting to subtask "1".')
    subtask = "1"

In [None]:
# Load the tokenizer from `tokenizer_dir`. The return value is not used here,
# so the helper presumably keeps the tokenizer in its own module state
# (NOTE(review): confirm in components.models_bert).
load_tokenizer(tokenizer_dir)

# Values available for the selected subtask; passed on to model loading and
# to `predict_bert_model` below.
values = get_available_values_by_subtask(subtask=subtask)

In [None]:
# Pre-load the saved model of the selected subtask. `predict` later hands the
# same model path and `values` to `predict_bert_model`.
pre_load_saved_model(os.path.join(model_dir, f'bert_train_subtask_{subtask}'), values)

## Predict function

In [None]:
def predict(input_list: List) -> List:
    """Predict value scores for a list of plain text strings.

    Args:
        input_list: Texts to classify (simple premise strings). They are
            wrapped in a one-column ``Text`` frame for the model.

    Returns:
        The model output converted with ``DataFrame.to_dict('records')``,
        i.e. a list of dicts keyed by the output columns of
        ``predict_bert_model`` (presumably one dict per input text and one
        key per value -- confirm against ``components.models_bert``).
        All scores are clipped to ``[0, 1]``.
    """
    df_predict = pd.DataFrame(input_list, columns=['Text'])

    result = predict_bert_model(df_predict, os.path.join(model_dir, f'bert_train_subtask_{subtask}'), values)
    # Force every score into the closed interval [0, 1].
    result = np.clip(result, 0.0, 1.0)
    if subtask == '2':
        # Subtask 2 has an "attained" and a "constrained" column per base
        # value; each pair is scaled down so that it sums to at most 1.
        for base_value in get_available_values_by_subtask(subtask='1'):
            value_attained = f'{base_value} attained'
            value_constrained = f'{base_value} constrained'
            pair_sum = result[value_attained] + result[value_constrained]
            # Divide by the pair sum only where it exceeds 1 and by exactly 1
            # elsewhere: rows already within bounds stay untouched and no
            # division by zero can occur. Working on whole columns also makes
            # this independent of the frame's index labels.
            modifier = 1.0 / pair_sum.where(pair_sum > 1.0, 1.0)
            result[value_attained] = modifier * result[value_attained]
            result[value_constrained] = modifier * result[value_constrained]

    return result.to_dict('records')

## Classification on TIRA

In [None]:
# Batch classification: only executed when not running as inference server.
if not runs_as_inference_server:
    argument_filepath = os.path.join(dataset_dir, 'sentences.tsv')

    # load arguments
    df_arguments = load_arguments_from_tsv(argument_filepath)

    # format dataset; only the third returned frame (the test split) is used
    _, _, df_test = split_arguments(df_arguments)

    # predict with Bert model
    # pd.concat(axis=1) aligns rows by index label, while the prediction frame
    # built below always has a fresh 0..n-1 index. Resetting the index of the
    # ID columns keeps both sides aligned row by row even if `df_test` carries
    # a filtered or otherwise non-default index.
    df_prediction = df_test[['Text-ID', 'Sentence-ID']].reset_index(drop=True)
    print("===> Bert: Predicting...")
    prediction_list = predict(df_test['Text'].tolist())  # call uniform predict function
    df_prediction = pd.concat([df_prediction, pd.DataFrame.from_dict(prediction_list)], axis=1)

    # write predictions
    print("===> Writing predictions...")
    write_tsv_dataframe(os.path.join(output_dir, 'predictions.tsv'), df_prediction)