# Spark Dataset API Tutorial

This tutorial walks through the Apache Spark Dataset API, which combines the type safety of RDDs with the query optimization of DataFrames.

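## Learning Objectives

By the end of this tutorial you will understand:
1. The core concepts and advantages of Dataset
2. How Dataset differs from DataFrame and RDD
3. How to create and manipulate a Dataset
4. Type-safe data processing
5. Dataset performance optimization
6. Practical use cases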

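## 1. Dataset Basics

Dataset is an abstraction introduced in Spark 1.6 (since Spark 2.0 a DataFrame is a `Dataset[Row]` in the Scala API). It combines the strengths of RDD and DataFrame:
- **Type safety**: type checking at compile time (in the Scala and Java APIs)
- **Performance optimization**: the Catalyst optimizer and the Tungsten execution engine
- **Object-oriented programming**: support for lambda functions and complex data types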

### Dataset vs DataFrame vs RDD

| Feature | RDD | DataFrame | Dataset |
|---------|-----|-----------|---------|
| Type safety | Compile time | Runtime | Compile time |
| Query optimization | None | Catalyst | Catalyst |
| API style | Functional | SQL-like | Mixed |
| Serialization | Java/Kryo | Tungsten | Tungsten (via Encoders) |
| GC overhead | High | Low | Low |

## 2. Environment Setup

In [None]:
from pyspark.sql import SparkSession
# Note: the star import shadows Python built-ins such as sum, max, min and round
# with their Spark column-function counterparts for the rest of this notebook.
from pyspark.sql.functions import *
from pyspark.sql.types import *
import pandas as pd

# Create the SparkSession
spark = SparkSession.builder \
    .appName("Dataset API Tutorial") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .getOrCreate()

# Set the log level
spark.sparkContext.setLogLevel("WARN")

print(f"Spark version: {spark.version}")
print("Dataset API tutorial environment is ready")

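## 3. Creating a Dataset

PySpark has no separate typed Dataset class: because Python is dynamically typed, the Python API exposes only DataFrame. We can still approximate type safety with type hints and explicit validation.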
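
One way to approximate type safety is to check each record against its type hints before handing the data to Spark. The following is a minimal sketch in plain Python, not a PySpark feature; the `City` type and the `validate_record` helper are hypothetical names invented for this illustration, and the check only handles simple (non-generic) annotations.

```python
from typing import NamedTuple, get_type_hints


class City(NamedTuple):
    # Hypothetical record type used only for this illustration.
    name: str
    population: int


def validate_record(record) -> None:
    """Raise TypeError if any field does not match its annotated type."""
    hints = get_type_hints(type(record))
    for field, expected in hints.items():
        value = getattr(record, field)
        if not isinstance(value, expected):
            raise TypeError(
                f"{field}: expected {expected.__name__}, got {type(value).__name__}"
            )


validate_record(City("Boston", 650000))  # passes silently

try:
    validate_record(City("Boston", "many"))  # wrong type for population
except TypeError as err:
    print(err)
```

Running such a check over the input list before calling `createDataFrame` surfaces bad records early, with a message that names the offending field.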

In [None]:
# Define the data types
from typing import NamedTuple
from dataclasses import dataclass

# Define the Person type with NamedTuple
class Person(NamedTuple):
    name: str
    age: int
    city: str
    salary: float

# Define the Product type with a dataclass
@dataclass
class Product:
    id: str
    name: str
    category: str
    price: float

    def __post_init__(self):
        # Validate the type and range of price
        if not isinstance(self.price, (int, float)) or self.price < 0:
            raise ValueError("Price must be a non-negative number")

print("Data types defined")

In [None]:
# Create the Person Dataset
person_data = [
    Person("Alice", 25, "New York", 75000.0),
    Person("Bob", 30, "San Francisco", 85000.0),
    Person("Charlie", 35, "Chicago", 95000.0),
    Person("Diana", 28, "Boston", 65000.0),
    Person("Eve", 32, "Seattle", 78000.0)
]

# Define the schema
person_schema = StructType([
    StructField("name", StringType(), False),
    StructField("age", IntegerType(), False),
    StructField("city", StringType(), False),
    StructField("salary", DoubleType(), False)
])

# Create the Dataset (in PySpark this is a DataFrame)
person_ds = spark.createDataFrame(person_data, person_schema)

print("Person Dataset:")
person_ds.show()
person_ds.printSchema()

In [None]:
# Create the Product Dataset
product_data = [
    Product("P001", "Laptop", "Electronics", 1200.0),
    Product("P002", "Mouse", "Electronics", 25.0),
    Product("P003", "Keyboard", "Electronics", 75.0),
    Product("P004", "Monitor", "Electronics", 300.0),
    Product("P005", "Chair", "Furniture", 150.0)
]

# Convert each dataclass instance to a dict so that column names map explicitly to fields
product_dict_data = [
    {"id": p.id, "name": p.name, "category": p.category, "price": p.price}
    for p in product_data
]

product_ds = spark.createDataFrame(product_dict_data)

print("Product Dataset:")
product_ds.show()
product_ds.printSchema()
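
The hand-written dict comprehension above has to be updated whenever a field is added to the dataclass. As a possible alternative, the standard library's `dataclasses.asdict` performs the same conversion generically. The sketch below uses a hypothetical `Item` type rather than the tutorial's `Product`, so that it runs on its own without Spark.

```python
from dataclasses import dataclass, asdict


@dataclass
class Item:
    # Hypothetical type for illustration; it mirrors the shape of a product record.
    id: str
    price: float


items = [Item("X1", 9.5), Item("X2", 20.0)]

# asdict() walks the dataclass fields, so new fields are picked up automatically.
rows = [asdict(item) for item in items]
print(rows)
```

The resulting list of dicts has the same shape as `product_dict_data` and could be passed to `spark.createDataFrame` in the same way.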

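## 4. Basic Dataset Operations

A Dataset supports two kinds of operations:
- **Transformation**: lazily returns a new Dataset without running any computation
- **Action**: triggers the computation and returns a result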
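
The split between transformations and actions is a form of lazy evaluation. The sketch below is only a plain-Python analogy built on generators, not Spark code: `lazy_filter` is a hypothetical helper that stands in for a transformation, and `list(...)` plays the role of an action.

```python
log = []


def lazy_filter(rows, predicate):
    # Generator: nothing runs until a consumer iterates over it.
    for row in rows:
        log.append(f"checked {row}")
        if predicate(row):
            yield row


salaries = [75000, 85000, 65000]

# "Transformation": builds a pipeline but performs no work yet.
pipeline = lazy_filter(salaries, lambda s: s > 70000)
print(len(log))  # 0, no row has been inspected so far

# "Action": consuming the pipeline triggers the work.
result = list(pipeline)
print(result)    # [75000, 85000]
print(len(log))  # 3
```

Spark goes further than this analogy: because the whole pipeline is known before anything runs, the optimizer can reorder and combine steps before executing them.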

In [None]:
# Transformations (in PySpark, column names are checked at runtime)
print("=== Basic transformations ===")

# Filter
high_salary_persons = person_ds.filter(col("salary") > 70000)
print("High earners:")
high_salary_persons.show()

# Map-style operation: derive a new column
person_with_bonus = person_ds.withColumn("bonus", col("salary") * 0.1)
print("\nWith bonus column:")
person_with_bonus.show()

# Select
name_salary = person_ds.select("name", "salary")
print("\nName and salary:")
name_salary.show()

In [None]:
# More complex transformations
print("=== More complex transformations ===")

# Add computed columns; when() branches are evaluated in order, first match wins
person_enhanced = person_ds.withColumn(
    "salary_level",
    when(col("salary") > 80000, "High")
    .when(col("salary") > 60000, "Medium")
    .otherwise("Low")
).withColumn(
    "age_group",
    when(col("age") < 30, "Young")
    .when(col("age") < 35, "Middle")
    .otherwise("Senior")
)

print("Enhanced Person Dataset:")
person_enhanced.show()

# Sort
sorted_by_salary = person_ds.orderBy(col("salary").desc())
print("\nSorted by salary:")
sorted_by_salary.show()
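
A chained `when(...).when(...).otherwise(...)` behaves like an `if`/`elif`/`else` ladder in which the first matching branch wins. The plain-Python sketch below mirrors the two expressions above so the expected labels for the sample people can be checked without a cluster; the helper names `salary_level` and `age_group` are chosen here to match the column names.

```python
def salary_level(salary: float) -> str:
    # Mirrors when(salary > 80000).when(salary > 60000).otherwise("Low")
    if salary > 80000:
        return "High"
    if salary > 60000:
        return "Medium"
    return "Low"


def age_group(age: int) -> str:
    # Mirrors when(age < 30).when(age < 35).otherwise("Senior")
    if age < 30:
        return "Young"
    if age < 35:
        return "Middle"
    return "Senior"


people = [
    ("Alice", 25, 75000.0),
    ("Bob", 30, 85000.0),
    ("Charlie", 35, 95000.0),
    ("Diana", 28, 65000.0),
    ("Eve", 32, 78000.0),
]

labels = {name: (salary_level(salary), age_group(age)) for name, age, salary in people}
for name, pair in labels.items():
    print(name, pair)
```

Because the comparisons are strict, a salary of exactly 80000 falls into "Medium" and an age of exactly 35 into "Senior".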

## 5. Aggregation and Grouping

In [None]:
# Basic aggregations
print("=== Basic aggregations ===")

# Summary statistics (max and min here are the Spark functions from the star import)
person_stats = person_ds.agg(
    count("*").alias("total_count"),
    avg("salary").alias("avg_salary"),
    max("salary").alias("max_salary"),
    min("salary").alias("min_salary"),
    stddev("salary").alias("salary_stddev")
)

print("Person statistics:")
person_stats.show()

# Group by city
city_stats = person_ds.groupBy("city").agg(
    count("*").alias("person_count"),
    avg("salary").alias("avg_salary"),
    avg("age").alias("avg_age")
).orderBy(col("avg_salary").desc())

print("\nStatistics by city:")
city_stats.show()
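
For a sample this small, the aggregate values can be cross-checked with Python's `statistics` module. The sketch below recomputes the salary statistics for the five sample people. One detail worth knowing is that Spark's `stddev` is an alias for the sample standard deviation (`stddev_samp`), which corresponds to `statistics.stdev`, not `statistics.pstdev`.

```python
import statistics

salaries = [75000.0, 85000.0, 95000.0, 65000.0, 78000.0]

mean_salary = statistics.mean(salaries)
sample_std = statistics.stdev(salaries)       # divides by n - 1, like stddev_samp
population_std = statistics.pstdev(salaries)  # divides by n, like stddev_pop

print(f"count={len(salaries)} avg={mean_salary}")
print(f"sample stddev={sample_std:.2f} population stddev={population_std:.2f}")
```

The sample standard deviation is always the larger of the two, and the gap matters most on small groups such as a per-city breakdown.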

In [None]:
# Product aggregations
print("=== Product aggregations ===")

# Group by category
category_stats = product_ds.groupBy("category").agg(
    count("*").alias("product_count"),
    avg("price").alias("avg_price"),
    max("price").alias("max_price"),
    min("price").alias("min_price"),
    sum("price").alias("total_value")
)

print("Statistics by category:")
category_stats.show()

# Price range analysis
price_ranges = product_ds.withColumn(
    "price_range",
    when(col("price") < 50, "Low")
    .when(col("price") < 200, "Medium")
    .otherwise("High")
).groupBy("price_range").agg(
    count("*").alias("count"),
    avg("price").alias("avg_price")
)

print("\nPrice range analysis:")
price_ranges.show()

## 6. Joins

In [None]:
# Create the order data
from typing import NamedTuple

class Order(NamedTuple):
    order_id: str
    customer_name: str
    product_id: str
    quantity: int
    order_date: str

order_data = [
    Order("O001", "Alice", "P001", 1, "2023-01-01"),
    Order("O002", "Bob", "P002", 2, "2023-01-02"),
    Order("O003", "Charlie", "P003", 1, "2023-01-03"),
    Order("O004", "Alice", "P004", 1, "2023-01-04"),
    Order("O005", "Diana", "P005", 2, "2023-01-05")
]

order_ds = spark.createDataFrame(order_data)

print("Order Dataset:")
order_ds.show()
order_ds.printSchema()

In [None]:
# Joins
print("=== Joins ===")

# Join orders with product information
order_with_product = order_ds.join(
    product_ds,
    order_ds.product_id == product_ds.id,
    "inner"
).select(
    order_ds.order_id,
    order_ds.customer_name,
    product_ds.name.alias("product_name"),
    product_ds.category,
    product_ds.price,
    order_ds.quantity,
    (product_ds.price * order_ds.quantity).alias("total_amount")
)

print("Orders with product details:")
order_with_product.show()

# Join orders with customer information
order_with_customer = order_ds.join(
    person_ds,
    order_ds.customer_name == person_ds.name,
    "inner"
).select(
    order_ds.order_id,
    person_ds.name,
    person_ds.city,
    person_ds.salary,
    order_ds.product_id,
    order_ds.quantity
)

print("\nOrders with customer details:")
order_with_customer.show()

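## 7. Window Functions

In [None]:
from pyspark.sql.window import Window

print("=== Window functions ===")

# Salary ranking
# Note: a window with no partitionBy moves all rows into a single partition,
# which is fine for this small sample but does not scale to large data.
salary_window = Window.orderBy(col("salary").desc())
city_salary_window = Window.partitionBy("city").orderBy(col("salary").desc())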

person_with_rank = person_ds.withColumn(
    "global_salary_rank", row_number().over(salary_window)
).withColumn(
    "city_salary_rank", row_number().over(city_salary_window)
).withColumn(
    "salary_percentile", percent_rank().over(salary_window)
)

print("Salary ranking:")
person_with_rank.show()

# Moving average
salary_moving_avg = person_ds.withColumn(
    "salary_moving_avg",
    avg("salary").over(
        salary_window.rowsBetween(-1, 1)
    )
).withColumn(
    "salary_cumsum",
    sum("salary").over(
        salary_window.rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
)

print("\nMoving average and running total:")
salary_moving_avg.orderBy(col("salary").desc()).show()
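
To make the window frames concrete, the sketch below recomputes the same quantities in plain Python for the five sample salaries, ordered descending as in `salary_window`. It is an illustration of the frame semantics only and assumes there are no ties, so that `percent_rank` reduces to `(row_number - 1) / (rows - 1)`.

```python
salaries = sorted([75000.0, 85000.0, 95000.0, 65000.0, 78000.0], reverse=True)
n = len(salaries)

rows = []
running_total = 0.0
for i, salary in enumerate(salaries):
    # rowsBetween(-1, 1): previous, current and next row, clipped at the edges
    frame = salaries[max(0, i - 1): i + 2]
    # rowsBetween(unboundedPreceding, currentRow): everything up to this row
    running_total += salary
    rows.append({
        "salary": salary,
        "row_number": i + 1,
        "percent_rank": i / (n - 1),
        "moving_avg": sum(frame) / len(frame),
        "cumsum": running_total,
    })

for row in rows:
    print(row)
```

The first and last rows average over only two values because the frame is clipped at the partition boundary, which is also how Spark treats a `rowsBetween(-1, 1)` frame.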

## 8. User-Defined Functions (UDF)

In [None]:
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, IntegerType

print("=== User-defined functions ===")

# Define the plain Python functions
def categorize_salary(salary):
    if salary >= 80000:
        return "High"
    elif salary >= 60000:
        return "Medium"
    else:
        return "Low"

def calculate_tax(salary):
    if salary <= 50000:
        return int(salary * 0.1)
    elif salary <= 80000:
        return int(salary * 0.15)
    else:
        return int(salary * 0.2)

# Wrap the functions as UDFs
categorize_salary_udf = udf(categorize_salary, StringType())
calculate_tax_udf = udf(calculate_tax, IntegerType())

# Apply the UDFs
person_with_udf = person_ds.withColumn(
    "salary_category", categorize_salary_udf(col("salary"))
).withColumn(
    "estimated_tax", calculate_tax_udf(col("salary"))
).withColumn(
    "net_salary", col("salary") - col("estimated_tax")
)

print("Result after applying the UDFs:")
person_with_udf.show()
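
Because a UDF wraps an ordinary Python function, the bracket logic can be tested on its own before it runs inside Spark. The sketch below repeats the two functions from the cell above and probes the bracket boundaries; the list of test values is chosen here purely for illustration.

```python
def categorize_salary(salary):
    if salary >= 80000:
        return "High"
    elif salary >= 60000:
        return "Medium"
    else:
        return "Low"


def calculate_tax(salary):
    # int() truncates toward zero, so fractional amounts are dropped
    if salary <= 50000:
        return int(salary * 0.1)
    elif salary <= 80000:
        return int(salary * 0.15)
    else:
        return int(salary * 0.2)


for salary in (59999.99, 60000, 79999.99, 80000):
    print(salary, categorize_salary(salary))

for salary in (50000, 80000, 85000):
    print(salary, calculate_tax(salary))
```

Note that this UDF uses `>=` for its thresholds while the `salary_level` column in section 4 uses `>`, so a salary of exactly 80000 is "High" here but "Medium" there.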

In [None]:
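# Vectorized UDF (pandas UDF): receives a whole batch as a pandas Series, which
# usually cuts serialization overhead compared with a row-at-a-time UDF.
# Note that Series.apply below still runs Python code once per element.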
from pyspark.sql.functions import pandas_udf
import pandas as pd

@pandas_udf(returnType=StringType())
def vectorized_salary_category(salary_series: pd.Series) -> pd.Series:
    return salary_series.apply(lambda x: "High" if x >= 80000 else ("Medium" if x >= 60000 else "Low"))

@pandas_udf(returnType=IntegerType())
def vectorized_tax_calculation(salary_series: pd.Series) -> pd.Series:
    def calc_tax(salary):
        if salary <= 50000:
            return int(salary * 0.1)
        elif salary <= 80000:
            return int(salary * 0.15)
        else:
            return int(salary * 0.2)
    
    return salary_series.apply(calc_tax)

# Apply the vectorized UDFs
person_with_pandas_udf = person_ds.withColumn(
    "salary_category_v2", vectorized_salary_category(col("salary"))
).withColumn(
    "estimated_tax_v2", vectorized_tax_calculation(col("salary"))
)

print("\nResult after applying the vectorized UDFs:")
person_with_pandas_udf.show()

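## 9. Data Validation and Quality Checks

In [None]:
print("=== Data validation and quality checks ===")

# Create a Dataset that contains problematic records
problematic_data = [
    ("Alice", 25, "New York", 75000.0),
    ("Bob", None, "San Francisco", 85000.0),  # missing age
    ("Charlie", 35, None, 95000.0),  # missing city
    ("Diana", 28, "Boston", None),  # missing salary
    (None, 32, "Seattle", 78000.0),  # missing name
    ("Frank", -5, "Miami", 60000.0),  # invalid age
    ("Grace", 30, "Denver", -1000.0)  # invalid salary
]

# person_schema declares every field as non-nullable, so createDataFrame would
# reject the None values above; build a nullable copy of the schema instead.
nullable_person_schema = StructType([
    StructField(f.name, f.dataType, True) for f in person_schema.fields
])
problematic_ds = spark.createDataFrame(problematic_data, nullable_person_schema)

print("Raw data (with problems):")
problematic_ds.show()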

# Data quality checks
print("\nData quality report:")
total_rows = problematic_ds.count()
print(f"Total rows: {total_rows}")

# Count nulls per column
null_counts = problematic_ds.select([
    sum(when(col(c).isNull(), 1).otherwise(0)).alias(f"{c}_nulls")
    for c in problematic_ds.columns
])
null_counts.show()

# Count invalid values
invalid_age_count = problematic_ds.filter(col("age") < 0).count()
invalid_salary_count = problematic_ds.filter(col("salary") < 0).count()

print(f"Records with invalid age: {invalid_age_count}")
print(f"Records with invalid salary: {invalid_salary_count}")

In [None]:
# Data cleaning
print("=== Data cleaning ===")

# Method 1: drop rows that contain nulls
clean_ds_method1 = problematic_ds.dropna()
print(f"Rows remaining after dropping nulls: {clean_ds_method1.count()}")
clean_ds_method1.show()

# Method 2: fill in defaults, then filter out invalid values
clean_ds_method2 = problematic_ds.fillna({
    "name": "Unknown",
    "age": 30,
    "city": "Unknown",
    "salary": 50000.0
}).filter(
    (col("age") >= 0) & (col("salary") >= 0)
)

print(f"\nRows remaining after default fill and filter: {clean_ds_method2.count()}")
clean_ds_method2.show()

# Method 3: fill with averages computed from the data
# Average over valid (non-negative) values only, so the invalid records do not
# skew the fill values; avg() already ignores nulls.
avg_age = problematic_ds.filter(col("age") >= 0).agg(avg("age")).collect()[0][0]
avg_salary = problematic_ds.filter(col("salary") >= 0).agg(avg("salary")).collect()[0][0]

clean_ds_method3 = problematic_ds.fillna({
    "name": "Unknown",
    "age": int(avg_age) if avg_age is not None else 30,
    "city": "Unknown",
    "salary": avg_salary if avg_salary is not None else 50000.0
}).filter(
    (col("age") >= 0) & (col("salary") >= 0)
)

print(f"\nRows remaining after average-based fill: {clean_ds_method3.count()}")
clean_ds_method3.show()
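
Averages used as fill values are sensitive to invalid records. The plain-Python sketch below takes the ages and salaries from the problematic records and compares the mean over all non-null values with the mean over valid (non-negative) values only; the helper names are invented for this illustration.

```python
ages = [25, None, 35, 28, 32, -5, 30]
salaries = [75000.0, 85000.0, 95000.0, None, 78000.0, 60000.0, -1000.0]


def mean(values):
    return sum(values) / len(values)


def non_null(values):
    return [v for v in values if v is not None]


def valid(values):
    # Non-null and non-negative, matching the filter used in the cleaning step
    return [v for v in values if v is not None and v >= 0]


print(f"age:    all non-null={mean(non_null(ages)):.2f}  valid only={mean(valid(ages)):.2f}")
print(f"salary: all non-null={mean(non_null(salaries)):.2f}  valid only={mean(valid(salaries)):.2f}")
```

A single salary of -1000 pulls the average down by more than 13000, which is why it is usually safer to exclude invalid values before deriving a fill value from the data.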

## 10. Performance Optimization

In [None]:
print("=== Performance optimization ===")

# Caching
print("1. Caching")
person_ds.cache()  # lazy: the data is materialized by the first action
print(f"Cached: {person_ds.is_cached}")

# Partitioning
print("\n2. Partition information")
print(f"Current number of partitions: {person_ds.rdd.getNumPartitions()}")

# Repartition by city
repartitioned_ds = person_ds.repartition(2, col("city"))
print(f"Partitions after repartitioning by city: {repartitioned_ds.rdd.getNumPartitions()}")

# Inspect the execution plan
print("\n3. Execution plan analysis")
complex_query = person_ds.filter(col("salary") > 70000) \
                         .groupBy("city") \
                         .agg(avg("salary").alias("avg_salary")) \
                         .orderBy("avg_salary")

print("Query execution plan:")
complex_query.explain(True)
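
Repartitioning by a column assigns rows to partitions by hashing the column value, so all rows that share a key land in the same partition. The sketch below illustrates only that idea in plain Python. It uses `zlib.crc32` as a stand-in hash, which is an assumption made for the illustration and not the hash function Spark uses, and the rows for "Carol" and "Dan" are invented sample data.

```python
import zlib
from collections import defaultdict


def partition_for(key: str, num_partitions: int) -> int:
    # Stand-in hash for illustration only; Spark uses its own hash function.
    return zlib.crc32(key.encode("utf-8")) % num_partitions


# Hypothetical (name, city) rows
rows = [
    ("Alice", "New York"),
    ("Bob", "San Francisco"),
    ("Carol", "New York"),
    ("Dan", "Chicago"),
]

partitions = defaultdict(list)
for name, city in rows:
    partitions[partition_for(city, 2)].append((name, city))

for index in sorted(partitions):
    print(index, partitions[index])
```

With only two partitions, several cities necessarily share a partition, but one city is never split across partitions. That property lets a later `groupBy("city")` avoid another shuffle.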

In [None]:
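# Broadcast join optimization
print("=== Broadcast join optimization ===")

from pyspark.sql.functions import broadcast

# Broadcast the small table to every executor
# Here product_ds is assumed to be the small table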
optimized_join = order_ds.join(
    broadcast(product_ds),
    order_ds.product_id == product_ds.id,
    "inner"
)

print("Broadcast join execution plan:")
optimized_join.explain()

print("\nBroadcast join result:")
optimized_join.select(
    order_ds.order_id,
    order_ds.customer_name,
    product_ds.name.alias("product_name"),
    product_ds.price
).show()
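
Conceptually, a broadcast join ships the small table to every executor, builds an in-memory lookup from it, and then probes that lookup once per row of the large table, so the large table never needs to be shuffled. The sketch below reproduces that build-and-probe pattern in plain Python with the tutorial's order and product data. It is a simplified model of the idea and not Spark's implementation.

```python
# Build side: the small table becomes a hash map keyed by the join column.
products = {
    "P001": ("Laptop", 1200.0),
    "P002": ("Mouse", 25.0),
    "P003": ("Keyboard", 75.0),
    "P004": ("Monitor", 300.0),
    "P005": ("Chair", 150.0),
}

# Probe side: (order_id, customer_name, product_id, quantity)
orders = [
    ("O001", "Alice", "P001", 1),
    ("O002", "Bob", "P002", 2),
    ("O003", "Charlie", "P003", 1),
    ("O004", "Alice", "P004", 1),
    ("O005", "Diana", "P005", 2),
]

joined = [
    (order_id, customer, products[pid][0], products[pid][1] * quantity)
    for order_id, customer, pid, quantity in orders
    if pid in products  # inner join: unmatched orders are dropped
]

for row in joined:
    print(row)
```

The approach pays off only when the build side fits comfortably in memory on each executor, which is why the hint is applied to the small product table.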

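## 11. Real-World Use Cases

In [None]:
print("=== Real-world case: sales data analysis ===")

# Load the sample sales data
sales_df = spark.read.option("header", "true").option("inferSchema", "true").csv("/home/jovyan/data/sample/sales_data.csv")

# Case 1: sales performance analysis
print("1. Sales performance analysis")

# Compute purchase statistics per customer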
customer_analysis = sales_df.groupBy("customer_id").agg(
    count("*").alias("purchase_count"),
    sum(col("price") * col("quantity")).alias("total_spent"),
    avg(col("price") * col("quantity")).alias("avg_order_value"),
    countDistinct("product_id").alias("unique_products"),
    countDistinct("category").alias("unique_categories")
).withColumn(
    "customer_tier",
    when(col("total_spent") > 2000, "Premium")
    .when(col("total_spent") > 1000, "Gold")
    .when(col("total_spent") > 500, "Silver")
    .otherwise("Bronze")
)

print("Customer analysis results:")
customer_analysis.orderBy(col("total_spent").desc()).show()

# Statistics per customer tier
tier_stats = customer_analysis.groupBy("customer_tier").agg(
    count("*").alias("customer_count"),
    avg("total_spent").alias("avg_spent"),
    avg("purchase_count").alias("avg_purchases")
)

print("\nCustomer tier statistics:")
tier_stats.show()

In [None]:
# Case 2: groundwork analysis for a product recommender
print("=== Case 2: groundwork analysis for a product recommender ===")

# Product popularity
product_popularity = sales_df.groupBy("product_id", "category").agg(
    count("*").alias("purchase_frequency"),
    sum("quantity").alias("total_quantity"),
    countDistinct("customer_id").alias("unique_customers"),
    avg("price").alias("avg_price")
).withColumn(
    "popularity_score",
    col("purchase_frequency") * 0.4 + col("unique_customers") * 0.6
)

print("Product popularity:")
product_popularity.orderBy(col("popularity_score").desc()).show()

# Category co-occurrence
# Note: the != condition yields each pair twice, as (A, B) and (B, A).
category_cooccurrence = sales_df.alias("s1").join(
    sales_df.alias("s2"),
    (col("s1.customer_id") == col("s2.customer_id")) & 
    (col("s1.category") != col("s2.category")),
    "inner"
).groupBy(
    col("s1.category").alias("category1"),
    col("s2.category").alias("category2")
).agg(
    countDistinct(col("s1.customer_id")).alias("common_customers")
).filter(
    col("common_customers") > 1
).orderBy(col("common_customers").desc())

print("\nCategory co-occurrence:")
category_cooccurrence.show()
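
The self-join counts, for every pair of distinct categories, how many customers bought from both. The plain-Python sketch below computes the same statistic on a handful of hypothetical purchases invented for this illustration (the real `sales_data.csv` is not reproduced here), and it counts each unordered pair only once.

```python
from collections import Counter
from itertools import combinations

# Hypothetical (customer_id, category) purchases, invented for this sketch.
purchases = [
    ("c1", "Electronics"), ("c1", "Furniture"),
    ("c2", "Electronics"), ("c2", "Furniture"), ("c2", "Books"),
    ("c3", "Books"),
]

# Collect the distinct categories each customer bought from.
categories_by_customer = {}
for customer, category in purchases:
    categories_by_customer.setdefault(customer, set()).add(category)

pair_counts = Counter()
for categories in categories_by_customer.values():
    # sorted() fixes the order so (A, B) and (B, A) count as the same pair
    for pair in combinations(sorted(categories), 2):
        pair_counts[pair] += 1

for pair, customers in pair_counts.most_common():
    print(pair, customers)
```

In the Spark version, replacing the `!=` join condition with `<` would give the same deduplicated pairs, at the cost of fixing which category appears in which column.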

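## 12. Practice Exercises

In [None]:
print("=== Practice exercises ===")

# Exercise 1: build an employee performance analysis system
print("Exercise 1: employee performance analysis system")

# Create the employee performance data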
performance_data = [
    ("Alice", "Engineering", 85, 92, 88, 4.2),
    ("Bob", "Sales", 78, 85, 82, 3.8),
    ("Charlie", "Marketing", 92, 88, 90, 4.5),
    ("Diana", "Engineering", 88, 90, 89, 4.3),
    ("Eve", "Sales", 82, 87, 85, 4.0),
    ("Frank", "Marketing", 75, 80, 78, 3.5),
    ("Grace", "Engineering", 90, 93, 92, 4.6)
]

performance_schema = StructType([
    StructField("name", StringType(), False),
    StructField("department", StringType(), False),
    StructField("technical_score", IntegerType(), False),
    StructField("communication_score", IntegerType(), False),
    StructField("overall_score", IntegerType(), False),
    StructField("manager_rating", DoubleType(), False)
])

performance_ds = spark.createDataFrame(performance_data, performance_schema)

# Task 1: compute the weighted performance score
performance_analysis = performance_ds.withColumn(
    "weighted_score",
    col("technical_score") * 0.4 + 
    col("communication_score") * 0.3 + 
    col("overall_score") * 0.2 +
    col("manager_rating") * 10 * 0.1
).withColumn(
    "performance_grade",
    when(col("weighted_score") >= 90, "A")
    .when(col("weighted_score") >= 80, "B")
    .when(col("weighted_score") >= 70, "C")
    .otherwise("D")
)

print("Employee performance analysis:")
performance_analysis.orderBy(col("weighted_score").desc()).show()

# Task 2: performance statistics per department
dept_performance = performance_analysis.groupBy("department").agg(
    count("*").alias("employee_count"),
    avg("weighted_score").alias("avg_performance"),
    max("weighted_score").alias("best_performance"),
    min("weighted_score").alias("worst_performance")
).withColumn(
    "dept_grade",
    when(col("avg_performance") >= 85, "Excellent")
    .when(col("avg_performance") >= 80, "Good")
    .otherwise("Needs Improvement")
)

print("\nDepartment performance statistics:")
dept_performance.show()
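
The weighted score is simple arithmetic, so it can be worked through by hand for a few employees. The sketch below repeats the formula in plain Python. It assumes that `manager_rating` is on a 0 to 5 scale, which the sample data suggests but the exercise does not state.

```python
def weighted_score(technical, communication, overall, manager_rating):
    # Same weights as the exercise: 40% / 30% / 20% / 10%
    return (technical * 0.4
            + communication * 0.3
            + overall * 0.2
            + manager_rating * 10 * 0.1)


def grade(score):
    if score >= 90:
        return "A"
    if score >= 80:
        return "B"
    if score >= 70:
        return "C"
    return "D"


alice = weighted_score(85, 92, 88, 4.2)   # 34 + 27.6 + 17.6 + 4.2
frank = weighted_score(75, 80, 78, 3.5)   # 30 + 24 + 15.6 + 3.5
best_possible = weighted_score(100, 100, 100, 5.0)

print(round(alice, 1), grade(alice))
print(round(frank, 1), grade(frank))
print(round(best_possible, 1))
```

Under that assumption a rating of 5 contributes only 5 points, so the highest reachable score is 95; multiplying the rating by 20 instead of 10 would put it on the same 0 to 100 scale as the other components.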

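## 13. Summary

In [None]:
print("=== Dataset API tutorial summary ===")
print("\n1. Core concepts:")
print("   - Dataset = DataFrame + type safety")
print("   - Compile-time type checking (Scala/Java; PySpark checks at runtime)")
print("   - Backed by the Catalyst optimizer")

print("\n2. Main advantages:")
print("   - Type-safe data processing")
print("   - High-performance execution")
print("   - A rich API")
print("   - Easier debugging and maintenance")

print("\n3. Core operations:")
print("   - Transformations: filter, select, withColumn, groupBy (map and flatMap exist only in the typed Scala/Java API)")
print("   - Actions: collect, count, show, write")
print("   - Aggregations: agg, groupBy, window functions")
print("   - Joins: join, broadcast join")

print("\n4. Performance optimization:")
print("   - Caching")
print("   - Partitioning")
print("   - Broadcast joins")
print("   - Vectorized UDFs")

print("\n5. Best practices:")
print("   - Define data types explicitly")
print("   - Prefer vectorized operations")
print("   - Cache selectively")
print("   - Inspect execution plans")
print("   - Validate data quality")

print("\nCongratulations! You now know the core concepts and operations of the Dataset API.")
print("The Dataset API pairs type safety with Catalyst-optimized execution,")
print("which makes it a strong fit for large-scale data applications.")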

In [None]:
# Clean up resources
print("Cleaning up Spark resources...")

# Drop cached data
person_ds.unpersist()
spark.catalog.clearCache()

print("Cleanup complete")
# spark.stop()  # left commented out to keep the session alive