In [1]:
from pyspark.sql import SparkSession

# Start a Spark session, or reuse one that is already running in this kernel.
spark = SparkSession.builder.getOrCreate()

# Read the CSV, treating the first line as a header and letting Spark infer column types.
df = spark.read.csv('data/kddcup99.csv', header=True, inferSchema=True)

# Preview the first 10 rows as a pandas DataFrame (rich notebook display).
df.limit(10).toPandas()

21/11/09 19:18:25 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
21/11/09 19:18:29 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


Unnamed: 0,duration,protocol_type,service,flag,src_bytes,dst_bytes,land,wrong_fragment,urgent,hot,...,dst_host_srv_count,dst_host_same_srv_rate,dst_host_diff_srv_rate,dst_host_same_src_port_rate,dst_host_srv_diff_host_rate,dst_host_serror_rate,dst_host_srv_serror_rate,dst_host_rerror_rate,dst_host_srv_rerror_rate,labels
0,0,b'tcp',b'http',b'SF',181,5450,0,0,0,0,...,9,1.0,0.0,0.11,0.0,0.0,0.0,0.0,0.0,b'normal.'
1,0,b'tcp',b'http',b'SF',239,486,0,0,0,0,...,19,1.0,0.0,0.05,0.0,0.0,0.0,0.0,0.0,b'normal.'
2,0,b'tcp',b'http',b'SF',235,1337,0,0,0,0,...,29,1.0,0.0,0.03,0.0,0.0,0.0,0.0,0.0,b'normal.'
3,0,b'tcp',b'http',b'SF',219,1337,0,0,0,0,...,39,1.0,0.0,0.03,0.0,0.0,0.0,0.0,0.0,b'normal.'
4,0,b'tcp',b'http',b'SF',217,2032,0,0,0,0,...,49,1.0,0.0,0.02,0.0,0.0,0.0,0.0,0.0,b'normal.'
5,0,b'tcp',b'http',b'SF',217,2032,0,0,0,0,...,59,1.0,0.0,0.02,0.0,0.0,0.0,0.0,0.0,b'normal.'
6,0,b'tcp',b'http',b'SF',212,1940,0,0,0,0,...,69,1.0,0.0,1.0,0.04,0.0,0.0,0.0,0.0,b'normal.'
7,0,b'tcp',b'http',b'SF',159,4087,0,0,0,0,...,79,1.0,0.0,0.09,0.04,0.0,0.0,0.0,0.0,b'normal.'
8,0,b'tcp',b'http',b'SF',210,151,0,0,0,0,...,89,1.0,0.0,0.12,0.04,0.0,0.0,0.0,0.0,b'normal.'
9,0,b'tcp',b'http',b'SF',212,786,0,0,0,1,...,99,1.0,0.0,0.12,0.05,0.0,0.0,0.0,0.0,b'normal.'


In [2]:
from __future__ import annotations

import logging
from typing import List, Optional

from lightgbm import LGBMModel
from pyspark.sql import DataFrame
from pyspark.sql.types import NumericType

from faas.base import Pipeline
from faas.encoder import OrdinalEncoder
from faas.scaler import StandardScaler
from faas.utils_dataframe import get_non_numeric_columns
from faas.utils_dataframe import JoinableByRowID


# Logger used by the definitions in this cell, named after the current module.
logger = logging.getLogger(name=__name__)


def is_numeric(df: DataFrame, column: str) -> bool:
    """Return whether ``column`` of ``df`` has a numeric Spark data type.

    Args:
        df: DataFrame whose schema is inspected.
        column: Name of the column to look up in ``df.schema``.

    Returns:
        True if the column's ``dataType`` is an instance of ``NumericType``,
        False otherwise.
    """
    return isinstance(df.schema[column].dataType, NumericType)


