In [25]:
import os

# Spark standalone master URL. The host/port are presumably specific to the
# current cluster allocation, so allow overriding it through the SPARK_MASTER
# environment variable instead of editing the notebook.
SPARK_MASTER = os.environ.get('SPARK_MASTER', 'spark://cm027:52337')
# Absolute location where the fitted LDA model is saved. It must start with '/':
# without it the path is resolved relative to the current working directory and
# no longer matches the '/scratch/tmv7269/lda' location used when saving.
SAVE_DIRECTORY = '/scratch/tmv7269/lda'

In [None]:
# HPC cluster's pyspark is 3.1.2
!pip install --upgrade datasets apache-beam pyspark==3.1.2 findspark

In [None]:
import os
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
import pyspark

In [None]:
# Only one SparkContext may be active per process, so stop any context left over
# from a previous run before the next cell creates one bound to SPARK_MASTER.
# Checking the active context first (a private but long-standing class
# attribute) avoids starting a throwaway default context just to stop it.
current_context = SparkContext._active_spark_context
if current_context is not None:
    current_context.stop()

In [None]:
# Build a SparkContext pointed at the cluster master and wrap it in a session.
conf = (
    SparkConf()
    .setAppName("LDA")
    .setMaster(SPARK_MASTER)
)
sc = SparkContext(conf=conf)
spark = SparkSession(sc)

# UDF registry for this session; spark.udf constructs the same UDFRegistration.
udf_registration = spark.udf

spark

In [26]:
from datasets import load_dataset
import pandas as pd
import time

# LDA hyperparameters; raise these toward their maximums for a production run.
K = 100          # number of topics
MAX_ITER = 100   # maximum EM iterations (LDA maxIter)
CHECKPOINT = 5   # checkpoint every N iterations (LDA checkpointInterval)

In [34]:
# Hugging Face "wikipedia" configuration to load. "20220301.simple" is Simple
# English Wikipedia (205,328 articles in its train split); set this to
# "20220301.en" for the full English Wikipedia, which is far larger.
WIKI_CONFIG = "20220301.simple"
dataset = load_dataset("wikipedia", WIKI_CONFIG)
dataset



  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'url', 'title', 'text'],
        num_rows: 205328
    })
})

In [10]:
from pyspark.sql.types import StringType, StructField, StructType

# All Wikipedia fields are strings. Passing the schema explicitly skips
# PySpark's row-by-row schema inference over the collected dicts. Fields are
# listed alphabetically to keep the column order inference produced
# ('id', 'text', 'title', 'url').
WIKI_SCHEMA = StructType(
    [StructField(name, StringType(), True) for name in ['id', 'text', 'title', 'url']]
)
sparkDF = spark.createDataFrame(dataset['train'], schema=WIKI_SCHEMA)
sparkDF.columns

['id', 'text', 'title', 'url']

In [11]:
# Sanity check: the Spark DataFrame must hold exactly the rows of the HF train split.
row_count = sparkDF.count()
assert row_count == dataset['train'].num_rows, (
    f"Spark row count {row_count} != dataset rows {dataset['train'].num_rows}"
)
row_count

