In [20]:
import logging
import numpy as np
from utilities import preprocessing
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import CountVectorizer
from pyspark.sql.functions import col, udf, rand, max
from pyspark.ml.clustering import LDA
from pyspark.ml.feature import HashingTF, IDF
from pyspark.ml.classification import LogisticRegression

In [3]:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class TextClassificationEngine:
    
    def __init__(self, spark):
        logger.info("Starting up text classification engine: ")
        self.spark = spark
        self.text_classification_data = self.__load_data_from_database()
        self.preprocessed_data = self.__data_preprocessing()
        self.hashing_tf, self.idf_vectorizer, self.rescaled_data = self.__vectorize_data()
        self.model = self.__train_model() 
    
    def __load_data_from_database(self) :
        logger.info("Loading labled data...")
        text_classification_data = spark.read \
                                        .format("jdbc") \
                                        .option("driver","com.mysql.cj.jdbc.Driver") \
                                        .option("url", "jdbc:mysql://web-database/Web") \
                                        .option("dbtable", "textClassification") \
                                        .option("user", "root") \
                                        .option("password", "123") \
                                        .load()
        text_classification_data = text_classification_data.select('category', 'descriptions')
        text_classification_data = text_classification_data.dropna(subset = ('category'))
        logger.info("Loading completed")
        return text_classification_data

    def __data_preprocessing(self):
        logger.info("Preprocessing data...")
        preprocessed_data = preprocessing(self.text_classification_data, 'descriptions')
        logger.info("Preprocessing completed")
        return preprocessed_data
    
    def __vectorize_data(self):
        logger.info("Vectorize data...")
        hashing_tf = HashingTF(inputCol = "filtered", outputCol = "raw_features", numFeatures=10000)
        featurized_data = hashing_tf.transform(self.preprocessed_data)

        idf = IDF(inputCol = "raw_features", outputCol = "features")
        idf_vectorizer = idf.fit(featurized_data)
        rescaled_data = idf_vectorizer.transform(featurized_data)
        logger.info("Vectorization completed")
        return hashing_tf, idf_vectorizer, rescaled_data

    def __train_model(self):
        labelEncoder = StringIndexer(inputCol = 'category',outputCol = 'label').fit(self.rescaled_data)
        df = labelEncoder.transform(self.rescaled_data)
        
        family = 'multinomial'
        regParam = 0.3
        elasticNetParam = 0
        maxIter = 50
        
        logger.info("Training text classification model...")
        lr = LogisticRegression(featuresCol = 'features',
                                labelCol = 'label',
                                family = family,
                                regParam = regParam,
                                elasticNetParam = regParam,
                                maxIter = maxIter)
        model = lr.fit(df)
        logger.info("Text classification model built!")
        return model
    
    def predict_label(self, input_data):
        schema = StructType([StructField("post_id", StringType(), True)\
                            ,StructField("descriptions", StringType(), True)])
        input_df = self.spark.createDataFrame(data = input_data, schema = schema)
        input_df = preprocessing(input_df, 'descriptions')
        
        featurized_input_df = self.hashing_tf.transform(input_df)
        rescaled_input_df = self.idf_vectorizer.transform(featurized_input_df) 
        predictions = self.model.transform(rescaled_input_df)
        def get_label(label): 
            label_dict = {0.0: 'Business',
                         1.0: 'Sci/Tech',
                         2.0: 'Sports',
                         3.0: 'World'}
            return label_dict[label]
        get_label_udf = udf(get_label, StringType())
        predictions = predictions.withColumn('label_name', get_label_udf(col('prediction')))
        return predictions.select('post_id', 'descriptions', 'label_name')

