In [None]:
!pip install pyspark



In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# PySpark imports
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql import functions as F
from pyspark.sql.types import StringType
from pyspark.sql import Window


In [None]:
spark = SparkSession.builder \
    .appName("Agriculture Data Processing") \
    .getOrCreate()

In [None]:
# Suppress warnings (equivalent to warnings.filterwarnings('ignore'))
spark.sparkContext.setLogLevel("ERROR")


-------------------------------------
### 1. DATA LOADING AND MERGING
-------------------------------------
##### STEP 1: DATA LOADING AND MERGING

In [None]:
# Load all datasets
fertilizer_df = spark.read.csv('Fertilizer.csv', header=True, inferSchema=True)
crop_yield_df = spark.read.csv('crop_yield.csv', header=True, inferSchema=True)
rainfall_df = spark.read.csv('rainfall_validation.csv', header=True, inferSchema=True)
temperature_df = spark.read.csv('final_temperature.csv', header=True, inferSchema=True)

print("    - All datasets loaded.")

print(f"    - Fertilizer shape: {fertilizer_df.count()}, {len(fertilizer_df.columns)}")
print(f"    - Crop yield shape: {crop_yield_df.count()}, {len(crop_yield_df.columns)}")
print(f"    - Rainfall shape: {rainfall_df.count()}, {len(rainfall_df.columns)}")
print(f"    - Temperature shape: {temperature_df.count()}, {len(temperature_df.columns)}")


    - All datasets loaded.
    - Fertilizer shape: 1843, 6
    - Crop yield shape: 246091, 7
    - Rainfall shape: 180, 14
    - Temperature shape: 33, 5


In [None]:
# Standardize text data function
def standardize_text(df, cols):
    for col_name in cols:
        if col_name in df.columns:
            df = df.withColumn(col_name, F.lower(F.trim(F.col(col_name))))
    return df
# Apply text standardization
crop_yield_df = standardize_text(crop_yield_df, ['State_Name', 'District_Name', 'Crop', 'Season'])
rainfall_df = standardize_text(rainfall_df, ['SUBDIVISION'])
temperature_df = standardize_text(temperature_df, ['States'])
fertilizer_df = standardize_text(fertilizer_df, ['Crop'])

print("    - Text data standardized (lowercase, stripped whitespace).")


    - Text data standardized (lowercase, stripped whitespace).


In [None]:
# State name mapping for consistency
state_name_mapping = {
    'andaman & nicobar islands': 'andaman and nicobar islands',
    'dadra & nagar haveli': 'dadra and nagar haveli',
    'jammu & kashmir': 'jammu and kashmir',
    'n.i. karnataka': 'north interior karnataka',
    's.i. karnataka': 'south interior karnataka',
    'rayalaseema': 'andhra pradesh',
    'coastal andhra pradesh': 'andhra pradesh',
    'telangana': 'andhra pradesh',
    'puducherry': 'pondicherry',
    'daman & diu': 'daman and diu',
    'uttaranchal': 'uttarakhand'
}
# Replace state names
for old_name, new_name in state_name_mapping.items():
    rainfall_df = rainfall_df.withColumn('SUBDIVISION', F.when(F.col('SUBDIVISION') == old_name, new_name).otherwise(F.col('SUBDIVISION')))
    temperature_df = temperature_df.withColumn('States', F.when(F.col('States') == old_name, new_name).otherwise(F.col('States')))

print("    - State names mapped for consistency.")

    - State names mapped for consistency.


In [None]:
# Process rainfall data
rainfall_df = rainfall_df.withColumnRenamed('SUBDIVISION', 'State_Name') \
    .withColumnRenamed('YEAR', 'Crop_Year') \
    .withColumn('kharif_rainfall', (F.col('JUN') + F.col('JUL') + F.col('AUG') + F.col('SEP'))) \
    .withColumn('rabi_rainfall', (F.col('OCT') + F.col('NOV') + F.col('DEC') + F.col('JAN'))) \
    .withColumn('summer_rainfall', (F.col('FEB') + F.col('MAR') + F.col('APR') + F.col('MAY'))) \
    .withColumn('yearly_rainfall', F.expr('JAN + FEB + MAR + APR + MAY + JUN + JUL + AUG + SEP + OCT + NOV + DEC'))

