# Introduction

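This notebook trains a BERTopic topic model on historical tweets. It then uses PySpark Structured Streaming (in a Spark NLP session) to predict topics for tweets arriving from Kafka.
* The data set contains 18,030 tweets, stored in parquet format at `data/raw/twitter_flat.parquet`.
* The trained model is saved in the `models` directory as `topic_model.joblib`.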

# Imports

In [1]:
from datetime import (
    datetime,
    timedelta
)
import pandas as pd
from pathlib import Path
from sklearn.datasets import fetch_20newsgroups
from bertopic import BERTopic
from joblib import (
    dump,
    load
)

from pyspark.sql import (
    SparkSession,
    functions as F
)
from pyspark.sql.types import (
    StringType,
    StructType,
    StructField,
    IntegerType,
    LongType,
    BooleanType,
    MapType
)
import sparknlp
from sparknlp import Finisher
from pyspark.ml import (
    Pipeline
)
from sparknlp.pretrained import PretrainedPipeline

from wordcloud import (
    WordCloud, 
    STOPWORDS, 
    ImageColorGenerator
)

from IPython.display import display, clear_output
import time

Error importing optional module skimage
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/_plotly_utils/optional_imports.py", line 30, in get_module
    return import_module(name)
  File "/opt/conda/lib/python3.8/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1014, in _gcd_import
  File "<frozen importlib._bootstrap>", line 991, in _find_and_load
  File "<frozen importlib._bootstrap>", line 975, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 671, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 783, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/opt/conda/lib/python3.8/site-packages/skimage/__init__.py", line 135, in <module>
    from .data import data_dir
  File "/opt/conda/lib/python3.8/site-packages/skimage/data/__init__.py", line 270, in

In [2]:
# Start a Spark session with Spark NLP attached and report both library versions.
spark = sparknlp.start()
for label, version in (
    ('spark.version', spark.version),
    ('sparknlp.version()', sparknlp.version()),
):
    print(f'{label}: {version}')

spark.version: 3.1.1
sparknlp.version(): 3.1.0


# Set up directories

In [3]:
# Every path hangs off the project root, one level above the notebook's working directory.
project_dir = Path.cwd().parent
data_dir = project_dir.joinpath('data')
raw_data_dir = data_dir.joinpath('raw')
processed_data_dir = data_dir.joinpath('processed')
models_dir = project_dir.joinpath('models')
pretrained_models_dir = models_dir.joinpath('pretrained')

# Load parquet

In [None]:
# `spark.read.parquet` could not infer this file's schema, so the data is
# loaded with pandas first and then converted into a Spark DataFrame.
path = raw_data_dir.joinpath('twitter_flat.parquet')
pdf = pd.read_parquet(path)
df = spark.createDataFrame(pdf)

In [None]:
pdf.head(n=5)  # first rows of the raw pandas frame

In [None]:
df.show(n=5)  # same rows, as seen by Spark

In [None]:
def get_longer(text1, text2):
    """Return the longer of two candidate tweet texts.

    Used to choose between a tweet's ``text`` and its ``extended_full_text``.
    Missing values (``None``, or a float ``NaN`` that pandas can produce for
    absent entries) are never preferred over a real string.

    Args:
        text1: Primary text; returned when ``text2`` is not a usable string.
        text2: Alternative text; wins when it is at least as long as ``text1``.

    Returns:
        ``text2`` if it is a string and ``text1`` is either not a string or
        not longer than it; otherwise ``text1``.
    """
    if not isinstance(text2, str):
        # Covers None and NaN: keep whatever the primary field holds.
        return text1
    if not isinstance(text1, str) or len(text2) >= len(text1):
        # Ties go to text2, matching the previous preference for the full text.
        return text2
    return text1

In [None]:
Replace `text` with the full tweet text from `extended_full_text` where it is available.

In [None]:
# Pair each tweet's text with its extended text positionally and keep the one
# chosen by `get_longer`; this avoids building a Series per row.
pdf['text'] = [
    get_longer(text, full_text)
    for text, full_text in zip(pdf['text'], pdf['extended_full_text'])
]

In [None]:
docs = pdf['text'].tolist()  # plain Python list of documents for BERTopic

# Train model

In [22]:
path = models_dir / 'topic_model'
model_file = path.with_suffix('.joblib')
if not model_file.exists():
    # Training takes ~15 minutes, so the fitted model is cached with joblib
    # and reloaded on later runs.
    topic_model = BERTopic(language="english", calculate_probabilities=True, verbose=True)
    topics, probs = topic_model.fit_transform(docs)
    dump(topic_model, model_file)
else:
    # NOTE(review): joblib.load unpickles the file; only load model files you created.
    topic_model = load(model_file)

# Predict the topic of one document

In [16]:
def get_topic_name(prediction, model):
    """Return the top-ranked word of the first document's predicted topic.

    Args:
        prediction: Output of ``model.transform(...)``; ``prediction[0]`` holds
            the predicted topic ids, one per document.
        model: Fitted BERTopic model whose ``get_topics()`` maps a topic id to
            its ranked list of (word, score) entries.

    Returns:
        The first word listed for the first document's topic.
    """
    topic_id = prediction[0][0]
    return model.get_topics()[topic_id][0][0]


# Sanity-check the fitted model on a single training document. Both the helper
# and the test prediction are defined here so the notebook runs top to bottom.
test_prediction = topic_model.transform([docs[0]])
get_topic_name(test_prediction, topic_model)

'why'

# Stream topics

In [10]:
# Subscribe to the raw tweet topic, replaying it from the earliest offset.
# (Add "maxOffsetsPerTrigger" to the options to throttle each micro-batch.)
stream_df = (
    spark.readStream
    .format("kafka")
    .options(**{
        "kafka.bootstrap.servers": "broker:29092",
        "startingOffsets": "earliest",
        "subscribe": "twitterdata",
    })
    .load()
)

In [11]:
# Subset of the tweet JSON fields parsed from each Kafka message; every field,
# including the nested structs, is nullable.
tweet_schema = (
    StructType()
    .add('created_at', StringType(), True)
    .add('id', LongType(), True)
    .add('text', StringType(), True)
    .add('is_quote_status', BooleanType(), True)
    .add('in_reply_to_user_id', LongType(), True)
    .add('user', StructType()
         .add('id', LongType(), True)
         .add('followers_count', IntegerType(), True)
         .add('friends_count', IntegerType(), True)
         .add('created_at', StringType(), True))
    .add('extended_tweet', StructType().add('full_text', StringType(), True))
    .add('retweeted_status', StructType().add('id', LongType(), True))
    .add('retweet_count', IntegerType(), True)
    .add('favorite_count', IntegerType(), True)
    .add('quote_count', IntegerType(), True)
    .add('reply_count', IntegerType(), True)
)

In [12]:
# Parse each Kafka message into tweet fields and keep only tweets whose Kafka
# timestamp is within one second of the current timestamp.
tweet_stream_df = (
    stream_df
    # Kafka delivers the payload as binary: decode it to a string, then parse the JSON.
    .withColumn('value', F.from_json(stream_df['value'].cast(StringType()), tweet_schema))
    .select('timestamp',
            'value.created_at',
            'value.text',
            'value.extended_tweet.full_text')
    # `text` may be truncated for long tweets (hence `extended_tweet.full_text`).
    # Prefer the full text so streamed inputs match the full texts the model was
    # trained on. Column order is unchanged.
    .withColumn('text', F.coalesce(F.col('full_text'), F.col('text')))
    .where(F.col('timestamp') > F.current_timestamp() - F.expr('INTERVAL 1 seconds'))
)

In [13]:
# Confirm the parsed stream exposes timestamp, created_at, text and full_text.
tweet_stream_df.printSchema()

root
 |-- timestamp: timestamp (nullable = true)
 |-- created_at: string (nullable = true)
 |-- text: string (nullable = true)
 |-- full_text: string (nullable = true)



## Check `tweet_stream_df`

In [14]:
# Mirror the parsed tweets into an in-memory table named `tweet_view`
# so they can be inspected with SQL while the stream runs.
tweet_stream = (
    tweet_stream_df.writeStream
    .queryName('tweet_view')
    .outputMode('update')
    .format('memory')
    .start()
)

In [15]:
# Peek at everything buffered in the in-memory table so far (no time filter).
table_name = 'tweet_view'
query = f"""
SELECT timestamp, text
FROM {table_name}
"""
query_df = spark.sql(query)
query_df.show()

+--------------------+--------------------+
|           timestamp|                text|
+--------------------+--------------------+
|2021-06-13 06:52:...|RT @jack: The peo...|
|2021-06-13 06:52:...|RT @ladyincrypto:...|
|2021-06-13 06:52:...|RT @DaCryptoMonke...|
|2021-06-13 06:52:...|12 Ways That I Pa...|
|2021-06-13 06:52:...|@GMaurya1411019 @...|
|2021-06-13 06:52:...|Treasury Secretar...|
|2021-06-13 06:52:...|RT @BVYCrypto: @C...|
|2021-06-13 06:52:...|I am starting to ...|
|2021-06-13 06:52:...|Possible Uptrend ...|
|2021-06-13 06:52:...|How to add the Sm...|
|2021-06-13 06:52:...|        Nice project|
|2021-06-13 06:52:...|RT @exoswrld: rem...|
|2021-06-13 06:52:...|RT @trader1sz: I’...|
|2021-06-13 06:52:...|good project
 #BS...|
|2021-06-13 06:52:...|RT @Roobet: $2,50...|
|2021-06-13 06:52:...|@JRNYcrypto Imagi...|
|2021-06-13 06:52:...|RT @Lion_King_Tok...|
|2021-06-13 06:52:...|RT @BitspawnGG: B...|
|2021-06-13 06:52:...|Thank you so much...|
|2021-06-13 06:52:...|RT @rawfuc

In [17]:
# Row count of `query_df`, taken when this cell runs; it can change between
# runs while the `tweet_stream` query is still writing to `tweet_view`.
query_df.count()

133

# Create UDF

In [25]:
@F.udf(returnType=StringType())
def predict_topic(text):
    """Predict a tweet's topic and return that topic's top-ranked word.

    Calls `topic_model.transform` on a single document for every row, which
    is slow when a batch contains many rows.

    Args:
        text: Tweet text from a Spark string column; may be null.

    Returns:
        The first word listed for the predicted topic in
        `topic_model.get_topics()`, or None when `text` is null.
    """
    # Nulls (e.g. tweets whose JSON did not match the schema) would presumably
    # make transform() raise and fail the whole micro-batch.
    if text is None:
        return None
    # transform() returns (topics, probabilities); element [0][0] is the
    # topic id of the single document passed in.
    prediction = topic_model.transform([text])
    topic_id = prediction[0][0]
    return topic_model.get_topics()[topic_id][0][0]

In [26]:
# Expose the UDF to Spark SQL under the name `predict_topic`.
spark.udf.register(name='predict_topic', f=predict_topic)

<function __main__.predict_topic(text)>

In [20]:
# Show only tweets received within the last second, then report how many there are.
# NOTE(review): show() and count() each execute the query, so CURRENT_TIMESTAMP()
# may be evaluated twice and the count may not match the rows displayed.
table_name = 'tweet_view'
query = f"""
SELECT 
  timestamp
, text
FROM {table_name}
WHERE timestamp > (CURRENT_TIMESTAMP() - INTERVAL 1 seconds)
"""
query_df = spark.sql(query)
query_df.show()
print(query_df.count())

+--------------------+--------------------+
|           timestamp|                text|
+--------------------+--------------------+
|2021-06-13 06:52:...|RT @SwapQi: 💥Lau...|
|2021-06-13 06:52:...|RT @BaubleNft: Wi...|
|2021-06-13 06:52:...|RT @orderally: 🎉...|
|2021-06-13 06:52:...|RT @potatoe421770...|
|2021-06-13 06:52:...|RT @AlkayalWajdi:...|
|2021-06-13 06:52:...|@SwapQi
Qiswap is...|
|2021-06-13 06:52:...|@bellathorne @the...|
|2021-06-13 06:52:...|RT @VincePrince24...|
|2021-06-13 06:52:...|RT @bobbyong: Cry...|
|2021-06-13 06:52:...|RT @raypaxful: Th...|
|2021-06-13 06:52:...|RT @CNBC: Bitcoin...|
|2021-06-13 06:52:...|RT @realjunsoncha...|
|2021-06-13 06:52:...|@mantarayswap 
Go...|
|2021-06-13 06:52:...|RT @CamelGlobal: ...|
|2021-06-13 06:52:...|RT @CryptoTamil: ...|
|2021-06-13 06:52:...|@elonmusk #ElonMu...|
|2021-06-13 06:52:...|@ArtZentsik Mrs. ...|
+--------------------+--------------------+

17


In [27]:
# Apply the topic UDF to the queried tweets and display the predictions.
# NOTE(review): this run was interrupted (KeyboardInterrupt); `predict_topic`
# calls `topic_model.transform` once per row, which is slow on many rows.
query_df.select('timestamp', 'text', predict_topic('text').alias('predicted_topic')).show()

KeyboardInterrupt: 

In [64]:
# Predict a topic for every tweet received in the last second, via the
# `predict_topic` function registered with Spark SQL.
table_name = 'tweet_view'
query_predict = f"""
SELECT 
  timestamp
, text
, predict_topic(text) AS predicted_topic
FROM {table_name}
WHERE timestamp > (CURRENT_TIMESTAMP() - INTERVAL 1 seconds)
"""
query_predict_df = spark.sql(query_predict)
query_predict_df.show()

+--------------------+--------------------+------------------+
|           timestamp|                text|   predicted_topic|
+--------------------+--------------------+------------------+
|2021-06-12 06:03:...|RT @m1sterone: Th...|               why|
|2021-06-12 06:03:...|@TheDoggyCoin shi...|               why|
|2021-06-12 06:03:...|RT @evan_van_ness...|               why|
|2021-06-12 06:03:...|RT @shiba_coin: H...|    soltanmojtaba6|
|2021-06-12 06:03:...|@dens_club @Polyt...|               why|
|2021-06-12 06:03:...|@VivetVeritate Wh...|          birthday|
|2021-06-12 06:03:...|RT @cycloneprotoc...|       coinwindcom|
|2021-06-12 06:03:...|Bitcoin and El Sa...|               why|
|2021-06-12 06:03:...|Filecoin and Chia...|             china|
|2021-06-12 06:03:...|RT @mythica_magic...|httpstcolbca2w1v4a|
|2021-06-12 06:03:...|RT @sama_kasa: Dn...|               sun|
|2021-06-12 06:03:...|RT @SPACEdotcom: ...|              halo|
|2021-06-12 06:03:...|The price of #Sol...|            

In [None]:
# Builds a prediction DataFrame but calls no action (e.g. `.show()`), so no
# Spark job runs and only the DataFrame's repr is displayed.
query_df.select('timestamp', 'text', predict_topic('text').alias('predicted_topic'))

# Stream to memory

Check whether the prediction model can be applied to the stream successfully.

# Stream to Kafka

In [28]:
# Checkpoint directory used by the Kafka sink query.
kafka_checkpoint = processed_data_dir.joinpath('kafka_checkpoint')

In [29]:
# Publish (tweet text, predicted topic) pairs to the `topic_predictions` Kafka topic.
# `tweet_stream_df` already keeps only tweets from the last second, so no
# second time filter is needed. Such a filter would also reference `timestamp`
# after it has been projected away.
kafka_stream = (
    tweet_stream_df
    .select(F.col('text').alias('key'), predict_topic('text').alias('value'))
    .writeStream
    .format('kafka')
    .option('checkpointLocation', kafka_checkpoint.as_posix())
    .option("kafka.bootstrap.servers", "broker:29092")
    .option("topic", "topic_predictions")
    .start()
)

In [242]:
# Stop the Kafka sink query once enough predictions have been published.
kafka_stream.stop()

# Test `topic_predictions` 

Read `topic_predictions` back from Kafka to check that the streamed predictions were written correctly.

In [30]:
# Read the published predictions back from Kafka, starting at the earliest offset.
pred_df = (
    spark.readStream
    .format("kafka")
    .options(**{
        "kafka.bootstrap.servers": "broker:29092",
        "startingOffsets": "earliest",
        "subscribe": "topic_predictions",
    })
    .load()
)

In [31]:
# Kafka source schema: `key` and `value` arrive as binary and must be cast before reading.
pred_df.printSchema()

root
 |-- key: binary (nullable = true)
 |-- value: binary (nullable = true)
 |-- topic: string (nullable = true)
 |-- partition: integer (nullable = true)
 |-- offset: long (nullable = true)
 |-- timestamp: timestamp (nullable = true)
 |-- timestampType: integer (nullable = true)



In [34]:
# Decode the binary Kafka key/value into strings, keep the other Kafka metadata
# columns unchanged and in order, and mirror the result into `prediction_raw_view`.
pred_raw_stream = (
    pred_df
    .select(*[
        pred_df[name].cast(StringType()).alias(name) if name in ('key', 'value') else pred_df[name]
        for name in pred_df.columns
    ])
    .writeStream
    .queryName('prediction_raw_view')
    .format('memory')
    .start()
)

In [38]:
# Print, then run, a query over everything mirrored into `prediction_raw_view`.
table_name = 'prediction_raw_view'
query = f"""
SELECT *
FROM {table_name}
"""
print(query)
spark.sql(query).show()


SELECT *
FROM prediction_raw_view

+--------------------+------------------+-----------------+---------+------+--------------------+-------------+
|                 key|             value|            topic|partition|offset|           timestamp|timestampType|
+--------------------+------------------+-----------------+---------+------+--------------------+-------------+
|RT @Bad_J0ker: We...|               why|topic_predictions|        0|     0|2021-06-13 06:56:...|            0|
|reputation and st...|            richie|topic_predictions|        0|     1|2021-06-13 06:56:...|            0|
|Very innovative c...|               why|topic_predictions|        0|     2|2021-06-13 06:56:...|            0|
|Burn it all! Init...|               why|topic_predictions|        0|     3|2021-06-13 06:56:...|            0|
|How hard is it to...|               why|topic_predictions|        0|     4|2021-06-13 06:56:...|            0|
|$XRP $DOGE $CRV H...|               why|topic_predictions|        0

In [37]:
# Count the rows currently in the view named in `query`.
# NOTE(review): this cell (In [37]) ran before the preview cell (In [38]) and
# returned 0 at that time; re-run it after the stream has written data.
print(query)
spark.sql(query).count()


SELECT *
FROM prediction_raw_view



0

# Stream to Parquet

In [247]:
# Output and checkpoint directories for the parquet sink.
parquet_path = processed_data_dir.joinpath('prediction_parquet')
parquet_checkpoint_path = processed_data_dir.joinpath('prediction_parquet_checkpoint')

In [248]:
# Persist (tweet text, predicted topic) pairs as parquet files under `parquet_path`.
# `tweet_stream_df` already keeps only tweets from the last second, so no
# second time filter is applied here.
parquet_stream = (
    tweet_stream_df
    .select(F.col('text').alias('key'), predict_topic('text').alias('value'))
    .writeStream
    # Explicit format instead of relying on the `spark.sql.sources.default` setting.
    .format('parquet')
    .option('path', parquet_path.as_posix())
    .outputMode('append')
    .option('checkpointLocation', parquet_checkpoint_path.as_posix())
    .start()
)

In [249]:
# Stop the parquet sink query once enough predictions have been written.
parquet_stream.stop()