# SVM model from ACL'22

In [None]:
import pandas as pd
import os
from typing import List
from components.setup import (load_values_from_json, load_arguments_from_tsv, split_arguments,
                              write_tsv_dataframe, create_dataframe_head)
from components.models_svm import (load_svms)

# TIRA integration. If TIRA_INFERENCE_SERVER is present in the environment
# (whatever its value), the batch classification cell at the end is skipped.
# The two directories fall back to local folders when TIRA does not set them.
runs_as_inference_server = 'TIRA_INFERENCE_SERVER' in os.environ
dataset_dir = os.getenv('TIRA_INPUT_DIRECTORY', './dataset')
output_dir = os.getenv('TIRA_OUTPUT_DIRECTORY', './output')

## Setup

In [None]:
# Resource locations and the label level used throughout the notebook.
data_dir = 'core_data/'    # folder containing values.json
model_dir = 'models/svm/'  # folder containing the svm_train_level<level>_*.json files
level = '2'                # string key into `values`; also embedded in the model file names

In [None]:
# Load the label definitions from <data_dir>/values.json.
# `values` is indexed by `level` below; `values[level]` supplies the label
# names used to load the models and to name the prediction columns.
values_filepath = os.path.join(data_dir, 'values.json')
values = load_values_from_json(values_filepath)

In [None]:
# Both model files are named after the configured level.
_vectorizer_filepath = os.path.join(model_dir, f'svm_train_level{level}_vectorizer.json')
_models_filepath = os.path.join(model_dir, f'svm_train_level{level}_models.json')

# Loaded once at startup; predict() looks models up in this registry by label name.
_model_registry = load_svms(values[level], _vectorizer_filepath, _models_filepath)

## Predict function

In [None]:
def predict(input_list: List) -> List:
    """Run every label's SVM of the configured level over the given premises.

    Args:
        input_list: Premise strings. They are wrapped in a pandas Series
            named 'Premise' before being passed to each model.

    Returns:
        One dict per input row (``DataFrame.to_dict('records')``), with the
        label names from ``values[level]`` as keys and that label's
        prediction as value.
    """
    premises = pd.Series(input_list, name='Premise')
    label_names = values[level]

    # One prediction column per label, each produced by that label's own model.
    predictions_by_label = {
        label_name: _model_registry[label_name].predict(premises)
        for label_name in label_names
    }

    # `columns=` fixes the column order to the order of the label list.
    prediction_frame = pd.DataFrame(predictions_by_label, columns=label_names)
    return prediction_frame.to_dict('records')

## Classification on TIRA

In [None]:
if not runs_as_inference_server:
    # Batch mode: classify the arguments file found in the input directory
    # and write a single predictions file to the output directory.
    argument_filepath = os.path.join(dataset_dir, 'arguments.tsv')

    # load arguments
    df_arguments = load_arguments_from_tsv(argument_filepath)

    # format dataset (only the third element returned by the split is used)
    _, _, df_test = split_arguments(df_arguments)

    # predict with the SVM models
    df_prediction = create_dataframe_head(df_test['Argument ID'])
    print("===> SVM: Predicting Level %s..." % level)
    prediction_list = predict(df_test['Premise'].tolist())  # call uniform predict function
    # NOTE(review): concat with axis=1 aligns rows on the index. The label
    # frame built here gets a fresh 0..n-1 index, so this presumably relies on
    # the head frame having the same index -- confirm against split_arguments.
    df_prediction = pd.concat([df_prediction, pd.DataFrame.from_dict(prediction_list)], axis=1)

    # write predictions
    print("===> Writing predictions...")
    # The './output' fallback may not exist on a local run; create it up front
    # so the write below does not fail on a missing directory.
    os.makedirs(output_dir, exist_ok=True)
    write_tsv_dataframe(os.path.join(output_dir, 'predictions.tsv'), df_prediction)