# Extract testing and training data from TSV.

In [1]:
from pathlib import Path
from dataclasses import dataclass
import polars as pl

@dataclass
class MLData:
    """
    Container pairing the two dataframes handed to a machine learning model.

    Both fields are plain, mutable dataclass attributes; no validation or
    copying is performed on assignment.

    Attributes:
        training_data (pl.DataFrame): Rows used to fit the model.
        testing_data (pl.DataFrame): Rows held back to evaluate the model.
    """

    training_data: pl.DataFrame
    testing_data: pl.DataFrame
    
class FormatMLData:
    """
    Load a tab-separated input file, patch its Null values and split it into
    training and testing dataframes for machine learning models.
    """

    def __init__(self, input_data_path: Path):
        """
        Create a formatter bound to a single input file.

        Args:
            input_data_path (Path): Location of the tab-separated input data.
        """
        self.input_data_path = input_data_path

    def read_input_data(self) -> pl.DataFrame:
        """
        Load the tab-separated input file into a dataframe.

        Returns:
            pl.DataFrame: The input data exactly as read from disk.
        """
        # Let schema inference look at up to 100 million rows rather than only
        # the leading records when deciding on column types.
        return pl.read_csv(
            self.input_data_path,
            separator="\t",
            infer_schema_length=100_000_000,
        )

    @staticmethod
    def fix_max_path_null(input_data: pl.DataFrame) -> pl.DataFrame:
        """
        Replace Null MAX_PATH entries with the row's EXOMISER_VARIANT_SCORE.

        Args:
            input_data (pl.DataFrame): Dataframe holding the MAX_PATH and
                EXOMISER_VARIANT_SCORE columns.
        Returns:
            pl.DataFrame: A dataframe in which every Null MAX_PATH value has
                been substituted by EXOMISER_VARIANT_SCORE from the same row.
        """
        fallback_score = pl.col("EXOMISER_VARIANT_SCORE")
        patched_max_path = pl.col("MAX_PATH").fill_null(fallback_score)
        return input_data.with_columns(patched_max_path)

    @staticmethod
    def fix_max_freq_null(input_data: pl.DataFrame) -> pl.DataFrame:
        """
        Replace Null MAX_FREQ entries with 0.

        Args:
            input_data (pl.DataFrame): Dataframe holding the MAX_FREQ column.
        Returns:
            pl.DataFrame: A dataframe in which every Null MAX_FREQ value has
                been substituted by 0.
        """
        patched_max_freq = pl.col("MAX_FREQ").fill_null(0)
        return input_data.with_columns(patched_max_freq)

    @staticmethod
    def retrieve_training_data(input_data: pl.DataFrame) -> pl.DataFrame:
        """
        Select the rows flagged for training.

        Args:
            input_data (pl.DataFrame): Dataframe holding the TRAIN_STATUS column.
        Returns:
            pl.DataFrame: Only the rows whose TRAIN_STATUS equals 1.
        """
        is_training_row = pl.col("TRAIN_STATUS") == 1
        return input_data.filter(is_training_row)

    @staticmethod
    def retrieve_test_data(input_data: pl.DataFrame) -> pl.DataFrame:
        """
        Select the rows flagged for testing.

        Args:
            input_data (pl.DataFrame): Dataframe holding the TRAIN_STATUS column.
        Returns:
            pl.DataFrame: Only the rows whose TRAIN_STATUS equals 0.
        """
        is_testing_row = pl.col("TRAIN_STATUS") == 0
        return input_data.filter(is_testing_row)

    def return_ml_data(self) -> MLData:
        """
        Run the full pipeline: read the file, patch Null values, then split.

        MAX_PATH Nulls are filled first, MAX_FREQ Nulls second, and the
        cleaned dataframe is then divided on TRAIN_STATUS. Rows whose
        TRAIN_STATUS is neither 1 nor 0 are placed in neither set.

        Returns:
            MLData: The cleaned training and testing dataframes.
        """
        cleaned = self.fix_max_freq_null(
            self.fix_max_path_null(self.read_input_data())
        )
        training_rows = self.retrieve_training_data(cleaned)
        testing_rows = self.retrieve_test_data(cleaned)
        return MLData(training_data=training_rows, testing_data=testing_rows)