class E2EPipline:
    """End-to-end pipeline: encode features, optionally scale the target, fit LightGBM.

    Non-numeric feature columns are passed through one ``OrdinalEncoder`` each,
    numeric feature columns are used as they are, and the resulting feature
    matrix is collected to pandas and fed to an ``LGBMModel``.

    Attributes are set in ``__init__``:
        target_column: name of the column to predict.
        feature_columns: list of input column names (never contains the target).
        categorical_features / numeric_features: the split of ``feature_columns``.
        getx: ``Pipeline`` of ordinal encoders for the categorical features.
        gety: ``Pipeline`` wrapping a ``StandardScaler`` on the target, or None
            when no ``target_group`` was given.
        m: the underlying ``LGBMModel``.
    """

    def __init__(
        self,
        df: DataFrame,
        target_column: str,
        target_group: Optional[str] = None,  # = 'protocol_type'
        feature_columns: Optional[List[str]] = None
    ):
        """Configure the encoders, the optional target scaler and the model.

        Args:
            df: DataFrame used to inspect column names and types; nothing is
                fitted here.
            target_column: Column to predict.
            target_group: Optional column passed to ``StandardScaler`` as
                ``group_column`` (presumably the target is standardized within
                each group -- confirm against ``faas.scaler``).
            feature_columns: Input column names. ``None`` means every column of
                ``df`` except the target. A single column name given as a bare
                string is treated as a one-element list.
        """
        self.target_column = target_column
        if feature_columns is None:
            feature_columns = [c for c in df.columns if c != target_column]
            logger.info(
                'Received None as feature_columns, automatically using all non-target columns. '
                f'len(feature_columns): {len(feature_columns)}'
            )
        elif isinstance(feature_columns, str):
            # A bare string would otherwise be splatted character by character
            # into df.select(*feature_columns).
            feature_columns = [feature_columns]
        else:
            # Copy so later changes to the caller's list cannot alter the pipeline.
            feature_columns = list(feature_columns)

        if target_column in feature_columns:
            # Training on the target itself leaks the answer into the features.
            logger.warning(
                f'target_column {target_column!r} was listed in feature_columns; '
                'dropping it from the features.'
            )
            feature_columns = [c for c in feature_columns if c != target_column]
        self.feature_columns = feature_columns

        categorical_features = get_non_numeric_columns(df.select(*feature_columns))
        numeric_features = [c for c in feature_columns if c not in categorical_features]
        logger.info(
            f'num_features: {len(feature_columns)} '
            f'num_categorical_features: {len(categorical_features)} '
            f'num_numeric_features: {len(numeric_features)}'
        )
        self.categorical_features = categorical_features
        self.numeric_features = numeric_features

        # One ordinal encoder per non-numeric feature column.
        self.getx = Pipeline(steps=[
            OrdinalEncoder(c)
            for c in self.categorical_features
        ])
        if target_group is not None:
            self.gety = Pipeline(steps=[
                StandardScaler(column=target_column, group_column=target_group)
            ])
        else:
            # TODO: a pass-through "selector" transformer would remove the
            # `gety is None` special cases in fit() and predict().
            self.gety = None
        # NOTE(review): any non-numeric target gets the 'binary' objective;
        # confirm that targets with more than two classes are not expected.
        self.m = LGBMModel(
            objective='regression' if is_numeric(df=df, column=self.target_column) else 'binary',
            deterministic=True,
        )

    def _feature_matrix(self, df: DataFrame):
        """Encode ``df`` with ``getx`` and collect the model inputs to pandas.

        The selected columns are ``self.getx.feature_columns`` followed by
        ``self.numeric_features``; fit() and predict() both use this helper so
        the column order seen by the model is the same in both.
        """
        return (
            self.getx
            .transform(df)
            .select(self.getx.feature_columns + self.numeric_features)
            .toPandas()
        )

    def fit(self, df: DataFrame) -> E2EPipline:
        """Fit the encoders, the optional target scaler and the model on ``df``.

        Args:
            df: Training data containing the feature columns and the target.

        Returns:
            This pipeline, so calls can be chained.
        """
        self.getx.fit(df)
        X = self._feature_matrix(df)
        # NOTE(review): X and y are collected by separate Spark actions, so
        # their rows line up only if Spark returns them in the same order both
        # times -- confirm this holds for the data sources in use.
        if self.gety is not None:
            self.gety.fit(df)
            y = (
                self.gety
                .transform(df)
                .select(self.gety.feature_columns)
                .toPandas()
            )
        else:
            y = (
                df
                .select(self.target_column)
                .toPandas()
            )
        self.m.fit(X=X, y=y)
        return self

    def predict(self, df: DataFrame) -> DataFrame:
        """Predict the target for ``df`` and return it as a Spark DataFrame.

        Args:
            df: Data containing the feature columns.

        Returns:
            The result of joining the predictions back via ``JoinableByRowID``,
            passed through ``gety.inverse_transform`` when a target scaler is
            configured.
        """
        jb = JoinableByRowID(df)
        ypred = self.m.predict(self._feature_matrix(jb.df))
        # With a target scaler the model predicts the scaler's output column;
        # otherwise it predicts the target column directly.
        if self.gety is not None:
            prediction_column = self.gety.feature_columns[0]
        else:
            prediction_column = self.target_column
        df_with_y = jb.join_by_row_id(ypred, column=prediction_column)
        if self.gety is not None:
            return self.gety.inverse_transform(df_with_y)
        return df_with_y


In [3]:
# Build the pipeline for the chosen target, then fit it on the full DataFrame.
# fit() returns the pipeline itself, so `e2e` ends up bound to the fitted object.
e2e = E2EPipline(df=df, target_column='dst_host_srv_count')
e2e = e2e.fit(df)



In [4]:
# Predict on the same DataFrame the pipeline was fitted on (in-sample predictions).
dfpred = e2e.predict(df=df)



In [5]:
# Actual target values for the first 10 rows, for comparison with the predictions.
(
    df
    .select('dst_bytes', 'dst_host_srv_count')
    .limit(10)
    .toPandas()
)

Unnamed: 0,dst_bytes,dst_host_srv_count
0,5450,9
1,486,19
2,1337,29
3,1337,39
4,2032,49
5,2032,59
6,1940,69
7,4087,79
8,151,89
9,786,99


In [6]:
# Predicted target values for the first 10 rows of the prediction DataFrame.
(
    dfpred
    .select('dst_bytes', 'dst_host_srv_count')
    .limit(10)
    .toPandas()
)

Unnamed: 0,dst_bytes,dst_host_srv_count
0,5450,21.554066
1,486,16.321566
2,1337,34.219186
3,1337,42.318821
4,2032,50.131344
5,2032,53.699666
6,1940,142.211937
7,4087,153.488182
8,151,162.879353
9,786,130.87418