processed_rainfall_df = rainfall_df.select('State_Name', 'Crop_Year', 'kharif_rainfall', 'rabi_rainfall', 'summer_rainfall', 'yearly_rainfall')

print("    - Rainfall data processed (seasonal/yearly sums calculated).")


    - Rainfall data processed (seasonal/yearly sums calculated).


In [None]:
# Merge datasets
df_merged = crop_yield_df.join(processed_rainfall_df, on=['State_Name', 'Crop_Year'], how='left')
print(f"    - Shape after merging rainfall: {df_merged.count()}, {len(df_merged.columns)}")

temperature_df = temperature_df.withColumnRenamed('States', 'State_Name')
df_merged = df_merged.join(temperature_df, on='State_Name', how='left')
print(f"    - Shape after merging temperature: {df_merged.count()}, {len(df_merged.columns)}")

df_merged = df_merged.join(fertilizer_df, on='Crop', how='left')
print(f"    - Shape after merging fertilizer: {df_merged.count()}, {len(df_merged.columns)}")


    - Shape after merging rainfall: 248760, 11
    - Shape after merging temperature: 261057, 15
    - Shape after merging fertilizer: 1726005, 20


In [None]:
df_merged.show()

+-------------------+--------------------+---------+-------------+------+------+----------+---------------+-------------+---------------+---------------+-----------+---------+-----------+------------------+----+----+----+----+----+
|               Crop|          State_Name|Crop_Year|District_Name|Season|  Area|Production|kharif_rainfall|rabi_rainfall|summer_rainfall|yearly_rainfall|kharif_temp|rabi_temp|summer_temp|       yearly_temp| _c0|   N|   P|   K|  pH|
+-------------------+--------------------+---------+-------------+------+------+----------+---------------+-------------+---------------+---------------+-----------+---------+-----------+------------------+----+----+----+----+----+
|           arecanut|andaman and nicob...|     2000|     nicobars|kharif|1254.0|    2000.0|           NULL|         NULL|           NULL|           NULL|       27.0|     26.6|       27.8|27.036363636363635|1836| 100|  40| 140|5.82|
|           arecanut|andaman and nicob...|     2000|     nicobars|kharif

-------------------------------------
### 2. DATA QUALITY CHECKS AND CLEANING
-------------------------------------

In [None]:
# Check for missing values
print("\nMissing values before cleaning:")
missing_values = df_merged.select([F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in df_merged.columns])
missing_values.show()


Missing values before cleaning:
+----+----------+---------+-------------+------+----+----------+---------------+-------------+---------------+---------------+-----------+---------+-----------+-----------+------+------+------+------+------+
|Crop|State_Name|Crop_Year|District_Name|Season|Area|Production|kharif_rainfall|rabi_rainfall|summer_rainfall|yearly_rainfall|kharif_temp|rabi_temp|summer_temp|yearly_temp|   _c0|     N|     P|     K|    pH|
+----+----------+---------+-------------+------+----+----------+---------------+-------------+---------------+---------------+-----------+---------+-----------+-----------+------+------+------+------+------+
|   0|         0|        0|            0|     0|   0|     18677|        1552997|      1552997|        1552997|        1552997|      41031|    41031|      41031|      41031|179671|179671|179671|179671|179671|
+----+----------+---------+-------------+------+----+----------+---------------+-------------+---------------+---------------+---------

In [None]:
# Handle missing values
# For numerical columns, fill with grouped means
numerical_cols = ['Area', 'Production', 'kharif_rainfall', 'rabi_rainfall',
                  'summer_rainfall', 'yearly_rainfall', 'kharif_temp',
                  'rabi_temp', 'summer_temp', 'yearly_temp', 'N', 'P', 'K', 'pH']

