In [31]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, split, udf, regexp_replace, lit, from_unixtime
from pyspark.sql.types import ArrayType, StringType, StructType, StructField, IntegerType, StringType, MapType
from pyspark.sql.functions import split
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, MapType, StringType
from pyspark.sql.functions import explode
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lower, regexp_replace, length, size, udf
from pyspark.sql.types import StringType, ArrayType, IntegerType
from pyspark.ml.feature import Tokenizer, StopWordsRemover

import json

from pyspark.sql.functions import split, explode, regexp_extract, col, collect_list, udf, broadcast
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, DoubleType, FloatType
from pyspark.ml.linalg import VectorUDT, Vectors
import numpy as np
import os

# NEWS PRE PROCESSING

In [32]:
from pyspark.sql import SparkSession

# Initialize a SparkSession
spark = SparkSession.builder \
    .appName("MIND Dataset Processing") \
    .getOrCreate()

24/12/12 18:07:23 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [33]:
# Define the path to the news.tsv file
news_path = "data/mind/MINDsmall_train/news.tsv"

# Define column names for the news.tsv file
news_columns = ["NewsID", "Category", "Subcategory", "Title", "Abstract", "URL", "TitleEntities", "AbstractEntities"]

# Load the news.tsv file into a Spark DataFrame
news_df = spark.read.csv(
    news_path,
    sep="\t",
    schema="NewsID STRING, Category STRING, Subcategory STRING, Title STRING, Abstract STRING, URL STRING, TitleEntities STRING, AbstractEntities STRING",
    header=False
)

# Assign column names
news_df.show(n=2, truncate=True) 

+------+---------+---------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|NewsID| Category|    Subcategory|               Title|            Abstract|                 URL|       TitleEntities|    AbstractEntities|
+------+---------+---------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|N55528|lifestyle|lifestyleroyals|The Brands Queen ...|Shop the notebook...|https://assets.ms...|[{"Label": "Princ...|                  []|
|N19639|   health|     weightloss|50 Worst Habits F...|These seemingly h...|https://assets.ms...|[{"Label": "Adipo...|[{"Label": "Adipo...|
+------+---------+---------------+--------------------+--------------------+--------------------+--------------------+--------------------+
only showing top 2 rows



In [34]:
news_df.columns

['NewsID',
 'Category',
 'Subcategory',
 'Title',
 'Abstract',
 'URL',
 'TitleEntities',
 'AbstractEntities']

In [35]:
# Initialize Spark session
spark = SparkSession.builder.appName("PreprocessingPipeline").getOrCreate()

# Load your data (modify the path as necessary)
news_df = spark.read.csv("data/mind/MINDsmall_train/news.tsv", sep="\t", header=False, inferSchema=True)

# Assign column names
news_df = news_df.toDF("NewsID", "Category", "Subcategory", "Title", "Abstract", "URL", "TitleEntities", "AbstractEntities")

# Display initial rows
news_df.show(5, truncate=True)

24/12/12 18:07:24 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


+------+---------+---------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|NewsID| Category|    Subcategory|               Title|            Abstract|                 URL|       TitleEntities|    AbstractEntities|
+------+---------+---------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|N55528|lifestyle|lifestyleroyals|The Brands Queen ...|Shop the notebook...|https://assets.ms...|[{"Label": "Princ...|                  []|
|N19639|   health|     weightloss|50 Worst Habits F...|These seemingly h...|https://assets.ms...|[{"Label": "Adipo...|[{"Label": "Adipo...|
|N61837|     news|      newsworld|The Cost of Trump...|Lt. Ivan Molchane...|https://assets.ms...|                  []|[{"Label": "Ukrai...|
|N53526|   health|         voices|I Was An NBA Wife...|I felt like I was...|https://assets.ms...|                  []|[{"Label": "Natio...|
|N38324|   health|  

In [36]:
### MISSING VALUES ###
print(f"Rows before dropping missing values: {news_df.count()}")

# Drop rows where Title or Abstract are missing
news_df = news_df.na.drop(subset=["Title", "Abstract"])

