# [TEST] sa_lr_pyspark_training

These tests cover the model-training and result-evaluation functions of the `sa_lr_pyspark_training` module.  

The primary objective is to ensure that the training and evaluation process is conducted accurately.  

The tests check that the evaluation metrics (area under the ROC curve, F1 score, and accuracy) are returned as floats in the range [0, 1].  

They also verify that the `train_test_split` function returns non-empty training and test DataFrames.

In [None]:
import pytest
from pyspark.sql import SparkSession

from training.sa_lr_pyspark_training import train_test_split, evaluate_model

In [None]:
@pytest.fixture(scope="session")
def spark_session():
    """Provide a local SparkSession for the tests and stop it afterwards.

    The session runs on ``local[2]`` under the application name ``"test"``.
    ``getOrCreate()`` returns the already-active session when one exists, so
    the tests shared a single session even before; session scope makes that
    sharing explicit and lets the fixture own the teardown.

    Yields:
        SparkSession: the session used by the tests.
    """
    session = (
        SparkSession.builder.master("local[2]").appName("test").getOrCreate()
    )
    yield session
    # Teardown: stop the session so the local Spark context does not outlive
    # the test run.
    session.stop()

In [None]:
def test_train_test_split(spark_session):
    """Check that train_test_split returns non-empty train and test DataFrames.

    Args:
        spark_session: SparkSession supplied by the ``spark_session`` fixture.
    """
    # Use enough rows that a split is very unlikely to leave either side
    # empty. NOTE(review): presumably train_test_split splits randomly; with
    # only two rows the non-empty assertions below would then depend on how
    # those two rows happen to be assigned -- confirm against the helper.
    row_count = 100
    data = [(i % 2, f"Sample tweet number {i}.") for i in range(row_count)]
    columns = ["label", "text"]
    df = spark_session.createDataFrame(data, columns)

    # Call the train_test_split function with the test DataFrame
    train_df, test_df = train_test_split(df)

    # Verify that both the training and the testing DataFrame contain rows
    assert train_df.count() > 0
    assert test_df.count() > 0

In [None]:
def test_evaluate_model(spark_session):
    """Check that evaluate_model reports ROC, F1 and Accuracy as floats in [0, 1].

    Args:
        spark_session: SparkSession supplied by the ``spark_session`` fixture.
    """
    # Two rows in which the prediction equals the label.
    # NOTE(review): label and prediction are Python ints here; Spark ML
    # evaluators commonly expect a double prediction column -- confirm that
    # evaluate_model accepts or casts integer columns.
    rows = [
        (1, 1, "This is a sample tweet."),
        (0, 0, "Another tweet with numbers 123."),
    ]
    scored_df = spark_session.createDataFrame(
        rows, ["label", "prediction", "features"]
    )

    # Evaluate the pre-scored DataFrame
    metrics = evaluate_model(
        scored_df, labelCol="label", predictionCol="prediction"
    )

    # Every reported metric must be a float inside the unit interval
    for metric_name in ("ROC", "F1", "Accuracy"):
        value = metrics[metric_name]
        assert isinstance(value, float)
        assert 0 <= value <= 1

In [None]:
# Run pytest when this file is executed directly rather than imported.
if __name__ == "__main__":
    pytest.main()