for col_name in numerical_cols:
    if col_name in df_merged.columns:
        if 'rainfall' in col_name:
            group_cols = ['State_Name', 'Crop_Year']
        elif 'temp' in col_name:
            group_cols = ['State_Name']
        elif col_name in ['N', 'P', 'K', 'pH']:
            group_cols = ['Crop']
        else:
            group_cols = ['State_Name', 'Crop', 'Season']

        df_merged = df_merged.withColumn(col_name, F.when(F.col(col_name).isNull(),
            F.mean(col_name).over(Window.partitionBy(*group_cols))).otherwise(F.col(col_name)))

        # Fill remaining with global mean
        global_mean = df_merged.select(F.mean(col_name)).first()[0]
        df_merged = df_merged.withColumn(col_name, F.when(F.col(col_name).isNull(), global_mean).otherwise(F.col(col_name)))

In [None]:
from pyspark.sql import functions as F

# For categorical columns, fill with mode or 'unknown'
categorical_cols = ['State_Name', 'District_Name', 'Crop', 'Season']
for col_name in categorical_cols:
    if col_name in df_merged.columns:
        # Calculate mode
        mode_value = df_merged.select(col_name).na.drop().agg(F.mode(col_name)).first()[0]

        # Fill missing values with mode or 'unknown' if mode is None
        df_merged = df_merged.withColumn(col_name,
                                          F.when(F.col(col_name).isNull(), mode_value if mode_value is not None else 'unknown')
                                          .otherwise(F.col(col_name)))

print("    - Missing values handled for categorical columns.")


    - Missing values handled for categorical columns.


In [None]:
print("\nMissing values after cleaning:")
missing_values_after = df_merged.select([F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in df_merged.columns])
missing_values_after.show()


Missing values after cleaning:
+----+----------+---------+-------------+------+----+----------+---------------+-------------+---------------+---------------+-----------+---------+-----------+-----------+------+---+---+---+---+
|Crop|State_Name|Crop_Year|District_Name|Season|Area|Production|kharif_rainfall|rabi_rainfall|summer_rainfall|yearly_rainfall|kharif_temp|rabi_temp|summer_temp|yearly_temp|   _c0|  N|  P|  K| pH|
+----+----------+---------+-------------+------+----+----------+---------------+-------------+---------------+---------------+-----------+---------+-----------+-----------+------+---+---+---+---+
|   0|         0|        0|            0|     0|   0|         0|              0|            0|              0|              0|          0|        0|          0|          0|179671|  0|  0|  0|  0|
+----+----------+---------+-------------+------+----+----------+---------------+-------------+---------------+---------------+-----------+---------+-----------+-----------+------+---+-

In [None]:
# Drop the _c0 column from the DataFrame
df_merged = df_merged.drop('_c0')
print("    - Column '_c0' dropped from the DataFrame.")

    - Column '_c0' dropped from the DataFrame.


In [None]:
df_merged.printSchema()

root
 |-- Crop: string (nullable = true)
 |-- State_Name: string (nullable = true)
 |-- Crop_Year: integer (nullable = true)
 |-- District_Name: string (nullable = true)
 |-- Season: string (nullable = true)
 |-- Area: double (nullable = true)
 |-- Production: double (nullable = true)
 |-- kharif_rainfall: double (nullable = true)
 |-- rabi_rainfall: double (nullable = true)
 |-- summer_rainfall: double (nullable = true)
 |-- yearly_rainfall: double (nullable = true)
 |-- kharif_temp: double (nullable = true)
 |-- rabi_temp: double (nullable = true)
 |-- summer_temp: double (nullable = true)
 |-- yearly_temp: double (nullable = true)
 |-- N: double (nullable = true)
 |-- P: double (nullable = true)
 |-- K: double (nullable = true)
 |-- pH: double (nullable = true)



In [None]:
# Remove duplicates if any
df_merged = df_merged.dropDuplicates()
print(f"\nNumber of duplicates after cleaning: {df_merged.count() - df_merged.distinct().count()}")


Number of duplicates after cleaning: 0


In [None]:
df_merged.show()