# Verify the results
print(f"Rows after dropping missing values: {news_df.count()}")


Rows before dropping missing values: 51282
Rows after dropping missing values: 48616


In [37]:
### TEXT CLEANING ###

# Define a function to clean text (remove special characters and convert to lowercase)
def clean_text(text):
    if text:
        return text.lower().replace("\n", " ").replace("\t", " ")
    return None

# Register the UDF
clean_text_udf = udf(lambda x: clean_text(x), StringType())

# Apply text cleaning to Title and Abstract
news_df = news_df.withColumn("CleanTitle", clean_text_udf(col("Title")))
news_df = news_df.withColumn("CleanAbstract", clean_text_udf(col("Abstract")))

# Display cleaned text
news_df.select("CleanTitle", "CleanAbstract").show(5, truncate=False)

+----------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|CleanTitle                                                            |CleanAbstract                                                                                                                                                                                       |
+----------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|the brands queen elizabeth, prince charles, and prince philip swear by|shop the notebooks, jackets, and more that the royals can't live without.                                             

In [38]:
### TOKENIZATION ###

# Tokenize CleanTitle and CleanAbstract
tokenizer_title = Tokenizer(inputCol="CleanTitle", outputCol="TitleTokens")
tokenizer_abstract = Tokenizer(inputCol="CleanAbstract", outputCol="AbstractTokens")

news_df = tokenizer_title.transform(news_df)
news_df = tokenizer_abstract.transform(news_df)

# Display tokenized data
news_df.select("TitleTokens", "AbstractTokens").show(5, truncate=False)

