# In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

# Per-crop SHAP analysis.
#
# For every crop type in the dataset this script fits a random-forest model of
# Crop_Yield, explains a sample of held-out rows with SHAP, saves an HTML force
# plot for the lowest-, median- and highest-yield sampled row, and collects the
# mean |SHAP| value per feature into a single summary CSV.

SHAP_DIR = "plots/rmplots/SHAP"
FEATURES = ["Soil_Quality", "Temperature", "Humidity", "N", "P", "K", "Soil_pH"]
RANDOM_STATE = 36    # one seed shared by the split, the model and the sampling
MAX_SHAP_ROWS = 100  # cap on explained rows per crop (subset for speed)

# makedirs also creates the intermediate "plots/rmplots" directory.
os.makedirs(SHAP_DIR, exist_ok=True)

ds = pd.read_csv("crop_yield_dataset.csv")
ds["Date"] = pd.to_datetime(ds["Date"])
ds["Year"] = ds["Date"].dt.year
ds["Month"] = ds["Date"].dt.month

crops = ds["Crop_Type"].unique()
summary_list = []  # one feature-importance frame per processed crop

for crop in crops:
    print(f"Processing {crop}...")

    # Subset dataset
    crop_ds = ds[ds["Crop_Type"] == crop]
    if len(crop_ds) < 2:
        # A train/test split needs at least one row on each side. This also
        # covers a missing (NaN) crop label, which never compares equal and
        # therefore selects no rows.
        print(f"Skipping {crop}: not enough rows for a train/test split.")
        continue

    X = crop_ds[FEATURES]
    y = crop_ds["Crop_Yield"]

    # Train/test split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=RANDOM_STATE
    )

    # Train model
    rf = RandomForestRegressor(n_estimators=100, random_state=RANDOM_STATE)
    rf.fit(X_train, y_train)

    # SHAP (subset for speed)
    if len(X_test) > MAX_SHAP_ROWS:
        X_sample = X_test.sample(MAX_SHAP_ROWS, random_state=RANDOM_STATE)
    else:
        X_sample = X_test
    explainer = shap.TreeExplainer(rf)
    shap_values = explainer.shap_values(X_sample)

    y_sample = y.loc[X_sample.index]  # align yields with sampled rows

    # Row labels of the min, (upper) median and max Crop Yield in the sample
    low_idx = y_sample.idxmin()
    high_idx = y_sample.idxmax()
    median_idx = y_sample.sort_values().index[len(y_sample) // 2]

    # Collect mean absolute SHAP values (feature importance)
    shap_importance = pd.DataFrame({
        "Feature": X_sample.columns,
        "MeanAbsSHAP": np.abs(shap_values).mean(axis=0),
    })
    shap_importance["Crop_Type"] = crop
    summary_list.append(shap_importance)

    for label, idx in [("low", low_idx), ("median", median_idx), ("high", high_idx)]:
        # shap_values rows follow X_sample's row order, so translate the row
        # label into its position. Assumes row labels are unique, which holds
        # for the default index produced by read_csv.
        pos = X_sample.index.get_loc(idx)

        shap_html = shap.force_plot(
            explainer.expected_value,
            shap_values[pos],
            X_sample.iloc[pos],
            matplotlib=False,
        )
        shap.save_html(f"{SHAP_DIR}/{crop}_shap_force_{label}.html", shap_html)


# Combine all crops into one CSV
if summary_list:
    summary_df = pd.concat(summary_list, ignore_index=True)
else:
    # pd.concat raises on an empty list; still emit a header-only table.
    summary_df = pd.DataFrame(columns=["Feature", "MeanAbsSHAP", "Crop_Type"])
summary_df.to_csv(f"{SHAP_DIR}/shap_summary_table.csv", index=False)

print(f"Done! Plots and CSV saved in {SHAP_DIR}/")