+---------+--------------+---------+---------------+------+-------+----------+-----------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+-----------------+-----------------+------------------+------------------+
|     Crop|    State_Name|Crop_Year|  District_Name|Season|   Area|Production|  kharif_rainfall|     rabi_rainfall|   summer_rainfall|   yearly_rainfall|       kharif_temp|         rabi_temp|       summer_temp|       yearly_temp|                N|                P|                 K|                pH|
+---------+--------------+---------+---------------+------+-------+----------+-----------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+-----------------+-----------------+------------------+------------------+
|arhar/tur|andhra pradesh|     2001|     srikakulam|kharif| 1395.0|     813.0|893.375234

In [None]:
# Save intermediate merged data
df_initial_merged = df_merged

df_initial_merged.toPandas().to_csv('agripredict_initial_merged_data.csv', index=False)
print("    - Intermediate merged data saved.")

    - Intermediate merged data saved.


-------------------------------------
### 3. FEATURE ENGINEERING
-------------------------------------

In [None]:
# Create relevant rainfall and temperature based on season
conditions = [
    df_merged['Season'].contains('kharif'),
    df_merged['Season'].contains('rabi'),
    df_merged['Season'].contains('summer') | df_merged['Season'].contains('zaid'),
    df_merged['Season'].contains('whole year')
]

choices = [
    df_merged['kharif_rainfall'],
    df_merged['rabi_rainfall'],
    df_merged['summer_rainfall'],
    df_merged['yearly_rainfall']
]

df_merged = df_merged.withColumn('relevant_rainfall', F.when(conditions[0], choices[0])
                                  .when(conditions[1], choices[1])
                                  .when(conditions[2], choices[2])
                                  .when(conditions[3], choices[3])
                                  .otherwise(None))

conditions_temp = [
    df_merged['Season'].contains('kharif'),
    df_merged['Season'].contains('rabi'),
    df_merged['Season'].contains('summer') | df_merged['Season'].contains('zaid'),
    df_merged['Season'].contains('whole year')
]

choices_temp = [
    df_merged['kharif_temp'],
    df_merged['rabi_temp'],
    df_merged['summer_temp'],
    df_merged['yearly_temp']
]

df_merged = df_merged.withColumn('relevant_temperature', F.when(conditions_temp[0], choices_temp[0])
                                  .when(conditions_temp[1], choices_temp[1])
                                  .when(conditions_temp[2], choices_temp[2])
                                  .when(conditions_temp[3], choices_temp[3])
                                  .otherwise(None))

print("    - Created 'relevant_rainfall' and 'relevant_temperature' features.")


    - Created 'relevant_rainfall' and 'relevant_temperature' features.


In [None]:
# Calculate yield (production/area)
df_merged = df_merged.withColumn('Area', F.when(F.col('Area') == 0, None).otherwise(F.col('Area')))
df_merged = df_merged.withColumn('Yield_ton_per_hec', F.col('Production') / F.col('Area'))
df_merged = df_merged.withColumn('Yield_ton_per_hec', F.when(F.col('Yield_ton_per_hec').isNull(), None).otherwise(F.col('Yield_ton_per_hec')))


In [None]:
# Fill yield NaNs with crop-season averages
df_merged = df_merged.withColumn('Yield_ton_per_hec', F.when(F.col('Yield_ton_per_hec').isNull(),
    F.mean('Yield_ton_per_hec').over(Window.partitionBy('Crop', 'Season'))).otherwise(F.col('Yield_ton_per_hec')))

# Fill remaining with global mean
global_yield_mean = df_merged.agg(F.mean('Yield_ton_per_hec')).first()[0]
df_merged = df_merged.withColumn('Yield_ton_per_hec', F.when(F.col('Yield_ton_per_hec').isNull(), global_yield_mean).otherwise(F.col('Yield_ton_per_hec')))

print("    - Calculated 'Yield_ton_per_hec'.")

    - Calculated 'Yield_ton_per_hec'.


In [None]:
# Create interaction features
df_merged = df_merged.withColumn('rainfall_temp_interaction', F.col('relevant_rainfall') * F.col('relevant_temperature'))
df_merged = df_merged.withColumn('NPK_ratio', F.col('N') / (F.col('P') + F.col('K') + 1e-6))

