In [2]:
# 第一步：数据加载与基本统计分析

# 本 notebook 的目标：
# 1. 初始化 Spark Session
# 2. 加载原始数据
# 3. 查看数据结构和基本信息
# 4. 进行描述性统计分析


In [3]:
# 导入必要的库
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import StringType
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

# 设置图表样式
plt.style.use('default')
sns.set_palette("husl")


In [4]:
# 初始化 Spark Session
spark = SparkSession.builder \
    .appName("TweetAnalysis_DataIngestion") \
    .master("local[*]") \
    .config("spark.driver.memory", "16g") \
    .config("spark.sql.execution.arrow.pyspark.enabled", "true") \
    .getOrCreate()

sc = spark.sparkContext
print(f"Spark Version: {spark.version}")
print(f"Available cores: {sc.defaultParallelism}")


Spark Version: 3.5.0
Available cores: 20


In [5]:
# 加载数据
raw_data_path = "/home/jovyan/work/data/raw/the-reddit-climate-change-dataset-comments.csv"

# 加载数据，让Spark自动推断Schema
df_raw = spark.read.csv(raw_data_path, header=True, inferSchema=True, multiLine=True, escape='"')

# 缓存DataFrame，后续操作会更快
df_raw.cache()

print("数据加载完成！")


数据加载完成！


In [6]:
# 查看数据结构
print("=== 数据结构 (Schema) ===")
df_raw.printSchema()


=== 数据结构 (Schema) ===
root
 |-- type: string (nullable = true)
 |-- id: string (nullable = true)
 |-- subreddit.id: string (nullable = true)
 |-- subreddit.name: string (nullable = true)
 |-- subreddit.nsfw: boolean (nullable = true)
 |-- created_utc: integer (nullable = true)
 |-- permalink: string (nullable = true)
 |-- body: string (nullable = true)
 |-- sentiment: double (nullable = true)
 |-- score: integer (nullable = true)



In [7]:
# 查看前几行数据
print("=== 前5行数据 ===")
df_raw.show(5, truncate=False)


=== 前5行数据 ===
+-------+-------+------------+--------------+--------------+-----------+----------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [8]:
# 基本统计信息
print("=== 基本统计信息 ===")

# 1. 数据总量
total_count = df_raw.count()
print(f"总评论数量: {total_count:,}")

# 2. 列数
num_columns = len(df_raw.columns)
print(f"列数: {num_columns}")

# 3. 列名
print(f"列名: {df_raw.columns}")


=== 基本统计信息 ===
总评论数量: 4,600,698
列数: 10
列名: ['type', 'id', 'subreddit.id', 'subreddit.name', 'subreddit.nsfw', 'created_utc', 'permalink', 'body', 'sentiment', 'score']
