In [1]:
# 第一步：数据加载与基本统计分析

# 本 notebook 的目标：
# 1. 初始化 Spark Session
# 2. 加载原始数据
# 3. 查看数据结构和基本信息
# 4. 进行描述性统计分析


In [2]:
# 导入必要的库
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import StringType
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

# 设置图表样式
plt.style.use('default')
sns.set_palette("husl")


In [3]:
# 初始化 Spark Session
spark = SparkSession.builder \
    .appName("TweetAnalysis_DataIngestion") \
    .master("local[*]") \
    .config("spark.driver.memory", "16g") \
    .config("spark.sql.execution.arrow.pyspark.enabled", "true") \
    .getOrCreate()

sc = spark.sparkContext
print(f"Spark Version: {spark.version}")
print(f"Available cores: {sc.defaultParallelism}")


Spark Version: 3.5.0
Available cores: 20


In [4]:
# 加载数据
raw_data_path = "/home/jovyan/work/data/raw/the-reddit-climate-change-dataset-comments.csv"

# 加载数据，让Spark自动推断Schema
df_raw = spark.read.csv(raw_data_path, header=True, inferSchema=True, multiLine=True, escape='"')

# 缓存DataFrame，后续操作会更快
df_raw.cache()

print("数据加载完成！")


数据加载完成！


In [5]:
# 查看数据结构
print("=== 数据结构 (Schema) ===")
df_raw.printSchema()


=== 数据结构 (Schema) ===
root
 |-- type: string (nullable = true)
 |-- id: string (nullable = true)
 |-- subreddit.id: string (nullable = true)
 |-- subreddit.name: string (nullable = true)
 |-- subreddit.nsfw: boolean (nullable = true)
 |-- created_utc: integer (nullable = true)
 |-- permalink: string (nullable = true)
 |-- body: string (nullable = true)
 |-- sentiment: double (nullable = true)
 |-- score: integer (nullable = true)



In [6]:
# 查看前几行数据
print("=== 前5行数据 ===")
df_raw.show(5, truncate=False)


=== 前5行数据 ===
+-------+-------+------------+--------------+--------------+-----------+----------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [7]:
# 基本统计信息
print("=== 基本统计信息 ===")

# 1. 数据总量
total_count = df_raw.count()
print(f"总评论数量: {total_count:,}")

# 2. 列数
num_columns = len(df_raw.columns)
print(f"列数: {num_columns}")

# 3. 列名
print(f"列名: {df_raw.columns}")


=== 基本统计信息 ===
总评论数量: 4,600,698
列数: 10
列名: ['type', 'id', 'subreddit.id', 'subreddit.name', 'subreddit.nsfw', 'created_utc', 'permalink', 'body', 'sentiment', 'score']


In [8]:
# 检查各列的详细信息和样本数据
print("=== 各列详细信息 ===")
for col_name in df_raw.columns:
    print(f"\n列名: {col_name}")
    print(f"数据类型: {dict(df_raw.dtypes)[col_name]}")
    
    # 显示非空值数量 - 使用反引号处理包含点号的列名
    if '.' in col_name:
        non_null_count = df_raw.filter(F.col(f"`{col_name}`").isNotNull()).count()
    else:
        non_null_count = df_raw.filter(F.col(col_name).isNotNull()).count()
    
    null_count = total_count - non_null_count
    print(f"非空值: {non_null_count:,} | 空值: {null_count:,} ({null_count/total_count*100:.2f}%)")
    
    # 显示样本值（对于字符串类型，限制长度）
    if col_name in ['body', 'permalink']:
        sample_values = df_raw.select(col_name).limit(2).collect()
        for i, row in enumerate(sample_values):
            value = str(row[0])[:100] + "..." if row[0] and len(str(row[0])) > 100 else row[0]
            print(f"  样本{i+1}: {value}")
    else:
        sample_values = df_raw.select(col_name).limit(3).collect()
        for i, row in enumerate(sample_values):
            print(f"  样本{i+1}: {row[0]}")
    print("-" * 50)


=== 各列详细信息 ===

列名: type
数据类型: string
非空值: 4,600,698 | 空值: 0 (0.00%)
  样本1: comment
  样本2: comment
  样本3: comment
--------------------------------------------------

列名: id
数据类型: string
非空值: 4,600,698 | 空值: 0 (0.00%)
  样本1: imlddn9
  样本2: imldbeh
  样本3: imldado