print("    - Created interaction features.")


    - Created interaction features.


In [None]:
# Create lag features
df_merged = df_merged.sort(['State_Name', 'District_Name', 'Crop', 'Season', 'Crop_Year'])
df_merged = df_merged.withColumn('lagged_production_1yr',
                                  F.lag('Production', 1).over(Window.partitionBy('State_Name', 'District_Name', 'Crop', 'Season').orderBy('Crop_Year')))
df_merged = df_merged.withColumn('lagged_yield_1yr',
                                  F.lag('Yield_ton_per_hec', 1).over(Window.partitionBy('State_Name', 'District_Name', 'Crop', 'Season').orderBy('Crop_Year')))


In [None]:
# Fill lag features with mean values
df_merged = df_merged.withColumn('lagged_production_1yr', F.when(F.col('lagged_production_1yr').isNull(),
    F.mean('lagged_production_1yr').over(Window.partitionBy('Crop', 'Season'))).otherwise(F.col('lagged_production_1yr')))
df_merged = df_merged.withColumn('lagged_yield_1yr', F.when(F.col('lagged_yield_1yr').isNull(),
    F.mean('lagged_yield_1yr').over(Window.partitionBy('Crop', 'Season'))).otherwise(F.col('lagged_yield_1yr')))


In [None]:
print("    - Created lag features.")

    - Created lag features.


-----------------------------------------
### 4. OUTLIER DETECTION AND TREATMENT
-----------------------------------------

In [None]:
numerical_cols = [
    'Area', 'Production', 'kharif_rainfall', 'rabi_rainfall',
    'summer_rainfall', 'yearly_rainfall', 'kharif_temp',
    'rabi_temp', 'summer_temp', 'yearly_temp', 'N', 'P', 'K',
    'pH', 'Yield_ton_per_hec', 'rainfall_temp_interaction',
    'NPK_ratio', 'lagged_production_1yr', 'lagged_yield_1yr'
]

In [None]:
# --- Step 1: Calculate Outlier Bounds Efficiently ---
# This step still requires one approxQuantile action per column but is necessary.
outlier_bounds = {}
print("    - Calculating IQR bounds for all columns...")
for col_name in numerical_cols:
    if col_name in df_merged.columns:
        quantiles = df_merged.approxQuantile(col_name, [0.25, 0.75], 0.01)
        if quantiles and quantiles[0] is not None and quantiles[1] is not None:
            Q1, Q3 = quantiles[0], quantiles[1]
            IQR = Q3 - Q1
            lower_bound = Q1 - 1.5 * IQR
            upper_bound = Q3 + 1.5 * IQR
            outlier_bounds[col_name] = {'lower_bound': lower_bound, 'upper_bound': upper_bound}
        else:
            print(f"        - Skipping {col_name} due to insufficient data or nulls.")


    - Calculating IQR bounds for all columns...


In [None]:
# --- Step 2: Create a capped DataFrame (a series of lazy transformations) ---
print("    - Performing Outlier Treatment (Capping) on all columns...")
df_capped = df_merged
for col_name in outlier_bounds:
    lower_bound = outlier_bounds[col_name]['lower_bound']
    upper_bound = outlier_bounds[col_name]['upper_bound']
    df_capped = df_capped.withColumn(
        col_name,
        F.when(F.col(col_name) < lower_bound, lower_bound)
         .when(F.col(col_name) > upper_bound, upper_bound)
         .otherwise(F.col(col_name))
    )
print("    - Capping logic defined for all columns")


    - Performing Outlier Treatment (Capping) on all columns...
    - Capping logic defined for all columns. No Spark job executed yet.


In [1]:
# --- Step 3: Visualize and Verify in a single loop ---
print("\n### Generating side-by-side plots and verifying capping... ###")
print("---------------------------------------------------------------")
total_rows = df_merged.count()

