# [TEST] sa_lr_pyspark_main

This test suite covers the workflow of the `sa_lr_pyspark_main` module.

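It checks that `read_file` loads a CSV through Spark with the expected options (`sep=","`, `inferSchema=True`, `header=False`). It also checks that `logistic_regression_workflow` passes the data through preprocessing, the train/test split, and model training.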

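The main purpose is to confirm that the workflow is wired correctly: each step is called once, with the expected arguments, and its results are returned. The steps themselves are mocked, so their internal logic is not tested here.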

Mocks replace Spark's CSV reader in the `read_file` test. They also replace the preprocessing, splitting, and training functions in the workflow test.

In [None]:
import pytest
from pyspark.sql import SparkSession
from mock import Mock
from pyspark.sql.dataframe import DataFrame

from execution.sa_lr_pyspark_main import read_file, logistic_regression_workflow

In [None]:
@pytest.fixture(scope="session")
def spark_session():
    """Provide a local SparkSession shared by all tests in this session.

    The session runs with master ``local[2]`` (two local worker threads).
    ``getOrCreate`` would return the same session to every test anyway, so
    the fixture is session-scoped. The session is created once and stopped
    after the last test, so its resources are released instead of leaking
    past the test run.

    Yields:
        SparkSession: the shared test session.
    """
    spark = SparkSession.builder.master("local[2]").appName("test").getOrCreate()
    yield spark
    # Teardown: runs once, after every test using this fixture has finished.
    spark.stop()

In [None]:
def test_read_file(spark_session):
    """read_file should delegate to Spark's CSV reader with the expected options.

    ``SparkSession.read`` is a property that builds a new reader object on
    every access. Assigning a Mock to ``spark_session.read.csv`` would only
    patch a throwaway reader, and ``read_file`` would still hit the real CSV
    reader. Patching ``csv`` on the reader *class* intercepts the call that
    ``read_file`` actually makes. ``patch.object`` restores the original
    method when the ``with`` block exits.
    """
    from unittest.mock import patch

    # DataFrame handed back in place of reading a real file.
    data = [
        (1, "This is a sample tweet."),
        (0, "Another tweet with numbers 123."),
    ]
    columns = ["label", "text"]
    df = spark_session.createDataFrame(data, columns)

    # Resolve the reader class from the session itself, so the patch targets
    # whatever reader implementation this session hands out.
    reader_cls = type(spark_session.read)
    file_url = "fake_file.csv"

    # The class-level MagicMock is not a descriptor. Calls through a reader
    # instance therefore reach it without ``self``, so the expected call
    # below matches the arguments exactly as read_file passes them.
    with patch.object(reader_cls, "csv", return_value=df) as csv_mock:
        result = read_file(file_url, spark_session)

    # Verify that the CSV reader was called once with the expected options.
    csv_mock.assert_called_once_with(file_url, sep=",", inferSchema=True, header=False)

    # Verify that the result is a DataFrame.
    assert isinstance(result, DataFrame)

In [None]:
def test_logistic_regression_workflow(spark_session):
    """logistic_regression_workflow should chain its three steps in order.

    The preprocessing output and the train/test split each get their own
    sentinel object. Each assertion therefore proves that a step received
    the previous step's output, rather than the raw input or some other
    value.
    """
    # Import under an alias. A later cell in this file defines a module-level
    # function with the same name, which rebinds the name imported at the top
    # of the file. Calling the bare name would test that local copy instead
    # of the module under test.
    from execution.sa_lr_pyspark_main import (
        logistic_regression_workflow as workflow_under_test,
    )

    # Create a test DataFrame used as the workflow's raw input.
    data = [
        (1, "This is a sample tweet."),
        (0, "Another tweet with numbers 123."),
    ]
    columns = ["label", "text"]
    df = spark_session.createDataFrame(data, columns)

    # Distinct, identity-compared placeholders for each intermediate value.
    processed = object()
    train_split = object()
    test_split = object()

    # Mocks for the functions used in the workflow.
    pre_process_mock = Mock(return_value=processed)
    train_test_split_mock = Mock(return_value=(train_split, test_split))
    logistic_regression_mock = Mock(return_value=("train_summary", "test_summary"))

    # Call the complete workflow.
    train_summary, test_summary = workflow_under_test(
        df, pre_process_mock, train_test_split_mock, logistic_regression_mock
    )

    # Each step must be called once, with the previous step's output.
    pre_process_mock.assert_called_once_with(df)
    train_test_split_mock.assert_called_once_with(processed)
    logistic_regression_mock.assert_called_once_with(train_split, test_split)

    # The training step's summaries are returned unchanged.
    assert train_summary == "train_summary"
    assert test_summary == "test_summary"

In [None]:
def logistic_regression_workflow(data, pre_process_function, train_test_split_function, logistic_regression_function):
    """Preprocess, split, and fit, returning the train and test summaries.

    Every stage is supplied by the caller as a callable:
    ``pre_process_function(data)`` produces the prepared dataset,
    ``train_test_split_function`` turns it into a two-item (train, test)
    result, and ``logistic_regression_function(train, test)`` returns a
    two-item (train_summary, test_summary) result.

    NOTE(review): this module-level definition rebinds the
    ``logistic_regression_workflow`` name imported from
    ``execution.sa_lr_pyspark_main`` at the top of the file. Any test that
    calls the bare name exercises this copy rather than the module under
    test. Consider removing or renaming it.

    Returns:
        tuple: ``(train_summary, test_summary)``.

    Raises:
        ValueError / TypeError: if the split or the training step does not
        return exactly two items (raised by tuple unpacking).
    """
    train_part, test_part = train_test_split_function(pre_process_function(data))
    first_summary, second_summary = logistic_regression_function(train_part, test_part)
    return first_summary, second_summary

In [None]:
if __name__ == '__main__':
    # pytest.main() with no arguments collects tests from the current working
    # directory, not necessarily from this file. When run as a script, point
    # pytest at this file explicitly. Where ``__file__`` is not defined (e.g.
    # inside a notebook kernel), fall back to the original no-argument call.
    _this_file = globals().get("__file__")
    pytest.main([_this_file] if _this_file else None)