In [4]:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class TopicModellingModel:
    
    def __init__(self, spark, label_name):
        logger.info("Starting up model LDA Business: ")
        self.spark = spark
        self.label_name = label_name
        self.data = self.__load_data_from_database()
        self.preprocessed_data = self.__data_preprocessing()
        self.vectorizer, self.wordVectors = self.__vectorize_data()
        self.model, self.final_df = self.__train_model() 
    
    def __load_data_from_database(self) :
        logger.info("Loading data...")
        data = spark.read \
                    .format("jdbc") \
                    .option("driver","com.mysql.cj.jdbc.Driver") \
                    .option("url", "jdbc:mysql://web-database/Web") \
                    .option("dbtable", "redditData") \
                    .option("user", "root") \
                    .option("password", "123") \
                    .load() \
                    .filter(col('category') == self.label_name)
        logger.info("Loading completed")
        return data

    def __data_preprocessing(self):
        logger.info("Preprocessing data...")
        preprocessed_data = self.data.select('id', 'category', 'descriptions')
        preprocessed_data = preprocessed_data.dropna(subset = ('category'))
        preprocessed_data = preprocessing(preprocessed_data, 'descriptions')
        logger.info("Preprocessing completed")
        return preprocessed_data
    
    def __vectorize_data(self):
        vectorizer = CountVectorizer().setInputCol("filtered").setOutputCol("features").fit(self.preprocessed_data)
        wordVectors_business = vectorizer.transform(self.preprocessed_data)
        return vectorizer, wordVectors_business

    def __train_model(self):
        k = 5
        maxIter = 50
        seed = 2
        lda = LDA(k = k, maxIter = maxIter, featuresCol = 'features', seed = seed)
        ldaModel = lda.fit(self.wordVectors)
        final_df = ldaModel.transform(self.wordVectors)

        to_array = udf(lambda v: v.toArray().tolist(), ArrayType(FloatType()))
        max_index = udf(lambda x: x.index(__builtin__.max(x)) if x is not None else None, IntegerType())
        final_df = final_df.withColumn('topicDistribution', to_array(final_df['topicDistribution']))
        final_df = final_df.withColumn('topic', max_index(final_df['topicDistribution']))
        logger.info("LDA Business model built!")
        return ldaModel, final_df
    
    def predict_topic(self, input_data):
        input_df = preprocessing(input_data, 'descriptions')
        input_wordVectors = self.vectorizer.transform(input_df)
        predictions = self.model.transform(input_wordVectors)
        
        to_array = udf(lambda v: v.toArray().tolist(), ArrayType(FloatType()))
        max_index = udf(lambda x: x.index(__builtin__.max(x)) if x is not None else None, IntegerType())
        predictions = predictions.withColumn('topicDistribution', to_array(predictions['topicDistribution']))
        predictions = predictions.withColumn('topic', max_index(predictions['topicDistribution']))
        return predictions.select('post_id', 'descriptions', 'label_name', 'topic')
    
    def get_recommendation(self, topic):
        relevant_posts = self.final_df.filter(col('topic') == topic)
        relevant_posts = relevant_posts.orderBy(rand()).limit(5).select('id')
        recommendation = relevant_posts.join(self.data, relevant_posts.id == self.data.id, "inner").drop(relevant_posts.id)
        return recommendation

In [5]:
spark = SparkSession.builder \
           .appName('Build Recommendation Model') \
           .config("spark.jars", "mysql-connector-j-8.0.32.jar")\
           .config("spark.driver.memory", "6g") \
           .config("spark.executor.memory", "8g") \
           .getOrCreate()

In [6]:
label_engine = TextClassificationEngine(spark)

INFO:__main__:Starting up text classification engine: 
INFO:__main__:Loading labled data...
INFO:__main__:Loading completed
INFO:__main__:Preprocessing data...
INFO:__main__:Preprocessing completed
INFO:__main__:Vectorize data...
INFO:__main__:Vectorization completed
INFO:__main__:Training text classification model...
INFO:__main__:Text classification model built!