23/04/29 18:18:15 WARN TaskSetManager: Stage 0 contains a task of very large size (11931 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

205328

In [14]:
from pyspark.ml.feature import StopWordsRemover, RegexTokenizer, CountVectorizer
from pyspark.sql.functions import col,udf
from pyspark.sql.types import IntegerType

In [15]:
# Split article text into lower-cased tokens; every non-letter character acts
# as a separator, so digits and punctuation never reach `words`.
tokenizer = RegexTokenizer(
    pattern='[^a-zA-Z]',
    inputCol='text',
    outputCol='words',
)
tokenized_df = tokenizer.transform(sparkDF).drop('text')
tokenized_df.head()

23/04/29 18:18:42 WARN TaskSetManager: Stage 2 contains a task of very large size (11931 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

Row(id='1', title='April', url='https://simple.wikipedia.org/wiki/April', words=['april', 'is', 'the', 'fourth', 'month', 'of', 'the', 'year', 'in', 'the', 'julian', 'and', 'gregorian', 'calendars', 'and', 'comes', 'between', 'march', 'and', 'may', 'it', 'is', 'one', 'of', 'four', 'months', 'to', 'have', 'days', 'april', 'always', 'begins', 'on', 'the', 'same', 'day', 'of', 'week', 'as', 'july', 'and', 'additionally', 'january', 'in', 'leap', 'years', 'april', 'always', 'ends', 'on', 'the', 'same', 'day', 'of', 'the', 'week', 'as', 'december', 'april', 's', 'flowers', 'are', 'the', 'sweet', 'pea', 'and', 'daisy', 'its', 'birthstone', 'is', 'the', 'diamond', 'the', 'meaning', 'of', 'the', 'diamond', 'is', 'innocence', 'the', 'month', 'april', 'comes', 'between', 'march', 'and', 'may', 'making', 'it', 'the', 'fourth', 'month', 'of', 'the', 'year', 'it', 'also', 'comes', 'first', 'in', 'the', 'year', 'out', 'of', 'the', 'four', 'months', 'that', 'have', 'days', 'as', 'june', 'september', 

In [16]:
# Remove Spark's default English stop words, keeping only the filtered tokens.
remover = StopWordsRemover(outputCol="filtered", inputCol="words")
removed_df = (
    remover
    .transform(tokenized_df)
    .drop('words')
)
removed_df.head()

23/04/29 18:18:58 WARN TaskSetManager: Stage 3 contains a task of very large size (11931 KiB). The maximum recommended task size is 1000 KiB.


Row(id='1', title='April', url='https://simple.wikipedia.org/wiki/April', filtered=['april', 'fourth', 'month', 'year', 'julian', 'gregorian', 'calendars', 'comes', 'march', 'may', 'one', 'four', 'months', 'days', 'april', 'always', 'begins', 'day', 'week', 'july', 'additionally', 'january', 'leap', 'years', 'april', 'always', 'ends', 'day', 'week', 'december', 'april', 'flowers', 'sweet', 'pea', 'daisy', 'birthstone', 'diamond', 'meaning', 'diamond', 'innocence', 'month', 'april', 'comes', 'march', 'may', 'making', 'fourth', 'month', 'year', 'also', 'comes', 'first', 'year', 'four', 'months', 'days', 'june', 'september', 'november', 'later', 'year', 'april', 'begins', 'day', 'week', 'july', 'every', 'year', 'day', 'week', 'january', 'leap', 'years', 'april', 'ends', 'day', 'week', 'december', 'every', 'year', 'last', 'days', 'exactly', 'weeks', 'days', 'apart', 'common', 'years', 'april', 'starts', 'day', 'week', 'october', 'previous', 'year', 'leap', 'years', 'may', 'previous', 'year

In [17]:
# Bag-of-words term counts, keeping only terms that occur in at least 2 documents.
cv = CountVectorizer(inputCol="filtered", outputCol="features", minDF=2.0).fit(removed_df)

# Cache the count vectors: LDA makes several passes over its input, and without
# a cache each pass re-runs the source -> tokenize -> stop-word pipeline.
lda_count = cv.transform(removed_df).drop('filtered').cache()
lda_count.head()

23/04/29 18:19:00 WARN TaskSetManager: Stage 4 contains a task of very large size (11931 KiB). The maximum recommended task size is 1000 KiB.
23/04/29 18:19:10 WARN DAGScheduler: Broadcasting large task binary with size 2.2 MiB
23/04/29 18:19:10 WARN TaskSetManager: Stage 8 contains a task of very large size (11931 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

Row(id='1', title='April', url='https://simple.wikipedia.org/wiki/April', features=SparseVector(221737, {0: 5.0, 1: 1.0, 2: 7.0, 3: 3.0, 4: 17.0, 5: 1.0, 6: 10.0, 7: 3.0, 8: 13.0, 9: 8.0, 12: 3.0, 13: 1.0, 14: 1.0, 17: 13.0, 18: 2.0, 19: 1.0, 22: 1.0, 25: 10.0, 26: 6.0, 27: 1.0, 32: 7.0, 34: 1.0, 35: 4.0, 37: 4.0, 38: 5.0, 40: 1.0, 41: 1.0, 42: 20.0, 44: 7.0, 46: 3.0, 49: 4.0, 50: 5.0, 55: 25.0, 56: 1.0, 59: 1.0, 60: 6.0, 61: 13.0, 62: 1.0, 65: 1.0, 66: 1.0, 67: 4.0, 71: 3.0, 72: 3.0, 76: 3.0, 77: 5.0, 78: 208.0, 81: 1.0, 82: 3.0, 85: 3.0, 86: 1.0, 90: 1.0, 93: 2.0, 94: 2.0, 96: 1.0, 97: 1.0, 103: 2.0, 106: 1.0, 109: 96.0, 111: 10.0, 117: 5.0, 120: 1.0, 123: 1.0, 126: 2.0, 129: 2.0, 140: 1.0, 142: 5.0, 147: 4.0, 152: 3.0, 153: 5.0, 154: 1.0, 155: 1.0, 156: 1.0, 158: 3.0, 159: 2.0, 160: 1.0, 161: 2.0, 162: 1.0, 170: 1.0, 175: 4.0, 176: 2.0, 181: 5.0, 183: 2.0, 184: 4.0, 186: 2.0, 187: 2.0, 188: 1.0, 192: 1.0, 194: 1.0, 198: 1.0, 199: 1.0, 200: 4.0, 203: 1.0, 204: 1.0, 205: 2.0, 206: 3.0

In [27]:
from pyspark.ml.clustering import LDA

# LDA's checkpointInterval is ignored unless the SparkContext has a checkpoint
# directory, so set one; otherwise the EM iterations' lineage grows unchecked.
# The directory must be reachable from every executor (assumes /scratch is a
# shared filesystem - confirm). It is kept apart from the model save path
# because that directory is overwritten when the model is saved.
CHECKPOINT_DIR = '/scratch/tmv7269/lda_checkpoints'
sc.setCheckpointDir(CHECKPOINT_DIR)

# EM-based LDA with K topics; the fixed seed makes runs repeatable.
lda = LDA(k=K, seed=1, optimizer="em", maxIter=MAX_ITER, checkpointInterval=CHECKPOINT)

# perf_counter is monotonic, so the measured duration ignores wall-clock changes.
start = time.perf_counter()
model = lda.fit(lda_count)
end = time.perf_counter()

print(f"Time elapsed: {end-start:.2f} seconds")

23/04/29 18:28:53 WARN DAGScheduler: Broadcasting large task binary with size 2.2 MiB
23/04/29 18:28:53 WARN TaskSetManager: Stage 46 contains a task of very large size (11931 KiB). The maximum recommended task size is 1000 KiB.
23/04/29 18:28:53 WARN DAGScheduler: Broadcasting large task binary with size 2.2 MiB
23/04/29 18:28:53 WARN TaskSetManager: Stage 47 contains a task of very large size (11931 KiB). The maximum recommended task size is 1000 KiB.
23/04/29 18:28:53 WARN DAGScheduler: Broadcasting large task binary with size 2.2 MiB
23/04/29 18:28:53 WARN DAGScheduler: Broadcasting large task binary with size 2.2 MiB
23/04/29 18:28:54 WARN TaskSetManager: Stage 48 contains a task of very large size (11931 KiB). The maximum recommended task size is 1000 KiB.
23/04/29 18:28:57 WARN TaskSetManager: Stage 49 contains a task of very large size (11931 KiB). The maximum recommended task size is 1000 KiB.

Time elapsed: 895.63 seconds


                                                                                

In [28]:
# Show the 5 highest-weighted vocabulary terms for each topic (topics numbered from 1).
topicIndices = model.describeTopics(maxTermsPerTopic = 5)
vocabList = cv.vocabulary

for topic_row in topicIndices.collect():
    term_lines = [
        f"{vocabList[term_index]} {term_weight:.2E}"
        for term_index, term_weight in zip(topic_row.termIndices, topic_row.termWeights)
    ]
    # Header, one line per term, then a blank separator line.
    print("\n".join([f"Topic {topic_row.topic + 1}: ", *term_lines, ""]))

Topic 1: 
league 6.48E-02
football 6.11E-02
rowspan 5.91E-02
club 4.23E-02
colspan 3.72E-02

Topic 2: 
may 3.65E-01
death 9.33E-02
home 5.77E-02
hospital 3.31E-02
natural 3.24E-02

Topic 3: 
countries 3.16E-02
empire 2.80E-02
history 1.25E-02
europe 1.24E-02
independence 1.20E-02

Topic 4: 
websites 5.12E-01
official 1.61E-01
website 1.39E-01
site 4.33E-02
page 2.72E-02

Topic 5: 
movies 1.86E-01
movie 7.49E-02
directed 4.30E-02
drama 4.09E-02
comedy 3.18E-02

Topic 6: 
john 8.11E-02
james 4.09E-02
george 3.89E-02
william 3.28E-02
robert 2.88E-02

Topic 7: 
business 1.64E-02
development 1.29E-02
bank 1.28E-02
public 1.08E-02
services 9.88E-03

Topic 8: 
best 1.01E-01
award 8.75E-02
awards 4.72E-02
won 3.73E-02
academy 2.92E-02

Topic 9: 
german 8.59E-02
germany 7.74E-02
austria 1.91E-02
austrian 1.85E-02
bavaria 1.77E-02

Topic 10: 
art 4.55E-02
ancient 3.34E-02
bc 2.87E-02
greek 2.39E-02
artists 1.59E-02

Topic 11: 
english 1.40E-01
french 1.15E-01
paris 4.06E-02
de 3.08E-02
l 2.98E-0

In [32]:
# Persist the fitted model, replacing anything already saved at this path.
# NOTE(review): this literal duplicates SAVE_DIRECTORY from the config cell;
# keep the two in sync.
lda_writer = model.write().overwrite()
lda_writer.save('/scratch/tmv7269/lda')

                                                                                