for col_name in numerical_cols:
    if col_name in outlier_bounds:
        lower_bound = outlier_bounds[col_name]['lower_bound']
        upper_bound = outlier_bounds[col_name]['upper_bound']

        print(f"\n--- Analyzing and Visualizing '{col_name}' ---")

        # --- Sub-step 3.1: Efficiently count outliers before and after ---
        outliers_low_before, outliers_high_before = df_merged.agg(
            F.sum(F.when(F.col(col_name) < lower_bound, 1).otherwise(0)).alias('low'),
            F.sum(F.when(F.col(col_name) > upper_bound, 1).otherwise(0)).alias('high')
        ).collect()[0]

        # We can't use the same bounds to check after capping, as new outliers may form.
        # So we calculate new bounds for the capped data for a more accurate report.
        quantiles_after = df_capped.approxQuantile(col_name, [0.25, 0.75], 0.01)
        if quantiles_after:
            Q1_after, Q3_after = quantiles_after[0], quantiles_after[1]
            IQR_after = Q3_after - Q1_after
            lower_bound_after = Q1_after - 1.5 * IQR_after
            upper_bound_after = Q3_after + 1.5 * IQR_after
            outliers_low_after, outliers_high_after = df_capped.agg(
                F.sum(F.when(F.col(col_name) < lower_bound_after, 1).otherwise(0)).alias('low'),
                F.sum(F.when(F.col(col_name) > upper_bound_after, 1).otherwise(0)).alias('high')
            ).collect()[0]

            print(f"  - Before Capping: {outliers_low_before + outliers_high_before} outliers ({(outliers_low_before + outliers_high_before) / total_rows * 100:.2f}%)")
            print(f"  - After Capping: {outliers_low_after + outliers_high_after} outliers ({(outliers_low_after + outliers_high_after) / total_rows * 100:.2f}%)")

        # --- Sub-step 3.2: Sample and Plot ---
        # Sample data from both DataFrames to ensure consistent plotting
        sample_size = 0.05  # Adjust sampling size as needed
        sample_df_before = df_merged.select(col_name).sample(False, sample_size, seed=42).toPandas()
        sample_df_after = df_capped.select(col_name).sample(False, sample_size, seed=42).toPandas()

        if not sample_df_before.empty and not sample_df_after.empty:
            plt.figure(figsize=(10, 6))

            # Before Capping Plot
            ax1 = plt.subplot(1, 2, 1)
            sns.boxplot(y=sample_df_before[col_name], ax=ax1)
            ax1.set_title(f'Box Plot of {col_name} (Before Capping)')
            ax1.set_ylabel(col_name)

            # After Capping Plot
            ax2 = plt.subplot(1, 2, 2)
            sns.boxplot(y=sample_df_after[col_name], ax=ax2)
            ax2.set_title(f'Box Plot of {col_name} (After Capping)')
            ax2.set_ylabel(col_name)

            plt.tight_layout()
            plt.show()
            plt.close()
        else:
            print(f"    - Skipping box plot for '{col_name}' (not enough sample data).")
    else:
        print(f"\n--- Skipping '{col_name}': Not found or all values are null. ---")

# The final df_merged_spark now holds the capped values.
df_merged_spark = df_capped
print("\n=== Outlier Detection & Treatment with Visualizations Complete. ===\n")


### Generating side-by-side plots and verifying capping... ###
---------------------------------------------------------------


NameError: name 'df_merged' is not defined

-----------------------------------------
### 5. EXPLORATORY DATA ANALYSIS (EDA)
-----------------------------------------


##### STEP 5: EXPLORATORY DATA ANALYSIS (EDA)

In [None]:
# 1. Distribution of numerical variables
numerical_cols_for_eda = ['Area', 'Production', 'kharif_rainfall', 'rabi_rainfall',
                         'summer_rainfall', 'yearly_rainfall', 'kharif_temp',
                         'rabi_temp', 'summer_temp', 'yearly_temp', 'N', 'P', 'K',
                         'pH', 'Yield_ton_per_hec', 'rainfall_temp_interaction',
                         'NPK_ratio']