--------------------------------------------------

列名: subreddit.id
数据类型: string


AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `subreddit`.`id` cannot be resolved. Did you mean one of the following? [`subreddit.id`, `subreddit.name`, `subreddit.nsfw`, `body`, `permalink`].;
'Filter isnotnull('subreddit.id)
+- Relation [type#0,id#1,subreddit.id#2,subreddit.name#3,subreddit.nsfw#4,created_utc#5,permalink#6,body#7,sentiment#8,score#9] csv


In [None]:
# 时间范围分析
print("=== 时间范围分析 ===")

# 转换时间戳为可读格式
df_with_time = df_raw.withColumn("timestamp", F.from_unixtime(F.col("created_utc")))

# 获取时间范围
time_stats = df_with_time.select(
    F.min("timestamp").alias("earliest_time"),
    F.max("timestamp").alias("latest_time")
).collect()[0]

print(f"最早评论时间: {time_stats['earliest_time']}")
print(f"最晚评论时间: {time_stats['latest_time']}")

# 按年份统计评论数量
yearly_stats = df_with_time.withColumn("year", F.year("timestamp")) \
                          .groupBy("year") \
                          .count() \
                          .orderBy("year")

print("\n按年份统计评论数量:")
yearly_stats.show()


In [None]:
# 子版块分析
print("=== 子版块分析 ===")

# 统计各子版块的评论数量 - 使用反引号处理列名
subreddit_stats = df_raw.groupBy(F.col("`subreddit.name`")) \
                        .count() \
                        .orderBy(F.desc("count"))

print("评论数量最多的前20个子版块:")
subreddit_stats.show(20)

# 转换为 Pandas 进行可视化
top_subreddits = subreddit_stats.limit(15).toPandas()

plt.figure(figsize=(12, 8))
sns.barplot(data=top_subreddits, x='count', y='subreddit.name')
plt.title('评论数量最多的前15个子版块')
plt.xlabel('评论数量')
plt.ylabel('子版块名称')
plt.tight_layout()
plt.show()


In [None]:
# 情感分数分析（数据中已有sentiment列）
print("=== 情感分数分析 ===")

# 基本统计
sentiment_stats = df_raw.select("sentiment").describe()
sentiment_stats.show()

# 情感分数分布
print("情感分数分布:")
sentiment_ranges = df_raw.withColumn("sentiment_range",
    F.when(F.col("sentiment") >= 0.1, "positive")
     .when(F.col("sentiment") <= -0.1, "negative")
     .otherwise("neutral")
).groupBy("sentiment_range").count().orderBy(F.desc("count"))

sentiment_ranges.show()

# 可视化情感分数分布
sentiment_sample = df_raw.select("sentiment").sample(False, 0.01).toPandas()  # 抽样1%

plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.hist(sentiment_sample['sentiment'], bins=50, alpha=0.7, edgecolor='black')
plt.title('情感分数分布直方图')
plt.xlabel('情感分数')
plt.ylabel('频次')

plt.subplot(1, 2, 2)
sns.boxplot(y=sentiment_sample['sentiment'])
plt.title('情感分数箱线图')
plt.ylabel('情感分数')

plt.tight_layout()
plt.show()


In [None]:
# 评论长度分析
print("=== 评论长度分析 ===")

# 计算评论长度
df_with_length = df_raw.withColumn("body_length", F.length(F.col("body")))

# 长度统计
length_stats = df_with_length.select("body_length").describe()
length_stats.show()

# 长度分布
print("评论长度分布:")
length_ranges = df_with_length.withColumn("length_range",
    F.when(F.col("body_length") <= 50, "very_short")
     .when(F.col("body_length") <= 200, "short")
     .when(F.col("body_length") <= 500, "medium")
     .when(F.col("body_length") <= 1000, "long")
     .otherwise("very_long")
).groupBy("length_range").count().orderBy(F.desc("count"))

length_ranges.show()

# 可视化评论长度分布
length_sample = df_with_length.select("body_length").sample(False, 0.01).toPandas()

plt.figure(figsize=(10, 6))
plt.hist(length_sample['body_length'], bins=50, alpha=0.7, edgecolor='black')
plt.title('评论长度分布直方图')
plt.xlabel('评论长度（字符数）')
plt.ylabel('频次')
plt.xlim(0, 2000)  # 限制x轴范围以便更好地观察
plt.show()
