In [0]:
from pyspark.sql.functions import avg, round as _round
from pyspark.sql.types import StructType, StructField, IntegerType, TimestampType
from datetime import datetime
from pyspark.sql.window import Window

# Schema for the tweet activity table; every column is declared nullable.
schema = (
    StructType()
    .add("user_id", IntegerType(), True)
    .add("tweet_date", TimestampType(), True)
    .add("tweet_count", IntegerType(), True)
)

# Sample rows: one (user_id, tweet_date, tweet_count) tuple per user per day,
# covering 2022-06-01 through 2022-06-07 at midnight.
data = [
    (uid, datetime(2022, 6, day, 0, 0, 0), tweets)
    for uid, daily_tweets in (
        (111, (2, 1, 3, 4, 5, 4, 6)),
        (199, (7, 5, 9, 1, 8, 2, 2)),
        (254, (1, 1, 2, 1, 3, 1, 3)),
    )
    # enumerate from 1 so the position in the tuple is the day of the month
    for day, tweets in enumerate(daily_tweets, start=1)
]

# Build the Spark DataFrame from the sample rows, applying the explicit schema
df = spark.createDataFrame(data, schema)

# Render the DataFrame in the notebook output
df.display()




user_id,tweet_date,tweet_count
111,2022-06-01T00:00:00.000+0000,2
111,2022-06-02T00:00:00.000+0000,1
111,2022-06-03T00:00:00.000+0000,3
111,2022-06-04T00:00:00.000+0000,4
111,2022-06-05T00:00:00.000+0000,5
111,2022-06-06T00:00:00.000+0000,4
111,2022-06-07T00:00:00.000+0000,6
199,2022-06-01T00:00:00.000+0000,7
199,2022-06-02T00:00:00.000+0000,5
199,2022-06-03T00:00:00.000+0000,9


In [0]:
# Number of rows (current row included) that feed each rolling average.
# The output column name below is fixed and still says "3_day".
ROLLING_WINDOW_ROWS = 3

# Per-user window ordered by tweet_date. The frame is row-based: it spans the
# current row and the (ROLLING_WINDOW_ROWS - 1) rows before it, so it lines up
# with a 3-day window only while each user has one row per consecutive day,
# as in the sample data. The first rows of each user average fewer rows.
window_spec = (
    Window.partitionBy("user_id")
    .orderBy("tweet_date")
    # Window.currentRow is the named boundary PySpark recommends over a bare 0
    .rowsBetween(-(ROLLING_WINDOW_ROWS - 1), Window.currentRow)
)

# Average tweet_count over the frame, rounded to 2 decimal places
result_df = df.withColumn(
    "average_3_day_rolling",
    _round(avg("tweet_count").over(window_spec), 2),
)

# display the result
result_df.display()


user_id,tweet_date,tweet_count,average_3_day_rolling
111,2022-06-01T00:00:00.000+0000,2,2.0
111,2022-06-02T00:00:00.000+0000,1,1.5
111,2022-06-03T00:00:00.000+0000,3,2.0
111,2022-06-04T00:00:00.000+0000,4,2.67
111,2022-06-05T00:00:00.000+0000,5,4.0
111,2022-06-06T00:00:00.000+0000,4,4.33
111,2022-06-07T00:00:00.000+0000,6,5.0
199,2022-06-01T00:00:00.000+0000,7,7.0
199,2022-06-02T00:00:00.000+0000,5,6.0
199,2022-06-03T00:00:00.000+0000,9,7.0


In [0]:
# Register the DataFrame as a temp view so SQL cells can query it by name
df.createOrReplaceTempView(name="tweet_data")

In [0]:
%sql
-- Rolling average of tweet_count per user over the current row and the two
-- preceding rows (ordered by tweet_date), rounded to 2 decimal places.
SELECT
  user_id,
  tweet_date,
  ROUND(AVG(tweet_count) OVER rolling_3_rows, 2) AS rolling_avg_3d
FROM tweet_data
WINDOW rolling_3_rows AS (
  PARTITION BY user_id
  ORDER BY tweet_date
  ROWS BETWEEN 2 PRECEDING AND CURRENT ROW
);

user_id,tweet_date,rolling_avg_3d
111,2022-06-01T00:00:00.000+0000,2.0
111,2022-06-02T00:00:00.000+0000,1.5
111,2022-06-03T00:00:00.000+0000,2.0
111,2022-06-04T00:00:00.000+0000,2.67
111,2022-06-05T00:00:00.000+0000,4.0
111,2022-06-06T00:00:00.000+0000,4.33
111,2022-06-07T00:00:00.000+0000,5.0
199,2022-06-01T00:00:00.000+0000,7.0
199,2022-06-02T00:00:00.000+0000,6.0
199,2022-06-03T00:00:00.000+0000,7.0
