In [24]:
import logging
import numpy as np
from utilities import preprocessing
from pyspark.sql.types import *
from pyspark.sql import SparkSession
from pyspark.ml.feature import HashingTF, IDF
from pyspark.ml.feature import StringIndexer
from pyspark.sql.functions import col, udf
from pyspark.ml.classification import LogisticRegression

In [131]:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class TextClassificationEngine:
    
    def __init__(self, spark):
        logger.info("Starting up text classification engine: ")
        self.spark = spark
        self.text_classification_data = self.__load_data_from_database()
        self.preprocessed_data = self.__data_preprocessing()
        self.hashing_tf, self.idf_vectorizer, self.rescaled_data = self.__vectorize_data()
        self.model = self.__train_model() 
    
    def __load_data_from_database(self) :
        logger.info("Loading labled data...")
        text_classification_data = spark.read \
                                        .format("jdbc") \
                                        .option("driver","com.mysql.cj.jdbc.Driver") \
                                        .option("url", "jdbc:mysql://web-database/Web") \
                                        .option("dbtable", "textClassification") \
                                        .option("user", "root") \
                                        .option("password", "123") \
                                        .load()
        text_classification_data = text_classification_data.select('category', 'descriptions')
        text_classification_data = text_classification_data.dropna(subset = ('category'))
        logger.info("Loading completed")
        return text_classification_data

    def __data_preprocessing(self):
        logger.info("Preprocessing data...")
        preprocessed_data = preprocessing(self.text_classification_data, 'descriptions')
        logger.info("Preprocessing completed")
        return preprocessed_data
    
    def __vectorize_data(self):
        logger.info("Vectorize data...")
        hashing_tf = HashingTF(inputCol = "filtered", outputCol = "raw_features", numFeatures=10000)
        featurized_data = hashing_tf.transform(self.preprocessed_data)

        idf = IDF(inputCol = "raw_features", outputCol = "features")
        idf_vectorizer = idf.fit(featurized_data)
        rescaled_data = idf_vectorizer.transform(featurized_data).cache()
        logger.info("Vectorization completed")
        return hashing_tf, idf_vectorizer, rescaled_data

    def __train_model(self):
        labelEncoder = StringIndexer(inputCol = 'category',outputCol = 'label').fit(self.rescaled_data)
        df = labelEncoder.transform(self.rescaled_data)
        
        family = 'multinomial'
        regParam = 0.3
        elasticNetParam = 0
        maxIter = 50
        
        logger.info("Training text classification model...")
        lr = LogisticRegression(featuresCol = 'features',
                                labelCol = 'label',
                                family = family,
                                regParam = regParam,
                                elasticNetParam = regParam,
                                maxIter = maxIter)
        model = lr.fit(df)
        logger.info("Text classification model built!")
        return model
    
    def predict_label(self, input_data):
        schema = StructType([StructField("post_id", StringType(), True)\
                            ,StructField("descriptions", StringType(), True)])
        input_df = self.spark.createDataFrame(data = input_data, schema = schema)
        input_df = preprocessing(input_df, 'descriptions')
        
        featurized_input_df = self.hashing_tf.transform(input_df)
        rescaled_input_df = self.idf_vectorizer.transform(featurized_input_df) 
        predictions = self.model.transform(rescaled_input_df)
        def get_label(label): 
            label_dict = {0.0: 'Business',
                         1.0: 'Sci/Tech',
                         2.0: 'Sports',
                         3.0: 'World'}
            return label_dict[label]
        get_label_udf = udf(get_label, StringType())
        predictions = predictions.withColumn('label_name', get_label_udf(col('prediction')))
        return predictions.select('post_id', 'descriptions', 'label_name')

In [32]:
spark = SparkSession.builder \
           .appName('Test Text Classification Model') \
           .config("spark.jars", "mysql-connector-j-8.0.32.jar")\
           .config("spark.driver.memory", "6g") \
           .config("spark.executor.memory", "8g") \
           .getOrCreate()

In [33]:
engine = TextClassificationEngine(spark)

INFO:__main__:Starting up text classification engine: 
INFO:__main__:Loading labled data...
INFO:__main__:Loading completed
INFO:__main__:Preprocessing data...
INFO:__main__:Preprocessing completed
INFO:__main__:Vectorize data...
INFO:__main__:Vectorization completed
INFO:__main__:Training text classification model...
INFO:__main__:Text classification model built!


In [207]:
test = [(1, "President Xi Jinping says he doesn't want a hodgepodge 'street stall economy' in Beijing, even as China's youth unemployment rates hit a record high"),
        (2, "Tech titans who are trying to live forever might soon have new ammunition: Next-gen anti-aging pills"),
        (3, "Bank of England governor says the UK is facing a wage-price spiral"), 
        (4, "Congratulation Erling Haaland who score 2 goals last match with Manchester United")]

In [208]:
a = engine.predict_label(test)

In [209]:
a.show()

+-------+--------------------+----------+
|post_id|        descriptions|label_name|
+-------+--------------------+----------+
|      4|Congratulations t...|  Sci/Tech|
+-------+--------------------+----------+

