In [2]:
from pyspark import SparkContext, StorageLevel
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.ml import Pipeline
from pyspark.sql.types import ArrayType, StringType, BooleanType, IntegerType, StructType, StructField, FloatType
from pyspark.ml.feature import Tokenizer, CountVectorizer, IDF, StopWordsRemover, Normalizer
from transformers import pipeline
import pandas as pd

In [3]:
spark = SparkSession.builder.appName("WILDCHAT-1M") \
    .config('spark.driver.memory', '24g') \
    .config('spark.executor.memory', '12g') \
    .config('spark.sql.debug.maxToStringFields', 1000) \
    .config("spark.default.parallelism", "10") \
    .master('local[8]') \
    .config("spark.driver.maxResultSize", "10g") \
    .getOrCreate()
spark

24/11/24 13:26:37 WARN Utils: Your hostname, Sharans-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 10.0.0.30 instead (on interface en0)
24/11/24 13:26:37 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/11/24 13:26:38 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [4]:
main_df = spark.read.parquet('/Users/sharan/Desktop/IDMP Data/*.parquet')
main_df.show()

+--------------------+------------------+-------------------+--------------------+----+--------+--------------------+--------------------+-----+--------+------------------+-------------+--------------------+--------------------+
|   conversation_hash|             model|          timestamp|        conversation|turn|language|   openai_moderation| detoxify_moderation|toxic|redacted|             state|      country|           hashed_ip|              header|
+--------------------+------------------+-------------------+--------------------+----+--------+--------------------+--------------------+-----+--------+------------------+-------------+--------------------+--------------------+
|698e02bae74e1ca4e...|        gpt-4-0314|2023-04-08 20:01:06|[{POUVEZ VOUS ME ...|   1|  French|[{{false, false, ...|[{0.0120244864374...| true|   false|             Dakar|      Senegal|cc4eb1e4234c16afc...|{fr,fr-FR;q=0.9,e...|
|c9ec5b440fbdd2a26...|        gpt-4-0314|2023-04-08 20:02:53|[{Hey there! Are ...|  

In [5]:
journalism_df = main_df

In [6]:
journalism_df.filter(F.col('country').isNull()).select('state').distinct().collect()

[Row(state=None)]

In [7]:
journalism_df.select('country').distinct().collect()

[Row(country='Russia'),
 Row(country='Macao'),
 Row(country='Yemen'),
 Row(country='Senegal'),
 Row(country='Sweden'),
 Row(country='Cabo Verde'),
 Row(country='The Netherlands'),
 Row(country='Philippines'),
 Row(country='Singapore'),
 Row(country='Malaysia'),
 Row(country='Iraq'),
 Row(country='Germany'),
 Row(country='Rwanda'),
 Row(country='Jordan'),
 Row(country='Maldives'),
 Row(country='Ivory Coast'),
 Row(country='France'),
 Row(country='Greece'),
 Row(country='Kosovo'),
 Row(country='Sri Lanka'),
 Row(country='Taiwan'),
 Row(country='Algeria'),
 Row(country='Togo'),
 Row(country='Slovakia'),
 Row(country='Argentina'),
 Row(country='Belgium'),
 Row(country='Ecuador'),
 Row(country='Qatar'),
 Row(country='Madagascar'),
 Row(country='Finland'),
 Row(country='Türkiye'),
 Row(country='Nicaragua'),
 Row(country='Myanmar'),
 Row(country='Ghana'),
 Row(country='Peru'),
 Row(country='Benin'),
 Row(country='United States'),
 Row(country='China'),
 Row(country='India'),
 Row(country='Bah

In [8]:
import requests

response = requests.get("https://restcountries.com/v3.1/region/europe")
journalism_countries = [country["name"]["common"] for country in response.json()]
journalism_countries = journalism_countries + ['United States']
print(journalism_countries)

['Norway', 'Greece', 'Åland Islands', 'Switzerland', 'Croatia', 'Iceland', 'Luxembourg', 'Hungary', 'Netherlands', 'Lithuania', 'Slovakia', 'Liechtenstein', 'Moldova', 'Italy', 'Jersey', 'Monaco', 'Belarus', 'Latvia', 'Andorra', 'France', 'Gibraltar', 'Denmark', 'North Macedonia', 'Malta', 'Czechia', 'Guernsey', 'Kosovo', 'Svalbard and Jan Mayen', 'Montenegro', 'Faroe Islands', 'Albania', 'Serbia', 'Ukraine', 'Isle of Man', 'Estonia', 'Romania', 'Bulgaria', 'Germany', 'Poland', 'United Kingdom', 'Finland', 'Sweden', 'Vatican City', 'Russia', 'Austria', 'Cyprus', 'Portugal', 'Bosnia and Herzegovina', 'Belgium', 'Spain', 'Slovenia', 'San Marino', 'Ireland', 'United States']


In [9]:
journalism_df = journalism_df.drop('timestamp', 'openai_moderation', 'detoxify_moderation', 'hashed_ip', 'header')
journalism_df = journalism_df.filter((F.col('toxic') == False) & 
                                  (F.col('language') == "English") & 
                                  (F.col('country').isin(journalism_countries))
                                )


In [10]:
journalism_df.select('country').distinct().collect()

[Row(country='Russia'),
 Row(country='Sweden'),
 Row(country='Germany'),
 Row(country='France'),
 Row(country='Greece'),
 Row(country='Kosovo'),
 Row(country='Slovakia'),
 Row(country='Belgium'),
 Row(country='Finland'),
 Row(country='United States'),
 Row(country='Belarus'),
 Row(country='Malta'),
 Row(country='Croatia'),
 Row(country='Italy'),
 Row(country='Lithuania'),
 Row(country='Norway'),
 Row(country='Spain'),
 Row(country='Czechia'),
 Row(country='Denmark'),
 Row(country='Ireland'),
 Row(country='Ukraine'),
 Row(country='Cyprus'),
 Row(country='Estonia'),
 Row(country='Switzerland'),
 Row(country='Latvia'),
 Row(country='North Macedonia'),
 Row(country='Slovenia'),
 Row(country='Luxembourg'),
 Row(country='Bosnia and Herzegovina'),
 Row(country='Poland'),
 Row(country='Portugal'),
 Row(country='Romania'),
 Row(country='Bulgaria'),
 Row(country='Austria'),
 Row(country='Serbia'),
 Row(country='Hungary'),
 Row(country='United Kingdom'),
 Row(country='Moldova'),
 Row(country='Alb

In [11]:
journalism_df.select('language').distinct().collect()

[Row(language='English')]

In [12]:
journalism_df.select('toxic').distinct().collect()

[Row(toxic=False)]

In [13]:
journalism_df = journalism_df.drop('toxic', 'language')

In [14]:
journalism_df.printSchema()

root
 |-- conversation_hash: string (nullable = true)
 |-- model: string (nullable = true)
 |-- conversation: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- content: string (nullable = true)
 |    |    |-- country: string (nullable = true)
 |    |    |-- hashed_ip: string (nullable = true)
 |    |    |-- header: struct (nullable = true)
 |    |    |    |-- accept-language: string (nullable = true)
 |    |    |    |-- user-agent: string (nullable = true)
 |    |    |-- language: string (nullable = true)
 |    |    |-- redacted: boolean (nullable = true)
 |    |    |-- role: string (nullable = true)
 |    |    |-- state: string (nullable = true)
 |    |    |-- timestamp: timestamp (nullable = true)
 |    |    |-- toxic: boolean (nullable = true)
 |    |    |-- turn_identifier: long (nullable = true)
 |-- turn: long (nullable = true)
 |-- redacted: boolean (nullable = true)
 |-- state: string (nullable = true)
 |-- country: string (nullable = true)


In [15]:
journalism_df = journalism_df.withColumn('conversation_explode', F.explode(F.col("conversation"))) \
    .withColumn('prompt', F.col('conversation_explode.content')) \
    .withColumn('turn_identifier', F.col('conversation_explode.turn_identifier')) \
    .drop('conversation_explode')

In [16]:
journalism_df.filter(F.col('state').isNull()).count()

101652

In [17]:
journalism_df = journalism_df.fillna({'state' : " "})
journalism_df.filter(F.col('state').isNull()).count()

0

In [18]:
journalism_df.printSchema()

root
 |-- conversation_hash: string (nullable = true)
 |-- model: string (nullable = true)
 |-- conversation: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- content: string (nullable = true)
 |    |    |-- country: string (nullable = true)
 |    |    |-- hashed_ip: string (nullable = true)
 |    |    |-- header: struct (nullable = true)
 |    |    |    |-- accept-language: string (nullable = true)
 |    |    |    |-- user-agent: string (nullable = true)
 |    |    |-- language: string (nullable = true)
 |    |    |-- redacted: boolean (nullable = true)
 |    |    |-- role: string (nullable = true)
 |    |    |-- state: string (nullable = true)
 |    |    |-- timestamp: timestamp (nullable = true)
 |    |    |-- toxic: boolean (nullable = true)
 |    |    |-- turn_identifier: long (nullable = true)
 |-- turn: long (nullable = true)
 |-- redacted: boolean (nullable = true)
 |-- state: string (nullable = false)
 |-- country: string (nullable = true)

In [19]:
import re

@F.udf(StringType())
def cleanText(prompt):
    if prompt:
        #clean_text = re.sub('[^a-zA-Z0-9]', '', prompt)
        #clean_text = re.sub(' \\-*/=:,.&|^%$@!%', r'\1', clean_text)
        # Remove all non-alphanumeric characters except specific ones (.-/=,:,&|^%@!%)
        clean_text = re.sub(r'[^a-zA-Z0-9\.\*/=:,.&|^%@!# ]', '', prompt)
        clean_text = clean_text.lower()
        # Replace multiple consecutive special characters with a single one
        clean_text = re.sub(r'([\-*/=:,.&|^%@!]#)\1+', r' \1 ', clean_text)
        clean_text = re.sub(r'(\d)([a-z])', r'\1 \2', clean_text)
        clean_text = re.sub(r'([.?])', r' \1 ', clean_text)
        clean_text = re.sub(r'\s+', ' ', clean_text).strip() 
        clean_text = re.sub(r'[\n\t]', ' ', clean_text)
        clean_text = clean_text.replace('/', '') 
        return clean_text
    return ''


journalism_df= journalism_df.withColumn('clean', cleanText(F.col('prompt')))

In [20]:
journalism_df = journalism_df.withColumn('clean', F.trim(F.col('clean')))

In [21]:
journalism_df = journalism_df.drop('conversation')

In [22]:
groupCols = [col for col in journalism_df.columns if col != 'prompt' and col != 'clean']
journalism_df = journalism_df.groupBy(groupCols).agg(F.concat_ws(' --botresp-- ', F.collect_list('prompt')).alias('full_interaction'),
                                                   F.concat_ws(' ', F.collect_list('clean')).alias('clean_interaction')
                                                  )

In [23]:
tk = Tokenizer(inputCol = 'clean_interaction', outputCol = 'tokenized_clean')
custom_stop_words = StopWordsRemover.loadDefaultStopWords("english") + ["\n", "\t", ""]
swr = StopWordsRemover(inputCol = 'tokenized_clean', outputCol = 'swr_clean_tokens', stopWords=custom_stop_words)
cv = CountVectorizer(inputCol = 'swr_clean_tokens', outputCol = 'raw_features', vocabSize=10000000, maxDF =0.7)
idf = IDF(inputCol = 'raw_features', outputCol = 'tfidf_features')
normalizer = Normalizer(inputCol="tfidf_features", outputCol="normalized_tfidf", p=2.0)

In [24]:
pipeline = Pipeline(stages=[tk, swr, cv, idf, normalizer])

pipeline_model = pipeline.fit(journalism_df)

processed_journalism_data = pipeline_model.transform(journalism_df)

24/11/24 13:26:51 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors
24/11/24 13:28:32 WARN DAGScheduler: Broadcasting large task binary with size 46.6 MiB
24/11/24 13:29:14 WARN DAGScheduler: Broadcasting large task binary with size 46.6 MiB
                                                                                

In [25]:
#Extract vocabulary from CountVectorizer
vocab = pipeline_model.stages[2].vocabulary

# Function to extract frequent terms from the TF-IDF vector
def extract_frequent_terms(tfidf_vector, vocab, threshold=0.2):
    # Convert the sparse vector to indices and values
    indices = tfidf_vector.indices.tolist()
    values = tfidf_vector.values.tolist()
    
    # Map indices to terms and filter by threshold
    terms = [vocab[i] for i, val in zip(indices, values) if val >= threshold if len(vocab[i]) > 1]
    return list(set(terms))

def extract_rare_terms(tfidf_vector, vocab, threshold=0.1):
    # Convert the sparse vector to indices and values
    indices = tfidf_vector.indices.tolist()
    values = tfidf_vector.values.tolist()
    
    # Map indices to terms and filter by threshold
    terms = [vocab[i] for i, val in zip(indices, values) if val <= threshold if len(vocab[i]) > 1]
    return list(set(terms))


# Broadcast the vocabulary to avoid repeated serialization
broadcast_vocab = spark.sparkContext.broadcast(vocab)

# Define UDF with a helper function
def frequent_terms_udf(tfidf_vector):
    return extract_frequent_terms(tfidf_vector, broadcast_vocab.value)

def rare_terms_udf(tfidf_vector):
    return extract_rare_terms(tfidf_vector, broadcast_vocab.value)

# Register the UDF
frequent_terms_udf = F.udf(frequent_terms_udf, ArrayType(StringType()))
rare_terms_udf = F.udf(rare_terms_udf, ArrayType(StringType()))

# Apply UDF to the DataFrame
processed_journalism_data = processed_journalism_data.withColumn(
    "frequent_terms",
    frequent_terms_udf(F.col("normalized_tfidf"))
).withColumn("rare_terms", rare_terms_udf(F.col("normalized_tfidf")))

processed_journalism_data.show()

24/11/24 13:29:54 WARN DAGScheduler: Broadcasting large task binary with size 94.5 MiB
[Stage 37:>                                                         (0 + 1) / 1]

+--------------------+------------------+----+--------+--------------------+--------------+---------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|   conversation_hash|             model|turn|redacted|               state|       country|turn_identifier|    full_interaction|   clean_interaction|     tokenized_clean|    swr_clean_tokens|        raw_features|      tfidf_features|    normalized_tfidf|      frequent_terms|          rare_terms|
+--------------------+------------------+----+--------+--------------------+--------------+---------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|0002cbb456c406ec3...|gpt-3.5-turbo-0613|   5|   false|          California| United States|        1866564|Fi

                                                                                

In [26]:
processed_journalism_data.filter(F.array_contains(F.col('frequent_terms'), 'software') | F.array_contains(F.col('rare_terms'), 'software')).first()


24/11/24 13:30:35 WARN DAGScheduler: Broadcasting large task binary with size 94.5 MiB
                                                                                

Row(conversation_hash='00393107193f78ebc59bca3804e8b34e', model='gpt-3.5-turbo-0301', turn=8, redacted=False, state='Bashkortostan Republic', country='Russia', turn_identifier=698151, full_interaction='Which software is used in Brazil for pca? --botresp-- I\'m not aware of any specific software used for post-clearance audit (PCA) in Brazil. The Brazilian government provides a digital platform called "Siscomex" to manage customs processes, including PCA. Siscomex provides various modules to streamline customs procedures, including a module for risk analysis that identifies high-risk shipments to be selected for PCA. \n\nCustoms officials in Brazil may also use other software solutions to help with PCA, such as data analytics tools to analyze large amounts of data and identify anomalies in the import or export process. However, the specific software used may vary depending on the customs authority or department within the Brazilian government.', clean_interaction='which software is used 

In [27]:
processed_journalism_data.printSchema()

root
 |-- conversation_hash: string (nullable = true)
 |-- model: string (nullable = true)
 |-- turn: long (nullable = true)
 |-- redacted: boolean (nullable = true)
 |-- state: string (nullable = false)
 |-- country: string (nullable = true)
 |-- turn_identifier: long (nullable = true)
 |-- full_interaction: string (nullable = false)
 |-- clean_interaction: string (nullable = false)
 |-- tokenized_clean: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- swr_clean_tokens: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- raw_features: vector (nullable = true)
 |-- tfidf_features: vector (nullable = true)
 |-- normalized_tfidf: vector (nullable = true)
 |-- frequent_terms: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- rare_terms: array (nullable = true)
 |    |-- element: string (containsNull = true)



In [28]:
processed_journalism_data = processed_journalism_data.drop('tokenized_clean','swr_clean_tokens','raw_features', 'tfidf_features', 'normalized_tfidf')

In [29]:
processed_journalism_data.filter(F.col('conversation_hash') == '00393107193f78ebc59bca3804e8b34e').show()

24/11/24 13:30:41 WARN DAGScheduler: Broadcasting large task binary with size 94.5 MiB
[Stage 43:>                                                         (0 + 1) / 1]

+--------------------+------------------+----+--------+--------------------+-------+---------------+--------------------+--------------------+--------------------+--------------------+
|   conversation_hash|             model|turn|redacted|               state|country|turn_identifier|    full_interaction|   clean_interaction|      frequent_terms|          rare_terms|
+--------------------+------------------+----+--------+--------------------+-------+---------------+--------------------+--------------------+--------------------+--------------------+
|00393107193f78ebc...|gpt-3.5-turbo-0301|   8|   false|Bashkortostan Rep...| Russia|         698151|Which software is...|which software is...|[pca, brazilian, ...|[digital, specifi...|
|00393107193f78ebc...|gpt-3.5-turbo-0301|   8|   false|Bashkortostan Rep...| Russia|         698059|Provide me with a...|provide me with a...|[038, monteiro,, ...|[machine, hope, f...|
|00393107193f78ebc...|gpt-3.5-turbo-0301|   8|   false|Bashkortostan Rep...

                                                                                

In [56]:
output_directory = "/Users/sharan/Desktop/JournalismChats"
processed_journalism_data.coalesce(15).write.mode('append').parquet(output_directory)

24/11/24 02:53:29 WARN DAGScheduler: Broadcasting large task binary with size 94.7 MiB
                                                                                