In [1]:


from pyspark.sql import SparkSession
from pyspark.sql.functions import col, mean
import pandas as pd
from sklearn.datasets import load_wine
import numpy as np

# Configuration for this cleaning run: file locations, the label column that
# must never be imputed, and the inclusive plausibility range for alcohol.
RAW_CSV_PATH = "wine_raw.csv"
CLEANED_CSV_PATH = "wine_cleaned_pyspark.csv"
LABEL_COL = "target"
ALCOHOL_MIN, ALCOHOL_MAX = 0, 20

# --- Build a deliberately "dirty" copy of the sklearn wine dataset ----------
wine = load_wine()
df_pd = pd.DataFrame(wine.data, columns=wine.feature_names)
df_pd["target"] = wine.target

# Inject missing values (.loc slices are label-based and inclusive) and append
# the first five rows a second time so there are exact duplicates to remove.
df_pd.loc[5:10, "alcohol"] = np.nan
df_pd.loc[20:22, "malic_acid"] = np.nan
df_pd = pd.concat([df_pd, df_pd.iloc[0:5]])
# pandas writes NaN as an empty field; Spark's CSV reader loads it back as null.
df_pd.to_csv(RAW_CSV_PATH, index=False)

# --- Clean the data with Spark ----------------------------------------------
spark = SparkSession.builder \
    .appName("WineDataCleaning") \
    .getOrCreate()

try:
    df_raw = spark.read.csv(RAW_CSV_PATH, header=True, inferSchema=True)
    print("Initial Dataset:")
    df_raw.show(5)
    print(f"Initial count: {df_raw.count()}")

    df_dedup = df_raw.dropDuplicates()
    print(f"After removing duplicates: {df_dedup.count()}")

    # Numeric feature columns eligible for mean imputation. The label column is
    # excluded: replacing a missing class label with an average is meaningless.
    num_cols = [
        col_name
        for col_name, dtype in df_dedup.dtypes
        if dtype in ["double", "int"] and col_name != LABEL_COL
    ]

    # Compute every column mean in ONE aggregation job instead of one job per
    # column. A mean of None means the column is entirely null; such columns
    # are skipped because na.fill cannot accept None as a replacement value.
    fill_values = {}
    if num_cols:
        means_row = df_dedup.select(
            [mean(col(c)).alias(c) for c in num_cols]
        ).first()
        fill_values = {
            c: m for c, m in zip(num_cols, means_row) if m is not None
        }
    df_imputed = df_dedup.na.fill(fill_values) if fill_values else df_dedup

    # Keep only rows whose alcohol value lies inside the plausible range.
    df = df_imputed.filter(
        (col("alcohol") >= ALCOHOL_MIN) & (col("alcohol") <= ALCOHOL_MAX)
    )

    # Spark writes a directory of part files at this path, not a single file.
    df.write.csv(CLEANED_CSV_PATH, header=True, mode="overwrite")

    print("✅ Data cleaning completed. Saved as 'wine_cleaned_pyspark.csv'")
finally:
    # Always release the Spark session, even when a step above fails.
    spark.stop()


Initial Dataset:
+-------+----------+----+-----------------+---------+-------------+----------+--------------------+---------------+---------------+----+----------------------------+-------+------+
|alcohol|malic_acid| ash|alcalinity_of_ash|magnesium|total_phenols|flavanoids|nonflavanoid_phenols|proanthocyanins|color_intensity| hue|od280/od315_of_diluted_wines|proline|target|
+-------+----------+----+-----------------+---------+-------------+----------+--------------------+---------------+---------------+----+----------------------------+-------+------+
|  14.23|      1.71|2.43|             15.6|    127.0|          2.8|      3.06|                0.28|           2.29|           5.64|1.04|                        3.92| 1065.0|     0|
|   13.2|      1.78|2.14|             11.2|    100.0|         2.65|      2.76|                0.26|           1.28|           4.38|1.05|                         3.4| 1050.0|     0|
|  13.16|      2.36|2.67|             18.6|    101.0|          2.8|      3.24|