+----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|TitleTokens                                                                       |AbstractTokens                                                                                                                                                                                                                           |
+----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|[the, brands, queen, elizabeth,, prince, c

In [39]:
### STOPWORDS REMOVAL ###

# Remove stopwords from TitleTokens and AbstractTokens
stopword_remover_title = StopWordsRemover(inputCol="TitleTokens", outputCol="FilteredTitleTokens")
stopword_remover_abstract = StopWordsRemover(inputCol="AbstractTokens", outputCol="FilteredAbstractTokens")

news_df = stopword_remover_title.transform(news_df)
news_df = stopword_remover_abstract.transform(news_df)

# Display filtered tokens
news_df.select("FilteredTitleTokens", "FilteredAbstractTokens").show(5, truncate=False)


+--------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|FilteredTitleTokens                                                 |FilteredAbstractTokens                                                                                                                                                |
+--------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|[brands, queen, elizabeth,, prince, charles,, prince, philip, swear]|[shop, notebooks,, jackets,, royals, live, without.]                                                                                                                  |
|[50, worst, habits, belly, fat]                

In [40]:
# Define a UDF to clean each token in the array
def clean_tokens(tokens):
    if tokens:
        return [token.replace(",", "") for token in tokens]  # Remove commas
    return tokens

# Register the UDF
clean_tokens_udf = udf(clean_tokens, ArrayType(StringType()))

# Apply the UDF to FilteredTitleTokens
news_df = news_df.withColumn("FilteredTitleTokens", clean_tokens_udf(col("FilteredTitleTokens")))
news_df = news_df.withColumn("FilteredAbstractTokens", clean_tokens_udf(col("FilteredAbstractTokens")))

In [41]:
news_df.select("Title", "FilteredTitleTokens").show(5, truncate=False)

+----------------------------------------------------------------------+------------------------------------------------------------------+
|Title                                                                 |FilteredTitleTokens                                               |
+----------------------------------------------------------------------+------------------------------------------------------------------+
|The Brands Queen Elizabeth, Prince Charles, and Prince Philip Swear By|[brands, queen, elizabeth, prince, charles, prince, philip, swear]|
|50 Worst Habits For Belly Fat                                         |[50, worst, habits, belly, fat]                                   |
|The Cost of Trump's Aid Freeze in the Trenches of Ukraine's War       |[cost, trump's, aid, freeze, trenches, ukraine's, war]            |
|I Was An NBA Wife. Here's How It Affected My Mental Health.           |[nba, wife., affected, mental, health.]                           |
|How to Get Rid of S

In [42]:
news_df.select("Abstract", "FilteredAbstractTokens").show(2, truncate=False)

+--------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------+
|Abstract                                                                                                            |FilteredAbstractTokens                                                                      |
+--------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------+
|Shop the notebooks, jackets, and more that the royals can't live without.                                           |[shop, notebooks, jackets, royals, live, without.]                                          |
|These seemingly harmless habits are holding you back and keeping you from shedding that unwanted belly fat for good.|[seemingly, harmless, habits, hold

In [43]:
news_df.columns

['NewsID',
 'Category',
 'Subcategory',
 'Title',
 'Abstract',
 'URL',
 'TitleEntities',
 'AbstractEntities',
 'CleanTitle',
 'CleanAbstract',
 'TitleTokens',
 'AbstractTokens',
 'FilteredTitleTokens',
 'FilteredAbstractTokens']

In [44]:
news_df.select('NewsID', 'Category', 'FilteredTitleTokens', 'FilteredAbstractTokens').show(1)

+------+---------+--------------------+----------------------+
|NewsID| Category| FilteredTitleTokens|FilteredAbstractTokens|
+------+---------+--------------------+----------------------+
|N55528|lifestyle|[brands, queen, e...|  [shop, notebooks,...|
+------+---------+--------------------+----------------------+
only showing top 1 row



In [45]:
pip install spark-nlp==5.5.1

Note: you may need to restart the kernel to use updated packages.


# EMBEDDINGS

In [46]:
import sparknlp
from pyspark.sql import SparkSession
from sparknlp.base import DocumentAssembler, TokenAssembler
from sparknlp.annotator import BertEmbeddings
from pyspark.ml import Pipeline
from pyspark.ml.feature import Tokenizer, StopWordsRemover, CountVectorizer, IDF
from pyspark.sql.functions import col, concat_ws, array_union, explode
from pyspark.sql.functions import concat, col

In [47]:
news_df.printSchema()

root
 |-- NewsID: string (nullable = true)
 |-- Category: string (nullable = true)
 |-- Subcategory: string (nullable = true)
 |-- Title: string (nullable = true)
 |-- Abstract: string (nullable = true)
 |-- URL: string (nullable = true)
 |-- TitleEntities: string (nullable = true)
 |-- AbstractEntities: string (nullable = true)
 |-- CleanTitle: string (nullable = true)
 |-- CleanAbstract: string (nullable = true)
 |-- TitleTokens: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- AbstractTokens: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- FilteredTitleTokens: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- FilteredAbstractTokens: array (nullable = true)
 |    |-- element: string (containsNull = true)



In [48]:
news_df = news_df.withColumn(
    "combined_tokens", concat(col("FilteredTitleTokens"), col("FilteredAbstractTokens"))
)


In [49]:
news_df.select("FilteredTitleTokens", "FilteredAbstractTokens").show(1, truncate=False)

+------------------------------------------------------------------+--------------------------------------------------+
|FilteredTitleTokens                                               |FilteredAbstractTokens                            |
+------------------------------------------------------------------+--------------------------------------------------+
|[brands, queen, elizabeth, prince, charles, prince, philip, swear]|[shop, notebooks, jackets, royals, live, without.]|
+------------------------------------------------------------------+--------------------------------------------------+
only showing top 1 row



In [50]:
news_df.select('combined_tokens').show(1, truncate=False)

+--------------------------------------------------------------------------------------------------------------------+
|combined_tokens                                                                                                     |
+--------------------------------------------------------------------------------------------------------------------+
|[brands, queen, elizabeth, prince, charles, prince, philip, swear, shop, notebooks, jackets, royals, live, without.]|
+--------------------------------------------------------------------------------------------------------------------+
only showing top 1 row



### TF-IDF

In [51]:
# Term frequency 

from pyspark.ml.feature import CountVectorizer

cv = CountVectorizer(inputCol="combined_tokens", outputCol="raw_features")
cv_model = cv.fit(news_df)
news_df_tf = cv_model.transform(news_df)


                                                                                

In [52]:
from pyspark.ml.feature import IDF

idf = IDF(inputCol="raw_features", outputCol="tf_idf")
idf_model = idf.fit(news_df_tf)
news_df_tfidf = idf_model.transform(news_df_tf)

24/12/12 18:07:33 WARN DAGScheduler: Broadcasting large task binary with size 1236.3 KiB
24/12/12 18:07:37 WARN DAGScheduler: Broadcasting large task binary with size 1237.3 KiB
                                                                                

In [53]:
news_df_tfidf.select("NewsID", "combined_tokens", "tf_idf").show(truncate=True)

24/12/12 18:07:38 WARN DAGScheduler: Broadcasting large task binary with size 2.9 MiB


+------+--------------------+--------------------+
|NewsID|     combined_tokens|              tf_idf|
+------+--------------------+--------------------+
|N55528|[brands, queen, e...|(109675,[310,977,...|
|N19639|[50, worst, habit...|(109675,[27,437,8...|
|N61837|[cost, trump's, a...|(109675,[63,176,3...|
|N53526|[nba, wife., affe...|(109675,[38,89,23...|
|N38324|[get, rid, skin, ...|(109675,[6,17,20,...|
| N2073|[nfl, able, fine,...|(109675,[87,94,17...|
|N49186|[orlando's, hotte...|(109675,[98,224,2...|
|N59295|[chile:, three, d...|(109675,[2,19,23,...|
|N24510|[best, ps5, games...|(109675,[13,28,33...|
|N39237|[report, weather-...|(109675,[52,65,93...|
| N9721|[50, foods, never...|(109675,[17,187,2...|
|N60905|[trying, make, ra...|(109675,[1,44,58,...|
|N39758|[25, biggest, gro...|(109675,[5,207,22...|
|N28361|[instagram, filte...|(109675,[142,181,...|
|N18680|[michigan, apple,...|(109675,[156,181,...|
|N55610|[kate, middleton'...|(109675,[33,35,41...|
|N35621|[stars, got, fire...|(1

### BERT

# USERS

In [54]:
# Define the schema
behaviors_schema = StructType([
    StructField("ImpressionID", StringType(), True),
    StructField("UserID", StringType(), True),
    StructField("Time", StringType(), True),
    StructField("History", StringType(), True),
    StructField("Impressions", StringType(), True)
])

# Load the behaviors.tsv file
behaviors_df = spark.read.csv(
    "data/mind/MINDsmall_train/behaviors.tsv",
    sep="\t",
    schema=behaviors_schema,
    header=False
)

# Display the schema and a sample row
# behaviors_df.printSchema()
behaviors_df.show(3, truncate=False)

+------------+------+---------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|ImpressionID|UserID|Time                 |History                           

In [55]:
# Split History into an array
behaviors_df = behaviors_df.withColumn("HistoryList", split(col("History"), " "))
behaviors_df = behaviors_df.drop("History")  # Drop original History column if not needed

# Verify the transformation
behaviors_df.select("ImpressionID", "UserID", "HistoryList").show(3, truncate=False)

+------------+------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|ImpressionID|UserID|HistoryList                                                                                                                                                                                                                                                                                                                 

In [56]:
# Split Impressions into an array
behaviors_df = behaviors_df.withColumn("ImpressionsList", split(col("Impressions"), " "))
behaviors_df = behaviors_df.drop("Impressions")  # Drop original Impressions column if not needed

# Verify the transformation
behaviors_df.select("ImpressionID", "ImpressionsList").show(2, truncate=False)

+------------+------------------------------------------------------------------------------------------------------------+
|ImpressionID|ImpressionsList                                                                                             |
+------------+------------------------------------------------------------------------------------------------------------+
|1           |[N55689-1, N35729-0]                                                                                        |
|2           |[N20678-0, N39317-0, N58114-0, N20495-0, N42977-0, N22407-0, N14592-0, N17059-1, N33677-0, N7821-0, N6890-0]|
+------------+------------------------------------------------------------------------------------------------------------+
only showing top 2 rows



In [57]:
# Explode ImpressionsList
impressions_exploded = behaviors_df.select(
    "ImpressionID",
    "UserID",
    "Time",
    "HistoryList",
    explode("ImpressionsList").alias("ImpressionItem")
)

# Extract CandidateNewsID and ClickLabel using regex
impressions_exploded = impressions_exploded \
    .withColumn("CandidateNewsID", regexp_extract(col("ImpressionItem"), r"^(N\d+)-\d+$", 1)) \
    .withColumn("ClickLabel", regexp_extract(col("ImpressionItem"), r"^N\d+-(\d+)$", 1).cast("integer")) \
    .drop("ImpressionItem")

# Verify the transformation
impressions_exploded.select("ImpressionID", "UserID", "CandidateNewsID", "ClickLabel").show(5, truncate=False)

+------------+------+---------------+----------+
|ImpressionID|UserID|CandidateNewsID|ClickLabel|
+------------+------+---------------+----------+
|1           |U13740|N55689         |1         |
|1           |U13740|N35729         |0         |
|2           |U91836|N20678         |0         |
|2           |U91836|N39317         |0         |
|2           |U91836|N58114         |0         |
+------------+------+---------------+----------+
only showing top 5 rows



In [58]:
# Join impressions with news_features_df on CandidateNewsID
impressions_with_features = impressions_exploded.join(
    news_df_tfidf,
    impressions_exploded.CandidateNewsID == news_df_tfidf.NewsID,
    how="left"
).drop(news_df_tfidf.NewsID)  # Drop duplicate NewsID column if present

In [59]:
impressions_with_features.show(3, truncate=True)

24/12/12 18:09:47 WARN DAGScheduler: Broadcasting large task binary with size 2.9 MiB
24/12/12 18:09:55 WARN DAGScheduler: Broadcasting large task binary with size 2.9 MiB


+------------+------+--------------------+--------------------+---------------+----------+------------+------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+----------------------+--------------------+--------------------+--------------------+
|ImpressionID|UserID|                Time|         HistoryList|CandidateNewsID|ClickLabel|    Category|       Subcategory|               Title|            Abstract|                 URL|       TitleEntities|    AbstractEntities|          CleanTitle|       CleanAbstract|         TitleTokens|      AbstractTokens| FilteredTitleTokens|FilteredAbstractTokens|     combined_tokens|        raw_features|              tf_idf|
+------------+------+--------------------+--------------------+---------------+----------+------------+------------------+--------------------+-------------------

                                                                                

In [96]:
    # Filter records where ClickLabel == 1
    clicked_news_df = impressions_with_features.filter(col("ClickLabel") == 1)
    clicked_news_df = impressions_with_features.drop('CleanTitle', 'CleanAbstract', 'TitleTokens', 'AbstractTokens', 'AbstractTokens', 'FilteredAbstractTokens','FilteredAbstractTokens', 'FilteredAbstractTokens', 'raw_features','combined_tokens','FilteredTitleTokens' )
    # Verify the filtered DataFrame
    clicked_news_df.show(5, truncate=True)

24/12/12 18:32:03 WARN DAGScheduler: Broadcasting large task binary with size 2.9 MiB
24/12/12 18:32:08 WARN DAGScheduler: Broadcasting large task binary with size 2.9 MiB


+------------+------+--------------------+--------------------+---------------+----------+------------+------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|ImpressionID|UserID|                Time|         HistoryList|CandidateNewsID|ClickLabel|    Category|       Subcategory|               Title|            Abstract|                 URL|       TitleEntities|    AbstractEntities|              tf_idf|
+------------+------+--------------------+--------------------+---------------+----------+------------+------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|      143656|U74335|11/11/2019 3:07:0...|[N13008, N19593, ...|         N11483|         0|       autos|         autossuvs|2020 Hyundai Venu...|Hyundai's new tin...|https://assets.ms...|                  []|[{"Label": "Nissa...|(109675,[1,85,301...|
|   

                                                                                

### USER EMBEDDINGS

In [69]:
spark = SparkSession.builder \
    .appName("User Embeddings") \
    .config("spark.pyspark.python", "/path/to/python") \
    .config("spark.pyspark.driver.python", "/path/to/python") \
    .getOrCreate()

24/12/12 18:22:05 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [74]:
# from pyspark.sql.functions import collect_list, col, udf
# from pyspark.ml.linalg import Vectors, VectorUDT
# import numpy as np

# # UDF to compute the average of a list of vectors
# def average_vectors(vectors):
#     if not vectors:
#         return Vectors.dense([0.0] * 109675)  # Replace 109675 with the actual embedding size
#     np_vectors = np.array([v.toArray() for v in vectors])
#     avg_vector = np.mean(np_vectors, axis=0)
#     return Vectors.dense(avg_vector)

# average_vectors_udf = udf(average_vectors, VectorUDT())

# def test_average_vectors():
#     vectors = [Vectors.dense([1.0, 2.0, 3.0]), Vectors.dense([4.0, 5.0, 6.0])]
#     np_vectors = np.array([v.toArray() for v in vectors])
#     avg_vector = np.mean(np_vectors, axis=0)
#     print(avg_vector)

# test_average_vectors()

In [70]:
!which python3

/usr/local/bin/python3


In [71]:
spark = SparkSession.builder \
    .appName("User Embeddings") \
    .config("spark.pyspark.python", "/usr/local/bin/python3") \
    .config("spark.pyspark.driver.python", "/usr/local/bin/python3") \
    .getOrCreate()

24/12/12 18:23:46 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [95]:
clicked_news_df.printSchema()

root
 |-- ImpressionID: string (nullable = true)
 |-- UserID: string (nullable = true)
 |-- Time: string (nullable = true)
 |-- HistoryList: array (nullable = true)
 |    |-- element: string (containsNull = false)
 |-- CandidateNewsID: string (nullable = false)
 |-- ClickLabel: integer (nullable = true)
 |-- Category: string (nullable = true)
 |-- Subcategory: string (nullable = true)
 |-- Title: string (nullable = true)
 |-- Abstract: string (nullable = true)
 |-- URL: string (nullable = true)
 |-- TitleEntities: string (nullable = true)
 |-- AbstractEntities: string (nullable = true)
 |-- tf_idf: vector (nullable = true)
 |-- tf_idf_array: array (nullable = true)
 |    |-- element: float (containsNull = true)



In [89]:
clicked_news_df = clicked_news_df.filter(col("tf_idf").isNotNull())

In [93]:
clicked_news_df.printSchema()
clicked_news_df.select('tf_idf').show(1)

root
 |-- ImpressionID: string (nullable = true)
 |-- UserID: string (nullable = true)
 |-- Time: string (nullable = true)
 |-- HistoryList: array (nullable = true)
 |    |-- element: string (containsNull = false)
 |-- CandidateNewsID: string (nullable = false)
 |-- ClickLabel: integer (nullable = true)
 |-- Category: string (nullable = true)
 |-- Subcategory: string (nullable = true)
 |-- Title: string (nullable = true)
 |-- Abstract: string (nullable = true)
 |-- URL: string (nullable = true)
 |-- TitleEntities: string (nullable = true)
 |-- AbstractEntities: string (nullable = true)
 |-- tf_idf: vector (nullable = true)
 |-- tf_idf_array: array (nullable = true)
 |    |-- element: float (containsNull = true)



24/12/12 18:29:38 WARN DAGScheduler: Broadcasting large task binary with size 2.9 MiB
                                                                                

+--------------------+
|              tf_idf|
+--------------------+
|(109675,[8,30,51,...|
+--------------------+
only showing top 1 row



In [85]:

from pyspark.sql.functions import pandas_udf
import pandas as pd
import numpy as np

@pandas_udf("array<float>")
def average_vectors_udf(vectors: pd.Series) -> pd.Series:
    # Convert list of vectors to a NumPy array and calculate the mean
    np_vectors = np.array(vectors.tolist())
    avg_vector = np.mean(np_vectors, axis=0)
    return pd.Series(avg_vector.tolist())