In [13]:
test = [('1', "President Xi Jinping says he doesn't want a hodgepodge 'street stall economy' in Beijing, even as China's youth unemployment rates hit a record high"),
        ('2', "Tech titans who are trying to live forever might soon have new ammunition: Next-gen anti-aging pills"),
        ('3', "Bank of England governor says the UK is facing a wage-price spiral"), 
        ('4', " A group of technology companies  including Texas Instruments Inc. &lt;TXN.N&gt;, STMicroelectronics  &lt;STM.PA&gt; and Broadcom Corp. &lt;BRCM.O&gt;, on Thursday said they  will propose a new wireless networking standard up to 10 times  the speed of the current generation.")]

In [14]:
predicted_label = label_engine.predict_label(test)

In [15]:
predicted_label.show()

+-------+--------------------+----------+
|post_id|        descriptions|label_name|
+-------+--------------------+----------+
|      1|President Xi Jinp...|     World|
|      2|Tech titans who a...|  Sci/Tech|
|      3|Bank of England g...|  Business|
|      4|A group of techno...|  Sci/Tech|
+-------+--------------------+----------+



In [28]:
grouped_label = predicted_label.groupBy("label_name").count()
max_count = grouped_label.agg(max("count")).first()[0]

highest_count_groups = grouped_label.filter(col('count') == max_count)
label_name = highest_count_groups.select('label_name').first()[0]

In [33]:
considered_post = predicted_label.filter(col('label_name') == label_name)

In [29]:
label_name

'Sci/Tech'

In [30]:
topic_engine = TopicModellingModel(spark, label_name)

INFO:__main__:Starting up model LDA Business: 
INFO:__main__:Loading data...
INFO:__main__:Loading completed
INFO:__main__:Preprocessing data...
INFO:__main__:Preprocessing completed
INFO:__main__:LDA Business model built!


In [34]:
predicted_topic = topic_engine.predict_topic(considered_post)

In [35]:
predicted_topic.show()

+-------+--------------------+----------+-----+
|post_id|        descriptions|label_name|topic|
+-------+--------------------+----------+-----+
|      2|Tech titans who a...|  Sci/Tech|    4|
|      4|A group of techno...|  Sci/Tech|    2|
+-------+--------------------+----------+-----+



In [52]:
grouped_topic = predicted_topic.groupBy("topic").count()
max_count = grouped_topic.agg(max("count")).first()[0]

highest_count_groups = grouped_topic.filter(col('count') == max_count)
topic = highest_count_groups.select('topic').first()[0]
topic

4

In [39]:
topic_engine.get_recommendation(2).show()

+---+-------+--------------------+-------------------+--------------------+--------------------+--------+
| id|post_id|        descriptions|        created_utc|          source_url|            post_url|category|
+---+-------+--------------------+-------------------+--------------------+--------------------+--------+
| 92|13l1sq4|One million cance...|2023-05-18 15:15:22|https://www.bbc.c...|https://www.reddi...|Sci/Tech|
|126|13lsql9|First true EV rel...|2023-05-19 11:28:02|https://www.noteb...|https://www.reddi...|Sci/Tech|
|130|13lt5hy|The 'world's smal...|2023-05-19 11:47:30|https://www.pcgam...|https://www.reddi...|Sci/Tech|
|135|13ltb99|1st Solar Bike Pa...|2023-05-19 11:55:03|https://cleantech...|https://www.reddi...|Sci/Tech|
|312|13kvzzx|How robots could ...|2023-05-18 11:15:48|https://www.canar...|https://www.reddi...|Sci/Tech|
+---+-------+--------------------+-------------------+--------------------+--------------------+--------+



In [40]:
a = topic_engine.get_recommendation(2).collect()

In [43]:
a[0]

Row(id=61, post_id='13llygz', descriptions='Tech giants should help pay for 5G, say Europe’s mobile operators', created_utc=datetime.datetime(2023, 5, 19, 5, 23, 41), source_url='https://www.standard.co.uk/tech/5g-mobile-internet-prices-meta-google-amazon-apple-b1082085.html', post_url='https://www.reddit.com/r/technology/comments/13llygz', category='Sci/Tech')