print("\n1. Distribution of Numerical Variables:")
for col_name in numerical_cols_for_eda:
    if col_name in df_merged.columns:
        plt.figure(figsize=(8, 4))
        sns.histplot(df_merged.select(col_name).toPandas()[col_name], kde=True, bins=50)
        plt.title(f'Distribution of {col_name}')
        plt.show()


In [None]:
# 2. Categorical variable distributions
print("\n2. Distribution of Categorical Variables:")
categorical_cols_for_eda = ['State_Name', 'District_Name', 'Crop', 'Season']
for col_name in categorical_cols_for_eda:
    if col_name in df_merged.columns:
        plt.figure(figsize=(8, 4))
        if df_merged.select(col_name).distinct().count() > 10:  # For columns with many categories
            top_categories = df_merged.groupBy(col_name).count().orderBy('count', ascending=False).limit(10).toPandas()[col_name]
            sns.countplot(y=col_name, data=df_merged.toPandas(), order=top_categories)
        else:
            sns.countplot(x=col_name, data=df_merged.toPandas())
        plt.title(f'Distribution of {col_name}')
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.show()

In [None]:
# 3. Correlation analysis
print("\n3. Correlation Analysis:")
plt.figure(figsize=(10, 6))
corr_matrix = df_merged.select(numerical_cols_for_eda).toPandas().corr()
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt=".2f", linewidths=.5)
plt.title('Correlation Matrix of Numerical Variables')
plt.tight_layout()
plt.show()



In [None]:
# 4. Time series analysis of yield and production
print("\n4. Time Series Analysis:")
if 'Crop_Year' in df_merged.columns:
    yearly_summary = df_merged.groupBy('Crop_Year').agg(
        spark_sum('Production').alias('Total_Production'),
        spark_mean('Yield_ton_per_hec').alias('Average_Yield')
    ).toPandas()

    plt.figure(figsize=(10, 6))
    plt.subplot(1, 2, 1)
    sns.lineplot(x='Crop_Year', y='Total_Production', data=yearly_summary, marker='o')
    plt.title('Total Production Over Years')
    plt.xlabel('Year')
    plt.ylabel('Total Production')

    plt.subplot(1, 2, 2)
    sns.lineplot(x='Crop_Year', y='Average_Yield', data=yearly_summary, marker='o', color='green')
    plt.title('Average Yield Over Years')
    plt.xlabel('Year')
    plt.ylabel('Average Yield (ton/hec)')
    plt.tight_layout()
    plt.show()

In [None]:
# 5. Relationship between key features and yield
print("\n5. Feature-Yield Relationships:")
key_features = ['relevant_rainfall', 'relevant_temperature', 'N', 'P', 'K', 'pH']
for feature in key_features:
    if feature in df_merged.columns:
        plt.figure(figsize=(8, 4))
        sns.scatterplot(x=feature, y='Yield_ton_per_hec', data=df_merged.toPandas(), alpha=0.5)
        plt.title(f'Relationship between {feature} and Yield')
        plt.tight_layout()
        plt.show()

--------------------------------------------------------------------------------
### 6. FINAL PROCESSED DATA SAVE
--------------------------------------------------------------------------------

##### FINAL PROCESSED DATA SAVE

In [None]:

# Select final columns to keep
final_columns = ['State_Name', 'District_Name', 'Crop', 'Season', 'Crop_Year',
                'Area', 'Production', 'kharif_rainfall', 'rabi_rainfall',
                'summer_rainfall', 'yearly_rainfall', 'kharif_temp', 'rabi_temp',
                'summer_temp', 'yearly_temp', 'N', 'P', 'K', 'pH',
                'relevant_rainfall', 'relevant_temperature', 'rainfall_temp_interaction',
                'NPK_ratio', 'lagged_production_1yr', 'lagged_yield_1yr',
                'Yield_ton_per_hec']

df_final = df_merged.select(final_columns)



In [None]:
df_final.show()

In [None]:
# Save final processed data
df_final.toPandas().to_csv('agripredict_final_processed_data.csv', index=False)
print("    - Final processed data saved to 'agripredict_final_processed_data.csv'")

print("\n=== PROCESSING COMPLETE ===")
df_final.show()